├── docs └── intro.png ├── requirements.txt ├── utils.py ├── app_gradio.py ├── README.md ├── .gitignore ├── twin.py └── panorama.py /docs/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0606zt/TwinDiffusion/HEAD/docs/intro.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | diffusers>=0.26.0 4 | transformers 5 | tqdm 6 | gradio -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | # randomly corp a 512*512 block from a panorama 9 | def crop_img(input_folder, output_folder): 10 | png_imgs = [file for file in os.listdir(input_folder) if file.endswith('.png')] 11 | for i, png_img in tqdm(enumerate(png_imgs)): 12 | img_path = os.path.join(input_folder, png_img) 13 | img = Image.open(img_path) 14 | width, height = img.size 15 | 16 | x = random.randint(0, width - 512) 17 | y = random.randint(0, height - 512) 18 | 19 | cropped_img = img.crop((x, y, x + 512, y + 512)) 20 | 21 | output_path = os.path.join(output_folder, png_img) 22 | cropped_img.save(output_path) 23 | 24 | 25 | # arrange multiple images into a large one 26 | def view_images(images, num_rows=1, offset_ratio=0.02): 27 | if type(images) is list: 28 | num_empty = len(images) % num_rows 29 | elif images.ndim == 4: 30 | num_empty = images.shape[0] % num_rows 31 | else: 32 | images = [images] 33 | num_empty = 0 34 | 35 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 36 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 37 | num_items = len(images) 38 | 39 | h, w, c = images[0].shape 40 | offset = int(h * offset_ratio) 41 | num_cols = num_items // num_rows 42 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 43 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 44 | for i in range(num_rows): 45 | for j in range(num_cols): 46 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 47 | i * num_cols + j] 48 | 49 | pil_img = Image.fromarray(image_) 50 | return pil_img 51 | -------------------------------------------------------------------------------- /app_gradio.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torchvision.transforms as TT 3 | from panorama import TwinDiffusion, seed_everything 4 | 5 | seed_everything(-1) 6 | td = TwinDiffusion('cuda', '2.0') 7 | 8 | 9 | def generate_panorama_fn(prompt, width, lam, view_stride, cross_time): 10 | image = td.text2panorama_optm( 11 | prompts=prompt, 12 | negative_prompts="", 13 | width=width, 14 | lam=lam, 15 | view_stride=view_stride, 16 | cross_time=cross_time 17 | ) 18 | image = TT.ToPILImage()(image[0]) 19 | return image 20 | 21 | 22 | description = """ 23 |

## Generate Panoramic Images with TwinDiffusion 24 | [[Code](https://github.com/0606zt/TwinDiffusion)] 25 | [[Paper](https://ebooks.iospress.nl/doi/10.3233/FAIA240512)]
26 | """ 27 | 28 | prompt_exp = [ 29 | ["A photo of the dolomites"], 30 | ["A photo of the mountain, lake, people and boats"], 31 | ["A landscape ink painting"], 32 | ["Natural landscape in anime style illustration"], 33 | ["A graphite sketch of a majestic mountain range"], 34 | ["A surrealistic artwork of urban park at dawn"] 35 | ] 36 | 37 | with gr.Blocks() as demo: 38 | gr.Markdown(description) 39 | output = gr.Image(label="Generated Image") 40 | with gr.Row(): 41 | with gr.Column(scale=3): 42 | prompt = gr.Textbox(label="Enter your prompt") 43 | with gr.Column(scale=1): 44 | generate_btn = gr.Button("Generate") 45 | gr.Examples(examples=prompt_exp, inputs=prompt) 46 | 47 | with gr.Accordion(label="⚙️ More Settings", open=False): 48 | width = gr.Slider( 49 | label="Image Width", 50 | minimum=512, 51 | maximum=4608, 52 | step=128, 53 | value=2048) 54 | lam = gr.Slider( 55 | label="Lambda", 56 | minimum=0, 57 | maximum=100, 58 | step=1, 59 | value=1) 60 | view_stride = gr.Slider( 61 | label="View Stride", 62 | minimum=8, 63 | maximum=48, 64 | step=8, 65 | value=16) 66 | cross_time = gr.Slider( 67 | label="Cross Time", 68 | minimum=2, 69 | maximum=10, 70 | step=1, 71 | value=2) 72 | 73 | generate_btn.click( 74 | fn=generate_panorama_fn, 75 | inputs=[prompt, width, lam, view_stride, cross_time], 76 | outputs=output 77 | ) 78 | demo.launch(share=True) 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ♊TwinDiffusion: Enhancing Coherence and Efficiency in Panoramic Image Generation with Diffusion Models 2 | 3 | ![intro](docs/intro.png) 4 | 5 | ## Brief Look 6 | 7 | **TwinDiffusion** is a MultiDiffusion-based framework that integrates two straightforward but effective methods to generate panoramic images with 8 | improved quality and efficiency. 9 | 10 | > **Abstract**   Diffusion models have emerged as effective tools for generating diverse and high-quality content. However, their capability in high-resolution image generation, particularly for panoramic images, still faces challenges such as visible seams and incoherent transitions. In this paper, we propose TwinDiffusion, an optimized framework designed to address these challenges through two key innovations: the Crop Fusion for quality enhancement and the Cross Sampling for efficiency optimization. We introduce a training-free optimizing stage to refine the similarity of adjacent image areas, as well as an interleaving sampling strategy to yield dynamic patches during the cropping process. A comprehensive evaluation is conducted to compare TwinDiffusion with the prior works, considering factors including coherence, fidelity, compatibility, and efficiency. The results demonstrate the superior performance of our approach in generating seamless and coherent panoramas, setting a new standard in quality and efficiency for panoramic image generation. 11 | 12 | For more details, please visit our [paper page](https://ebooks.iospress.nl/doi/10.3233/FAIA240512). 
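For readers who want the core idea before opening `twin.py` or `panorama.py`: the training-free Crop Fusion step is a per-timestep weighted average over the latent columns where two adjacent crops overlap, with weight 1 on the overlapping columns of the left neighbour crop and weight `lam` (`--lam`) on the current crop's own pre-fusion values. The snippet below is a minimal, self-contained sketch of that update; the helper name `crop_fusion_update` is ours and not part of the repository API.

```python
import torch


def crop_fusion_update(left_latent: torch.Tensor,
                       right_latent: torch.Tensor,
                       view_stride: int,
                       lam: float = 1.0) -> torch.Tensor:
    """Blend the overlapping columns of two adjacent latent crops.

    The overlap of the right crop is replaced by the closed-form minimizer of
    ||left_overlap - z||^2 + lam * ||right_overlap - z||^2, i.e. a
    (1, lam)-weighted average of the two overlapping regions.
    """
    window = left_latent.shape[-1]                   # latent window width (64 for SD 2.0)
    overlap = window - view_stride                   # number of shared latent columns
    left_overlap = left_latent[..., view_stride:]    # right edge of the left crop
    right_overlap = right_latent[..., :overlap]      # left edge of the right crop

    fused = right_latent.clone()
    fused[..., :overlap] = (left_overlap + lam * right_overlap) / (1 + lam)
    return fused


# Toy usage on random latents of shape (batch, channels, H, W):
z_left, z_right = torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64)
print(crop_fusion_update(z_left, z_right, view_stride=16).shape)  # torch.Size([1, 4, 64, 64])
```

In the actual scripts this blend is applied only during the first half of the denoising steps (the `i < num_inference_steps // 2` check, marked with the `Ablation.1` comment).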
13 | 14 | ## Quick Start 15 | 16 | **Installation**   Set up and configure the environment by installing the required packages: 17 | 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | **Generation**   We support SDXL and batch generation, you can generate twin images and panoramic images with the following command: 23 | 24 | ```bash 25 | python twin.py --prompt "A photo of dolomites" --seed 5 --n 3 26 | ``` 27 | ```bash 28 | python panorama.py \ 29 | --prompt "Landscape ink painting" \ 30 | --sd_version 2.0 \ 31 | --H 512 \ 32 | --W 4096 \ 33 | --seed -1 \ 34 | --lam 1 \ 35 | --view_stride 16 \ 36 | --cross_time 2 \ 37 | --n 1 38 | ``` 39 | 40 | **App**   We also provide a gradio app for interactive testing: 41 | 42 | ```bash 43 | python app_gradio.py 44 | ``` 45 | 46 | ## Citation 47 | 48 | If you find our work helpful, please consider citing: 49 | 50 | ```bibtex 51 | @incollection{zhou2024twindiffusion, 52 | title={TwinDiffusion: Enhancing Coherence and Efficiency in Panoramic Image Generation with Diffusion Models}, 53 | author={Zhou, Teng and Tang, Yongchuan}, 54 | booktitle={ECAI 2024}, 55 | pages={386--393}, 56 | year={2024}, 57 | publisher={IOS Press} 58 | } 59 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /twin.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection 2 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as TT 6 | import torch.nn.functional as nnf 7 | import numpy as np 8 | import argparse 9 | import os 10 | from tqdm import tqdm 11 | from utils import view_images 12 | 13 | 14 | def seed_everything(seed): 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | # torch.backends.cudnn.deterministic = True 18 | # torch.backends.cudnn.benchmark = True 19 | 20 | 21 | class TwinDiffusion(nn.Module): 22 | def __init__(self, device, sd_version='2.0', hf_key=None): 23 | super().__init__() 24 | 25 | self.device = device 26 | self.sd_version = sd_version 27 | 28 | if hf_key is not None: 29 | print(f'Using hugging face custom model key: {hf_key}') 30 | model_key = hf_key 31 | elif self.sd_version == '2.1': 32 | model_key = "stabilityai/stable-diffusion-2-1-base" 33 | elif self.sd_version == '2.0': 34 | model_key = "stabilityai/stable-diffusion-2-base" 35 | elif self.sd_version == '1.5': 36 | model_key = "runwayml/stable-diffusion-v1-5" 37 | elif self.sd_version == 'xl-1.0': 38 | model_key = "stabilityai/stable-diffusion-xl-base-1.0" 39 | else: 40 | raise ValueError(f'Stable Diffusion Version {self.sd_version} NOT Supported.') 41 | 42 | print('Loading stable diffusion...') 43 | 44 | self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device) 45 | self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") 46 | self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device) 47 | self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device) 48 | self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") 49 | if 'xl' in self.sd_version: 50 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer_2") 51 | self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_key, subfolder="text_encoder_2").to(self.device) 52 | 53 | print('Loaded stable diffusion!') 54 | 55 | @torch.no_grad() 56 | def get_text_embeds(self, prompts, negative_prompts): 57 | if 'xl' in self.sd_version: 58 | tokenizers = [self.tokenizer, self.tokenizer_2] 59 | text_encoders = [self.text_encoder, self.text_encoder_2] 60 | text_embeddings_list = [] 61 | uncond_text_embeddings_list = [] 62 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 63 | text_input = tokenizer(prompts, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt') 64 | text_embeddings = text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True) 65 | pooled_text_embeddings = text_embeddings[0] 66 | text_embeddings = text_embeddings.hidden_states[-2] 67 | text_embeddings_list.append(text_embeddings) 68 | 69 | uncond_input = tokenizer(negative_prompts, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt') 70 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(self.device), output_hidden_states=True) 71 | pooled_uncond_embeddings = uncond_embeddings[0] 72 | uncond_embeddings = uncond_embeddings.hidden_states[-2] 73 | 
uncond_text_embeddings_list.append(uncond_embeddings) 74 | 75 | text_embeddings = torch.cat(text_embeddings_list, dim=-1) 76 | uncond_embeddings = torch.cat(uncond_text_embeddings_list, dim=-1) 77 | 78 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 79 | add_text_embeds = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings]) 80 | return text_embeddings, add_text_embeds 81 | else: 82 | text_input = self.tokenizer(prompts, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt') 83 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 84 | 85 | uncond_input = self.tokenizer(negative_prompts, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') 86 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 87 | 88 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 89 | return text_embeddings 90 | 91 | @torch.no_grad() 92 | def denoise_single_step(self, latents, t, text_embeds, guidance_scale, added_cond_kwargs=None): 93 | latent_model_input = torch.cat([latents] * 2) 94 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t) 95 | if 'xl' in self.sd_version: 96 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds, added_cond_kwargs=added_cond_kwargs).sample 97 | else: 98 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample 99 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 100 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 101 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 102 | return latents 103 | 104 | @torch.no_grad() 105 | def decode_latents(self, latents): 106 | imgs = self.vae.decode(1 / self.vae.config.scaling_factor * latents).sample 107 | imgs = (imgs / 2 + 0.5).clamp(0, 1) 108 | imgs = imgs.cpu().permute(0, 2, 3, 1).numpy() 109 | imgs = (imgs * 255).astype(np.uint8) 110 | return imgs 111 | 112 | def generate_twin_images(self, prompts, negative_prompts, lam=1.0, view_stride=32, num_inference_steps=50): 113 | height = width = 1024 if 'xl' in self.sd_version else 512 114 | guidance_scale = 5.0 if 'xl' in self.sd_version else 7.5 115 | window_size = 128 if 'xl' in self.sd_version else 64 116 | 117 | prompts = [prompts] if isinstance(prompts, str) else prompts 118 | negative_prompts = [negative_prompts] if isinstance(negative_prompts, str) else negative_prompts 119 | 120 | batch_size = len(prompts) 121 | 122 | if 'xl' in self.sd_version: 123 | text_embeds_1_2, add_text_embeds_1_2 = self.get_text_embeds(prompts * 2, negative_prompts * 2) 124 | text_embeds, add_text_embeds = self.get_text_embeds(prompts, negative_prompts) 125 | 126 | add_time_ids = torch.tensor([[height, width, 0, 0, height, width]], dtype=text_embeds.dtype, device=self.device) 127 | negative_add_time_ids = add_time_ids 128 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids]) 129 | add_time_ids_1_2 = add_time_ids.repeat(2 * batch_size, 1) 130 | add_time_ids = add_time_ids.repeat(batch_size, 1) 131 | 132 | added_cond_kwargs_1_2 = {"text_embeds": add_text_embeds_1_2, "time_ids": add_time_ids_1_2} 133 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 134 | else: 135 | text_embeds_1_2 = self.get_text_embeds(prompts * 2, negative_prompts * 2) 136 | text_embeds = self.get_text_embeds(prompts, negative_prompts) 137 | 138 | 
self.scheduler.set_timesteps(num_inference_steps) 139 | 140 | latent_1_2 = torch.randn((batch_size, 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device) # z1 z2 141 | latent_1_2 = latent_1_2 * self.scheduler.init_noise_sigma 142 | latent_1_2[:, 1, :, :, :window_size - view_stride] = latent_1_2[:, 0, :, :, view_stride:] # initialize [z2]l=[z1]r 143 | latent_2_optm = latent_1_2[:, 1].clone() # initialize z2*=z2 144 | latent_1_2 = latent_1_2.reshape(2 * batch_size, *latent_1_2.shape[2:]) 145 | 146 | with tqdm(self.scheduler.timesteps, desc='Generating images') as pbar: 147 | for i, t in enumerate(pbar): 148 | if 'xl' in self.sd_version: 149 | latent_1_2 = self.denoise_single_step(latent_1_2, t, text_embeds_1_2, guidance_scale, added_cond_kwargs_1_2) 150 | latent_2_optm = self.denoise_single_step(latent_2_optm, t, text_embeds, guidance_scale, added_cond_kwargs) 151 | else: 152 | latent_1_2 = self.denoise_single_step(latent_1_2, t, text_embeds_1_2, guidance_scale) 153 | latent_2_optm = self.denoise_single_step(latent_2_optm, t, text_embeds, guidance_scale) 154 | 155 | # Crop Fusion 156 | if i < num_inference_steps // 2: # see Ablation.1 157 | latent_2_optm_pre = latent_2_optm.clone().detach() 158 | latent_1_2 = latent_1_2.reshape(batch_size, 2, *latent_1_2.shape[1:]) 159 | 160 | # training-based optimization 161 | # latent_2_optm.requires_grad = True 162 | # optimizer = torch.optim.Adam([latent_2_optm], lr=1e-3 * (1. - i / 100.)) 163 | # 164 | # for epoch in range(train_epochs): 165 | # loss = nnf.mse_loss(latent_1_2[:, 0, :, :, view_stride:], latent_2_optm[:, :, :, :window_size - view_stride]) + \ 166 | # lam * nnf.mse_loss(latent_2_optm_pre, latent_2_optm) 167 | # # limit test of the function 168 | # # loss = nnf.mse_loss(latent_1_2[:, 0, :, :, view_stride:], latent_2_optm[:, :, :, :window_size - view_stride]) + \ 169 | # # lam * nnf.mse_loss(latent_1_2[:, 1], latent_2_optm) 170 | # 171 | # pbar.set_postfix({'epoch': epoch, 'loss': loss.item() / batch_size}) 172 | # 173 | # optimizer.zero_grad() 174 | # loss.backward() 175 | # optimizer.step() 176 | 177 | # training-free optimization 178 | latent_2_optm[:, :, :, :window_size - view_stride] = (latent_1_2[:, 0, :, :, view_stride:] + 179 | lam * latent_2_optm_pre[:, :, :, :window_size - view_stride]) / (1 + lam) 180 | # limit test 181 | # latent_2_optm[:, :, :, :window_size - view_stride] = (latent_1_2[:, 0, :, :, view_stride:] + 182 | # lam * latent_1_2[:, 1, :, :, :window_size - view_stride]) / (1 + lam) 183 | 184 | latent_1_2 = latent_1_2.reshape(2 * batch_size, *latent_1_2.shape[2:]) 185 | 186 | latents = torch.cat([latent_1_2.reshape(batch_size, 2, *latent_1_2.shape[1:]), latent_2_optm.unsqueeze(1)], dim=1) 187 | imgs = self.decode_latents(latents.reshape(3 * batch_size, *latents.shape[2:])) 188 | imgs = imgs.reshape(batch_size, 3, *imgs.shape[1:]) # return I1、I2、I2*, I1 and I2* are twin images 189 | return imgs 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('--prompt', type=str, default="A photo of the dolomites") 195 | parser.add_argument('--negative', type=str, default="") 196 | parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0', '2.1', 'xl-1.0']) 197 | parser.add_argument('--seed', type=int, default=-1) 198 | parser.add_argument('--lam', type=float, default=1.0) 199 | parser.add_argument('--n', type=int, default=1) 200 | opt = parser.parse_args() 201 | 202 | if opt.seed != -1: 203 | seed_everything(opt.seed) 204 | 
205 | device = torch.device('cuda:1') 206 | td = TwinDiffusion(device, opt.sd_version) 207 | 208 | imgs = td.generate_twin_images([opt.prompt] * opt.n, [opt.negative] * opt.n, lam=opt.lam) # [n,3,height,width,3] 209 | 210 | for i in tqdm(range(opt.n), desc='Saving images'): 211 | img = view_images(imgs[i]) 212 | img.save(f"out{i}.png") 213 | 214 | -------------------------------------------------------------------------------- /panorama.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection 2 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as nnf 6 | import torchvision.transforms as TT 7 | import numpy as np 8 | import argparse 9 | import os 10 | import time 11 | from tqdm import tqdm 12 | 13 | 14 | def seed_everything(seed): 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | # torch.backends.cudnn.deterministic = True 18 | # torch.backends.cudnn.benchmark = True 19 | 20 | 21 | def get_views(height, width, window_size, stride): 22 | num_blocks_height = (height - window_size) // stride + 1 23 | num_blocks_width = (width - window_size) // stride + 1 24 | total_num_blocks = int(num_blocks_height * num_blocks_width) 25 | views = [] 26 | for i in range(total_num_blocks): 27 | h_start = int((i // num_blocks_width) * stride) 28 | h_end = h_start + window_size 29 | w_start = int((i % num_blocks_width) * stride) 30 | w_end = w_start + window_size 31 | views.append((h_start, h_end, w_start, w_end)) 32 | return views 33 | 34 | 35 | class TwinDiffusion(nn.Module): 36 | def __init__(self, device, sd_version='2.0', hf_key=None): 37 | super().__init__() 38 | 39 | self.device = device 40 | self.sd_version = sd_version 41 | 42 | if hf_key is not None: 43 | print(f'Using hugging face custom model key: {hf_key}') 44 | model_key = hf_key 45 | elif self.sd_version == '2.1': 46 | model_key = "stabilityai/stable-diffusion-2-1-base" 47 | elif self.sd_version == '2.0': 48 | model_key = "stabilityai/stable-diffusion-2-base" 49 | elif self.sd_version == '1.5': 50 | model_key = "runwayml/stable-diffusion-v1-5" 51 | elif self.sd_version == 'xl-1.0': 52 | model_key = "stabilityai/stable-diffusion-xl-base-1.0" 53 | else: 54 | raise ValueError(f'Stable Diffusion Version {self.sd_version} NOT Supported.') 55 | 56 | print('Loading stable diffusion...') 57 | 58 | self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device) 59 | self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") 60 | self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device) 61 | self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device) 62 | self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") 63 | if 'xl' in self.sd_version: 64 | self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer_2") 65 | self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_key, subfolder="text_encoder_2").to(self.device) 66 | 67 | print('Loaded stable diffusion!') 68 | 69 | @torch.no_grad() 70 | def get_text_embeds(self, prompts, negative_prompts): 71 | if 'xl' in self.sd_version: 72 | tokenizers = [self.tokenizer, self.tokenizer_2] 73 | text_encoders = [self.text_encoder, self.text_encoder_2] 74 | text_embeddings_list = [] 75 | 
uncond_text_embeddings_list = [] 76 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 77 | text_input = tokenizer(prompts, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt') 78 | text_embeddings = text_encoder(text_input.input_ids.to(self.device), output_hidden_states=True) 79 | pooled_text_embeddings = text_embeddings[0] 80 | text_embeddings = text_embeddings.hidden_states[-2] 81 | text_embeddings_list.append(text_embeddings) 82 | 83 | uncond_input = tokenizer(negative_prompts, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt') 84 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(self.device), output_hidden_states=True) 85 | pooled_uncond_embeddings = uncond_embeddings[0] 86 | uncond_embeddings = uncond_embeddings.hidden_states[-2] 87 | uncond_text_embeddings_list.append(uncond_embeddings) 88 | 89 | text_embeddings = torch.cat(text_embeddings_list, dim=-1) 90 | uncond_embeddings = torch.cat(uncond_text_embeddings_list, dim=-1) 91 | 92 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 93 | add_text_embeds = torch.cat([pooled_uncond_embeddings, pooled_text_embeddings]) 94 | return text_embeddings, add_text_embeds 95 | else: 96 | text_input = self.tokenizer(prompts, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt') 97 | text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] 98 | 99 | uncond_input = self.tokenizer(negative_prompts, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') 100 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 101 | 102 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 103 | return text_embeddings 104 | 105 | @torch.no_grad() 106 | def decode_latents(self, latents): 107 | imgs = self.vae.decode(1 / self.vae.config.scaling_factor * latents).sample 108 | imgs = (imgs / 2 + 0.5).clamp(0, 1) 109 | imgs = imgs.cpu().permute(0, 2, 3, 1).numpy() 110 | imgs = (imgs * 255).astype(np.uint8) 111 | return imgs 112 | 113 | @torch.no_grad() 114 | def denoise_single_step(self, latents, t, text_embeds, guidance_scale, added_cond_kwargs=None): 115 | latent_model_input = torch.cat([latents] * 2) 116 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t) 117 | if 'xl' in self.sd_version: 118 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds, added_cond_kwargs=added_cond_kwargs).sample 119 | else: 120 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample 121 | noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) 122 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) 123 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 124 | return latents 125 | 126 | def text2panorama_optm(self, prompts, negative_prompts, height=512, width=2048, lam=1.0, view_stride=16, cross_time=2, num_inference_steps=50): 127 | """ 128 | height, width:the size of panoramas 129 | lam:the Lagrange multiplier of Crop Fusion function 130 | view_stride:the step size when cropping panoramas 131 | cross_time:the frequency of Cross Sampling 132 | """ 133 | guidance_scale = 5.0 if 'xl' in self.sd_version else 7.5 134 | window_size = 128 if 'xl' in self.sd_version else 64 135 | 136 | prompts = [prompts] if isinstance(prompts, str) else prompts 137 | negative_prompts = 
[negative_prompts] if isinstance(negative_prompts, str) else negative_prompts 138 | 139 | batch_size = len(prompts) 140 | 141 | if 'xl' in self.sd_version: 142 | text_embeds, add_text_embeds = self.get_text_embeds(prompts, negative_prompts) 143 | 144 | add_time_ids = torch.tensor([[height, width, 0, 0, height, width]], dtype=text_embeds.dtype, device=self.device) 145 | negative_add_time_ids = add_time_ids 146 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids]) 147 | add_time_ids = add_time_ids.repeat(batch_size, 1) 148 | 149 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 150 | else: 151 | text_embeds = self.get_text_embeds(prompts, negative_prompts) 152 | 153 | self.scheduler.set_timesteps(num_inference_steps) 154 | 155 | latents = torch.randn((batch_size, self.unet.config.in_channels, height // 8, width // 8), device=self.device) 156 | latents = latents * self.scheduler.init_noise_sigma 157 | 158 | views = get_views(height // 8, width // 8, window_size=window_size, stride=view_stride) 159 | count = torch.zeros_like(latents) 160 | value = torch.zeros_like(latents) 161 | 162 | # Cross Sampling 163 | cross_stride = 0 164 | all_views = [views] 165 | all_cross_strides = [cross_stride] 166 | for _ in range(cross_time - 1): 167 | cross_stride += view_stride // cross_time 168 | views_cross = [views[0]] + [(hs, he, ws + cross_stride, we + cross_stride) for hs, he, ws, we in views[1:-1]] + [views[-1]] 169 | all_views.append(views_cross) 170 | all_cross_strides.append(cross_stride) 171 | 172 | with tqdm(self.scheduler.timesteps, desc='Generating images') as pbar: 173 | for i, t in enumerate(pbar): 174 | count.zero_(), value.zero_() 175 | 176 | for idx, view in enumerate(all_views[i % cross_time]): 177 | h_start, h_end, w_start, w_end = view 178 | latents_view = latents[:, :, h_start:h_end, w_start:w_end] 179 | 180 | if 'xl' in self.sd_version: 181 | latents_view = self.denoise_single_step(latents_view, t, text_embeds, guidance_scale, added_cond_kwargs=added_cond_kwargs) 182 | else: 183 | latents_view = self.denoise_single_step(latents_view, t, text_embeds, guidance_scale) 184 | 185 | # Crop Fusion 186 | if idx > 0 and i < num_inference_steps // 2: 187 | latents_view_pre = latents_view.clone().detach() 188 | 189 | # training-based optimization 190 | # latents_view.requires_grad = True 191 | # optimizer = torch.optim.Adam([latents_view], lr=1e-5 * (1. 
- i / 100.)) 192 | # 193 | # for epoch in range(train_epochs): 194 | # if idx == 1: 195 | # loss = nnf.mse_loss(nbr_views_optm[:, :, :, view_stride + all_cross_strides[i % cross_time]:], latents_view[:, :, :, :window_size - view_stride - all_cross_strides[i % cross_time]]) + \ 196 | # lam * nnf.mse_loss(latents_view_pre, latents_view) 197 | # elif idx == len(views) - 1: 198 | # loss = nnf.mse_loss(nbr_views_optm[:, :, :, view_stride - all_cross_strides[i % cross_time]:], latents_view[:, :, :, :window_size - view_stride + all_cross_strides[i % cross_time]]) + \ 199 | # lam * nnf.mse_loss(latents_view_pre, latents_view) 200 | # else: 201 | # loss = nnf.mse_loss(nbr_views_optm[:, :, :, view_stride:], latents_view[:, :, :, :window_size - view_stride]) + \ 202 | # lam * nnf.mse_loss(latents_view_pre, latents_view) 203 | # 204 | # optimizer.zero_grad() 205 | # loss.backward() 206 | # optimizer.step() 207 | 208 | # training-free optimization 209 | if idx == 1: 210 | latents_view[:, :, :, :window_size - view_stride - all_cross_strides[i % cross_time]] = (nbr_views_optm[:, :, :, view_stride + all_cross_strides[i % cross_time]:] + 211 | lam * latents_view_pre[:, :, :, :window_size - view_stride - all_cross_strides[i % cross_time]]) / (1 + lam) 212 | elif idx == len(views) - 1: 213 | latents_view[:, :, :, :window_size - view_stride + all_cross_strides[i % cross_time]] = (nbr_views_optm[:, :, :, view_stride - all_cross_strides[i % cross_time]:] + 214 | lam * latents_view_pre[:, :, :, :window_size - view_stride + all_cross_strides[i % cross_time]]) / (1 + lam) 215 | else: 216 | latents_view[:, :, :, :window_size - view_stride] = (nbr_views_optm[:, :, :, view_stride:] + 217 | lam * latents_view_pre[:, :, :, :window_size - view_stride]) / (1 + lam) 218 | 219 | value[:, :, h_start:h_end, w_start:w_end] += latents_view 220 | count[:, :, h_start:h_end, w_start:w_end] += 1 221 | 222 | nbr_views_optm = latents_view.clone() # reserving the left neighbor z^{i-1} for z^i 223 | 224 | latents = torch.where(count > 0, value / count, value) 225 | 226 | imgs = self.decode_latents(latents) 227 | return imgs 228 | 229 | 230 | if __name__ == '__main__': 231 | # sd: 232 | # H, W = 512, 2048 233 | # view_stride = 16 234 | # sdxl: 235 | # H, W = 1024, 4096 236 | # view_stride = 32 237 | 238 | parser = argparse.ArgumentParser() 239 | parser.add_argument('--prompt', type=str, default="A photo of the dolomites") 240 | parser.add_argument('--negative', type=str, default="") 241 | parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0', '2.1', 'xl-1.0']) 242 | parser.add_argument('--H', type=int, default=512) 243 | parser.add_argument('--W', type=int, default=2048) 244 | parser.add_argument('--seed', type=int, default=-1) 245 | parser.add_argument('--lam', type=float, default=1.0) 246 | parser.add_argument('--view_stride', type=int, default=16) 247 | parser.add_argument('--cross_time', type=int, default=2) 248 | parser.add_argument('--n', type=int, default=1) 249 | opt = parser.parse_args() 250 | 251 | if opt.seed != -1: 252 | seed_everything(opt.seed) 253 | 254 | device = torch.device('cuda') 255 | td = TwinDiffusion(device, opt.sd_version) 256 | 257 | start = time.time() 258 | imgs = td.text2panorama_optm([opt.prompt] * opt.n, [opt.negative] * opt.n, opt.H, opt.W, lam=opt.lam, view_stride=opt.view_stride) 259 | print(f"time: {time.time() - start} s") 260 | 261 | for i in tqdm(range(opt.n), desc='Saving images'): 262 | TT.ToPILImage()(imgs[i]).save(f"out{i}.png") 263 | 264 | 
--------------------------------------------------------------------------------
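The Cross Sampling schedule is easy to miss because it is built inline inside `text2panorama_optm`. The standalone sketch below reproduces it under our own framing: `get_views` mirrors the helper of the same name in `panorama.py`, while `build_cross_schedule` is a hypothetical name for the loop that populates `all_views`. It prints which window coordinates are visited at even and odd timesteps for the default SD 2.0 settings (512x2048 panorama, latent window 64, `view_stride` 16, `cross_time` 2).

```python
def get_views(height, width, window_size, stride):
    """Enumerate sliding-window crops over the latent canvas, left to right, top to bottom."""
    num_blocks_height = (height - window_size) // stride + 1
    num_blocks_width = (width - window_size) // stride + 1
    views = []
    for i in range(num_blocks_height * num_blocks_width):
        h_start = (i // num_blocks_width) * stride
        w_start = (i % num_blocks_width) * stride
        views.append((h_start, h_start + window_size, w_start, w_start + window_size))
    return views


def build_cross_schedule(views, view_stride, cross_time):
    """Shift the interior windows by a growing offset so that crop boundaries move
    between denoising steps; the first and last windows stay pinned to the edges."""
    all_views, cross_stride = [views], 0
    for _ in range(cross_time - 1):
        cross_stride += view_stride // cross_time
        shifted = ([views[0]]
                   + [(hs, he, ws + cross_stride, we + cross_stride) for hs, he, ws, we in views[1:-1]]
                   + [views[-1]])
        all_views.append(shifted)
    return all_views  # denoising step i uses all_views[i % cross_time]


views = get_views(64, 256, window_size=64, stride=16)  # 512x2048 image -> 64x256 latents
schedule = build_cross_schedule(views, view_stride=16, cross_time=2)
for parity, vs in enumerate(schedule):
    print(f"parity {parity}:", [w_start for _, _, w_start, _ in vs[:5]], "...")
# parity 0: [0, 16, 32, 48, 64] ...
# parity 1: [0, 24, 40, 56, 72] ...   (interior windows shifted by 16 // 2 = 8)
```

Because the window set alternates from step to step, crop boundaries do not sit in the same place at every step, which is how the interleaving strategy keeps transitions coherent while cropping with a relatively large `view_stride`.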