├── vivid123 ├── configs │ ├── __init__.py │ └── base_schema.py ├── metrics │ ├── __init__.py │ └── utils.py ├── models │ ├── __init__.py │ └── clip_camera_projection.py ├── pipelines │ ├── __init__.py │ ├── vivid123_pipeline.py │ └── zero123_pipeline.py ├── __init__.py └── generation_utils.py ├── .gitignore ├── scripts ├── task_example.yaml ├── job_config_yaml_generation.py └── gso_metadata_object_prompt_100.csv ├── run_generation.py ├── run_zero123.py ├── run_batch_generation.py ├── README.md ├── run_evaluation.py ├── run_calculate_stats.py └── LICENSE /vivid123/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_schema import ViVid123BaseSchema 2 | -------------------------------------------------------------------------------- /vivid123/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import LPIPSMeter, PSNRMeter, SSIM, FOR -------------------------------------------------------------------------------- /vivid123/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_camera_projection import CLIPCameraProjection -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | *.npy 4 | *.png 5 | *.json 6 | *.yaml 7 | *.zip 8 | *.mp4 9 | 10 | *.sh 11 | *.out 12 | *.csv -------------------------------------------------------------------------------- /vivid123/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .vivid123_pipeline import ViVid123Pipeline 2 | from .zero123_pipeline import Zero1to3StableDiffusionPipeline -------------------------------------------------------------------------------- /vivid123/__init__.py: -------------------------------------------------------------------------------- 1 | from .configs import * 2 | from .models import * 3 | from .pipelines import * 4 | 5 | from .generation_utils import generation_vivid123, prepare_vivid123_pipeline -------------------------------------------------------------------------------- /scripts/task_example.yaml: -------------------------------------------------------------------------------- 1 | delta_azimuth_end: 45.0 2 | delta_azimuth_start: -45.0 3 | delta_elevation_end: 0.0 4 | delta_elevation_start: 0.0 5 | delta_radius_end: 0.0 6 | delta_radius_start: 0.0 7 | eta: 1.0 8 | guidance_scale_video: 1.0 9 | guidance_scale_zero123: 3.0 10 | height: 256 11 | input_image_path: tmp/racoon/img/012.png 12 | name: racoon 13 | noise_identical_accross_frames: false 14 | num_frames: 25 15 | num_inference_steps: 50 16 | prompt: '' 17 | refiner_guidance_scale: 12.0 18 | refiner_strength: 0.3 19 | video_end_step_percentage: 1.0 20 | video_linear_end_weight: 0.5 21 | video_linear_start_weight: 1.0 22 | video_start_step_percentage: 0.0 23 | width: 256 24 | zero123_end_step_percentage: 1.0 25 | zero123_linear_end_weight: 1.0 26 | zero123_linear_start_weight: 1.0 27 | zero123_start_step_percentage: 0.0 28 | -------------------------------------------------------------------------------- /run_generation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from vivid123 import generation_vivid123, prepare_vivid123_pipeline 4 | 5 | ZERO123_MODEL_ID = "bennyguo/zero123-xl-diffusers" 6 | VIDEO_MODEL_ID = 
"cerspense/zeroscope_v2_576w" 7 | VIDEO_XL_MODEL_ID = "cerspense/zeroscope_v2_XL" 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description='ViVid123 Generation') 12 | parser.add_argument('--task_yaml_path', type=str, required=True, help='The path for the task yaml') 13 | args = parser.parse_args() 14 | 15 | vivid123_pipe, xl_pipe = prepare_vivid123_pipeline( 16 | ZERO123_MODEL_ID=ZERO123_MODEL_ID, 17 | VIDEO_MODEL_ID=VIDEO_MODEL_ID, 18 | VIDEO_XL_MODEL_ID=VIDEO_XL_MODEL_ID 19 | ) 20 | 21 | generation_vivid123(config_path=args.task_yaml_path, vivid123_pipe=vivid123_pipe, xl_pipe=xl_pipe) -------------------------------------------------------------------------------- /vivid123/configs/base_schema.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | class ViVid123BaseSchema(BaseModel): 4 | # Disable aliasing underscore to hyphen 5 | class Config: 6 | alias_generator = lambda string: string 7 | 8 | num_frames: int = 25 9 | delta_elevation_start: float = 0.0 10 | delta_elevation_end: float = 0.0 11 | delta_azimuth_start: float = -45.0 12 | delta_azimuth_end: float = 45.0 13 | delta_radius_start: float = 0.0 14 | delta_radius_end: float = 0.0 15 | height: int = 256 16 | width: int = 256 17 | # num_videos_per_image_prompt: int = 1 # Only support 1 for running on < 24G memory GPU 18 | num_inference_steps: int = 50 19 | guidance_scale_zero123: float = 3.0 20 | guidance_scale_video: float = 1.0 21 | eta: float = 0.0 # 0.0 for purely deterministic, 1.0 for purely stochastic 22 | noise_identical_accross_frames: bool = False 23 | prompt: str = "" 24 | 25 | video_linear_start_weight: float = 1.0 26 | video_linear_end_weight: float = 0.5 27 | video_start_step_percentage: float = 0.0 28 | video_end_step_percentage: float = 1.0 29 | zero123_linear_start_weight: float = 1.0 30 | zero123_linear_end_weight: float = 1.0 31 | zero123_start_step_percentage: float = 0.0 32 | zero123_end_step_percentage: float = 1.0 33 | 34 | skip_refiner: bool = False 35 | refiner_strength: float = 0.3 36 | refiner_guidance_scale: float = 12.0 37 | 38 | obj_name: str = "new_balance_used" 39 | input_image_path: str = "tmp/new_balance_used/012.png" 40 | exp_name: str = "test_exp" 41 | -------------------------------------------------------------------------------- /scripts/job_config_yaml_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import yaml 4 | import csv 5 | import argparse 6 | from vivid123.configs import ViVid123BaseSchema 7 | 8 | 9 | SLURM_TMPDIR = os.getenv("SLURM_TMPDIR") if os.getenv("SLURM_TMPDIR") else "/home/erqun/vivid123/tmp" 10 | 11 | job_specs = [ 12 | # {"num_frames": 24, "delta_azimuth_start": 15, "delta_azimuth_end": 360, "exp_name": "num_frames_24"}, 13 | {} # default job specified by default schema in vivid123/configs/base_schema.py 14 | ] 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser(description='ViVid123 Generation') 18 | parser.add_argument('--run_on_slurm', action='store_true', help="whether to run on a slurm cluster") 19 | args = parser.parse_args() 20 | 21 | for job_spec in job_specs: 22 | with open("scripts/gso_metadata_object_prompt_100.csv", 'r') as f_metadata: 23 | csv_lines = csv.reader(f_metadata, delimiter=',', quotechar='"') 24 | my_model = ViVid123BaseSchema() 25 | for fieldname, value in job_spec.items(): 26 | if hasattr(my_model, fieldname): 27 | setattr(my_model, 
fieldname, value) 28 | else: 29 | raise ValueError(f"No field {fieldname}") 30 | 31 | task_yamls_output_dir = f"exps/task_yamls/{my_model.exp_name}" 32 | os.makedirs(task_yamls_output_dir, exist_ok=True) 33 | for i, csv_line in enumerate(csv_lines): 34 | my_model.obj_name = csv_line[0] 35 | if args.run_on_slurm: 36 | my_model.input_image_path = r"${SLURM_TMPDIR}/" + f"{my_model.obj_name}/img/012.png" 37 | else: 38 | my_model.input_image_path = f"./tmp/{my_model.obj_name}/img/012.png" 39 | with open(os.path.join(task_yamls_output_dir, f"{my_model.obj_name}.yaml"), "w") as f_job: 40 | print(f"dumping yaml to ", os.path.join(task_yamls_output_dir, f"{my_model.obj_name}.yaml")) 41 | yaml.dump(my_model.model_dump(), f_job) 42 | -------------------------------------------------------------------------------- /run_zero123.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | 5 | from diffusers.models import UNet2DConditionModel, AutoencoderKL 6 | from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler 7 | from transformers import CLIPVisionModelWithProjection 8 | from vivid123.models import CLIPCameraProjection 9 | from vivid123.pipelines import Zero1to3StableDiffusionPipeline 10 | 11 | from diffusers.utils import export_to_video 12 | 13 | model_id = "bennyguo/zero123-xl-diffusers" 14 | 15 | zero123_unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", cache_dir="/scratch/.cache") 16 | zero123_cam_proj = CLIPCameraProjection.from_pretrained(model_id, subfolder="clip_camera_projection", cache_dir="/scratch/.cache") 17 | zero123_img_enc = CLIPVisionModelWithProjection.from_pretrained(model_id, subfolder="image_encoder", cache_dir="/scratch/.cache") 18 | vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", cache_dir="/scratch/.cache") 19 | scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler", cache_dir="/scratch/.cache") 20 | zero123_pipe = Zero1to3StableDiffusionPipeline( 21 | vae=vae, 22 | image_encoder=zero123_img_enc, 23 | unet=zero123_unet, 24 | scheduler=scheduler, 25 | cc_projection=zero123_cam_proj, 26 | requires_safety_checker=False, 27 | safety_checker=None, 28 | feature_extractor=None, 29 | ) 30 | 31 | # zero123_pipe.enable_xformers_memory_efficient_attention() 32 | # zero123_pipe.enable_vae_tiling() 33 | # zero123_pipe.enable_attention_slicing() 34 | zero123_pipe = zero123_pipe.to("cuda") 35 | 36 | query_pose = [0, 45.0, 0.0] 37 | 38 | # for single input 39 | H, W = (256, 256) 40 | num_images_per_prompt = 1 41 | 42 | input_image = Image.open("data/Squirrel/img/012.png").convert("RGBA").resize((H, W), Image.BICUBIC) 43 | background = Image.new("RGBA", input_image.size, (255, 255, 255)) 44 | alpha_composite = Image.alpha_composite(background, input_image) 45 | 46 | input_images = [alpha_composite] 47 | query_poses = [query_pose] 48 | 49 | images = zero123_pipe( 50 | input_imgs=input_images, 51 | prompt_imgs=input_images, 52 | poses=query_poses, 53 | height=H, 54 | width=W, 55 | guidance_scale=3.0, 56 | num_inference_steps=50, 57 | ).images 58 | 59 | # save imgs 60 | log_dir = "logs" 61 | os.makedirs(log_dir, exist_ok=True) 62 | bs = len(input_images) 63 | i = 0 64 | for obj in range(bs): 65 | for idx in range(num_images_per_prompt): 66 | images[i].save(os.path.join(log_dir, f"obj{obj}_{idx}.jpg")) 67 | i += 1 -------------------------------------------------------------------------------- 
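Note on run_zero123.py above: it imports `export_to_video` but only saves individual frames. Below is a minimal, untested sketch of how the same pipeline could be swept over several azimuths and written out as a short orbit video. It assumes the `zero123_pipe`, `alpha_composite`, `H`, and `W` objects defined in that script, assumes the `[elevation, azimuth, radius]` ordering used by `query_pose`, and uses `imageio` (already in the requirements) for the video write; the output path is a placeholder.

```python
import imageio
import numpy as np

# Sweep the azimuth from -45 to 45 degrees, one diffusion call per pose.
frames = []
for azimuth in np.linspace(-45.0, 45.0, num=9):
    out = zero123_pipe(
        input_imgs=[alpha_composite],
        prompt_imgs=[alpha_composite],
        poses=[[0.0, float(azimuth), 0.0]],  # assumed [elevation, azimuth, radius] ordering
        height=H,
        width=W,
        guidance_scale=3.0,
        num_inference_steps=50,
    ).images
    frames.append(np.asarray(out[0]))

# Write the swept views as a short orbit video (placeholder output path).
imageio.mimsave("logs/zero123_orbit.mp4", frames, fps=8)
```

Calling the pipeline one pose at a time keeps peak memory close to the single-view case shown in the script above.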
/run_batch_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import csv 5 | 6 | from vivid123 import generation_vivid123, prepare_vivid123_pipeline 7 | 8 | 9 | ZERO123_MODEL_ID = "bennyguo/zero123-xl-diffusers" 10 | VIDEO_MODEL_ID = "cerspense/zeroscope_v2_576w" 11 | VIDEO_XL_MODEL_ID = "cerspense/zeroscope_v2_XL" 12 | 13 | SLURM_TMPDIR = os.getenv("SLURM_TMPDIR") if os.getenv("SLURM_TMPDIR") else "./tmp" 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser(description='ViVid123 Generation') 18 | parser.add_argument('--task_yamls_dir', type=str, required=True, help='The directory for all configs') 19 | parser.add_argument('--dataset_dir', type=str, required=True, help='The directory for all groundtruth renderings, each object being a zip file') 20 | parser.add_argument('--output_dir', type=str, required=True, help='The root directory for all outputs') 21 | parser.add_argument('--obj_csv_file', type=str, required=True, help='The csv file for all objects') 22 | parser.add_argument('--run_from_obj_index', type=int, default=0, help='The index of object to start with') 23 | parser.add_argument('--run_to_obj_index', type=int, default=999, help='The index of object to end with') 24 | args = parser.parse_args() 25 | 26 | vivid123_pipe, xl_pipe = prepare_vivid123_pipeline( 27 | ZERO123_MODEL_ID=ZERO123_MODEL_ID, 28 | VIDEO_MODEL_ID=VIDEO_MODEL_ID, 29 | VIDEO_XL_MODEL_ID=VIDEO_XL_MODEL_ID 30 | ) 31 | 32 | with open(args.obj_csv_file, 'r') as csv_file: 33 | csv_lines = csv.reader(csv_file, delimiter=',', quotechar='"') 34 | for i, csv_line in enumerate(csv_lines): 35 | if i < args.run_from_obj_index: 36 | continue 37 | if i > args.run_to_obj_index: 38 | break 39 | 40 | obj_name = csv_line[0] 41 | if os.path.isfile(f"{args.output_dir}/{obj_name}/xl.mp4"): 42 | print(f"{obj_name} has already been generated, skipping...") 43 | continue 44 | 45 | print(f"Processing {obj_name}") 46 | if not os.path.exists(f"{SLURM_TMPDIR}/{obj_name}"): 47 | print(f"unpacking {args.dataset_dir}/{obj_name}.zip to {SLURM_TMPDIR}/{obj_name}") 48 | shutil.unpack_archive(f"{args.dataset_dir}/{obj_name}.zip", f"{SLURM_TMPDIR}/{obj_name}") 49 | 50 | generation_vivid123( 51 | vivid123_pipe=vivid123_pipe, 52 | xl_pipe=xl_pipe, 53 | config_path=f"{args.task_yamls_dir}/{obj_name}.yaml", 54 | output_root_dir=args.output_dir, 55 | ) -------------------------------------------------------------------------------- /vivid123/models/clip_camera_projection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch
15 | from diffusers.configuration_utils import ConfigMixin, register_to_config
16 | from diffusers.models.modeling_utils import ModelMixin
17 | from diffusers.utils import logging
18 | 
19 | logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
20 | 
21 | 
22 | class CLIPCameraProjection(ModelMixin, ConfigMixin):
23 |     """
24 |     A projection layer that maps the concatenation of a CLIP image embedding and a camera embedding back to the
25 |     CLIP embedding dimension.
26 |     Parameters:
27 |         embedding_dim (`int`, *optional*, defaults to 768): The dimension of the input CLIP embedding `clip_embed`.
28 |         additional_embeddings (`int`, *optional*, defaults to 4): The number of camera values appended to the CLIP
29 |             embedding, so that the projection input has dimension `embedding_dim + additional_embeddings`.
30 |     """
31 | 
32 |     @register_to_config
33 |     def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):
34 |         super().__init__()
35 |         self.embedding_dim = embedding_dim
36 |         self.additional_embeddings = additional_embeddings
37 | 
38 |         self.input_dim = self.embedding_dim + self.additional_embeddings
39 |         self.output_dim = self.embedding_dim
40 | 
41 |         self.proj = torch.nn.Linear(self.input_dim, self.output_dim)
42 | 
43 |     def forward(
44 |         self,
45 |         embedding: torch.FloatTensor,
46 |     ):
47 |         """
48 |         The [`CLIPCameraProjection`] forward method.
49 |         Args:
50 |             embedding (`torch.FloatTensor` of shape `(batch_size, input_dim)`):
51 |                 The concatenated CLIP and camera embedding to project.
52 |         Returns:
53 |             The projected embedding (`torch.FloatTensor` of shape `(batch_size, output_dim)`).
54 |         """
55 |         proj_embedding = self.proj(embedding)
56 |         return proj_embedding
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ViVid-1-to-3: Novel View Synthesis with Video Diffusion Models
2 | 
3 | This repository is a reference implementation for ViVid-1-to-3. It combines video diffusion with novel-view synthesis diffusion models for increased pose and appearance consistency.
4 | 
5 | [[arXiv]](https://arxiv.org/abs/2312.01305), [[project page]](https://ubc-vision.github.io/vivid123/)
6 | 
7 | ## Requirements
8 | ```bash
9 | pip install torch "diffusers==0.24" transformers accelerate einops kornia imageio[ffmpeg] opencv-python pydantic scikit-image lpips
10 | ```
11 | 
12 | ## Run single generation task
13 | Put the reference image at $IMAGE_PATH and set `input_image_path` in `scripts/task_example.yaml` to that path.
Then run
14 | ```bash
15 | python run_generation.py --task_yaml_path=scripts/task_example.yaml
16 | ```
17 | 
18 | ## Run batch generation tasks
19 | Batch generation is supported on both a local PC and SLURM clusters.
20 | ### Prepare batch generation config yaml file
21 | We tested our method on 100 [GSO](https://app.gazebosim.org/GoogleResearch/fuel/collections/Scanned%20Objects%20by%20Google%20Research) objects. The list of objects is in `scripts/gso_metadata_object_prompt_100.csv`, along with our labeled text prompts in case you would like to test prompt-based generation yourself. We have rendered the 100 objects beforehand; the renderings can be downloaded [here](https://drive.google.com/file/d/1A9PJDRD27igX5p88slWVF_QSDKxaZDCZ/view?usp=sharing). Decompress the content into `gso-100`. Then run the following line to prepare a batch generation job on a PC:
22 | ```bash
23 | python -m scripts.job_config_yaml_generation
24 | ```
25 | Or run the following line to prepare a batch generation job on a SLURM cluster, which will place temporary files in `$SLURM_TMPDIR` on your cluster:
26 | ```bash
27 | python -m scripts.job_config_yaml_generation --run_on_slurm
28 | ```
29 | All the yaml files will be generated under `exps/task_yamls/<exp_name>` (`exp_name` defaults to `test_exp` in `vivid123/configs/base_schema.py`); pass that folder to `--task_yamls_dir` in the commands below, which use `tasks_gso` as a placeholder.
30 | 
31 | If you want to run a customized batch generation, simply add an entry to the `job_specs` list at the beginning of `scripts/job_config_yaml_generation.py` and run it with the same command. An example entry is commented out there.
32 | 
33 | 
34 | ### Batch generation
35 | For batch generation, run
36 | ```bash
37 | python run_batch_generation.py --task_yamls_dir=tasks_gso --dataset_dir=gso-100 --output_dir=outputs --obj_csv_file=scripts/gso_metadata_object_prompt_100.csv
38 | ```
39 | 
40 | ### Tips for scheduling batch generation on SLURM clusters
41 | One generation takes about 1 min 30 s on a V100 GPU. If the number of generations is too large for a single SLURM job,
42 | you can split the dataset across jobs using the `--run_from_obj_index` and `--run_to_obj_index` options. For example
43 | ```bash
44 | python run_batch_generation.py --task_yamls_dir=tasks_gso --dataset_dir=gso-100 --output_dir=outputs --obj_csv_file=scripts/gso_metadata_object_prompt_100.csv --run_from_obj_index=0 --run_to_obj_index=50
45 | ```
46 | 
47 | ### Run evaluation
48 | #### Get metrics for each object
49 | To evaluate a batch generation, put the experiments you want to evaluate in the `eval_specs` list in `run_evaluation.py`. Make sure the `exp_name` key has the same value as that of your batch generation. You should also modify `expdir` and `savedir` in `run_evaluation.py`. To run the $EXP_ID-th experiment in the list, do the following:
50 | ```bash
51 | python run_evaluation.py --exp_id $EXP_ID
52 | ```
53 | After the evaluation has run, intermediate results on PSNR, SSIM, LPIPS, FOR_8 and FOR_16 for each object are written to `savedir`.
54 | #### Get stats for this experiment
55 | Finally, use `run_calculate_stats.py` to get the PSNR, SSIM, LPIPS, FOR_8 and FOR_16 stats for this experiment over your whole dataset. Make sure to modify `psnr_save_dir`, `lpips_save_dir`, `ssim_save_dir`, `for_8_save_dir` and `for_16_save_dir` in `run_calculate_stats.py` to match the folder storing the intermediate results from the last step.
56 | ```bash 57 | python run_calculate_stats.py 58 | ``` 59 | 60 | 61 | 62 | ## Acknowledgement 63 | This repo is based on the Huggingface community [implementation](https://github.com/huggingface/diffusers/blob/main/examples/community/pipeline_zero1to3.py) and [converted weights](https://huggingface.co/bennyguo/zero123-xl-diffusers) of [Zero-1-to-3](https://github.com/cvlab-columbia/zero123), as well as the Huggingface community text-to-video model [Zeroscope v2](https://huggingface.co/cerspense/zeroscope_v2_576w). Thanks for their awesome works. 64 | 65 | ## Citation 66 | 67 | If you use this code in your research, please cite our paper: 68 | ``` 69 | @inproceedings{kwak2024vivid, 70 | title={Vivid-1-to-3: Novel view synthesis with video diffusion models}, 71 | author={Kwak, Jeong-gi and Dong, Erqun and Jin, Yuhe and Ko, Hanseok and Mahajan, Shweta and Yi, Kwang Moo}, 72 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 73 | pages={6775--6785}, 74 | year={2024} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import argparse 3 | import datetime 4 | import os 5 | 6 | import glob 7 | import numpy as np 8 | 9 | import shutil 10 | import csv 11 | from vivid123.metrics import LPIPSMeter, PSNRMeter, SSIM, FOR 12 | 13 | SLURM_TMPDIR = ( 14 | os.getenv("SLURM_TMPDIR") 15 | if os.getenv("SLURM_TMPDIR") 16 | else "/scratch/rendering-360/" # the dir where the gt images are decompressed to, if it exists on your local machine 17 | ) 18 | 19 | # should specify the indeces of the frames to be evaluated in both the generation dir and the gt dir, like the example below 20 | eval_specs = [ 21 | { 22 | "exp_name": "num_frames_24", 23 | "vid_frame_indeces": [2 * i for i in range(12)], 24 | "gt_indeces": [3 + 6 * i for i in range(12)], 25 | }, 26 | ] 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | # For batch running on Compute Canada 31 | parser.add_argument("--exp_id", type=int) 32 | parser.add_argument( 33 | "--metadata", 34 | type=str, 35 | default="scripts/gso_metadata_object_prompt_100.csv", 36 | ) 37 | parser.add_argument( 38 | "--gt_dir", 39 | type=str, 40 | default="/scratch/rendering-360-zips", 41 | help="Directory containing the ground truth images, each object in a separate zip file, the zip file contains a folder named 'img' with the images.", 42 | ) 43 | 44 | args = parser.parse_args() 45 | 46 | lpips_scorer = LPIPSMeter(device="cuda:0", size=512, net="vgg") 47 | psnr_scorer = PSNRMeter(size=512) 48 | ssim_scorer = SSIM(size=512) 49 | for_scorer = FOR( 50 | size=512, 51 | ) 52 | 53 | exp = eval_specs[args.exp_id]["exp_name"] 54 | expdir = f"/scratch-ssd/vivid123/exps/samples/{exp}" 55 | savedir = f"/scratch-ssd/vivid123/exps/evaluations/{exp}" 56 | os.makedirs(savedir, exist_ok=True) 57 | 58 | vid_frame_indeces = eval_specs[args.exp_id]["vid_frame_indeces"] 59 | gt_indeces = eval_specs[args.exp_id]["gt_indeces"] 60 | num_views = len(vid_frame_indeces) 61 | 62 | csv_columns = ( 63 | ["obj", "psnr", "lpips", "ssim", "for_8", "for_16"] 64 | + [f"psnr_{i}" for i in range(num_views)] 65 | + [f"lpips_{i}" for i in range(num_views)] 66 | + [f"ssim_{i}" for i in range(num_views)] 67 | + [f"for_8_{i}" for i in range(num_views)] 68 | + [f"for_16_{i}" for i in range(num_views)] 69 | ) 70 | 71 | with open(args.metadata, newline="") as csvmetadatafile: 
72 | csv_lines = csv.reader(csvmetadatafile, delimiter=",", quotechar='"') 73 | for csv_line in csv_lines: 74 | object_identifier = csv_line[0] 75 | csv_exp_file = f"{savedir}/{object_identifier}.csv" 76 | if os.path.isfile(csv_exp_file): 77 | continue 78 | 79 | if not os.path.isfile(f"{SLURM_TMPDIR}/{object_identifier}/img/000.png"): 80 | shutil.unpack_archive( 81 | f"{args.gt_dir}/{object_identifier}.zip", 82 | f"{SLURM_TMPDIR}/{object_identifier}", 83 | ) 84 | 85 | result_dict = {} 86 | 87 | gt_paths_sorted = sorted( 88 | glob.glob(f"{SLURM_TMPDIR}/{object_identifier}/img/*.png") 89 | ) 90 | pred_paths_sorted = sorted( 91 | glob.glob(f"{expdir}/{object_identifier}/xl_frames/*.png") 92 | ) 93 | print(f"object_identifier: {object_identifier}") 94 | gt_paths = [gt_paths_sorted[i] for i in gt_indeces] 95 | pred_paths = [pred_paths_sorted[i] for i in vid_frame_indeces] 96 | if ( 97 | len(gt_paths) == 0 98 | or len(pred_paths) == 0 99 | or len(gt_paths) != len(pred_paths) 100 | ): 101 | print(f"gt_path_list: {gt_paths}") 102 | print(f"pred_path_list: {pred_paths}") 103 | print( 104 | f"\n\n{object_identifier} doesn't have data or the rendering wasn't complete in {expdir}! Skipping this object...\n\n" 105 | ) 106 | continue 107 | 108 | result_dict["obj"] = object_identifier 109 | result_dict["psnr"], psnrs = psnr_scorer.score_gt(gt_paths, pred_paths) 110 | result_dict["lpips"], lpips = lpips_scorer.score_gt(gt_paths, pred_paths) 111 | result_dict["ssim"], ssims = ssim_scorer.score_gt(gt_paths, pred_paths) 112 | masked_flow_error_and_mask = for_scorer.raft_predict( 113 | gt_paths, 114 | pred_paths, 115 | results_path=f"exps/optical_flow_tmp/{exp}", 116 | obj_id=object_identifier, 117 | ) 118 | result_dict["for_8"], for_8s = for_scorer.score_gt( 119 | masked_flow_error_and_mask, threshold=8 120 | ) 121 | result_dict["for_16"], for_16s = for_scorer.score_gt( 122 | masked_flow_error_and_mask, threshold=16 123 | ) 124 | 125 | for i in range(num_views): 126 | result_dict[f"psnr_{i}"] = psnrs[i] 127 | result_dict[f"lpips_{i}"] = lpips[i] 128 | result_dict[f"ssim_{i}"] = ssims[i] 129 | result_dict[f"for_8_{i}"] = for_8s[i] 130 | result_dict[f"for_16_{i}"] = for_16s[i] 131 | 132 | print(f"PSNR for {object_identifier}: {result_dict['psnr']}") 133 | print(f"LPIPS for {object_identifier}: {result_dict['lpips']}") 134 | print(f"SSIM for {object_identifier}: {result_dict['ssim']}") 135 | print(f"FOR_8 for {object_identifier}: {result_dict['for_8']}") 136 | print(f"FOR_16 for {object_identifier}: {result_dict['for_16']}") 137 | 138 | with open(csv_exp_file, "a") as csvexpfile: 139 | writer = csv.DictWriter(csvexpfile, fieldnames=csv_columns) 140 | writer.writeheader() 141 | writer.writerow(result_dict) 142 | -------------------------------------------------------------------------------- /scripts/gso_metadata_object_prompt_100.csv: -------------------------------------------------------------------------------- 1 | 3D_Dollhouse_Sink,"a sink toy in dollhouse" 2 | 3D_Dollhouse_Sofa,"a sofa toy in dollhouse" 3 | 3D_Dollhouse_Swing,"a swing toy in dollhouse" 4 | 3D_Dollhouse_TablePurple,"a purple table toy in dollhouse" 5 | adiZero_Slide_2_SC,"a slide in white and yellow color" 6 | adistar_boost_m,"gray sneaker with green bottom" 7 | ASICS_GELLinksmaster_WhiteRasberryGunmetal,"sneaker in white and black color" 8 | Android_Lego,"Android 3D lego" 9 | Circo_Fish_Toothbrush_Holder_14995988,"a green and yellow turtle toy with a ground" 10 | Crosley_Alarm_Clock_Vintage_Metal,"vintag ealarm clock" 11 | 
DOLL_FAMILY,"doll family" 12 | DPC_Handmade_Hat_Brown,"a brown hat" 13 | Dino_3,"dinosaur toy" 14 | Dino_4,"dinosaur toy" 15 | Dino_5,"dinosaur toy" 16 | Down_To_Earth_Orchid_Pot_Ceramic_Lime,"ceramic lime orchid pot" 17 | Great_Dinos_Triceratops_Toy,"Android mini figure" 18 | Mens_Santa_Cruz_Thong_in_Chocolate_lvxYW7lek6B,"thong sandals" 19 | Mens_Striper_Sneaker_in_White_rnp8HUli59Y,"" 20 | My_First_Wiggle_Crocodile,"wiggle crocodile toy" 21 | My_Little_Pony_Princess_Celestia,"pony princess celestia toy" 22 | Nickelodeon_Teenage_Mutant_Ninja_Turtles_Leonardo,"teenage mutant ninja turtles Leonardo" 23 | Nickelodeon_Teenage_Mutant_Ninja_Turtles_Michelangelo,"teenage mutant ninja turtles Michelangelo" 24 | Nickelodeon_Teenage_Mutant_Ninja_Turtles_Raphael,"teenage mutant ninja turtles Raphael" 25 | Nintendo_Mario_Action_Figure,"Nintendo Mario action figure" 26 | Nintendo_Yoshi_Action_Figure,"Nintendo Yoshi action figure" 27 | Olive_Kids_Game_On_Pack_n_Snack,"kids backpack" 28 | Olive_Kids_Trains_Planes_Trucks_Bogo_Backpack,"kids backpack" 29 | Ortho_Forward_Facing,"lion toy" 30 | Ortho_Forward_Facing_3Q6J2oKJD92,"hippo toy" 31 | Ortho_Forward_Facing_QCaor9ImJ2G,"koala toy" 32 | Pepsi_Caffeine_Free_Diet_12_CT,"a pepsi box" 33 | Pepsi_Cola_Wild_Cherry_Diet_12_12_fl_oz_355_ml_cans_144_fl_oz_426_lt,"pepsi red and white box" 34 | Poppin_File_Sorter_Blue,"file sorter blue" 35 | Predito_LZ_TRX_FG_W,"a white soccer shoe with pink and blue stripes" 36 | RJ_Rabbit_Easter_Basket_Blue,"a basket with flowers on it is shown" 37 | Racoon,"a stuffed animal raccoon sitting" 38 | Razer_Kraken_Pro_headset_Full_size_Black,"razer kraken surround sound gaming headset" 39 | Razer_Taipan_Black_Ambidextrous_Gaming_Mouse,"a black computer mouse" 40 | Razer_Taipan_White_Ambidextrous_Gaming_Mouse,"a white computer mouse" 41 | Reebok_ALLYLYNN,"a blue high top sneaker with yellow accents" 42 | Reebok_CL_RAYEN,"a sneaker crossfit nano blue" 43 | Reebok_FS_HI_INT_R12,"a white and red high top sneaker with a red sole" 44 | Reebok_SMOOTHFLEX_CUSHRUN_20,"a running shoe with a blue and green design" 45 | Reebok_TURBO_RC,"a running shoe of grey and yellow color" 46 | Reebok_ZIGSTORM,"a blue and yellow sandal with a strap" 47 | Retail_Leadership_Summit,"a hat on a ground with a white band" 48 | Room_Essentials_Kitchen_Towels_16_x_26_2_count,"two folded blankets on a ground" 49 | Rubbermaid_Large_Drainer,"a black rack with several baskets on it" 50 | SAMOA,"adidas originals nike sb trainers" 51 | SORTING_TRAIN,"a toy train with a man on it" 52 | STACKING_BEAR,"a toy bear with a rainbow colored stack of toys" 53 | Schleich_African_Black_Rhino,"a rhino standing" 54 | Schleich_Allosaurus,"a small dinosaur is walking on a ground" 55 | Schleich_Bald_Eagle,"a bald eagle is flying in the air" 56 | Schleich_S_Bayala_Unicorn_70432,"a white horse with green flowers on its back" 57 | Schleich_Spinosaurus_Action_Figure,"a small dinosaur toy" 58 | Schleich_Therizinosaurus_ln9cruulPqc,"a dinosaur toy with its head up and its legs spread out" 59 | Shark,"a stuffed shark toy on the ground" 60 | Smith_Hawken_Woven_BasketTray_Organizer_with_3_Compartments_95_x_9_x_13,"a wicker basket with four compartments on a ground" 61 | Sonny_School_Bus,"a toy truck with a driver" 62 | Sootheze_Cold_Therapy_Elephant,"a stuffed elephant toy sitting on the ground" 63 | Sootheze_Toasty_Orca,"a stuffed killer whale toy on the ground" 64 | Sperry_TopSider_pSUFPWQXPp3,"a brown ankle boot with a zipper on the side" 65 | 
SpiderMan_Titan_Hero_12Inch_Action_Figure_5Hnn4mtkFsP,"a toy spiderman standing on the ground" 66 | SpiderMan_Titan_Hero_12Inch_Action_Figure_oo1qph4wwiW,"a toy spider man standing on the ground" 67 | Squirrel,"a stuffed squirrel toy sitting on its hind legs" 68 | Squirt_Strain_Fruit_Basket,"a green basket with fruit in it" 69 | Target_Basket_Medium,"a white basket with a white cover on it" 70 | Thomas_Friends_Wooden_Railway_Porter_5JzRhMm3a9o,"thomas the tank engine" 71 | Thomas_Friends_Wooden_Railway_Talking_Thomas_z7yi7UFHJRj,"thomas the tank engine wooden train" 72 | Threshold_Basket_Natural_Finish_Fabric_Liner_Small,"a wicker basket with a white lid" 73 | Timberland_Mens_Earthkeepers_Newmarket_6Inch_Cupsole_Boot,"a red high top shoe with laces on the side" 74 | Timberland_Mens_Earthkeepers_Stormbuck_Lite_Plain_Toe_Oxford,"a tan shoe on the ground" 75 | Toysmith_Windem_Up_Flippin_Animals_Dog,"a toy dog with a tongue sticking out on a ground" 76 | Transformers_Age_of_Extinction_Mega_1Step_Bumblebee_Figure,"transformers bumblebee" 77 | Transformers_Age_of_Extinction_Stomp_and_Chomp_Grimlock_Figure,"transformers robot dinosaur" 78 | TriStar_Products_PPC_Power_Pressure_Cooker_XL_in_Black,"a large stainless steel pressure cooker on a ground" 79 | UGG_Classic_Tall_Womens_Boots_Chestnut_7,"a tan boot on the ground" 80 | UGG_Jena_Womens_Java_7,"a brown boot shoe on the ground" 81 | US_Army_Stash_Lunch_Bag,"a military style lunch bag with a badge on it" 82 | Vtech_Roll_Learn_Turtle,"a toy turtle with a green shell and colorful toys" 83 | Vtech_Stack_Sing_Rings_636_Months,"a stack of colorful toys on the ground" 84 | Weisshai_Great_White_Shark,"a shark swimming in the dark with no light" 85 | Whey_Protein_3_Flavor_Variety_Pack_12_Packets,"whey protein powder in a box" 86 | Whey_Protein_Vanilla,"a jar of whey protein on the ground" 87 | Wishbone_Pencil_Case,"a purple and white pencil case" 88 | Womens_Betty_Chukka_Boot_in_Navy_aEE8OqvMII4,"a blue sneaker shoe on the ground" 89 | Womens_Betty_Chukka_Boot_in_Salt_Washed_Red_AL2YrOt9CRy,"a pink shoe with gold laces on it" 90 | Womens_Bluefish_2Eye_Boat_Shoe_in_Linen_Natural_Sparkle_Suede_kqi81aojcOR,"a white boat shoe on the ground" 91 | Womens_Bluefish_2Eye_Boat_Shoe_in_White_Tumbled_YG44xIePRHw,"a white boat shoe on the ground" 92 | Womens_Canvas_Bahama_in_White_4UyOhP6rYGO,"a white boat shoe on the ground" 93 | Womens_Cloud_Logo_Authentic_Original_Boat_Shoe_in_Black_Supersoft_8LigQYwf4gr,"a black boat shoe with white soles" 94 | Womens_Hikerfish_Boot_in_Black_Leopard_bVSNY1Le1sm,"a black boot shoe with leopard print on the side" 95 | Womens_Sequin_Bahama_in_White_Sequin_V9K1hf24Oxe,"a white boat shoe on the ground" 96 | Womens_Sequin_Bahama_in_White_Sequin_yGVsSA4tOwJ,"a white boat shoe with laces on the side" 97 | Womens_Sparkle_Suede_Bahama_in_Silver_Sparkle_Suede_Grey_Patent_tYrIBLMhSTN,"a white shoe on the ground" 98 | Womens_Teva_Capistrano_Bootie_ldjRT9yZ5Ht,"a grey boot shoe on a ground" 99 | ZX700_lYiwcTIekXk,"adidas shoe white green" 100 | ZX700_mzGbdP3u6JB,"adidas shoe black orange" 101 | -------------------------------------------------------------------------------- /run_calculate_stats.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import itertools 3 | import argparse 4 | import numpy as np 5 | import os 6 | 7 | eval_specs = [ 8 | {"exp_name": "num_frames_12", "num_frames": 12}, 9 | ] 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | 
parser.add_argument( 15 | "--metadata", 16 | type=str, 17 | default="/scratch-ssd/vivid123/scripts/gso_metadata_object_prompt_100.csv", 18 | ) 19 | 20 | args = parser.parse_args() 21 | 22 | psnr_save_dir = f"exps/stats/total_result_metrics_psnr" 23 | lpips_save_dir = f"exps/stats/total_result_metrics_lpips" 24 | ssim_save_dir = f"exps/stats/total_result_metrics_ssim" 25 | for_8_save_dir = f"exps/stats/total_result_metrics_for_8" 26 | for_16_save_dir = f"exps/stats/total_result_metrics_for_16" 27 | os.makedirs(psnr_save_dir, exist_ok=True) 28 | os.makedirs(lpips_save_dir, exist_ok=True) 29 | os.makedirs(ssim_save_dir, exist_ok=True) 30 | os.makedirs(for_8_save_dir, exist_ok=True) 31 | os.makedirs(for_16_save_dir, exist_ok=True) 32 | 33 | for eval_spec in eval_specs: 34 | exp_name = eval_spec["exp_name"] 35 | num_views = eval_spec["num_frames"] 36 | eval_dir = f"exps/evaluations/{exp_name}" 37 | 38 | # Aggregate the evaluation results to the final stats 39 | csv_total_file_psnr = f"{psnr_save_dir}/{exp_name}.csv" 40 | csv_total_columns_psnr = ["exp_name", "psnr"] + [ 41 | f"psnr_{i}" for i in range(num_views) 42 | ] 43 | with open(csv_total_file_psnr, "w") as csvtotalfile: 44 | writer = csv.DictWriter(csvtotalfile, fieldnames=csv_total_columns_psnr) 45 | writer.writeheader() 46 | 47 | csv_total_file_lpips = f"{lpips_save_dir}/{exp_name}.csv" 48 | csv_total_columns_lpips = ["exp_name", "lpips"] + [ 49 | f"lpips_{i}" for i in range(num_views) 50 | ] 51 | with open(csv_total_file_lpips, "w") as csvtotalfile: 52 | writer = csv.DictWriter(csvtotalfile, fieldnames=csv_total_columns_lpips) 53 | writer.writeheader() 54 | 55 | csv_total_file_ssim = f"{ssim_save_dir}/{exp_name}.csv" 56 | csv_total_columns_ssim = ["exp_name", "ssim"] + [ 57 | f"ssim_{i}" for i in range(num_views) 58 | ] 59 | with open(csv_total_file_ssim, "w") as csvtotalfile: 60 | writer = csv.DictWriter(csvtotalfile, fieldnames=csv_total_columns_ssim) 61 | writer.writeheader() 62 | 63 | csv_total_file_for_8 = f"{for_8_save_dir}/{exp_name}.csv" 64 | csv_total_columns_for_8 = ["exp_name", "for_8"] + [ 65 | f"for_8_{i}" for i in range(num_views) 66 | ] 67 | with open(csv_total_file_for_8, "w") as csvtotalfile: 68 | writer = csv.DictWriter(csvtotalfile, fieldnames=csv_total_columns_for_8) 69 | writer.writeheader() 70 | 71 | csv_total_file_for_16 = f"{for_16_save_dir}/{exp_name}.csv" 72 | csv_total_columns_for_16 = ["exp_name", "for_16"] + [ 73 | f"for_16_{i}" for i in range(num_views) 74 | ] 75 | with open(csv_total_file_for_16, "w") as csvtotalfile: 76 | writer = csv.DictWriter(csvtotalfile, fieldnames=csv_total_columns_for_16) 77 | writer.writeheader() 78 | 79 | results_psnr = {"psnr": []} 80 | results_lpips = {"lpips": []} 81 | results_ssim = {"ssim": []} 82 | results_for_8 = {"for_8": []} 83 | results_for_16 = {"for_16": []} 84 | for i in range(num_views): 85 | results_psnr[f"psnr_{i}"] = [] 86 | results_lpips[f"lpips_{i}"] = [] 87 | results_ssim[f"ssim_{i}"] = [] 88 | results_for_8[f"for_8_{i}"] = [] 89 | results_for_16[f"for_16_{i}"] = [] 90 | 91 | count = 0 92 | with open(args.metadata, newline="") as csvmetadatafile: 93 | csv_lines = csv.reader(csvmetadatafile, delimiter=",", quotechar='"') 94 | for csv_line in csv_lines: 95 | object_identifier = csv_line[0] 96 | if not os.path.isfile(f"{eval_dir}/{object_identifier}.csv"): 97 | print( 98 | f"WARNING: {exp_name} doesn't have {object_identifier}! Skipping this object..." 
99 | ) 100 | continue 101 | 102 | count += 1 103 | with open( 104 | f"{eval_dir}/{object_identifier}.csv", newline="" 105 | ) as csv_object_metric_file: 106 | reader = csv.DictReader(csv_object_metric_file, delimiter=",") 107 | row = reader.__next__() 108 | print(row) 109 | results_psnr["psnr"].append(float(row["psnr"])) 110 | results_lpips["lpips"].append(float(row["lpips"])) 111 | results_ssim["ssim"].append(float(row["ssim"])) 112 | results_for_8["for_8"].append(float(row["for_8"])) 113 | results_for_16["for_16"].append(float(row["for_16"])) 114 | for i in range(num_views): 115 | results_psnr[f"psnr_{i}"].append(float(row[f"psnr_{i}"])) 116 | results_lpips[f"lpips_{i}"].append(float(row[f"lpips_{i}"])) 117 | results_ssim[f"ssim_{i}"].append(float(row[f"ssim_{i}"])) 118 | results_for_8[f"for_8_{i}"].append(float(row[f"for_8_{i}"])) 119 | results_for_16[f"for_16_{i}"].append(float(row[f"for_16_{i}"])) 120 | 121 | print(f"{exp_name} has {count} objects finished") 122 | 123 | # write into csv file 124 | with open(csv_total_file_psnr, "a") as csvtotalfile: 125 | writer = csv.DictWriter(csvtotalfile, fieldnames=csv_total_columns_psnr) 126 | row_to_write = {} 127 | row_to_write["exp_name"] = exp_name 128 | row_to_write["psnr"] = np.mean(results_psnr["psnr"]) 129 | for i in range(num_views): 130 | row_to_write[f"psnr_{i}"] = np.mean(results_psnr[f"psnr_{i}"]) 131 | writer.writerow(row_to_write) 132 | 133 | with open(csv_total_file_lpips, "a") as csvtotalfile: 134 | writer = csv.DictWriter( 135 | csvtotalfile, fieldnames=csv_total_columns_lpips 136 | ) 137 | row_to_write = {} 138 | row_to_write["exp_name"] = exp_name 139 | row_to_write["lpips"] = np.mean(results_lpips["lpips"]) 140 | for i in range(num_views): 141 | row_to_write[f"lpips_{i}"] = np.mean(results_lpips[f"lpips_{i}"]) 142 | writer.writerow(row_to_write) 143 | 144 | with open(csv_total_file_ssim, "a") as csvtotalfile: 145 | writer = csv.DictWriter(csvtotalfile, fieldnames=csv_total_columns_ssim) 146 | row_to_write = {} 147 | row_to_write["exp_name"] = exp_name 148 | row_to_write["ssim"] = np.mean(results_ssim["ssim"]) 149 | for i in range(num_views): 150 | row_to_write[f"ssim_{i}"] = np.mean(results_ssim[f"ssim_{i}"]) 151 | writer.writerow(row_to_write) 152 | 153 | with open(csv_total_file_for_8, "a") as csvtotalfile: 154 | writer = csv.DictWriter( 155 | csvtotalfile, fieldnames=csv_total_columns_for_8 156 | ) 157 | row_to_write = {} 158 | row_to_write["exp_name"] = exp_name 159 | row_to_write["for_8"] = np.mean(results_for_8["for_8"]) 160 | for i in range(num_views): 161 | row_to_write[f"for_8_{i}"] = np.mean(results_for_8[f"for_8_{i}"]) 162 | writer.writerow(row_to_write) 163 | 164 | with open(csv_total_file_for_16, "a") as csvtotalfile: 165 | writer = csv.DictWriter( 166 | csvtotalfile, fieldnames=csv_total_columns_for_16 167 | ) 168 | row_to_write = {} 169 | row_to_write["exp_name"] = exp_name 170 | row_to_write["for_16"] = np.mean(results_for_16["for_16"]) 171 | for i in range(num_views): 172 | row_to_write[f"for_16_{i}"] = np.mean(results_for_16[f"for_16_{i}"]) 173 | writer.writerow(row_to_write) 174 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /vivid123/metrics/utils.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/guochengqian/Magic123/blob/main/all_metrics/metric_utils.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import torchvision.transforms as T 7 | import torchvision.transforms.functional as TF 8 | from torchvision.models.optical_flow import raft_large 9 | import matplotlib.pyplot as plt 10 | 11 | # import clip 12 | from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor 13 | from torchvision import transforms 14 | import numpy as np 15 | import torch.nn.functional as F 16 | from tqdm import tqdm 17 | import cv2 18 | from PIL import Image 19 | 20 | # import torchvision.transforms as transforms 21 | import glob 22 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 23 | import lpips 24 | from os.path import join as osp 25 | 26 | 27 | def numpy_to_torch(images): 28 | images = images * 2.0 - 1.0 29 | images = torch.from_numpy(images.transpose((0, 3, 1, 2))).float() 30 | return images.cuda() 31 | 32 | 33 | class LPIPSMeter: 34 | def __init__( 35 | self, net="alex", device=None, size=224 36 | ): # or we can use 'alex', 'vgg' as network 37 | self.size = size 38 | self.net = net 39 | self.results = [] 40 | self.device = ( 41 | device 42 | if device is not None 43 | else torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | ) 45 | self.fn = lpips.LPIPS(net=net).eval().to(self.device) 46 | 47 | def measure(self): 48 | return np.mean(self.results) 49 | 50 | def report(self): 51 | return f"LPIPS ({self.net}) = {self.measure():.6f}" 52 | 53 | def read_img_list(self, img_list): 54 | size = self.size 55 | images = [] 56 | 57 | for img_path in img_list: 58 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 59 | 60 | if img.shape[2] == 4: # Handle BGRA images 61 | alpha = img[:, :, 3] # Extract alpha channel 62 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR 63 | 64 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB 65 | img[np.where(alpha == 0)] = [ 66 | 255, 67 | 255, 68 | 255, 69 | ] # Set transparent pixels to white 70 | else: # Handle other image formats like JPG and PNG 71 | img = cv2.cvtColor(img, 
cv2.COLOR_BGR2RGB) 72 | 73 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 74 | images.append(img) 75 | 76 | images = np.stack(images, axis=0) 77 | images = images.astype(np.float32) / 255.0 78 | 79 | return images 80 | 81 | # * recommend to use this function for evaluation 82 | @torch.no_grad() 83 | def score_gt(self, ref_paths, novel_paths): 84 | self.results = [] 85 | # Load images 86 | img0, img1 = self.read_img_list(ref_paths), self.read_img_list(novel_paths) 87 | img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1) 88 | img0 = F.interpolate(img0, size=(self.size, self.size), mode="area") 89 | img1 = F.interpolate(img1, size=(self.size, self.size), mode="area") 90 | 91 | self.results.append(self.fn.forward(img0, img1).cpu().squeeze().numpy()) 92 | 93 | return self.measure(), self.results[0] 94 | 95 | 96 | class PSNRMeter: 97 | def __init__(self, size=800): 98 | self.results = [] 99 | self.size = size 100 | 101 | def read_img_list(self, img_list): 102 | size = self.size 103 | images = [] 104 | for img_path in img_list: 105 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 106 | 107 | if img.shape[2] == 4: # Handle BGRA images 108 | alpha = img[:, :, 3] # Extract alpha channel 109 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB 112 | img[np.where(alpha == 0)] = [ 113 | 255, 114 | 255, 115 | 255, 116 | ] # Set transparent pixels to white 117 | else: # Handle other image formats like JPG and PNG 118 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 119 | 120 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 121 | images.append(img) 122 | 123 | images = np.stack(images, axis=0) 124 | images = images.astype(np.float32) / 255.0 125 | return images 126 | 127 | def update(self, preds, truths): 128 | psnr_values = [] 129 | # For each pair of images in the batches 130 | for img1, img2 in zip(preds, truths): 131 | # Compute the PSNR and add it to the list 132 | psnr = compare_psnr( 133 | img1, img2, data_range=1.0 134 | ) # assuming your images are scaled to [0,1] 135 | psnr_values.append(psnr) 136 | 137 | # Convert the list of PSNR values to a numpy array 138 | self.results = psnr_values 139 | 140 | def measure(self): 141 | return np.mean(self.results) 142 | 143 | def report(self): 144 | return f"PSNR = {self.measure():.6f}" 145 | 146 | # * recommend to use this function for evaluation 147 | def score_gt(self, ref_paths, novel_paths): 148 | self.results = [] 149 | print(f"ref_paths: {ref_paths}") 150 | print(f"novel_pahts: {novel_paths}") 151 | # [B, N, 3] or [B, H, W, 3], range[0, 1] 152 | preds = self.read_img_list(novel_paths) 153 | truths = self.read_img_list(ref_paths) 154 | self.update(preds, truths) 155 | return self.measure(), self.results 156 | 157 | 158 | def rgb_ssim( 159 | img0, 160 | img1, 161 | max_val, 162 | filter_size=11, 163 | filter_sigma=1.5, 164 | k1=0.01, 165 | k2=0.03, 166 | return_map=False, 167 | ): 168 | """Evaluation metrics ssim""" 169 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 170 | assert len(img0.shape) == 3 171 | assert img0.shape[-1] == 3 172 | assert img0.shape == img1.shape 173 | 174 | # Construct a 1D Gaussian blur filter. 
175 | hw = filter_size // 2 176 | shift = (2 * hw - filter_size + 1) / 2 177 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2 178 | filt = np.exp(-0.5 * f_i) 179 | filt /= np.sum(filt) 180 | 181 | # Blur in x and y (faster than the 2D convolution). 182 | def convolve2d(z, f): 183 | import scipy 184 | 185 | return scipy.signal.convolve2d(z, f, mode="valid") 186 | 187 | filt_fn = lambda z: np.stack( 188 | [ 189 | convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :]) 190 | for i in range(z.shape[-1]) 191 | ], 192 | -1, 193 | ) 194 | mu0 = filt_fn(img0) 195 | mu1 = filt_fn(img1) 196 | mu00 = mu0 * mu0 197 | mu11 = mu1 * mu1 198 | mu01 = mu0 * mu1 199 | sigma00 = filt_fn(img0**2) - mu00 200 | sigma11 = filt_fn(img1**2) - mu11 201 | sigma01 = filt_fn(img0 * img1) - mu01 202 | 203 | # Clip the variances and covariances to valid values. 204 | # Variance must be non-negative: 205 | sigma00 = np.maximum(0.0, sigma00) 206 | sigma11 = np.maximum(0.0, sigma11) 207 | sigma01 = np.sign(sigma01) * np.minimum(np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 208 | c1 = (k1 * max_val) ** 2 209 | c2 = (k2 * max_val) ** 2 210 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 211 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 212 | ssim_map = numer / denom 213 | ssim = np.mean(ssim_map) 214 | return ssim_map if return_map else ssim 215 | 216 | 217 | class SSIM: 218 | def __init__(self, use_gpu=True, size=512): 219 | super().__init__() 220 | self.use_gpu = use_gpu 221 | self.size = size 222 | 223 | def measure(self): 224 | return np.mean(self.results) 225 | 226 | def read_img_list(self, img_list): 227 | size = self.size 228 | images = [] 229 | for img_path in img_list: 230 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 231 | 232 | if img.shape[2] == 4: # Handle BGRA images 233 | alpha = img[:, :, 3] # Extract alpha channel 234 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR 235 | 236 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB 237 | img[np.where(alpha == 0)] = [ 238 | 255, 239 | 255, 240 | 255, 241 | ] # Set transparent pixels to white 242 | else: # Handle other image formats like JPG and PNG 243 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 244 | 245 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 246 | img = img.astype(np.float32) / 255.0 247 | images.append(img) 248 | 249 | return images 250 | 251 | def update(self, preds, truths): 252 | ssim_values = [] 253 | for img1, img2 in zip(preds, truths): 254 | val = rgb_ssim(img1, img2, max_val=1.0) 255 | ssim_values.append(val) 256 | 257 | self.results = ssim_values 258 | 259 | @torch.no_grad() 260 | def score_gt(self, ref_paths, novel_paths): 261 | self.results = [] 262 | # [B, N, 3] or [B, H, W, 3], range[0, 1] 263 | preds = self.read_img_list(novel_paths) 264 | truths = self.read_img_list(ref_paths) 265 | self.update(preds, truths) 266 | return self.measure(), self.results 267 | 268 | 269 | class FOR: 270 | def __init__( 271 | self, 272 | use_gpu=True, 273 | size=512, 274 | ): 275 | super().__init__() 276 | self.use_gpu = use_gpu 277 | self.size = size 278 | self.raft = raft_large(pretrained=True, progress=False).to("cuda") 279 | self.raft = self.raft.eval() 280 | 281 | def read_img_list(self, img_list): 282 | size = self.size 283 | images = [] 284 | masks = [] 285 | for img_path in img_list: 286 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 287 | 288 | if img.shape[2] == 4: # Handle BGRA images 289 | alpha = img[:, :, 3] # Extract alpha channel 290 | mask = 
alpha != 0 291 | masks.append(mask) 292 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR 293 | 294 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB 295 | img[np.where(alpha == 0)] = [ 296 | 255, 297 | 255, 298 | 255, 299 | ] # Set transparent pixels to white 300 | else: # Handle other image formats like JPG and PNG 301 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 302 | # For RGB image, do nothing to masks, as a placeholder 303 | pass 304 | 305 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 306 | img = img.astype(np.float32) / 255.0 307 | images.append(img) 308 | 309 | return images, masks 310 | 311 | def compute_optical_flow(self, gt_imgs, pred_imgs, gt_masks, results_path, obj_id): 312 | optical_flow_results_path = os.path.join( 313 | results_path, obj_id, "optical_flow.npy" 314 | ) 315 | 316 | if os.path.exists(optical_flow_results_path): 317 | return np.load(optical_flow_results_path) 318 | 319 | # compute raft 320 | gt_batch = (((np.stack(gt_imgs, axis=0) / 255) - 0.5) * 2).astype( 321 | np.float32 322 | ) # np array BxHxWx3 323 | rendered_batch = (((np.stack(pred_imgs, axis=0) / 255) - 0.5) * 2).astype( 324 | np.float32 325 | ) # np array BxHxWx3 326 | mask_batch = np.stack(gt_masks)[..., None] # np array BxHxWx1 327 | 328 | gt_batch = torch.tensor(gt_batch).permute( 329 | 0, 3, 1, 2 330 | ) # convert to pytorch tensor and change shape to Bx3xHxW 331 | rendered_batch = torch.tensor(rendered_batch).permute( 332 | 0, 3, 1, 2 333 | ) # convert to pytorch tensor and change shape to Bx3xHxW 334 | 335 | # flows_batch: torch.Tensor, shape=Bx2xHxW, B is num_views 336 | flows_batch = ( 337 | self.raft(gt_batch.to("cuda"), rendered_batch.to("cuda"))[-1] 338 | .detach() 339 | .cpu() 340 | .permute(0, 2, 3, 1) 341 | .numpy() 342 | ) 343 | # convert to BxHxWx2 344 | 345 | optical_flow_error = np.sqrt( 346 | np.sum(flows_batch**2, axis=-1, keepdims=True) 347 | ) # BxHxWx1 348 | masked_error = optical_flow_error * mask_batch 349 | if masked_error.shape != (gt_batch.shape[0], 512, 512, 1): 350 | raise ValueError("masked error has wrong shape") 351 | 352 | masked_error_and_mask = np.concatenate( 353 | (masked_error, mask_batch), axis=-1 354 | ) # BxHxWx2 355 | 356 | os.makedirs(os.path.join(results_path, obj_id), exist_ok=True) 357 | np.save(optical_flow_results_path, masked_error_and_mask) 358 | return masked_error_and_mask 359 | 360 | @torch.no_grad() 361 | def raft_predict(self, ref_paths, novel_paths, results_path, obj_id): 362 | self.results = [] 363 | # [B, N, 3] or [B, H, W, 3], range[0, 1] 364 | preds, _ = self.read_img_list(novel_paths) 365 | truths, masks = self.read_img_list(ref_paths) 366 | masked_error_and_mask = self.compute_optical_flow( 367 | truths, preds, masks, results_path, obj_id 368 | ) 369 | return masked_error_and_mask 370 | 371 | def score_gt(self, masked_error_and_mask, threshold): 372 | optical_flow_error = masked_error_and_mask[..., 0] 373 | mask = masked_error_and_mask[..., 1].astype(bool) 374 | optical_flow_error[~mask] = 100 375 | number_valid_pixels = np.sum(mask.reshape(mask.shape[0], -1), axis=-1) 376 | 377 | # compute optical flow outlier ratio for each view 378 | view_ids = range(int(optical_flow_error.shape[0])) 379 | outlier_ratios = [] 380 | for view_id in view_ids: 381 | outlier_ratio = ( 382 | 1 383 | - np.sum(optical_flow_error[view_id] < threshold) 384 | / number_valid_pixels[view_id] 385 | ) 386 | outlier_ratios.append(outlier_ratio) 387 | 388 | return np.mean(outlier_ratios), outlier_ratios 389 | 
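# Illustrative usage sketch (not part of the original evaluation code): a minimal example of
# driving the three reference-based meters on two aligned lists of frame paths. The directories
# below are hypothetical placeholders, and LPIPSMeter additionally expects a CUDA device because
# `numpy_to_torch` moves tensors to the GPU.
if __name__ == "__main__":
    ref_paths = sorted(glob.glob("tmp/racoon/gt/*.png"))                 # ground-truth frames (hypothetical path)
    novel_paths = sorted(glob.glob("outputs/racoon/base_frames/*.png"))  # generated frames (hypothetical path)

    lpips_mean, _ = LPIPSMeter(net="alex", size=256).score_gt(ref_paths, novel_paths)
    psnr_mean, _ = PSNRMeter(size=256).score_gt(ref_paths, novel_paths)
    ssim_mean, _ = SSIM(size=256).score_gt(ref_paths, novel_paths)
    print(f"LPIPS = {lpips_mean:.4f} | PSNR = {psnr_mean:.2f} dB | SSIM = {ssim_mean:.4f}")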
-------------------------------------------------------------------------------- /vivid123/generation_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Any 3 | import yaml 4 | from yaml.parser import ParserError 5 | import re 6 | 7 | import torch 8 | from PIL import Image 9 | import numpy as np 10 | import imageio.v3 as imageio 11 | 12 | from diffusers.pipelines import DiffusionPipeline 13 | from diffusers.models import UNet2DConditionModel, AutoencoderKL 14 | from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler 15 | from diffusers.pipelines import DiffusionPipeline 16 | from transformers import CLIPVisionModelWithProjection 17 | 18 | from .models import CLIPCameraProjection 19 | from .pipelines import ViVid123Pipeline, Zero1to3StableDiffusionPipeline 20 | from .configs import ViVid123BaseSchema 21 | 22 | 23 | XDG_CACHE_HOME = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 24 | 25 | 26 | def prepare_cam_pose_input( 27 | num_frames: int = 25, 28 | delta_elevation_start: float = 0.0, 29 | delta_elevation_end: float = 0.0, 30 | delta_azimuth_start: float = -45.0, 31 | delta_azimuth_end: float = 45.0, 32 | delta_radius_start: float = 0.0, 33 | delta_radius_end: float = 0.0, 34 | ): 35 | r""" 36 | The function to prepare the input to the vivid123 pipeline 37 | Args: 38 | delta_elevation_start (`float`, *optional*, defaults to 0.0): 39 | The starting relative elevation angle of the camera, in degree. Relative to the elevation of the reference image. 40 | The camera is facing towards the origin. 41 | delta_elevation_end (`float`, *optional*, defaults to 0.0): 42 | The ending relative elevation angle of the camera, in degree. Relative to the elevation of the reference image. 43 | The camera is facing towards the origin. 44 | delta_azimuth_start (`float`, *optional*, defaults to -45.0): 45 | The starting relative azimuth angle of the camera, in degree. Relative to the elevation of the reference image. 46 | The camera is facing towards the origin. 47 | delta_azimuth_end (`float`, *optional*, defaults to 45.0): 48 | The ending relative azimuth angle of the camera, in degree. Relative to the elevation of the reference image. 49 | The camera is facing towards the origin. 
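        delta_radius_start (`float`, *optional*, defaults to 0.0):
            The starting relative radius offset of the camera, relative to the radius of the reference image's camera.
        delta_radius_end (`float`, *optional*, defaults to 0.0):
            The ending relative radius offset of the camera, relative to the radius of the reference image's camera.
        num_frames (`int`, *optional*, defaults to 25):
            The number of camera poses generated; each pose component is linearly interpolated between its start and end value.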
50 | 51 | Returns: 52 | """ 53 | cam_elevation = np.radians(np.linspace(delta_elevation_start, delta_elevation_end, num_frames))[..., None] 54 | cam_azimuth = np.radians(np.linspace(delta_azimuth_start, delta_azimuth_end, num_frames)) 55 | cam_azimuth_sin_cos = np.stack([np.sin(cam_azimuth), np.cos(cam_azimuth)], axis=-1) 56 | cam_radius = np.linspace(delta_radius_start, delta_radius_end, num_frames)[..., None] 57 | 58 | cam_pose_np = np.concatenate([cam_elevation, cam_azimuth_sin_cos, cam_radius], axis=-1) 59 | cam_pose_torch = torch.from_numpy(cam_pose_np) 60 | 61 | return cam_pose_torch 62 | 63 | 64 | # refer to https://stackoverflow.com/a/33507138/6257375 65 | def conver_rgba_to_rgb_white_bg( 66 | image: Image, 67 | H: int = 256, 68 | W: int = 256, 69 | ): 70 | input_image = image.convert("RGBA").resize((H, W), Image.BICUBIC) 71 | background = Image.new("RGBA", input_image.size, (255, 255, 255)) 72 | alpha_composite = Image.alpha_composite(background, input_image) 73 | 74 | return alpha_composite 75 | 76 | 77 | def prepare_fusion_schedule_linear( 78 | num_inference_steps: int = 50, 79 | video_linear_start_weight: float = 1.0, 80 | video_linear_end_weight: float = 0.5, 81 | video_start_step_percentage: float = 0.0, 82 | video_end_step_percentage: float = 1.0, 83 | zero123_linear_start_weight: float = 1.0, 84 | zero123_linear_end_weight: float = 1.0, 85 | zero123_start_step_percentage: float = 0.0, 86 | zero123_end_step_percentage: float = 1.0, 87 | ): 88 | """ 89 | Prepare the fusion schedule of video diffusion and zero123 at all the denoising steps 90 | Args: 91 | video_linear_start_weight (`float`, *optional*, defaults to 1.0): 92 | The weight of the video diffusion at the start of the video. The weight is linearly increased from 93 | `video_linear_start_weight` to `video_linear_end_weight` during the video diffusion. 94 | video_linear_end_weight (`float`, *optional*, defaults to 0.5): 95 | The weight of the video diffusion at the end of the video. The weight is linearly increased from 96 | `video_linear_start_weight` to `video_linear_end_weight` during the video diffusion. 97 | video_start_step_percentage (`float`, *optional*, defaults to 0.0): 98 | The percentage of the total number of inference steps at which the video diffusion starts. The video 99 | diffusion is linearly increased from `video_linear_start_weight` to `video_linear_end_weight` between 100 | `video_start_step_percentage` and `video_end_step_percentage`. 101 | video_end_step_percentage (`float`, *optional*, defaults to 1.0): 102 | The percentage of the total number of inference steps at which the video diffusion ends. The video 103 | diffusion is linearly increased from `video_linear_start_weight` to `video_linear_end_weight` between 104 | `video_start_step_percentage` and `video_end_step_percentage`. 105 | zero123_linear_start_weight (`float`, *optional*, defaults to 1.0): 106 | The weight of the zero123 diffusion at the start of the video. The weight is linearly increased from 107 | `zero123_linear_start_weight` to `zero123_linear_end_weight` during the zero123 diffusion. 108 | zero123_linear_end_weight (`float`, *optional*, defaults to 1.0): 109 | The weight of the zero123 diffusion at the end of the video. The weight is linearly increased from 110 | `zero123_linear_start_weight` to `zero123_linear_end_weight` during the zero123 diffusion. 111 | zero123_start_step_percentage (`float`, *optional*, defaults to 0.0): 112 | The percentage of the total number of inference steps at which the zero123 diffusion starts. 
The 113 | zero123 diffusion is linearly increased from `zero123_linear_start_weight` to 114 | `zero123_linear_end_weight` between `zero123_start_step_percentage` and `zero123_end_step_percentage`. 115 | zero123_end_step_percentage (`float`, *optional*, defaults to 1.0): 116 | The percentage of the total number of inference steps at which the zero123 diffusion ends. The 117 | zero123 diffusion is linearly increased from `zero123_linear_start_weight` to 118 | `zero123_linear_end_weight` between `zero123_start_step_percentage` and `zero123_end_step_percentage`. 119 | 120 | Return: 121 | A tuple of two tensors, 122 | video_schedule (`torch.Tensor`): The schedule of the video diffusion weighting, with shape `[num_inference_steps]`. 123 | zero123_schedule (`torch.Tensor`): The schedule of the zero123 diffusion weighting, with shape `[num_inference_steps]`. 124 | """ 125 | 126 | assert ( 127 | video_linear_start_weight >= 0.0 and video_linear_start_weight <= 1.0 128 | ), "video_linear_start_weight must be between 0.0 and 1.0" 129 | assert ( 130 | video_linear_end_weight >= 0.0 and video_linear_end_weight <= 1.0 131 | ), "video_linear_end_weight must be between 0.0 and 1.0" 132 | assert ( 133 | video_start_step_percentage >= 0.0 and video_start_step_percentage <= 1.0 134 | ), "video_start_step_percentage must be between 0.0 and 1.0" 135 | assert ( 136 | video_end_step_percentage >= 0.0 and video_end_step_percentage <= 1.0 137 | ), "video_end_step_percentage must be between 0.0 and 1.0" 138 | assert ( 139 | zero123_linear_start_weight >= 0.0 and zero123_linear_start_weight <= 1.0 140 | ), "zero123_linear_start_weight must be between 0.0 and 1.0" 141 | assert ( 142 | zero123_linear_end_weight >= 0.0 and zero123_linear_end_weight <= 1.0 143 | ), "zero123_linear_end_weight must be between 0.0 and 1.0" 144 | assert ( 145 | zero123_start_step_percentage >= 0.0 and zero123_start_step_percentage <= 1.0 146 | ), "zero123_start_step_percentage must be between 0.0 and 1.0" 147 | assert ( 148 | zero123_end_step_percentage >= 0.0 and zero123_end_step_percentage <= 1.0 149 | ), "zero123_end_step_percentage must be between 0.0 and 1.0" 150 | 151 | video_schedule = torch.linspace( 152 | start=video_linear_start_weight, 153 | end=video_linear_end_weight, 154 | steps=int((video_end_step_percentage - video_start_step_percentage) * num_inference_steps), 155 | ) 156 | zero123_schedule = torch.linspace( 157 | start=zero123_linear_start_weight, 158 | end=zero123_linear_end_weight, 159 | steps=int((zero123_end_step_percentage - zero123_start_step_percentage) * num_inference_steps), 160 | ) 161 | if video_schedule.shape[0] < num_inference_steps: 162 | video_schedule = torch.cat( 163 | [ 164 | video_linear_start_weight * torch.ones([video_start_step_percentage * num_inference_steps]), 165 | video_schedule, 166 | video_linear_end_weight 167 | * torch.ones([num_inference_steps - video_end_step_percentage * num_inference_steps]), 168 | ] 169 | ) 170 | if zero123_schedule.shape[0] < num_inference_steps: 171 | zero123_schedule = torch.cat( 172 | [ 173 | zero123_linear_start_weight * torch.ones([zero123_start_step_percentage * num_inference_steps]), 174 | zero123_schedule, 175 | zero123_linear_end_weight 176 | * torch.ones([num_inference_steps - zero123_end_step_percentage * num_inference_steps]), 177 | ] 178 | ) 179 | 180 | return (video_schedule, zero123_schedule) 181 | 182 | 183 | def save_videos_grid_zeroscope_nplist(video_frames: List[np.ndarray], path: str, n_rows=6, fps=8, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 184 | 
# fourcc = cv2.VideoWriter_fourcc(*"mp4v") 185 | f = len(video_frames) 186 | h, w, c = video_frames[0].shape 187 | #images = [(image).astype("uint8") for image in video_frames] 188 | 189 | os.makedirs(os.path.dirname(path), exist_ok=True) 190 | imageio.imwrite(path, video_frames, fps=fps) 191 | 192 | 193 | def prepare_vivid123_pipeline( 194 | ZERO123_MODEL_ID: str = "bennyguo/zero123-xl-diffusers", 195 | VIDEO_MODEL_ID: str = "cerspense/zeroscope_v2_576w", 196 | VIDEO_XL_MODEL_ID: str = "cerspense/zeroscope_v2_XL" 197 | ): 198 | zero123_unet = UNet2DConditionModel.from_pretrained(ZERO123_MODEL_ID, subfolder="unet", cache_dir=XDG_CACHE_HOME) 199 | zero123_cam_proj = CLIPCameraProjection.from_pretrained(ZERO123_MODEL_ID, subfolder="clip_camera_projection", cache_dir=XDG_CACHE_HOME) 200 | zero123_img_enc = CLIPVisionModelWithProjection.from_pretrained(ZERO123_MODEL_ID, subfolder="image_encoder", cache_dir=XDG_CACHE_HOME) 201 | vivid123_pipe = ViVid123Pipeline.from_pretrained( 202 | VIDEO_MODEL_ID, 203 | cache_dir=XDG_CACHE_HOME, 204 | novel_view_unet=zero123_unet, 205 | image_encoder=zero123_img_enc, 206 | cc_projection=zero123_cam_proj, 207 | # torch_dtype=torch.float16, 208 | ) 209 | vivid123_pipe.scheduler = DPMSolverMultistepScheduler.from_config(vivid123_pipe.scheduler.config, cache_dir=XDG_CACHE_HOME) 210 | vivid123_pipe.enable_model_cpu_offload() 211 | 212 | xl_pipe = DiffusionPipeline.from_pretrained(VIDEO_XL_MODEL_ID, torch_dtype=torch.float16, cache_dir=XDG_CACHE_HOME) 213 | xl_pipe.scheduler = DPMSolverMultistepScheduler.from_config(xl_pipe.scheduler.config, cache_dir=XDG_CACHE_HOME) 214 | xl_pipe.enable_model_cpu_offload() 215 | 216 | return vivid123_pipe, xl_pipe 217 | 218 | 219 | def generation_vivid123( 220 | vivid123_pipe: ViVid123Pipeline, 221 | xl_pipe: DiffusionPipeline, 222 | config_path: str, 223 | output_root_dir: str = "./outputs", 224 | ): 225 | # loading yaml config 226 | _var_matcher = re.compile(r"\${([^}^{]+)}") 227 | _tag_matcher = re.compile(r"[^$]*\${([^}^{]+)}.*") 228 | 229 | def _path_constructor(_loader: Any, node: Any): 230 | def replace_fn(match): 231 | envparts = f"{match.group(1)}:".split(":") 232 | return os.environ.get(envparts[0], envparts[1]) 233 | return _var_matcher.sub(replace_fn, node.value) 234 | 235 | def load_yaml(filename: str) -> dict: 236 | yaml.add_implicit_resolver("!envvar", _tag_matcher, None, yaml.SafeLoader) 237 | yaml.add_constructor("!envvar", _path_constructor, yaml.SafeLoader) 238 | with open(filename, "r") as f: 239 | return yaml.safe_load(f.read()) 240 | 241 | yaml_loaded = load_yaml(config_path) 242 | print(f"input_image_path is: ", yaml_loaded["input_image_path"]) 243 | cfg = ViVid123BaseSchema.model_validate(yaml_loaded) 244 | 245 | # get reference image 246 | print(f"input_image_path is: {cfg.input_image_path}") 247 | input_image = Image.open(cfg.input_image_path) 248 | input_image = conver_rgba_to_rgb_white_bg(input_image, H=cfg.height, W=cfg.width) 249 | 250 | cam_pose = prepare_cam_pose_input( 251 | num_frames=cfg.num_frames, 252 | delta_elevation_start=cfg.delta_elevation_start, 253 | delta_elevation_end=cfg.delta_elevation_end, 254 | delta_azimuth_start=cfg.delta_azimuth_start, 255 | delta_azimuth_end=cfg.delta_azimuth_end, 256 | delta_radius_start=cfg.delta_radius_start, 257 | delta_radius_end=cfg.delta_radius_end, 258 | ) 259 | 260 | fusion_schedule = prepare_fusion_schedule_linear( 261 | num_inference_steps=cfg.num_inference_steps, 262 | video_linear_start_weight=cfg.video_linear_start_weight, 263 | 
video_linear_end_weight=cfg.video_linear_end_weight, 264 | video_start_step_percentage=cfg.video_start_step_percentage, 265 | video_end_step_percentage=cfg.video_end_step_percentage, 266 | zero123_linear_start_weight=cfg.zero123_linear_start_weight, 267 | zero123_linear_end_weight=cfg.zero123_linear_end_weight, 268 | zero123_start_step_percentage=cfg.zero123_start_step_percentage, 269 | zero123_end_step_percentage=cfg.zero123_end_step_percentage, 270 | ) 271 | 272 | vid_base_frames = vivid123_pipe( 273 | image=input_image, 274 | cam_pose_torch=cam_pose, 275 | fusion_schedule=fusion_schedule, 276 | height=cfg.height, 277 | width=cfg.width, 278 | num_frames=cfg.num_frames, 279 | prompt=cfg.prompt, 280 | guidance_scale_video=cfg.guidance_scale_video, 281 | guidance_scale_zero123=cfg.guidance_scale_zero123, 282 | num_inference_steps=cfg.num_inference_steps, 283 | noise_identical_accross_frames=cfg.noise_identical_accross_frames, 284 | eta=cfg.eta, 285 | ).frames 286 | 287 | # save imgs 288 | os.makedirs(os.path.join(output_root_dir, cfg.obj_name), exist_ok=True) 289 | input_image.save(f"{output_root_dir}/{cfg.obj_name}/input.png") 290 | os.makedirs(os.path.join(output_root_dir, cfg.obj_name, "base_frames"), exist_ok=True) 291 | for i in range(len(vid_base_frames)): 292 | Image.fromarray(vid_base_frames[i]).save(f"{output_root_dir}/{cfg.obj_name}/base_frames/{str(i).zfill(3)}.png") 293 | 294 | save_videos_grid_zeroscope_nplist(vid_base_frames, f"{output_root_dir}/{cfg.obj_name}/base.mp4") 295 | 296 | if cfg.skip_refiner: 297 | return 298 | 299 | video_xl_input = [Image.fromarray(frame).resize((576, 576)) for frame in vid_base_frames] 300 | 301 | video_xl_frames = xl_pipe( 302 | prompt=cfg.prompt, video=video_xl_input, strength=cfg.refiner_strength, guidance_scale=cfg.refiner_guidance_scale 303 | ).frames 304 | 305 | os.makedirs(os.path.join(output_root_dir, cfg.obj_name, "xl_frames"), exist_ok=True) 306 | for i in range(len(video_xl_frames)): 307 | Image.fromarray(video_xl_frames[i]).save(f"{output_root_dir}/{cfg.obj_name}/xl_frames/{str(i).zfill(3)}.png") 308 | save_videos_grid_zeroscope_nplist(video_xl_frames, f"{output_root_dir}/{cfg.obj_name}/xl.mp4") 309 | 310 | 311 | def prepare_zero123_pipeline( 312 | ZERO123_MODEL_ID: str = "bennyguo/zero123-xl-diffusers", 313 | VIDEO_XL_MODEL_ID: str = "cerspense/zeroscope_v2_XL" 314 | ): 315 | zero123_unet = UNet2DConditionModel.from_pretrained(ZERO123_MODEL_ID, subfolder="unet", cache_dir=XDG_CACHE_HOME) 316 | zero123_cam_proj = CLIPCameraProjection.from_pretrained(ZERO123_MODEL_ID, subfolder="clip_camera_projection", cache_dir=XDG_CACHE_HOME) 317 | zero123_img_enc = CLIPVisionModelWithProjection.from_pretrained(ZERO123_MODEL_ID, subfolder="image_encoder", cache_dir=XDG_CACHE_HOME) 318 | vae = AutoencoderKL.from_pretrained(ZERO123_MODEL_ID, subfolder="vae", cache_dir=XDG_CACHE_HOME) 319 | scheduler = DDIMScheduler.from_pretrained(ZERO123_MODEL_ID, subfolder="scheduler", cache_dir=XDG_CACHE_HOME) 320 | zero123_pipe = Zero1to3StableDiffusionPipeline( 321 | vae=vae, 322 | image_encoder=zero123_img_enc, 323 | unet=zero123_unet, 324 | scheduler=scheduler, 325 | cc_projection=zero123_cam_proj, 326 | requires_safety_checker=False, 327 | safety_checker=None, 328 | feature_extractor=None, 329 | ) 330 | 331 | xl_pipe = DiffusionPipeline.from_pretrained(VIDEO_XL_MODEL_ID, torch_dtype=torch.float16, cache_dir=XDG_CACHE_HOME) 332 | xl_pipe.scheduler = DPMSolverMultistepScheduler.from_config(xl_pipe.scheduler.config, cache_dir=XDG_CACHE_HOME) 333 | 
xl_pipe.enable_model_cpu_offload() 334 | 335 | return zero123_pipe, xl_pipe 336 | -------------------------------------------------------------------------------- /vivid123/pipelines/vivid123_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Callable, Dict, List, Optional, Union 16 | from dataclasses import dataclass 17 | 18 | import numpy as np 19 | import torch 20 | from einops import rearrange 21 | import PIL 22 | 23 | from diffusers.pipelines import TextToVideoSDPipeline 24 | from diffusers.pipelines.text_to_video_synthesis import TextToVideoSDPipelineOutput 25 | from diffusers.models import AutoencoderKL, UNet3DConditionModel, UNet2DConditionModel 26 | from diffusers.schedulers import KarrasDiffusionSchedulers 27 | from diffusers.utils import ( 28 | is_accelerate_available, 29 | is_accelerate_version, 30 | logging, 31 | replace_example_docstring, 32 | BaseOutput, 33 | ) 34 | from diffusers.utils.torch_utils import randn_tensor 35 | from diffusers.image_processor import VaeImageProcessor 36 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg 37 | from transformers import CLIPVisionModelWithProjection, CLIPTextModel, CLIPTokenizer 38 | from ..models import CLIPCameraProjection 39 | import kornia 40 | 41 | 42 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 43 | 44 | EXAMPLE_DOC_STRING = """ 45 | Examples: 46 | ```py 47 | >>> import torch 48 | >>> from diffusers import TextToVideoSDPipeline 49 | >>> from diffusers.utils import export_to_video 50 | 51 | >>> pipe = TextToVideoSDPipeline.from_pretrained( 52 | ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" 53 | ... 
) 54 | >>> pipe.enable_model_cpu_offload() 55 | 56 | >>> prompt = "Spiderman is surfing" 57 | >>> video_frames = pipe(prompt).frames 58 | >>> video_path = export_to_video(video_frames) 59 | >>> video_path 60 | ``` 61 | """ 62 | 63 | 64 | def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: 65 | # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 66 | # reshape to ncfhw 67 | mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) 68 | std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) 69 | # unnormalize back to [0,1] 70 | video = video.mul_(std).add_(mean) 71 | video.clamp_(0, 1) 72 | # prepare the final outputs 73 | i, c, f, h, w = video.shape 74 | images = video.permute(2, 3, 0, 4, 1).reshape( 75 | f, h, i * w, c 76 | ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) 77 | images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) 78 | images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c 79 | return images 80 | 81 | 82 | class ViVid123Pipeline(TextToVideoSDPipeline): 83 | r""" 84 | Pipeline for text-to-video generation. 85 | 86 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 87 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 88 | 89 | Args: 90 | vae ([`AutoencoderKL`]): 91 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 92 | text_encoder ([`CLIPTextModel`]): 93 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 94 | tokenizer (`CLIPTokenizer`): 95 | A [`~transformers.CLIPTokenizer`] to tokenize text. 96 | unet ([`UNet3DConditionModel`]): 97 | A [`UNet3DConditionModel`] to denoise the encoded video latents. 98 | scheduler ([`SchedulerMixin`]): 99 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 100 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
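        novel_view_unet ([`UNet2DConditionModel`]):
            The Zero-1-to-3 2D U-Net that produces the novel-view noise prediction fused with the video branch at
            each denoising step.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image encoder used to embed the reference image for the zero123 conditioning.
        cc_projection ([`CLIPCameraProjection`]):
            Projection layer that maps the concatenated CLIP image embedding and camera pose into the conditioning
            consumed by `novel_view_unet`.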
101 | """ 102 | 103 | def __init__( 104 | self, 105 | vae: AutoencoderKL, 106 | text_encoder: CLIPTextModel, 107 | tokenizer: CLIPTokenizer, 108 | unet: UNet3DConditionModel, 109 | scheduler: KarrasDiffusionSchedulers, 110 | novel_view_unet: UNet2DConditionModel, 111 | image_encoder: CLIPVisionModelWithProjection, 112 | cc_projection: CLIPCameraProjection, 113 | ): 114 | super().__init__(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) 115 | 116 | self.register_modules( 117 | novel_view_unet=novel_view_unet, 118 | image_encoder=image_encoder, 119 | cc_projection=cc_projection, 120 | ) 121 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 122 | 123 | self.image_processor = VaeImageProcessor( 124 | vae_scale_factor=self.vae_scale_factor, 125 | do_convert_rgb=True, 126 | do_normalize=True, 127 | ) 128 | 129 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs 130 | def check_inputs( 131 | self, 132 | prompt, 133 | height, 134 | width, 135 | callback_steps, 136 | negative_prompt=None, 137 | prompt_embeds=None, 138 | negative_prompt_embeds=None, 139 | num_inference_steps=50, 140 | fusion_schedule=None, 141 | ): 142 | if height % 8 != 0 or width % 8 != 0: 143 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 144 | 145 | if (callback_steps is None) or ( 146 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 147 | ): 148 | raise ValueError( 149 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 150 | f" {type(callback_steps)}." 151 | ) 152 | 153 | if prompt is not None and prompt_embeds is not None: 154 | raise ValueError( 155 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 156 | " only forward one of the two." 157 | ) 158 | elif prompt is None and prompt_embeds is None: 159 | raise ValueError( 160 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 161 | ) 162 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 163 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 164 | 165 | if negative_prompt is not None and negative_prompt_embeds is not None: 166 | raise ValueError( 167 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 168 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 169 | ) 170 | 171 | if prompt_embeds is not None and negative_prompt_embeds is not None: 172 | if prompt_embeds.shape != negative_prompt_embeds.shape: 173 | raise ValueError( 174 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 175 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 176 | f" {negative_prompt_embeds.shape}." 177 | ) 178 | 179 | if fusion_schedule is None: 180 | raise ValueError( 181 | "Fusion schedule is not provided." 182 | ) 183 | 184 | if len(fusion_schedule[0]) != num_inference_steps or len(fusion_schedule[1]) != num_inference_steps: 185 | raise ValueError( 186 | "Fusion schedule length does not match the number of timesteps." 
187 | ) 188 | 189 | def prepare_latents( 190 | self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, noise_identical_accross_frames=False 191 | ): 192 | shape = ( 193 | batch_size, 194 | num_channels_latents, 195 | num_frames if not noise_identical_accross_frames else 1, 196 | height // self.vae_scale_factor, 197 | width // self.vae_scale_factor, 198 | ) 199 | if isinstance(generator, list) and len(generator) != batch_size: 200 | raise ValueError( 201 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 202 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 203 | ) 204 | 205 | if latents is None: 206 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 207 | else: 208 | if latents.shape != shape: 209 | raise ValueError( 210 | f"User-prepared `latents` must have shape {shape}, when noise_identical_accross_frames={noise_identical_accross_frames} but got {latents.shape}." 211 | ) 212 | latents = latents.to(device) 213 | 214 | if noise_identical_accross_frames: 215 | latents = latents.repeat(1, 1, num_frames, 1, 1) 216 | 217 | # scale the initial noise by the standard deviation required by the scheduler 218 | latents = latents * self.scheduler.init_noise_sigma 219 | return latents 220 | 221 | def prepare_img_latents( 222 | self, image, batch_size, dtype, device, generator=None, do_zero123_classifier_free_guidance=False 223 | ): 224 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 225 | raise ValueError( 226 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 227 | ) 228 | 229 | if isinstance(image, torch.Tensor): 230 | # Batch single image 231 | if image.ndim == 3: 232 | assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" 233 | image = image.unsqueeze(0) 234 | 235 | assert image.ndim == 4, "Image must have 4 dimensions" 236 | 237 | # Check image is in [-1, 1] 238 | if image.min() < -1 or image.max() > 1: 239 | raise ValueError("Image should be in [-1, 1] range") 240 | else: 241 | # preprocess image 242 | if isinstance(image, (PIL.Image.Image, np.ndarray)): 243 | image = [image] 244 | 245 | if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): 246 | image = [np.array(i.convert("RGB"))[None, :] for i in image] 247 | image = np.concatenate(image, axis=0) 248 | elif isinstance(image, list) and isinstance(image[0], np.ndarray): 249 | image = np.concatenate([i[None, :] for i in image], axis=0) 250 | 251 | image = image.transpose(0, 3, 1, 2) 252 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 253 | 254 | image = image.to(device=device, dtype=dtype) 255 | 256 | if isinstance(generator, list) and len(generator) != batch_size: 257 | raise ValueError( 258 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 259 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
260 | ) 261 | 262 | if isinstance(generator, list): 263 | init_latents = [ 264 | self.vae.encode(image[i : i + 1]).latent_dist.mode(generator[i]) for i in range(batch_size) # sample 265 | ] 266 | init_latents = torch.cat(init_latents, dim=0) 267 | else: 268 | init_latents = self.vae.encode(image).latent_dist.mode() 269 | 270 | # init_latents = self.vae.config.scaling_factor * init_latents # todo in original zero123's inference gradio_new.py, model.encode_first_stage() is not scaled by scaling_factor 271 | if batch_size > init_latents.shape[0]: 272 | # init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1) 273 | num_images_per_prompt = batch_size // init_latents.shape[0] 274 | # duplicate image latents for each generation per prompt, using mps friendly method 275 | bs_embed, emb_c, emb_h, emb_w = init_latents.shape 276 | init_latents = init_latents.unsqueeze(1) 277 | init_latents = init_latents.repeat(1, num_images_per_prompt, 1, 1, 1) 278 | init_latents = init_latents.view(bs_embed * num_images_per_prompt, emb_c, emb_h, emb_w) 279 | 280 | # init_latents = torch.cat([init_latents]*2) if do_zero123_classifier_free_guidance else init_latents # follow zero123 281 | init_latents = ( 282 | torch.cat([torch.zeros_like(init_latents), init_latents]) 283 | if do_zero123_classifier_free_guidance 284 | else init_latents 285 | ) 286 | 287 | init_latents = init_latents.to(device=device, dtype=dtype) 288 | return init_latents 289 | 290 | def CLIP_preprocess(self, x): 291 | dtype = x.dtype 292 | # following openai's implementation 293 | # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741 294 | # follow openai preprocessing to keep exact same, input tensor [-1, 1], otherwise the preprocessing will be different, https://github.com/huggingface/transformers/pull/22608 295 | if isinstance(x, torch.Tensor): 296 | if x.min() < -1.0 or x.max() > 1.0: 297 | raise ValueError("Expected input tensor to have values in the range [-1, 1]") 298 | x = kornia.geometry.resize( 299 | x.to(torch.float32), (224, 224), interpolation="bicubic", align_corners=True, antialias=False 300 | ).to(dtype=dtype) 301 | x = (x + 1.0) / 2.0 302 | # renormalize according to clip 303 | x = kornia.enhance.normalize( 304 | x, torch.Tensor([0.48145466, 0.4578275, 0.40821073]), torch.Tensor([0.26862954, 0.26130258, 0.27577711]) 305 | ) 306 | return x 307 | 308 | # from stable_diffusion_image_variation 309 | def _encode_image(self, image, device, num_images_per_prompt, do_video_classifier_free_guidance): 310 | dtype = next(self.image_encoder.parameters()).dtype 311 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 312 | raise ValueError( 313 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 314 | ) 315 | 316 | if isinstance(image, torch.Tensor): 317 | # Batch single image 318 | if image.ndim == 3: 319 | assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" 320 | image = image.unsqueeze(0) 321 | 322 | assert image.ndim == 4, "Image must have 4 dimensions" 323 | 324 | # Check image is in [-1, 1] 325 | if image.min() < -1 or image.max() > 1: 326 | raise ValueError("Image should be in [-1, 1] range") 327 | else: 328 | # preprocess image 329 | if isinstance(image, (PIL.Image.Image, np.ndarray)): 330 | image = [image] 331 | 332 | if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): 333 | image = [np.array(i.convert("RGB"))[None, :] for i in image] 334 | 
image = np.concatenate(image, axis=0) 335 | elif isinstance(image, list) and isinstance(image[0], np.ndarray): 336 | image = np.concatenate([i[None, :] for i in image], axis=0) 337 | 338 | image = image.transpose(0, 3, 1, 2) 339 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 340 | 341 | image = image.to(device=device, dtype=dtype) 342 | 343 | image = self.CLIP_preprocess(image) 344 | # if not isinstance(image, torch.Tensor): 345 | # # 0-255 346 | # print("Warning: image is processed by hf's preprocess, which is different from openai original's.") 347 | # image = self.feature_extractor(images=image, return_tensors="pt").pixel_values 348 | image_embeddings = self.image_encoder(image).image_embeds.to(dtype=dtype) 349 | image_embeddings = image_embeddings.unsqueeze(1) 350 | 351 | # duplicate image embeddings for each generation per prompt, using mps friendly method 352 | bs_embed, seq_len, _ = image_embeddings.shape 353 | image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) 354 | image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 355 | 356 | if do_video_classifier_free_guidance: 357 | negative_prompt_embeds = torch.zeros_like(image_embeddings) 358 | 359 | # For classifier free guidance, we need to do two forward passes. 360 | # Here we concatenate the unconditional and text embeddings into a single batch 361 | # to avoid doing two forward passes 362 | image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) 363 | 364 | return image_embeddings 365 | 366 | def _encode_pose(self, pose, device, num_images_per_prompt, do_video_classifier_free_guidance): 367 | dtype = next(self.cc_projection.parameters()).dtype 368 | if isinstance(pose, torch.Tensor): 369 | pose_embeddings = pose.unsqueeze(1).to(device=device, dtype=dtype) 370 | else: 371 | if isinstance(pose[0], list): 372 | pose = torch.Tensor(pose) 373 | else: 374 | pose = torch.Tensor([pose]) 375 | x, y, z = pose[:, 0].unsqueeze(1), pose[:, 1].unsqueeze(1), pose[:, 2].unsqueeze(1) 376 | pose_embeddings = ( 377 | torch.cat([torch.deg2rad(x), torch.sin(torch.deg2rad(y)), torch.cos(torch.deg2rad(y)), z], dim=-1) 378 | .unsqueeze(1) 379 | .to(device=device, dtype=dtype) 380 | ) # B, 1, 4 381 | # duplicate pose embeddings for each generation per prompt, using mps friendly method 382 | bs_embed, seq_len, _ = pose_embeddings.shape 383 | pose_embeddings = pose_embeddings.repeat(1, num_images_per_prompt, 1) 384 | pose_embeddings = pose_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 385 | if do_video_classifier_free_guidance: 386 | negative_prompt_embeds = torch.zeros_like(pose_embeddings) 387 | 388 | # For classifier free guidance, we need to do two forward passes. 
389 | # Here we concatenate the unconditional and text embeddings into a single batch 390 | # to avoid doing two forward passes 391 | pose_embeddings = torch.cat([negative_prompt_embeds, pose_embeddings]) 392 | return pose_embeddings 393 | 394 | def _encode_image_with_pose(self, image, pose, device, num_images_per_prompt, do_video_classifier_free_guidance): 395 | img_prompt_embeds = self._encode_image(image, device, num_images_per_prompt, False) 396 | pose_prompt_embeds = self._encode_pose(pose, device, num_images_per_prompt, False) 397 | prompt_embeds = torch.cat([img_prompt_embeds, pose_prompt_embeds], dim=-1) 398 | prompt_embeds = self.cc_projection(prompt_embeds) 399 | # prompt_embeds = img_prompt_embeds 400 | # follow 0123, add negative prompt, after projection 401 | if do_video_classifier_free_guidance: 402 | negative_prompt = torch.zeros_like(prompt_embeds) 403 | prompt_embeds = torch.cat([negative_prompt, prompt_embeds]) 404 | return prompt_embeds 405 | 406 | @torch.no_grad() 407 | @replace_example_docstring(EXAMPLE_DOC_STRING) 408 | def __call__( 409 | self, 410 | prompt: Union[str, List[str]] = None, 411 | height: Optional[int] = None, 412 | width: Optional[int] = None, 413 | num_frames: int = 16, 414 | num_inference_steps: int = 50, 415 | guidance_scale_video: float = 9.0, 416 | negative_prompt: Optional[Union[str, List[str]]] = None, 417 | eta: float = 0.0, 418 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 419 | latents: Optional[torch.FloatTensor] = None, 420 | prompt_embeds: Optional[torch.FloatTensor] = None, 421 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 422 | output_type: Optional[str] = "np", 423 | return_dict: bool = True, 424 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 425 | callback_steps: int = 1, 426 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 427 | guidance_rescale: float = 0.0, 428 | # vivid123 params below 429 | image: Optional[ 430 | Union[ 431 | torch.FloatTensor, 432 | PIL.Image.Image, 433 | np.ndarray, 434 | List[torch.FloatTensor], 435 | List[PIL.Image.Image], 436 | List[np.ndarray], 437 | ] 438 | ] = None, 439 | cam_pose_torch: Optional[torch.FloatTensor] = None, 440 | fusion_schedule: Optional[tuple[float]] = None, 441 | ddim_eta_0123: float = 1.0, 442 | guidance_scale_zero123: float = 3.0, 443 | noise_identical_accross_frames: bool = False, 444 | ): 445 | r""" 446 | The call function to the pipeline for generation. 447 | 448 | Args: 449 | prompt (`str` or `List[str]`, *optional*): 450 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 451 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 452 | The height in pixels of the generated video. 453 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 454 | The width in pixels of the generated video. 455 | num_frames (`int`, *optional*, defaults to 16): 456 | The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds 457 | amounts to 2 seconds of video. 458 | num_inference_steps (`int`, *optional*, defaults to 50): 459 | The number of denoising steps. More denoising steps usually lead to a higher quality videos at the 460 | expense of slower inference. 
461 | guidance_scale (`float`, *optional*, defaults to 7.5): 462 | A higher guidance scale value encourages the model to generate images closely linked to the text 463 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 464 | negative_prompt (`str` or `List[str]`, *optional*): 465 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to 466 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 467 | num_images_per_prompt (`int`, *optional*, defaults to 1): 468 | The number of images to generate per prompt. 469 | eta (`float`, *optional*, defaults to 0.0): 470 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 471 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 472 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 473 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 474 | generation deterministic. 475 | latents (`torch.FloatTensor`, *optional*): 476 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video 477 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 478 | tensor is generated by sampling using the supplied random `generator`. Latents should be of shape 479 | `(batch_size, num_channel, num_frames, height, width)`. 480 | prompt_embeds (`torch.FloatTensor`, *optional*): 481 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 482 | provided, text embeddings are generated from the `prompt` input argument. 483 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 484 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 485 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 486 | output_type (`str`, *optional*, defaults to `"np"`): 487 | The output format of the generated video. Choose between `torch.FloatTensor` or `np.array`. 488 | return_dict (`bool`, *optional*, defaults to `True`): 489 | Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead 490 | of a plain tuple. 491 | callback (`Callable`, *optional*): 492 | A function that calls every `callback_steps` steps during inference. The function is called with the 493 | following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 494 | callback_steps (`int`, *optional*, defaults to 1): 495 | The frequency at which the `callback` function is called. If not specified, the callback is called at 496 | every step. 497 | cross_attention_kwargs (`dict`, *optional*): 498 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in 499 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 500 | guidance_rescale (`float`, *optional*, defaults to 0.0): 501 | Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are 502 | Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when 503 | using zero terminal SNR. 
504 | guidance_scale_zero123 (`float`, *optional*, defaults to 3.0): 505 | A higher guidance scale value encourages the model to generate images closely linked to the text 506 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 507 | cam_pose_torch: (`torch.FloatTensor`, *optional*): 508 | Camera pose in torch tensor, shape (4,). The elements mean (el, sin(az), cos(az), radius) 509 | fusion_schedule (`tuple[float]`, *optional*): 510 | Fusion schedule for video diffusion and zero123. The first element is the schedule for video diffusion, and the 511 | second element is the schedule for zero123. The length of each schedule should be the same as the number 512 | of timesteps. 513 | ddim_eta_0123 (`float`, *optional*, defaults to 1.0): 514 | The eta value for the 0123 diffusion steps. Only applies to the [`~schedulers.DDIMScheduler`], and is 515 | ignored in other schedulers. 516 | 517 | Example: 518 | 519 | 520 | Returns: 521 | [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: 522 | If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is 523 | returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. 524 | """ 525 | # 0. Default height and width to unet 526 | height = height or self.unet.config.sample_size * self.vae_scale_factor 527 | width = width or self.unet.config.sample_size * self.vae_scale_factor 528 | 529 | num_videos_per_image_prompt = 1 530 | 531 | # 1. Check inputs. Raise error if not correct 532 | self.check_inputs( 533 | prompt, 534 | height, 535 | width, 536 | callback_steps, 537 | negative_prompt, 538 | prompt_embeds, 539 | negative_prompt_embeds, 540 | num_inference_steps, 541 | fusion_schedule 542 | ) 543 | 544 | # 2. Define call parameters 545 | if prompt is not None and isinstance(prompt, str): 546 | batch_size = 1 547 | elif prompt is not None and isinstance(prompt, list): 548 | batch_size = len(prompt) 549 | else: 550 | batch_size = prompt_embeds.shape[0] 551 | 552 | device = self._execution_device 553 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 554 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 555 | # corresponds to doing no classifier free guidance. 556 | do_video_classifier_free_guidance = guidance_scale_video > 1.0 557 | do_zero123_classifier_free_guidance = guidance_scale_zero123 > 1.0 558 | 559 | # 3.1 Encode input prompt for video diffusion 560 | text_encoder_lora_scale = ( 561 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None 562 | ) 563 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 564 | prompt=prompt, 565 | device=device, 566 | # by diffusers v0.23.1, the naming of diffusers.pipelines.TextToVideoSDPipeline is still "num_images_per_prompt", 567 | # where it should be "num_videos_per_prompt" 568 | num_images_per_prompt=num_videos_per_image_prompt, 569 | do_classifier_free_guidance=do_video_classifier_free_guidance, 570 | negative_prompt=negative_prompt, 571 | prompt_embeds=prompt_embeds, 572 | negative_prompt_embeds=negative_prompt_embeds, 573 | lora_scale=text_encoder_lora_scale, 574 | ) 575 | # For classifier free guidance, we need to do two forward passes. 
576 | # Here we concatenate the unconditional and text embeddings into a single batch 577 | # to avoid doing two forward passes 578 | if do_video_classifier_free_guidance: 579 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 580 | 581 | # 3.2 Encode input image for zero123 582 | zero123_cond_images = [image for _ in range(num_frames)] 583 | zero123_embeds = self._encode_image_with_pose( 584 | zero123_cond_images, 585 | cam_pose_torch, 586 | device, 587 | num_videos_per_image_prompt, 588 | do_zero123_classifier_free_guidance, 589 | ) # (2xF) x 1 x 768 590 | 591 | # 4. Prepare timesteps 592 | self.scheduler.set_timesteps(num_inference_steps, device=device) 593 | timesteps = self.scheduler.timesteps 594 | 595 | # 5. Prepare latent variables 596 | num_channels_latents = self.unet.config.in_channels 597 | latents = self.prepare_latents( 598 | batch_size * num_videos_per_image_prompt, 599 | num_channels_latents, 600 | num_frames, 601 | height, 602 | width, 603 | prompt_embeds.dtype, 604 | device, 605 | generator, 606 | latents, 607 | noise_identical_accross_frames, 608 | ) 609 | 610 | # 6. Prepare Zero123 image latents 611 | img_latents = self.prepare_img_latents( 612 | zero123_cond_images, 613 | batch_size=num_frames, 614 | dtype=zero123_embeds.dtype, 615 | device=device, 616 | generator=generator, 617 | do_zero123_classifier_free_guidance=True, 618 | ) 619 | 620 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 621 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 622 | 623 | # 8. Denoising loop 624 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 625 | with self.progress_bar(total=num_inference_steps) as progress_bar: 626 | for i, t in enumerate(timesteps): 627 | # expand the latents if we are doing classifier free guidance 628 | latent_model_input = torch.cat([latents] * 2) if do_video_classifier_free_guidance else latents 629 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 630 | 631 | # predict the noise residual with video diffusion 632 | noise_pred_video = self.unet( 633 | latent_model_input, 634 | t, 635 | encoder_hidden_states=prompt_embeds, 636 | cross_attention_kwargs=cross_attention_kwargs, 637 | return_dict=False, 638 | )[0] 639 | 640 | # perform classifier-free guidance for video diffusion 641 | if do_video_classifier_free_guidance: 642 | noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) 643 | noise_pred_video = noise_pred_video_uncond + guidance_scale_video * ( 644 | noise_pred_video_text - noise_pred_video_uncond 645 | ) 646 | # if do_video_classifier_free_guidance and guidance_rescale > 0.0: 647 | # # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 648 | # noise_pred_video = rescale_noise_cfg( 649 | # noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale 650 | # ) 651 | 652 | # zero123 denoising 653 | latent_model_input_zero123 = torch.cat([latents] * 2) if do_zero123_classifier_free_guidance else latents 654 | augmented_latent_model_input_zero123 = torch.cat( 655 | [rearrange(latent_model_input_zero123, "B C F H W -> (B F) C H W"), img_latents], 656 | dim=1, 657 | ).to(self.novel_view_unet.dtype) 658 | noise_pred_zero123 = self.novel_view_unet( 659 | augmented_latent_model_input_zero123, 660 | t, 661 | encoder_hidden_states=zero123_embeds, 662 | return_dict=True, 663 | ).sample 664 | noise_pred_zero123 = rearrange(noise_pred_zero123, "(B F) C H W -> B C F H W", F=num_frames) 665 | 666 | if do_zero123_classifier_free_guidance: 667 | noise_pred_zero123_uncond, noise_pred_zero123_text = noise_pred_zero123.chunk(2) 668 | noise_pred_zero123 = noise_pred_zero123_uncond + guidance_scale_zero123 * ( 669 | noise_pred_zero123_text - noise_pred_zero123_uncond 670 | ) 671 | 672 | # fusing video diffusion with zero123 673 | noise_pred = fusion_schedule[0][i] * noise_pred_video + fusion_schedule[1][i] * noise_pred_zero123 674 | 675 | # reshape latents 676 | bsz, channel, frames, width, height = latents.shape 677 | latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) 678 | noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) 679 | 680 | # compute the previous noisy sample x_t -> x_t-1 681 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 682 | 683 | # reshape latents back 684 | latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) 685 | 686 | # call the callback, if provided 687 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 688 | progress_bar.update() 689 | if callback is not None and i % callback_steps == 0: 690 | callback(i, t, latents) 691 | 692 | if output_type == "latent": 693 | return TextToVideoSDPipelineOutput(frames=latents) 694 | 695 | video_tensor = self.decode_latents(latents) 696 | 697 | if output_type == "pt": 698 | video = video_tensor 699 | else: 700 | video = tensor2vid(video_tensor) 701 | 702 | # Offload last model to CPU 703 | self.maybe_free_model_hooks() 704 | 705 | if not return_dict: 706 | return (video,) 707 | 708 | return TextToVideoSDPipelineOutput(frames=video) 709 | -------------------------------------------------------------------------------- /vivid123/pipelines/zero123_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | # A diffuser version implementation of Zero1to3 (https://github.com/cvlab-columbia/zero123), ICCV 2023 3 | # by Xin Kong 4 | 5 | import inspect 6 | from typing import Any, Callable, Dict, List, Optional, Union 7 | 8 | import kornia 9 | import numpy as np 10 | import PIL.Image 11 | import torch 12 | from packaging import version 13 | from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection 14 | 15 | # from ...configuration_utils import FrozenDict 16 | # from ...models import AutoencoderKL, UNet2DConditionModel 17 | # from ...schedulers import KarrasDiffusionSchedulers 18 | # from ...utils import ( 19 | # deprecate, 20 | # is_accelerate_available, 21 | # is_accelerate_version, 22 | # logging, 23 | # randn_tensor, 24 | # replace_example_docstring, 25 | # ) 26 | # from 
..pipeline_utils import DiffusionPipeline 27 | # from . import StableDiffusionPipelineOutput 28 | # from .safety_checker import StableDiffusionSafetyChecker 29 | from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel 30 | from diffusers.configuration_utils import ConfigMixin, FrozenDict 31 | from diffusers.models.modeling_utils import ModelMixin 32 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker 33 | from diffusers.schedulers import KarrasDiffusionSchedulers 34 | from diffusers.utils import ( 35 | deprecate, 36 | is_accelerate_available, 37 | is_accelerate_version, 38 | logging, 39 | replace_example_docstring, 40 | ) 41 | from diffusers.utils.torch_utils import randn_tensor 42 | 43 | from ..models import CLIPCameraProjection 44 | 45 | 46 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 47 | # todo 48 | EXAMPLE_DOC_STRING = """ 49 | Examples: 50 | ```py 51 | >>> import torch 52 | >>> from diffusers import StableDiffusionPipeline 53 | >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) 54 | >>> pipe = pipe.to("cuda") 55 | >>> prompt = "a photo of an astronaut riding a horse on mars" 56 | >>> image = pipe(prompt).images[0] 57 | ``` 58 | """ 59 | 60 | class Zero1to3StableDiffusionPipeline(DiffusionPipeline): 61 | r""" 62 | Pipeline for single view conditioned novel view generation using Zero1to3. 63 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 64 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 65 | Args: 66 | vae ([`AutoencoderKL`]): 67 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 68 | image_encoder ([`CLIPVisionModelWithProjection`]): 69 | Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of 70 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), 71 | specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 72 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 73 | scheduler ([`SchedulerMixin`]): 74 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 75 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 76 | safety_checker ([`StableDiffusionSafetyChecker`]): 77 | Classification module that estimates whether generated images could be considered offensive or harmful. 78 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 79 | feature_extractor ([`CLIPFeatureExtractor`]): 80 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 81 | cc_projection ([`CCProjection`]): 82 | Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size. 
83 | """ 84 | 85 | _optional_components = ["safety_checker", "feature_extractor"] 86 | 87 | def __init__( 88 | self, 89 | vae: AutoencoderKL, 90 | image_encoder: CLIPVisionModelWithProjection, 91 | unet: UNet2DConditionModel, 92 | scheduler: KarrasDiffusionSchedulers, 93 | safety_checker: StableDiffusionSafetyChecker, 94 | feature_extractor: CLIPFeatureExtractor, 95 | cc_projection: CLIPCameraProjection, 96 | requires_safety_checker: bool = True, 97 | ): 98 | super().__init__() 99 | 100 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 101 | deprecation_message = ( 102 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 103 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 104 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 105 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 106 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 107 | " file" 108 | ) 109 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 110 | new_config = dict(scheduler.config) 111 | new_config["steps_offset"] = 1 112 | scheduler._internal_dict = FrozenDict(new_config) 113 | 114 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 115 | deprecation_message = ( 116 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 117 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 118 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 119 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 120 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 121 | ) 122 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 123 | new_config = dict(scheduler.config) 124 | new_config["clip_sample"] = False 125 | scheduler._internal_dict = FrozenDict(new_config) 126 | 127 | if safety_checker is None and requires_safety_checker: 128 | logger.warning( 129 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 130 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 131 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 132 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 133 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 134 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 135 | ) 136 | 137 | if safety_checker is not None and feature_extractor is None: 138 | raise ValueError( 139 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 140 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
141 | ) 142 | 143 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 144 | version.parse(unet.config._diffusers_version).base_version 145 | ) < version.parse("0.9.0.dev0") 146 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 147 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 148 | deprecation_message = ( 149 | "The configuration file of the unet has set the default `sample_size` to smaller than" 150 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 151 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 152 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 153 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 154 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 155 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 156 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 157 | " the `unet/config.json` file" 158 | ) 159 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 160 | new_config = dict(unet.config) 161 | new_config["sample_size"] = 64 162 | unet._internal_dict = FrozenDict(new_config) 163 | 164 | self.register_modules( 165 | vae=vae, 166 | image_encoder=image_encoder, 167 | unet=unet, 168 | scheduler=scheduler, 169 | safety_checker=safety_checker, 170 | feature_extractor=feature_extractor, 171 | cc_projection=cc_projection, 172 | ) 173 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 174 | self.register_to_config(requires_safety_checker=requires_safety_checker) 175 | # self.model_mode = None 176 | 177 | def enable_vae_slicing(self): 178 | r""" 179 | Enable sliced VAE decoding. 180 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several 181 | steps. This is useful to save some memory and allow larger batch sizes. 182 | """ 183 | self.vae.enable_slicing() 184 | 185 | def disable_vae_slicing(self): 186 | r""" 187 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to 188 | computing decoding in one step. 189 | """ 190 | self.vae.disable_slicing() 191 | 192 | def enable_vae_tiling(self): 193 | r""" 194 | Enable tiled VAE decoding. 195 | When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in 196 | several steps. This is useful to save a large amount of memory and to allow the processing of larger images. 197 | """ 198 | self.vae.enable_tiling() 199 | 200 | def disable_vae_tiling(self): 201 | r""" 202 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to 203 | computing decoding in one step. 204 | """ 205 | self.vae.disable_tiling() 206 | 207 | def enable_sequential_cpu_offload(self, gpu_id=0): 208 | r""" 209 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 210 | text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a 211 | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 212 | Note that offloading happens on a submodule basis. 
Memory savings are higher than with 213 | `enable_model_cpu_offload`, but performance is lower. 214 | """ 215 | if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): 216 | from accelerate import cpu_offload 217 | else: 218 | raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") 219 | 220 | device = torch.device(f"cuda:{gpu_id}") 221 | 222 | if self.device.type != "cpu": 223 | self.to("cpu", silence_dtype_warnings=True) 224 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 225 | 226 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 227 | cpu_offload(cpu_offloaded_model, device) 228 | 229 | if self.safety_checker is not None: 230 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) 231 | 232 | def enable_model_cpu_offload(self, gpu_id=0): 233 | r""" 234 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared 235 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 236 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 237 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 238 | """ 239 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 240 | from accelerate import cpu_offload_with_hook 241 | else: 242 | raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") 243 | 244 | device = torch.device(f"cuda:{gpu_id}") 245 | 246 | if self.device.type != "cpu": 247 | self.to("cpu", silence_dtype_warnings=True) 248 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 249 | 250 | hook = None 251 | for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: 252 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) 253 | 254 | if self.safety_checker is not None: 255 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) 256 | 257 | # We'll offload the last model manually. 258 | self.final_offload_hook = hook 259 | 260 | @property 261 | def _execution_device(self): 262 | r""" 263 | Returns the device on which the pipeline's models will be executed. After calling 264 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 265 | hooks. 266 | """ 267 | if not hasattr(self.unet, "_hf_hook"): 268 | return self.device 269 | for module in self.unet.modules(): 270 | if ( 271 | hasattr(module, "_hf_hook") 272 | and hasattr(module._hf_hook, "execution_device") 273 | and module._hf_hook.execution_device is not None 274 | ): 275 | return torch.device(module._hf_hook.execution_device) 276 | return self.device 277 | 278 | def _encode_prompt( 279 | self, 280 | prompt, 281 | device, 282 | num_images_per_prompt, 283 | do_classifier_free_guidance, 284 | negative_prompt=None, 285 | prompt_embeds: Optional[torch.FloatTensor] = None, 286 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 287 | ): 288 | r""" 289 | Encodes the prompt into text encoder hidden states. 
290 | Args: 291 | prompt (`str` or `List[str]`, *optional*): 292 | prompt to be encoded 293 | device: (`torch.device`): 294 | torch device 295 | num_images_per_prompt (`int`): 296 | number of images that should be generated per prompt 297 | do_classifier_free_guidance (`bool`): 298 | whether to use classifier free guidance or not 299 | negative_prompt (`str` or `List[str]`, *optional*): 300 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 301 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 302 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 303 | prompt_embeds (`torch.FloatTensor`, *optional*): 304 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 305 | provided, text embeddings will be generated from `prompt` input argument. 306 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 307 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 308 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 309 | argument. 310 | """ 311 | if prompt is not None and isinstance(prompt, str): 312 | batch_size = 1 313 | elif prompt is not None and isinstance(prompt, list): 314 | batch_size = len(prompt) 315 | else: 316 | batch_size = prompt_embeds.shape[0] 317 | 318 | if prompt_embeds is None: 319 | text_inputs = self.tokenizer( 320 | prompt, 321 | padding="max_length", 322 | max_length=self.tokenizer.model_max_length, 323 | truncation=True, 324 | return_tensors="pt", 325 | ) 326 | text_input_ids = text_inputs.input_ids 327 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 328 | 329 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 330 | text_input_ids, untruncated_ids 331 | ): 332 | removed_text = self.tokenizer.batch_decode( 333 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 334 | ) 335 | logger.warning( 336 | "The following part of your input was truncated because CLIP can only handle sequences up to" 337 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 338 | ) 339 | 340 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 341 | attention_mask = text_inputs.attention_mask.to(device) 342 | else: 343 | attention_mask = None 344 | 345 | prompt_embeds = self.text_encoder( 346 | text_input_ids.to(device), 347 | attention_mask=attention_mask, 348 | ) 349 | prompt_embeds = prompt_embeds[0] 350 | 351 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 352 | 353 | bs_embed, seq_len, _ = prompt_embeds.shape 354 | # duplicate text embeddings for each generation per prompt, using mps friendly method 355 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 356 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 357 | 358 | # get unconditional embeddings for classifier free guidance 359 | if do_classifier_free_guidance and negative_prompt_embeds is None: 360 | uncond_tokens: List[str] 361 | if negative_prompt is None: 362 | uncond_tokens = [""] * batch_size 363 | elif type(prompt) is not type(negative_prompt): 364 | raise TypeError( 365 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 366 | f" {type(prompt)}." 
367 | ) 368 | elif isinstance(negative_prompt, str): 369 | uncond_tokens = [negative_prompt] 370 | elif batch_size != len(negative_prompt): 371 | raise ValueError( 372 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 373 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 374 | " the batch size of `prompt`." 375 | ) 376 | else: 377 | uncond_tokens = negative_prompt 378 | 379 | max_length = prompt_embeds.shape[1] 380 | uncond_input = self.tokenizer( 381 | uncond_tokens, 382 | padding="max_length", 383 | max_length=max_length, 384 | truncation=True, 385 | return_tensors="pt", 386 | ) 387 | 388 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 389 | attention_mask = uncond_input.attention_mask.to(device) 390 | else: 391 | attention_mask = None 392 | 393 | negative_prompt_embeds = self.text_encoder( 394 | uncond_input.input_ids.to(device), 395 | attention_mask=attention_mask, 396 | ) 397 | negative_prompt_embeds = negative_prompt_embeds[0] 398 | 399 | if do_classifier_free_guidance: 400 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 401 | seq_len = negative_prompt_embeds.shape[1] 402 | 403 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 404 | 405 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 406 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 407 | 408 | # For classifier free guidance, we need to do two forward passes. 409 | # Here we concatenate the unconditional and text embeddings into a single batch 410 | # to avoid doing two forward passes 411 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 412 | 413 | return prompt_embeds 414 | 415 | def CLIP_preprocess(self, x): 416 | dtype = x.dtype 417 | # following openai's implementation 418 | # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741 419 | # follow openai preprocessing to keep exact same, input tensor [-1, 1], otherwise the preprocessing will be different, https://github.com/huggingface/transformers/pull/22608 420 | if isinstance(x, torch.Tensor): 421 | if x.min() < -1.0 or x.max() > 1.0: 422 | raise ValueError("Expected input tensor to have values in the range [-1, 1]") 423 | x = kornia.geometry.resize( 424 | x.to(torch.float32), (224, 224), interpolation="bicubic", align_corners=True, antialias=False 425 | ).to(dtype=dtype) 426 | x = (x + 1.0) / 2.0 427 | # renormalize according to clip 428 | x = kornia.enhance.normalize( 429 | x, torch.Tensor([0.48145466, 0.4578275, 0.40821073]), torch.Tensor([0.26862954, 0.26130258, 0.27577711]) 430 | ) 431 | return x 432 | 433 | # from image_variation 434 | def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): 435 | dtype = next(self.image_encoder.parameters()).dtype 436 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 437 | raise ValueError( 438 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 439 | ) 440 | 441 | if isinstance(image, torch.Tensor): 442 | # Batch single image 443 | if image.ndim == 3: 444 | assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" 445 | image = image.unsqueeze(0) 446 | 447 | assert image.ndim == 4, "Image must 
have 4 dimensions" 448 | 449 | # Check image is in [-1, 1] 450 | if image.min() < -1 or image.max() > 1: 451 | raise ValueError("Image should be in [-1, 1] range") 452 | else: 453 | # preprocess image 454 | if isinstance(image, (PIL.Image.Image, np.ndarray)): 455 | image = [image] 456 | 457 | if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): 458 | image = [np.array(i.convert("RGB"))[None, :] for i in image] 459 | image = np.concatenate(image, axis=0) 460 | elif isinstance(image, list) and isinstance(image[0], np.ndarray): 461 | image = np.concatenate([i[None, :] for i in image], axis=0) 462 | 463 | image = image.transpose(0, 3, 1, 2) 464 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 465 | 466 | image = image.to(device=device, dtype=dtype) 467 | 468 | image = self.CLIP_preprocess(image) 469 | # if not isinstance(image, torch.Tensor): 470 | # # 0-255 471 | # print("Warning: image is processed by hf's preprocess, which is different from openai original's.") 472 | # image = self.feature_extractor(images=image, return_tensors="pt").pixel_values 473 | image_embeddings = self.image_encoder(image).image_embeds.to(dtype=dtype) 474 | image_embeddings = image_embeddings.unsqueeze(1) 475 | 476 | # duplicate image embeddings for each generation per prompt, using mps friendly method 477 | bs_embed, seq_len, _ = image_embeddings.shape 478 | image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) 479 | image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 480 | 481 | if do_classifier_free_guidance: 482 | negative_prompt_embeds = torch.zeros_like(image_embeddings) 483 | 484 | # For classifier free guidance, we need to do two forward passes. 485 | # Here we concatenate the unconditional and text embeddings into a single batch 486 | # to avoid doing two forward passes 487 | image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) 488 | 489 | return image_embeddings 490 | 491 | def _encode_pose(self, pose, device, num_images_per_prompt, do_classifier_free_guidance): 492 | dtype = next(self.cc_projection.parameters()).dtype 493 | if isinstance(pose, torch.Tensor): 494 | pose_embeddings = pose.unsqueeze(1).to(device=device, dtype=dtype) 495 | else: 496 | if isinstance(pose[0], list): 497 | pose = torch.Tensor(pose) 498 | else: 499 | pose = torch.Tensor([pose]) 500 | x, y, z = pose[:, 0].unsqueeze(1), pose[:, 1].unsqueeze(1), pose[:, 2].unsqueeze(1) 501 | pose_embeddings = ( 502 | torch.cat([torch.deg2rad(x), torch.sin(torch.deg2rad(y)), torch.cos(torch.deg2rad(y)), z], dim=-1) 503 | .unsqueeze(1) 504 | .to(device=device, dtype=dtype) 505 | ) # B, 1, 4 506 | # duplicate pose embeddings for each generation per prompt, using mps friendly method 507 | bs_embed, seq_len, _ = pose_embeddings.shape 508 | pose_embeddings = pose_embeddings.repeat(1, num_images_per_prompt, 1) 509 | pose_embeddings = pose_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 510 | if do_classifier_free_guidance: 511 | negative_prompt_embeds = torch.zeros_like(pose_embeddings) 512 | 513 | # For classifier free guidance, we need to do two forward passes. 
514 | # Here we concatenate the unconditional and text embeddings into a single batch 515 | # to avoid doing two forward passes 516 | pose_embeddings = torch.cat([negative_prompt_embeds, pose_embeddings]) 517 | return pose_embeddings 518 | 519 | def _encode_image_with_pose(self, image, pose, device, num_images_per_prompt, do_classifier_free_guidance): 520 | img_prompt_embeds = self._encode_image(image, device, num_images_per_prompt, False) 521 | pose_prompt_embeds = self._encode_pose(pose, device, num_images_per_prompt, False) 522 | prompt_embeds = torch.cat([img_prompt_embeds, pose_prompt_embeds], dim=-1) 523 | prompt_embeds = self.cc_projection(prompt_embeds) 524 | # prompt_embeds = img_prompt_embeds 525 | # follow 0123, add negative prompt, after projection 526 | if do_classifier_free_guidance: 527 | negative_prompt = torch.zeros_like(prompt_embeds) 528 | prompt_embeds = torch.cat([negative_prompt, prompt_embeds]) 529 | return prompt_embeds 530 | 531 | def run_safety_checker(self, image, device, dtype): 532 | if self.safety_checker is not None: 533 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) 534 | image, has_nsfw_concept = self.safety_checker( 535 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 536 | ) 537 | else: 538 | has_nsfw_concept = None 539 | return image, has_nsfw_concept 540 | 541 | def decode_latents(self, latents): 542 | latents = 1 / self.vae.config.scaling_factor * latents 543 | image = self.vae.decode(latents).sample 544 | image = (image / 2 + 0.5).clamp(0, 1) 545 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 546 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 547 | return image 548 | 549 | def prepare_extra_step_kwargs(self, generator, eta): 550 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 551 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 552 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 553 | # and should be between [0, 1] 554 | 555 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 556 | extra_step_kwargs = {} 557 | if accepts_eta: 558 | extra_step_kwargs["eta"] = eta 559 | 560 | # check if the scheduler accepts generator 561 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 562 | if accepts_generator: 563 | extra_step_kwargs["generator"] = generator 564 | return extra_step_kwargs 565 | 566 | def check_inputs(self, image, height, width, callback_steps): 567 | if ( 568 | not isinstance(image, torch.Tensor) 569 | and not isinstance(image, PIL.Image.Image) 570 | and not isinstance(image, list) 571 | ): 572 | raise ValueError( 573 | "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" 574 | f" {type(image)}" 575 | ) 576 | 577 | if height % 8 != 0 or width % 8 != 0: 578 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 579 | 580 | if (callback_steps is None) or ( 581 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 582 | ): 583 | raise ValueError( 584 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 585 | f" {type(callback_steps)}." 
586 | ) 587 | 588 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 589 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 590 | if isinstance(generator, list) and len(generator) != batch_size: 591 | raise ValueError( 592 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 593 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 594 | ) 595 | 596 | if latents is None: 597 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 598 | else: 599 | latents = latents.to(device) 600 | 601 | # scale the initial noise by the standard deviation required by the scheduler 602 | latents = latents * self.scheduler.init_noise_sigma 603 | return latents 604 | 605 | def prepare_img_latents(self, image, batch_size, dtype, device, generator=None, do_classifier_free_guidance=False): 606 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 607 | raise ValueError( 608 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 609 | ) 610 | 611 | if isinstance(image, torch.Tensor): 612 | # Batch single image 613 | if image.ndim == 3: 614 | assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" 615 | image = image.unsqueeze(0) 616 | 617 | assert image.ndim == 4, "Image must have 4 dimensions" 618 | 619 | # Check image is in [-1, 1] 620 | if image.min() < -1 or image.max() > 1: 621 | raise ValueError("Image should be in [-1, 1] range") 622 | else: 623 | # preprocess image 624 | if isinstance(image, (PIL.Image.Image, np.ndarray)): 625 | image = [image] 626 | 627 | if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): 628 | image = [np.array(i.convert("RGB"))[None, :] for i in image] 629 | image = np.concatenate(image, axis=0) 630 | elif isinstance(image, list) and isinstance(image[0], np.ndarray): 631 | image = np.concatenate([i[None, :] for i in image], axis=0) 632 | 633 | image = image.transpose(0, 3, 1, 2) 634 | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 635 | 636 | image = image.to(device=device, dtype=dtype) 637 | 638 | if isinstance(generator, list) and len(generator) != batch_size: 639 | raise ValueError( 640 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 641 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
642 | ) 643 | 644 | if isinstance(generator, list): 645 | init_latents = [ 646 | self.vae.encode(image[i : i + 1]).latent_dist.mode(generator[i]) 647 | for i in range(batch_size) # sample 648 | ] 649 | init_latents = torch.cat(init_latents, dim=0) 650 | else: 651 | init_latents = self.vae.encode(image).latent_dist.mode() 652 | 653 | # init_latents = self.vae.config.scaling_factor * init_latents # todo in original zero123's inference gradio_new.py, model.encode_first_stage() is not scaled by scaling_factor 654 | if batch_size > init_latents.shape[0]: 655 | # init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1) 656 | num_images_per_prompt = batch_size // init_latents.shape[0] 657 | # duplicate image latents for each generation per prompt, using mps friendly method 658 | bs_embed, emb_c, emb_h, emb_w = init_latents.shape 659 | init_latents = init_latents.unsqueeze(1) 660 | init_latents = init_latents.repeat(1, num_images_per_prompt, 1, 1, 1) 661 | init_latents = init_latents.view(bs_embed * num_images_per_prompt, emb_c, emb_h, emb_w) 662 | 663 | # init_latents = torch.cat([init_latents]*2) if do_classifier_free_guidance else init_latents # follow zero123 664 | init_latents = ( 665 | torch.cat([torch.zeros_like(init_latents), init_latents]) if do_classifier_free_guidance else init_latents 666 | ) 667 | 668 | init_latents = init_latents.to(device=device, dtype=dtype) 669 | return init_latents 670 | 671 | @torch.no_grad() 672 | @replace_example_docstring(EXAMPLE_DOC_STRING) 673 | def __call__( 674 | self, 675 | input_imgs: Union[torch.FloatTensor, PIL.Image.Image] = None, 676 | prompt_imgs: Union[torch.FloatTensor, PIL.Image.Image] = None, 677 | poses: Union[List[float], List[List[float]]] = None, 678 | torch_dtype=torch.float32, 679 | height: Optional[int] = None, 680 | width: Optional[int] = None, 681 | num_inference_steps: int = 50, 682 | guidance_scale: float = 3.0, 683 | negative_prompt: Optional[Union[str, List[str]]] = None, 684 | num_images_per_prompt: Optional[int] = 1, 685 | eta: float = 0.0, 686 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 687 | latents: Optional[torch.FloatTensor] = None, 688 | prompt_embeds: Optional[torch.FloatTensor] = None, 689 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 690 | output_type: Optional[str] = "pil", 691 | return_dict: bool = True, 692 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 693 | callback_steps: int = 1, 694 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 695 | controlnet_conditioning_scale: float = 1.0, 696 | ): 697 | r""" 698 | Function invoked when calling the pipeline for generation. 699 | Args: 700 | input_imgs (`PIL` or `List[PIL]`, *optional*): 701 | The single input image for each 3D object 702 | prompt_imgs (`PIL` or `List[PIL]`, *optional*): 703 | Same as input_imgs, but will be used later as an image prompt condition, encoded by CLIP feature 704 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 705 | The height in pixels of the generated image. 706 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 707 | The width in pixels of the generated image. 708 | num_inference_steps (`int`, *optional*, defaults to 50): 709 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 710 | expense of slower inference. 
711 | guidance_scale (`float`, *optional*, defaults to 3.0): 712 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 713 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 714 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 715 | 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, 716 | usually at the expense of lower image quality. 717 | negative_prompt (`str` or `List[str]`, *optional*): 718 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 719 | `negative_prompt_embeds` instead. 720 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 721 | num_images_per_prompt (`int`, *optional*, defaults to 1): 722 | The number of images to generate per prompt. 723 | eta (`float`, *optional*, defaults to 0.0): 724 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 725 | [`schedulers.DDIMScheduler`], will be ignored for others. 726 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 727 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 728 | to make generation deterministic. 729 | latents (`torch.FloatTensor`, *optional*): 730 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 731 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 732 | tensor will be generated by sampling using the supplied random `generator`. 733 | prompt_embeds (`torch.FloatTensor`, *optional*): 734 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 735 | provided, text embeddings will be generated from `prompt` input argument. 736 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 737 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 738 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 739 | argument. 740 | output_type (`str`, *optional*, defaults to `"pil"`): 741 | The output format of the generated image. Choose between 742 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 743 | return_dict (`bool`, *optional*, defaults to `True`): 744 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 745 | plain tuple. 746 | callback (`Callable`, *optional*): 747 | A function that will be called every `callback_steps` steps during inference. The function will be 748 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 749 | callback_steps (`int`, *optional*, defaults to 1): 750 | The frequency at which the `callback` function will be called. If not specified, the callback will be 751 | called at every step. 752 | cross_attention_kwargs (`dict`, *optional*): 753 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 754 | `self.processor` in 755 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
756 | Examples: 757 | Returns: 758 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 759 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 760 | When returning a tuple, the first element is a list with the generated images, and the second element is a 761 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 762 | (nsfw) content, according to the `safety_checker`. 763 | """ 764 | # 0. Default height and width to unet 765 | height = height or self.unet.config.sample_size * self.vae_scale_factor 766 | width = width or self.unet.config.sample_size * self.vae_scale_factor 767 | 768 | # 1. Check inputs. Raise error if not correct 769 | # input_image = hint_imgs 770 | self.check_inputs(input_imgs, height, width, callback_steps) 771 | 772 | # 2. Define call parameters 773 | if isinstance(input_imgs, PIL.Image.Image): 774 | batch_size = 1 775 | elif isinstance(input_imgs, list): 776 | batch_size = len(input_imgs) 777 | else: 778 | batch_size = input_imgs.shape[0] 779 | device = self._execution_device 780 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 781 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 782 | # corresponds to doing no classifier free guidance. 783 | do_classifier_free_guidance = guidance_scale > 1.0 784 | 785 | # 3. Encode input image with pose as prompt 786 | prompt_embeds = self._encode_image_with_pose( 787 | prompt_imgs, poses, device, num_images_per_prompt, do_classifier_free_guidance 788 | ) 789 | 790 | # 4. Prepare timesteps 791 | self.scheduler.set_timesteps(num_inference_steps, device=device) 792 | timesteps = self.scheduler.timesteps 793 | 794 | # 5. Prepare latent variables 795 | latents = self.prepare_latents( 796 | batch_size * num_images_per_prompt, 797 | 4, 798 | height, 799 | width, 800 | prompt_embeds.dtype, 801 | device, 802 | generator, 803 | latents, 804 | ) 805 | 806 | # 6. Prepare image latents 807 | img_latents = self.prepare_img_latents( 808 | input_imgs, 809 | batch_size * num_images_per_prompt, 810 | prompt_embeds.dtype, 811 | device, 812 | generator, 813 | do_classifier_free_guidance, 814 | ) 815 | 816 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 817 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 818 | 819 | # 7. 
Denoising loop 820 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 821 | with self.progress_bar(total=num_inference_steps) as progress_bar: 822 | for i, t in enumerate(timesteps): 823 | # expand the latents if we are doing classifier free guidance 824 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 825 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 826 | latent_model_input = torch.cat([latent_model_input, img_latents], dim=1) 827 | 828 | # predict the noise residual 829 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample 830 | 831 | # perform guidance 832 | if do_classifier_free_guidance: 833 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 834 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 835 | 836 | # compute the previous noisy sample x_t -> x_t-1 837 | # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype) 838 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 839 | 840 | # call the callback, if provided 841 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 842 | progress_bar.update() 843 | if callback is not None and i % callback_steps == 0: 844 | step_idx = i // getattr(self.scheduler, "order", 1) 845 | callback(step_idx, t, latents) 846 | 847 | # 8. Post-processing 848 | has_nsfw_concept = None 849 | if output_type == "latent": 850 | image = latents 851 | elif output_type == "pil": 852 | # 8. Post-processing 853 | image = self.decode_latents(latents) 854 | # 10. Convert to PIL 855 | image = self.numpy_to_pil(image) 856 | else: 857 | # 8. Post-processing 858 | image = self.decode_latents(latents) 859 | 860 | # Offload last model to CPU 861 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 862 | self.final_offload_hook.offload() 863 | 864 | if not return_dict: 865 | return (image, has_nsfw_concept) 866 | 867 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) --------------------------------------------------------------------------------
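The camera conditioning in `_encode_pose` above maps each pose triple to a four-dimensional vector: the first angle is converted to radians directly, the second angle is expanded into its sine and cosine, and the radius offset passes through unchanged, before the result is concatenated with the CLIP image embedding and projected by `cc_projection`. A standalone sketch of that mapping is below; the helper name and the (elevation, azimuth, radius) labelling are ours, inferred from the usual zero123 convention rather than stated in the pipeline code.

```py
# Sketch of the pose-to-embedding mapping performed inside `_encode_pose`.
# The function name and the argument labels are illustrative, not part of the pipeline API.
import torch


def pose_to_embedding(delta_elevation_deg: float, delta_azimuth_deg: float, delta_radius: float) -> torch.Tensor:
    """Return the (B, 1, 4) pose embedding that gets concatenated with the CLIP image embedding."""
    pose = torch.tensor([[delta_elevation_deg, delta_azimuth_deg, delta_radius]])
    x, y, z = pose[:, 0:1], pose[:, 1:2], pose[:, 2:3]
    emb = torch.cat(
        [torch.deg2rad(x), torch.sin(torch.deg2rad(y)), torch.cos(torch.deg2rad(y)), z],
        dim=-1,
    )
    return emb.unsqueeze(1)  # shape (1, 1, 4)


print(pose_to_embedding(0.0, 45.0, 0.0))
# approximately tensor([[[0.0000, 0.7071, 0.7071, 0.0000]]]) for a +45 degree azimuth change
```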
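Putting the pieces together, a minimal invocation of `Zero1to3StableDiffusionPipeline` could look as follows. This is a sketch under stated assumptions, not a verified recipe: the checkpoint path is a placeholder and must point at a diffusers-format zero123 export whose `model_index.json` lists the `cc_projection` module registered in `__init__`, and the conditioning image should already match the requested `height`/`width`, since `prepare_img_latents` encodes it without resizing.

```py
# Minimal usage sketch for Zero1to3StableDiffusionPipeline (paths and the pose
# labelling are placeholders / assumptions, see the note above).
import torch
from PIL import Image

from vivid123.pipelines.zero123_pipeline import Zero1to3StableDiffusionPipeline

pipe = Zero1to3StableDiffusionPipeline.from_pretrained(
    "path/to/zero123-diffusers-checkpoint",  # placeholder: a diffusers-format zero123 export with cc_projection
    torch_dtype=torch.float16,
).to("cuda")

# Any RGB conditioning view of the object; the path is a placeholder.
# CLIP_preprocess resizes its copy to 224x224 internally, while the VAE branch
# encodes the image as-is, so its size should match height/width below.
cond = Image.open("input_view.png").convert("RGB").resize((256, 256))

out = pipe(
    input_imgs=cond,           # encoded by the VAE and concatenated to the UNet input channels
    prompt_imgs=cond,          # encoded by CLIP, joined with the pose embedding via cc_projection
    poses=[[0.0, 45.0, 0.0]],  # (delta elevation, delta azimuth, delta radius); labelling assumed, angles in degrees
    height=256,
    width=256,
    num_inference_steps=50,
    guidance_scale=3.0,        # matches the signature default; values > 1 enable classifier-free guidance
)
out.images[0].save("novel_view.png")
```

With `output_type` left at its default `"pil"`, the call returns a `StableDiffusionPipelineOutput` whose `images` field is a list of PIL images, alongside the `nsfw_content_detected` flags from the optional safety checker.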