├── LICENSE
├── README.md
├── app.py
├── demo.ipynb
├── magic_mix.py
└── requirements.txt

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Partho

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MagicMix
[![Generic badge](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue.svg)](https://huggingface.co/spaces/daspartho/MagicMix)

Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/pdf/2210.16056.pdf) paper.

![magicmix](https://user-images.githubusercontent.com/59410571/206903603-6c8da6ef-69c4-4400-b4a3-aef9206ff396.png)

The method mixes two different concepts semantically to synthesize a new concept, while preserving the spatial layout and geometry of the original image.

It takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.

The method has three parameters:
- `v`: the interpolation constant used in the layout generation phase. The greater the value of `v`, the greater the influence of the prompt on the layout generation process.
- `kmax` and `kmin`: these determine the step ranges for the layout and content generation phases (the sketch below the usage examples shows how they map to denoising steps). A higher `kmax` discards more information about the layout of the original image, while a higher `kmin` allots more steps to the content generation phase.

### Usage

```python
from PIL import Image
from magic_mix import magic_mix

img = Image.open('phone.jpg')
out_img = magic_mix(img, 'bed', kmax=0.5)
out_img.save("mix.jpg")
```
```
python3 magic_mix.py \
    "phone.jpg" \
    "bed" \
    "mix.jpg" \
    --kmin 0.3 \
    --kmax 0.6 \
    --v 0.5 \
    --steps 50 \
    --seed 42 \
    --guidance_scale 7.5
```
Also, check out the [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb) for example usage of the implementation to reproduce examples from the paper.
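To build intuition for `kmin` and `kmax`, here is a minimal sketch of the arithmetic `magic_mix.py` uses to turn them into a denoising window (shown with the default `steps=50`):

```python
# How kmin/kmax map onto the sampling schedule (mirrors magic_mix.py).
steps = 50
kmin, kmax = 0.3, 0.6

# Timesteps are ordered from most to least noisy, so the k values are
# counted back from the end of the schedule.
tmax = steps - int(kmax * steps)  # 20: denoising starts at this index
tmin = steps - int(kmin * steps)  # 35: layout phase ends at this index

# Indices tmax..tmin interpolate the noised input image with the
# prompt-conditioned latents (layout phase); the remaining indices
# follow the prompt alone (content phase).
print(f"layout steps: {tmin - tmax}, content steps: {steps - tmin}")  # 15 and 15
```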
You can also use the community pipeline in the diffusers library.

```python
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to('cuda')

img = Image.open('phone.jpg')
mix_img = pipe(
    img,
    prompt='bed',
    kmin=0.3,
    kmax=0.5,
    mix_factor=0.5,
)
mix_img.save('mix.jpg')
```

### Some examples reproduced from the paper:

##### Input Image:

![telephone](https://user-images.githubusercontent.com/59410571/206903102-34e79b9f-9ed2-4fac-bb38-82871343c655.jpg)

##### Prompt: "Bed"

##### Output Image:

![telephone-bed](https://user-images.githubusercontent.com/59410571/206903104-913a671d-ef53-4ae4-919d-64c3059c8f67.jpg)

##### Input Image:

![sign](https://user-images.githubusercontent.com/59410571/206903307-b066dddd-8aaf-4104-9d5c-8427a51f37a7.jpg)

##### Prompt: "Family"

##### Output Image:

![sign-family](https://user-images.githubusercontent.com/59410571/206903320-7530a8ac-6594-4449-8328-bbc31befd9e8.jpg)

##### Input Image:

![sushi](https://user-images.githubusercontent.com/59410571/206903325-a06268ef-903e-434b-8365-68fb8b003d1e.jpg)

##### Prompt: "ice-cream"

##### Output Image:

![sushi-ice-cream](https://user-images.githubusercontent.com/59410571/206903341-e66d5c27-1543-489f-833b-dc8afc6c68e6.jpg)

##### Input Image:

![pineapple](https://user-images.githubusercontent.com/59410571/206903362-7c0464a7-ace4-4810-8fe3-37cab3d929a6.jpg)

##### Prompt: "Cake"

##### Output Image:

![pineapple-cake](https://user-images.githubusercontent.com/59410571/206903377-3b0fb63c-061e-4070-a8d1-eaca5738ae36.jpg)

### Note
**I'm not the author of the paper, and this is not an official implementation.**
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import gradio as gr
from magic_mix import magic_mix

iface = gr.Interface(
    description="Implementation of MagicMix: Semantic Mixing with Diffusion Models paper",
    article="<a href='https://github.com/daspartho/MagicMix' target='_blank'>Github</a>",
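    # inputs below map positionally to
    # magic_mix(img, prompt, kmin, kmax, v, seed, steps, guidance_scale)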
    fn=magic_mix,
    inputs=[
        gr.Image(shape=(512, 512), type="pil"),  # layout image
        gr.Text(),  # content prompt
        gr.Slider(value=0.3, minimum=0, maximum=1, step=0.1),  # kmin
        gr.Slider(value=0.6, minimum=0, maximum=1, step=0.1),  # kmax
        gr.Slider(value=0.5, minimum=0, maximum=1, step=0.1),  # v
        gr.Number(value=42, maximum=2**64 - 1),  # seed
        gr.Slider(value=50),  # steps
        gr.Slider(value=7.5, minimum=1, maximum=15, step=0.1),  # guidance_scale
    ],
    outputs=gr.Image(),
    title="MagicMix",
)

iface.launch()
--------------------------------------------------------------------------------
/magic_mix.py:
--------------------------------------------------------------------------------
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, logging
import torch
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from PIL import Image

# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loading the components we'll use

tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14",
)

text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
).to(device)

vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",
).to(device)

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="unet",
).to(device)

beta_start, beta_end = 0.00085, 0.012
scheduler = DDIMScheduler(
    beta_start=beta_start,
    beta_end=beta_end,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    clip_sample=False,
    set_alpha_to_one=False,
)


# convert a PIL image to scaled VAE latents
def encode(img):
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(device) * 2 - 1)
        latent = 0.18215 * latent.latent_dist.sample()
    return latent


# convert latents back to a PIL image
def decode(latent):
    latent = (1 / 0.18215) * latent
    with torch.no_grad():
        img = vae.decode(latent).sample
    img = (img / 2 + 0.5).clamp(0, 1)
    img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
    img = (img * 255).round().astype("uint8")
    return Image.fromarray(img[0])


# convert a prompt into text embeddings, concatenated with unconditional
# embeddings for classifier-free guidance
def prep_text(prompt):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    text_embedding = text_encoder(text_input.input_ids.to(device))[0]

    uncond_input = tokenizer(
        "",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    uncond_embedding = text_encoder(uncond_input.input_ids.to(device))[0]

    return torch.cat([uncond_embedding, text_embedding])


def magic_mix(
    img,  # specifies the layout semantics
    prompt,  # specifies the content semantics
    kmin=0.3,
    kmax=0.6,
    v=0.5,  # interpolation constant
    seed=42,
    steps=50,
    guidance_scale=7.5,
):
    tmin = steps - int(kmin * steps)
    tmax = steps - int(kmax * steps)

    text_embeddings = prep_text(prompt)
    scheduler.set_timesteps(steps)

    width, height = img.size
    encoded = encode(img)

    torch.manual_seed(seed)
    noise = torch.randn(
        (1, unet.config.in_channels, height // 8, width // 8),
    ).to(device)

    # noise the original image's latents up to the tmax noise level
    latents = scheduler.add_noise(encoded, noise, timesteps=scheduler.timesteps[tmax])

    # duplicate the latents so a single UNet pass covers both the
    # unconditional and the text-conditioned prediction
    input = torch.cat([latents] * 2)

    input = scheduler.scale_model_input(input, scheduler.timesteps[tmax])

    with torch.no_grad():
        pred = unet(
            input,
            scheduler.timesteps[tmax],
            encoder_hidden_states=text_embeddings,
        ).sample

    # classifier-free guidance
    pred_uncond, pred_text = pred.chunk(2)
    pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

    latents = scheduler.step(pred, scheduler.timesteps[tmax], latents).prev_sample

    for i, t in enumerate(tqdm(scheduler.timesteps)):
        if i > tmax:
            if i < tmin:  # layout generation phase
                orig_latents = scheduler.add_noise(encoded, noise, timesteps=t)

                # interpolate between the noised original latents and the
                # conditionally generated latents to preserve layout semantics
                input = (v * latents) + (1 - v) * orig_latents
                input = torch.cat([input] * 2)

            else:  # content generation phase
                input = torch.cat([latents] * 2)

            input = scheduler.scale_model_input(input, t)

            with torch.no_grad():
                pred = unet(
                    input,
                    t,
                    encoder_hidden_states=text_embeddings,
                ).sample

            pred_uncond, pred_text = pred.chunk(2)
            pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

            latents = scheduler.step(pred, t, latents).prev_sample

    return decode(latents)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "img_file",
        type=str,
        help="image file to provide the layout semantics for the mixing process",
    )
    parser.add_argument(
        "prompt",
        type=str,
        help="prompt to provide the content semantics for the mixing process",
    )
    parser.add_argument("out_file", type=str, help="filename to save the generation to")
    parser.add_argument("--kmin", type=float, default=0.3)
    parser.add_argument("--kmax", type=float, default=0.6)
    parser.add_argument("--v", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=7.5)

    args = parser.parse_args()

    img = Image.open(args.img_file)
    out_img = magic_mix(
        img,
        args.prompt,
        args.kmin,
        args.kmax,
        args.v,
        args.seed,
        args.steps,
        args.guidance_scale,
    )
    out_img.save(args.out_file)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
diffusers
transformers
accelerate
tqdm
pillow
gradio
--------------------------------------------------------------------------------