├── models
│   ├── __pycache__
│   │   └── net.cpython-39.pyc
│   └── net.py
├── utils
│   ├── __pycache__
│   │   ├── plot.cpython-39.pyc
│   │   ├── helper.cpython-39.pyc
│   │   ├── plots.cpython-39.pyc
│   │   └── helper_plot.cpython-39.pyc
│   ├── helper.py
│   └── plots.py
├── .ipynb_checkpoints
│   ├── diffusion_07_DDIM-checkpoint.ipynb
│   ├── diffusion_08_IDDPM-checkpoint.ipynb
│   ├── diffusion_09_SDE-checkpoint.ipynb
│   ├── diffusion_11_disco diffusion-checkpoint.ipynb
│   └── diffusion_10_guided diffusion-checkpoint.ipynb
├── README.md
├── diffusion_11_Classifier Free Diffusion.ipynb
├── diffusion_13_CLIP guided diffusion.ipynb
└── diffusion_14_Augmented CLIP Guided Diffusion.ipynb

/models/__pycache__/net.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlin-ai/diffusion_tutorial/HEAD/models/__pycache__/net.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/plot.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlin-ai/diffusion_tutorial/HEAD/utils/__pycache__/plot.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/helper.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlin-ai/diffusion_tutorial/HEAD/utils/__pycache__/helper.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/plots.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlin-ai/diffusion_tutorial/HEAD/utils/__pycache__/plots.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/helper_plot.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlin-ai/diffusion_tutorial/HEAD/utils/__pycache__/helper_plot.cpython-39.pyc
--------------------------------------------------------------------------------
/models/net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | class ConditionalLinear(nn.Module):
5 |     def __init__(self, num_in, num_out, n_steps):
6 |         super(ConditionalLinear,self).__init__()
7 |         self.num_out = num_out
8 |         self.lin = nn.Linear(num_in, num_out)
9 |         self.embed = nn.Embedding(n_steps, num_out)
10 |         self.embed.weight.data.uniform_()
11 | 
12 |     def forward(self, x, y):
13 |         out = self.lin(x)
14 |         gamma = self.embed(y)
15 |         out = gamma.view(-1, self.num_out) * out
16 |         return out
17 | 
18 | 
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/.ipynb_checkpoints/diffusion_07_DDIM-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "3c56be64",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": []
10 |   }
11 |  ],
12 |  "metadata": {
13 |   "kernelspec": {
14 |    "display_name": "Python 3 (ipykernel)",
15 |    "language": "python",
16 |    "name": "python3"
17 |   },
18 |   "language_info": {
19 |    "codemirror_mode": {
20 |     "name": "ipython",
21 |     "version": 3
22 |    },
23 |    "file_extension": ".py",
24 |    "mimetype": "text/x-python",
25 |    "name": "python",
26 | 
"nbconvert_exporter": "python", 27 | "pygments_lexer": "ipython3", 28 | "version": "3.9.7" 29 | } 30 | }, 31 | "nbformat": 4, 32 | "nbformat_minor": 5 33 | } 34 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/diffusion_08_IDDPM-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3c56be64", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [] 10 | } 11 | ], 12 | "metadata": { 13 | "kernelspec": { 14 | "display_name": "Python 3 (ipykernel)", 15 | "language": "python", 16 | "name": "python3" 17 | }, 18 | "language_info": { 19 | "codemirror_mode": { 20 | "name": "ipython", 21 | "version": 3 22 | }, 23 | "file_extension": ".py", 24 | "mimetype": "text/x-python", 25 | "name": "python", 26 | "nbconvert_exporter": "python", 27 | "pygments_lexer": "ipython3", 28 | "version": "3.9.7" 29 | } 30 | }, 31 | "nbformat": 4, 32 | "nbformat_minor": 5 33 | } 34 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/diffusion_09_SDE-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3c56be64", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [] 10 | } 11 | ], 12 | "metadata": { 13 | "kernelspec": { 14 | "display_name": "Python 3 (ipykernel)", 15 | "language": "python", 16 | "name": "python3" 17 | }, 18 | "language_info": { 19 | "codemirror_mode": { 20 | "name": "ipython", 21 | "version": 3 22 | }, 23 | "file_extension": ".py", 24 | "mimetype": "text/x-python", 25 | "name": "python", 26 | "nbconvert_exporter": "python", 27 | "pygments_lexer": "ipython3", 28 | "version": "3.9.7" 29 | } 30 | }, 31 | "nbformat": 4, 32 | "nbformat_minor": 5 33 | } 34 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/diffusion_11_disco diffusion-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3c56be64", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [] 10 | } 11 | ], 12 | "metadata": { 13 | "kernelspec": { 14 | "display_name": "Python 3 (ipykernel)", 15 | "language": "python", 16 | "name": "python3" 17 | }, 18 | "language_info": { 19 | "codemirror_mode": { 20 | "name": "ipython", 21 | "version": 3 22 | }, 23 | "file_extension": ".py", 24 | "mimetype": "text/x-python", 25 | "name": "python", 26 | "nbconvert_exporter": "python", 27 | "pygments_lexer": "ipython3", 28 | "version": "3.9.7" 29 | } 30 | }, 31 | "nbformat": 4, 32 | "nbformat_minor": 5 33 | } 34 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/diffusion_10_guided diffusion-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3c56be64", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [] 10 | } 11 | ], 12 | "metadata": { 13 | "kernelspec": { 14 | "display_name": "Python 3 (ipykernel)", 15 | "language": "python", 16 | "name": "python3" 17 | }, 18 | "language_info": { 19 | "codemirror_mode": { 20 | "name": "ipython", 21 | "version": 3 22 | }, 23 | "file_extension": ".py", 24 | "mimetype": "text/x-python", 
25 |    "name": "python",
26 |    "nbconvert_exporter": "python",
27 |    "pygments_lexer": "ipython3",
28 |    "version": "3.9.7"
29 |   }
30 |  },
31 |  "nbformat": 4,
32 |  "nbformat_minor": 5
33 | }
34 | 
--------------------------------------------------------------------------------
/utils/helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.datasets import make_swiss_roll
3 | 
4 | def sample_batch(size, noise=1.0):
5 |     x, _ = make_swiss_roll(size, noise=noise)
6 |     return x[:, [0, 2]] / 10.0
7 | 
8 | def extract(input, t, x):
9 |     shape = x.shape
10 |     out = torch.gather(input, 0, t.to(input.device))
11 |     reshape = [t.shape[0]] + [1] * (len(shape) - 1)
12 |     return out.reshape(*reshape)
13 | 
14 | def make_beta_schedule(schedule='linear', n_timesteps=1000, start=1e-5, end=1e-2):
15 |     if schedule == 'linear':
16 |         betas = torch.linspace(start, end, n_timesteps)
17 |     elif schedule == "quad":
18 |         betas = torch.linspace(start ** 0.5, end ** 0.5, n_timesteps) ** 2
19 |     elif schedule == "sigmoid":
20 |         betas = torch.linspace(-6, 6, n_timesteps)
21 |         betas = torch.sigmoid(betas) * (end - start) + start
22 |     return betas
23 | 
24 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # Diffusion Models
4 | 
5 | This notebook series introduces a new family of generative models: diffusion probabilistic models.
6 | 
7 | We will work through a series of recent diffusion models and some of their most interesting applications, covering the topics below (where [x] marks notebooks that are already finished):
8 | 
9 | - [x] 1 Score matching
10 | - [x] 2 Langevin dynamics
11 | - [x] 3 DPM (2015): Deep Unsupervised Learning Using Nonequilibrium Thermodynamics
12 | - [x] 4 NCSN (2020): Noise Conditional Score Networks
13 | - [x] 5 DDPM (2020): Denoising Diffusion Probabilistic Models
14 | - [ ] 6 WaveGrad (2020): Estimating Gradients for Waveform Generation
15 | - [x] 7 DDIM (2021): Denoising Diffusion Implicit Models
16 | - [ ] 8 IDDPM (2021): Improved Denoising Diffusion Probabilistic Models
17 | - [ ] 9 SDE (2021): Score-Based Generative Modeling through Stochastic Differential Equations
18 | - [ ] 10 Guided Diffusion (2021): Diffusion Models Beat GANs on Image Synthesis
19 | - [x] 11 Classifier Free Diffusion (2021): Classifier-Free Diffusion Guidance
20 | - [x] 12 Latent Diffusion (2022): High-Resolution Image Synthesis with Latent Diffusion Models
21 | - [x] 13 CLIP guided diffusion
22 | - [x] 14 Augmented CLIP Guided Diffusion
23 | - [x] 15 Disco Diffusion
24 | - [x] 16 High Resolution Image Synthesis with Latent Diffusion
--------------------------------------------------------------------------------
/diffusion_11_Classifier Free Diffusion.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "# Run this line in Colab to install the package if it is\n",
10 |     "# not already installed.\n",
11 |     "!pip install git+https://github.com/openai/glide-text2im"
12 |    ]
13 |   },
14 |   {
15 |    "cell_type": "code",
16 |    "execution_count": null,
17 |    "metadata": {},
18 |    "outputs": [],
19 |    "source": [
20 |     "from PIL import Image\n",
21 |     "from IPython.display import display\n",
22 |     "import torch as th\n",
23 |     "\n",
24 |     "from glide_text2im.download import load_checkpoint\n",
25 |     "from glide_text2im.model_creation import (\n",
26 |     "    create_model_and_diffusion,\n",
27 |     "    model_and_diffusion_defaults,\n",
28 |     "    model_and_diffusion_defaults_upsampler\n",
29 | 
")" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# This notebook supports both CPU and GPU.\n", 39 | "# On CPU, generating one sample may take on the order of 20 minutes.\n", 40 | "# On a GPU, it should be under a minute.\n", 41 | "\n", 42 | "has_cuda = th.cuda.is_available()\n", 43 | "device = th.device('cpu' if not has_cuda else 'cuda')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "# Create base model.\n", 53 | "options = model_and_diffusion_defaults()\n", 54 | "options['use_fp16'] = has_cuda\n", 55 | "options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling\n", 56 | "model, diffusion = create_model_and_diffusion(**options)\n", 57 | "model.eval()\n", 58 | "if has_cuda:\n", 59 | " model.convert_to_fp16()\n", 60 | "model.to(device)\n", 61 | "model.load_state_dict(load_checkpoint('base', device))\n", 62 | "print('total base parameters', sum(x.numel() for x in model.parameters()))" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# Create upsampler model.\n", 72 | "options_up = model_and_diffusion_defaults_upsampler()\n", 73 | "options_up['use_fp16'] = has_cuda\n", 74 | "options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling\n", 75 | "model_up, diffusion_up = create_model_and_diffusion(**options_up)\n", 76 | "model_up.eval()\n", 77 | "if has_cuda:\n", 78 | " model_up.convert_to_fp16()\n", 79 | "model_up.to(device)\n", 80 | "model_up.load_state_dict(load_checkpoint('upsample', device))\n", 81 | "print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "def show_images(batch: th.Tensor):\n", 91 | " \"\"\" Display a batch of images inline. 
\"\"\"\n", 92 | " scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n", 93 | " reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])\n", 94 | " display(Image.fromarray(reshaped.numpy()))" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Sampling parameters\n", 104 | "prompt = \"an oil painting of a corgi\"\n", 105 | "batch_size = 1\n", 106 | "guidance_scale = 3.0\n", 107 | "\n", 108 | "# Tune this parameter to control the sharpness of 256x256 images.\n", 109 | "# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n", 110 | "upsample_temp = 0.997" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "##############################\n", 120 | "# Sample from the base model #\n", 121 | "##############################\n", 122 | "\n", 123 | "# Create the text tokens to feed to the model.\n", 124 | "tokens = model.tokenizer.encode(prompt)\n", 125 | "tokens, mask = model.tokenizer.padded_tokens_and_mask(\n", 126 | " tokens, options['text_ctx']\n", 127 | ")\n", 128 | "\n", 129 | "# Create the classifier-free guidance tokens (empty)\n", 130 | "full_batch_size = batch_size * 2\n", 131 | "uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(\n", 132 | " [], options['text_ctx']\n", 133 | ")\n", 134 | "\n", 135 | "# Pack the tokens together into model kwargs.\n", 136 | "model_kwargs = dict(\n", 137 | " tokens=th.tensor(\n", 138 | " [tokens] * batch_size + [uncond_tokens] * batch_size, device=device\n", 139 | " ),\n", 140 | " mask=th.tensor(\n", 141 | " [mask] * batch_size + [uncond_mask] * batch_size,\n", 142 | " dtype=th.bool,\n", 143 | " device=device,\n", 144 | " ),\n", 145 | ")\n", 146 | "\n", 147 | "# Create a classifier-free guidance sampling function\n", 148 | "def model_fn(x_t, ts, **kwargs):\n", 149 | " half = x_t[: len(x_t) // 2]\n", 150 | " combined = th.cat([half, half], dim=0)\n", 151 | " model_out = model(combined, ts, **kwargs)\n", 152 | " eps, rest = model_out[:, :3], model_out[:, 3:]\n", 153 | " cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n", 154 | " half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n", 155 | " eps = th.cat([half_eps, half_eps], dim=0)\n", 156 | " return th.cat([eps, rest], dim=1)\n", 157 | "\n", 158 | "# Sample from the base model.\n", 159 | "model.del_cache()\n", 160 | "samples = diffusion.p_sample_loop(\n", 161 | " model_fn,\n", 162 | " (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n", 163 | " device=device,\n", 164 | " clip_denoised=True,\n", 165 | " progress=True,\n", 166 | " model_kwargs=model_kwargs,\n", 167 | " cond_fn=None,\n", 168 | ")[:batch_size]\n", 169 | "model.del_cache()\n", 170 | "\n", 171 | "# Show the output\n", 172 | "show_images(samples)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "##############################\n", 182 | "# Upsample the 64x64 samples #\n", 183 | "##############################\n", 184 | "\n", 185 | "tokens = model_up.tokenizer.encode(prompt)\n", 186 | "tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n", 187 | " tokens, options_up['text_ctx']\n", 188 | ")\n", 189 | "\n", 190 | "# Create the model conditioning dict.\n", 191 | "model_kwargs = dict(\n", 192 | " # Low-res image to upsample.\n", 193 | " 
low_res=((samples+1)*127.5).round()/127.5 - 1,\n", 194 | "\n", 195 | " # Text tokens\n", 196 | " tokens=th.tensor(\n", 197 | " [tokens] * batch_size, device=device\n", 198 | " ),\n", 199 | " mask=th.tensor(\n", 200 | " [mask] * batch_size,\n", 201 | " dtype=th.bool,\n", 202 | " device=device,\n", 203 | " ),\n", 204 | ")\n", 205 | "\n", 206 | "# Sample from the base model.\n", 207 | "model_up.del_cache()\n", 208 | "up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n", 209 | "up_samples = diffusion_up.ddim_sample_loop(\n", 210 | " model_up,\n", 211 | " up_shape,\n", 212 | " noise=th.randn(up_shape, device=device) * upsample_temp,\n", 213 | " device=device,\n", 214 | " clip_denoised=True,\n", 215 | " progress=True,\n", 216 | " model_kwargs=model_kwargs,\n", 217 | " cond_fn=None,\n", 218 | ")[:batch_size]\n", 219 | "model_up.del_cache()\n", 220 | "\n", 221 | "# Show the output\n", 222 | "show_images(up_samples)" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "interpreter": { 228 | "hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781" 229 | }, 230 | "kernelspec": { 231 | "display_name": "Python 3", 232 | "language": "python", 233 | "name": "python3" 234 | }, 235 | "language_info": { 236 | "codemirror_mode": { 237 | "name": "ipython", 238 | "version": 3 239 | }, 240 | "file_extension": ".py", 241 | "mimetype": "text/x-python", 242 | "name": "python", 243 | "nbconvert_exporter": "python", 244 | "pygments_lexer": "ipython3", 245 | "version": "3.7.3" 246 | }, 247 | "accelerator": "GPU" 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib as mpl 3 | import matplotlib.pyplot as plt 4 | import torch 5 | from sklearn.mixture import GaussianMixture 6 | import torch.distributions as distribution 7 | from matplotlib.patches import Ellipse 8 | 9 | def hdr_plot_style(): 10 | plt.style.use('dark_background') 11 | mpl.rcParams.update({'font.size': 18, 'lines.linewidth': 3, 'lines.markersize': 15}) 12 | # avoid type 3 (i.e. 
bitmap) fonts in figures
13 |     mpl.rcParams['ps.useafm'] = True
14 |     mpl.rcParams['pdf.use14corefonts'] = True
15 |     mpl.rcParams['text.usetex'] = False
16 |     mpl.rcParams['font.family'] = 'sans-serif'
17 |     mpl.rcParams['font.sans-serif'] = 'Courier New'
18 |     # mpl.rcParams['text.hinting'] = False
19 |     # Set color cycle
20 |     colors = mpl.cycler('color', ['#3388BB', '#EE6666', '#9988DD', '#EECC55', '#88BB44', '#FFBBBB'])
21 |     #plt.rc('figure', facecolor='#00000000', edgecolor='black')
22 |     #plt.rc('axes', facecolor='#FFFFFF88', edgecolor='white', axisbelow=True, grid=True, prop_cycle=colors)
23 |     plt.rc('legend', facecolor='#666666EE', edgecolor='white', fontsize=16)
24 |     plt.rc('grid', color='white', linestyle='solid')
25 |     plt.rc('text', color='white')
26 |     plt.rc('xtick', direction='out', color='white')
27 |     plt.rc('ytick', direction='out', color='white')
28 |     plt.rc('patch', edgecolor='#E6E6E6')
29 | 
30 | # define a function that generates a number of subplots in a single row with the given titles
31 | def prep_plots(titles, fig_size, fig_num=1):
32 |     """
33 |     create a figure with the number of sub_plots given by the number of titles, and return all generated subplot axes
34 |     as a list
35 |     """
36 |     # first close possibly existing old figures; if you don't do this, Jupyter Lab will complain after a while when it collects more than 20 existing figures for the same cell
37 |     # plt.close(fig_num)
38 |     # create a new figure
39 |     hdr_plot_style()
40 |     fig=plt.figure(fig_num, figsize=fig_size)
41 |     ax_list = []
42 |     for ind, title in enumerate(titles, start=1):
43 |         ax=fig.add_subplot(1, len(titles), ind)
44 |         ax.set_title(title)
45 |         ax_list.append(ax)
46 |     return ax_list
47 | 
48 | def finalize_plots(axes_list, legend=True, fig_title=None):
49 |     """
50 |     adds grid and legend to all axes in the given list
51 |     """
52 |     if fig_title:
53 |         fig = axes_list[0].figure
54 |         fig.suptitle(fig_title, y=1)
55 |     for ax in axes_list:
56 |         ax.grid(True)
57 |         if legend:
58 |             ax.legend()
59 | 
60 | def plot_patterns(P,D):
61 |     """ Plots two-dimensional input patterns P, colored by their class labels D """
62 |     hdr_plot_style()
63 |     nPats = P.shape[1]
64 |     nUnits = D.shape[0]
65 |     if nUnits < 2:
66 |         D = np.concatenate((D, np.zeros((1, nPats))))
67 |     # Create the figure
68 |     fig = plt.figure(figsize=(10, 8))
69 |     ax = plt.gca()
70 |     # Calculate the bounds for the plot and cause axes to be drawn.
71 | xmin, xmax = np.min(P[0, :]), np.max(P[0, :]) 72 | xb = (xmax - xmin) * 0.2 73 | ymin, ymax = np.min(P[1, :]), np.max(P[1, :]) 74 | yb = (ymax-ymin) * 0.2 75 | ax.set(xlim=[xmin-xb, xmax+xb], ylim=[ymin-yb, ymax+yb]) 76 | plt.title('Input Classification') 77 | plt.xlabel('x1'); plt.ylabel('x2') 78 | # classVal = 1 + D[1,:] + 2 * D[2,:]; 79 | colors = [[0, 0.2, 0.9], [0, 0.9, 0.2], [0, 0, 1], [0, 1, 0]] 80 | symbols = 'ooo*+x'; Dcopy = D[:] 81 | #Dcopy[Dcopy == 0] = 1 82 | for i in range(nPats): 83 | c = Dcopy[i] 84 | ax.scatter(P[0,i], P[1,i], marker=symbols[c], c=colors[c], s=50, linewidths=2, edgecolor='w') 85 | #ax.legend() 86 | ax.grid(True) 87 | return fig 88 | 89 | def plot_boundary(W,iVal,style,fig): 90 | """ Plots (bi-dimensionnal) input patterns """ 91 | nUnits = W.shape[0] 92 | colors = plt.cm.inferno_r.colors[1::3] 93 | xLims = plt.gca().get_xlim() 94 | for i in range(nUnits): 95 | if len(style) == 1: 96 | color = [1, 1, 1]; 97 | else: 98 | color = colors[int((3 * iVal + 9) % len(colors))] 99 | plt.plot(xLims,(-np.dot(W[i, 1], xLims) - W[i, 0]) / W[i, 2], linestyle=style, color=color, linewidth=1.5); 100 | fig.canvas.draw() 101 | 102 | def visualize_boundary_linear(X, y, model): 103 | # VISUALIZEBOUNDARYLINEAR plots a linear decision boundary learned by the SVM 104 | # VISUALIZEBOUNDARYLINEAR(X, y, model) plots a linear decision boundary 105 | # learned by the SVM and overlays the data on it 106 | hdr_plot_style() 107 | w = model["w"] 108 | b = model["b"] 109 | xp = np.linspace(np.min(X[:, 0]), np.max(X[:, 0]), 100).transpose() 110 | yp = - (w[0] * xp + b) / w[1] 111 | plt.figure(figsize=(12, 8)) 112 | pos = (y == 1)[:, 0] 113 | neg = (y == -1)[:, 0] 114 | plt.scatter(X[pos, 0], X[pos, 1], marker='x', linewidths=2, s=23, c=[0, 0.5, 0]) 115 | plt.scatter(X[neg, 0], X[neg, 1], marker='o', linewidths=2, s=23, c=[1, 0, 0]) 116 | plt.plot(xp, yp, '-b') 117 | plt.scatter(model["X"][:, 0], model["X"][:, 1], marker='o', linewidths=4, s=40, c=None, edgecolors=[0.1, 0.1, 0.1]) 118 | 119 | def plot_data(X, y): 120 | #PLOTDATA Plots the data points X and y into a new figure 121 | # PLOTDATA(x,y) plots the data points with + for the positive examples 122 | # and o for the negative examples. X is assumed to be a Mx2 matrix. 
123 | # 124 | # Note: This was slightly modified such that it expects y = 1 or y = 0 125 | hdr_plot_style() 126 | # Find Indices of Positive and Negative Examples 127 | pos = (y == 1)[:, 0] 128 | neg = (y == 0)[:, 0] 129 | # Plot Examples 130 | fig = plt.figure(figsize=(12, 8)) 131 | plt.scatter(X[pos, 0], X[pos, 1], marker='x', edgecolor='k', linewidths=2, s=50, c=[0, 0.5, 0]) 132 | plt.scatter(X[neg, 0], X[neg, 1], marker='o', edgecolor='k', linewidths=2, s=50, c=[1, 0, 0]) 133 | return fig 134 | 135 | def visualize_boundary(X, y, model): 136 | #VISUALIZEBOUNDARY plots a non-linear decision boundary learned by the SVM 137 | # VISUALIZEBOUNDARYLINEAR(X, y, model) plots a non-linear decision 138 | # boundary learned by the SVM and overlays the data on it 139 | hdr_plot_style() 140 | # Plot the training data on top of the boundary 141 | plot_data(X, y) 142 | # Make classification predictions over a grid of values 143 | x1plot = np.linspace(np.min(X[:, 0]), np.max(X[:, 0]), 100).transpose() 144 | x2plot = np.linspace(np.min(X[:, 1]), np.max(X[:, 1]), 100).transpose() 145 | [X1, X2] = np.meshgrid(x1plot, x2plot) 146 | vals = np.zeros(X1.shape) 147 | for i in range(X1.shape[1]): 148 | this_X = np.vstack((X1[:, i], X2[:, i])) 149 | vals[:, i] = svmPredict(model, this_X) 150 | # Plot the SVM boundary 151 | plt.contour(X1, X2, vals, [1, 1], c='b') 152 | # Plot the support vectors 153 | plt.scatter(model["X"][:, 0], model["X"][:, 1], marker='o', linewidths=4, s=10, c=[0.1, 0.1, 0.1]) 154 | 155 | def plot_svc_decision_function(model, ax=None, plot_support=True): 156 | """Plot the decision function for a 2D SVC""" 157 | if ax is None: 158 | ax = plt.gca() 159 | xlim = ax.get_xlim() 160 | ylim = ax.get_ylim() 161 | 162 | # create grid to evaluate model 163 | x = np.linspace(xlim[0], xlim[1], 30) 164 | y = np.linspace(ylim[0], ylim[1], 30) 165 | Y, X = np.meshgrid(y, x) 166 | xy = np.vstack([X.ravel(), Y.ravel()]).T 167 | P = model.decision_function(xy).reshape(X.shape) 168 | 169 | # plot decision boundary and margins 170 | ax.contour(X, Y, P, colors='w', 171 | levels=[-1, 0, 1], alpha=0.9, 172 | linestyles=['--', '-', '--']) 173 | 174 | # plot support vectors 175 | if plot_support: 176 | ax.scatter(model.support_vectors_[:, 0], 177 | model.support_vectors_[:, 1], 178 | s=300, linewidth=2, edgecolor='w', facecolors='none'); 179 | ax.set_xlim(xlim) 180 | ax.set_ylim(ylim) 181 | 182 | 183 | def plot_gaussian_ellipsoid(m, C, sdwidth=1, npts=None, axh=None, color='r'): 184 | # PLOT_GAUSSIAN_ELLIPSOIDS plots 2-d and 3-d Gaussian distributions 185 | # 186 | # H = PLOT_GAUSSIAN_ELLIPSOIDS(M, C) plots the distribution specified by 187 | # mean M and covariance C. The distribution is plotted as an ellipse (in 188 | # 2-d) or an ellipsoid (in 3-d). By default, the distributions are 189 | # plotted in the current axes. 190 | 191 | # PLOT_GAUSSIAN_ELLIPSOIDS(M, C, SD) uses SD as the standard deviation 192 | # along the major and minor axes (larger SD => larger ellipse). By 193 | # default, SD = 1. 194 | # PLOT_GAUSSIAN_ELLIPSOIDS(M, C, SD, NPTS) plots the ellipse or 195 | # ellipsoid with a resolution of NPTS 196 | # 197 | # PLOT_GAUSSIAN_ELLIPSOIDS(M, C, SD, NPTS, AX) adds the plot to the 198 | # axes specified by the axis handle AX. 
199 | # 200 | # Examples: 201 | # ------------------------------------------- 202 | # # Plot three 2-d Gaussians 203 | # figure; 204 | # h1 = plot_gaussian_ellipsoid([1 1], [1 0.5; 0.5 1]); 205 | # h2 = plot_gaussian_ellipsoid([2 1.5], [1 -0.7; -0.7 1]); 206 | # h3 = plot_gaussian_ellipsoid([0 0], [1 0; 0 1]); 207 | # set(h2,'color','r'); 208 | # set(h3,'color','g'); 209 | # 210 | # # "Contour map" of a 2-d Gaussian 211 | # figure; 212 | # for sd = [0.3:0.4:4], 213 | # h = plot_gaussian_ellipsoid([0 0], [1 0.8; 0.8 1], sd); 214 | # end 215 | # 216 | # # Plot three 3-d Gaussians 217 | # figure; 218 | # h1 = plot_gaussian_ellipsoid([1 1 0], [1 0.5 0.2; 0.5 1 0.4; 0.2 0.4 1]); 219 | # h2 = plot_gaussian_ellipsoid([1.5 1 .5], [1 -0.7 0.6; -0.7 1 0; 0.6 0 1]); 220 | # h3 = plot_gaussian_ellipsoid([1 2 2], [0.5 0 0; 0 0.5 0; 0 0 0.5]); 221 | # set(h2,'facealpha',0.6); 222 | # view(129,36); set(gca,'proj','perspective'); grid on; 223 | # grid on; axis equal; axis tight; 224 | # ------------------------------------------- 225 | # 226 | # Gautam Vallabha, Sep-23-2007, Gautam.Vallabha@mathworks.com 227 | 228 | # Revision 1.0, Sep-23-2007 229 | # - File created 230 | # Revision 1.1, 26-Sep-2007 231 | # - NARGOUT==0 check added. 232 | # - Help added on NPTS for ellipsoids 233 | 234 | if axh is None: 235 | axh = plt.gca() 236 | if m.size != len(m): 237 | raise Exception('M must be a vector'); 238 | if (m.size == 2): 239 | h = show2d(m[:], C, sdwidth, npts, axh, color) 240 | elif (m.size == 3): 241 | h = show3d(m[:], C, sdwidth, npts, axh, color) 242 | else: 243 | raise Exception('Unsupported dimensionality'); 244 | return h 245 | 246 | #----------------------------- 247 | def show2d(means, C, sdwidth, npts=None, axh=None, color='r'): 248 | if (npts is None): 249 | npts = 50 250 | # plot the gaussian fits 251 | tt = np.linspace(0, 2 * np.pi, npts).transpose() 252 | x = np.cos(tt); 253 | y = np.sin(tt); 254 | ap = np.vstack((x[:], y[:])).transpose() 255 | v, d = np.linalg.eigvals(C) 256 | d = sdwidth / np.sqrt(d) # convert variance to sdwidth*sd 257 | bp = np.dot(v, np.dot(d, ap)) + means 258 | h = axh.plot(bp[:, 0], bp[:, 1], ls='-', color=color) 259 | return h 260 | 261 | #----------------------------- 262 | def show3d(means, C, sdwidth, npts=None, axh=None): 263 | if (npts is None): 264 | npts = 20 265 | x, y, z = sphere(npts); 266 | ap = np.concatenate((x[:], y[:], z[:])).transpose() 267 | v, d = eigvals(C) 268 | if any(d[:] < 0): 269 | print('warning: negative eigenvalues') 270 | d = np.max(d, 0) 271 | d = sdwidth * np.sqrt(d); # convert variance to sdwidth*sd 272 | bp = (v * d * ap) + repmat(means, 1, size(ap,2)); 273 | xp = reshape(bp[0,:], size(x)); 274 | yp = reshape(bp[1,:], size(y)); 275 | zp = reshape(bp[2,:], size(z)); 276 | h = axh.surf(xp, yp, zp); 277 | return h 278 | 279 | def fit_multivariate_gaussian(X_s): 280 | gmm = GaussianMixture(n_components=1).fit(X_s) 281 | labels = gmm.predict(X_s) 282 | N = 50 283 | X = np.linspace(-2, 10, N) 284 | Y = np.linspace(-4, 4, N) 285 | X, Y = np.meshgrid(X, Y) 286 | # Pack X and Y into a single 3-dimensional array 287 | pos = np.empty(X.shape + (2,)) 288 | pos[:, :, 0] = X 289 | pos[:, :, 1] = Y 290 | norm = distribution.MultivariateNormal(torch.Tensor(gmm.means_[0]), torch.Tensor(gmm.covariances_[0])) 291 | Z = torch.exp(norm.log_prob(torch.Tensor(pos))).numpy() 292 | plt.figure(figsize=(10, 8)); 293 | ax = plt.gca() 294 | cset = ax.contourf(X, Y, Z, cmap='magma') 295 | plt.scatter(X_s[:, 0], X_s[:, 1], c='b', s=60, edgecolor='w', zorder=2.5); 
plt.grid(True); 296 | return labels 297 | 298 | def fit_gaussian_mixture(X_s): 299 | gmm = GaussianMixture(n_components=4).fit(X_s) 300 | labels = gmm.predict(X_s) 301 | N = 50 302 | X = np.linspace(-2, 10, N) 303 | Y = np.linspace(-4, 4, N) 304 | X, Y = np.meshgrid(X, Y) 305 | # Pack X and Y into a single 3-dimensional array 306 | pos = np.empty(X.shape + (2,)) 307 | pos[:, :, 0] = X 308 | pos[:, :, 1] = Y 309 | Z = np.zeros((pos.shape[0], pos.shape[1])) 310 | for i in range(4): 311 | norm = distribution.MultivariateNormal(torch.Tensor(gmm.means_[i]), torch.Tensor(gmm.covariances_[i])) 312 | Z += torch.exp(norm.log_prob(torch.Tensor(pos))).numpy() 313 | plt.figure(figsize=(10, 8)); 314 | ax = plt.gca() 315 | cset = ax.contourf(X, Y, Z, cmap='magma') 316 | plt.scatter(X_s[:, 0], X_s[:, 1], c='b', s=60, edgecolor='w', zorder=2.5); plt.grid(True); 317 | return labels 318 | 319 | def draw_ellipse(position, covariance, ax=None, **kwargs): 320 | """Draw an ellipse with a given position and covariance""" 321 | ax = ax or plt.gca() 322 | 323 | # Convert covariance to principal axes 324 | if covariance.shape == (2, 2): 325 | U, s, Vt = np.linalg.svd(covariance) 326 | angle = np.degrees(np.arctan2(U[1, 0], U[0, 0])) 327 | width, height = 2 * np.sqrt(s) 328 | else: 329 | angle = 0 330 | width, height = 2 * np.sqrt(covariance) 331 | 332 | # Draw the Ellipse 333 | for nsig in range(1, 4): 334 | ax.add_patch(Ellipse(position, nsig * width, nsig * height, 335 | angle, **kwargs)) 336 | 337 | def plot_gmm(gmm, X, label=True, ax=None): 338 | plt.figure(figsize=(10,8)) 339 | ax = ax or plt.gca() 340 | labels = gmm.fit(X).predict(X) 341 | if label: 342 | ax.scatter(X[:, 0], X[:, 1], c=labels, s=40, cmap='magma', edgecolor='gray', zorder=2) 343 | else: 344 | ax.scatter(X[:, 0], X[:, 1], s=40, zorder=2) 345 | ax.axis('equal') 346 | 347 | w_factor = 0.4 / gmm.weights_.max() 348 | for pos, covar, w in zip(gmm.means_, gmm.covariances_, gmm.weights_): 349 | draw_ellipse(pos, covar, alpha=w * w_factor) -------------------------------------------------------------------------------- /diffusion_13_CLIP guided diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "CLIP-Conditioned CLIP-Guided Diffusion (cc12m_1, 256x256)", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "tWgrsecxuFmz" 10 | ], 11 | "include_colab_link": true 12 | }, 13 | "kernelspec": { 14 | "name": "python3", 15 | "display_name": "Python 3" 16 | }, 17 | "language_info": { 18 | "name": "python" 19 | }, 20 | "accelerator": "GPU" 21 | }, 22 | "cells": [ 23 | { 24 | "cell_type": "markdown", 25 | "metadata": { 26 | "id": "view-in-github", 27 | "colab_type": "text" 28 | }, 29 | "source": [ 30 | "\"Open" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "source": [ 36 | "# [CLIP-Conditioned CLIP-Guided Diffusion (cc12m_1, 256x256)](https://github.com/crowsonkb/v-diffusion-pytorch)\n", 37 | "\n", 38 | "By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings)\n", 39 | "\n", 40 | "and JD Pressman (https://twitter.com/jd_pressman).\n", 41 | "\n", 42 | "Notebook by BoneAmputee (https://twitter.com/BoneAmputee)." 
43 | ], 44 | "metadata": { 45 | "id": "ISWlIVtq7Oui" 46 | } 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "source": [ 51 | "# Setup" 52 | ], 53 | "metadata": { 54 | "id": "tWgrsecxuFmz" 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": { 61 | "id": "H7c8Jo0Yqb_d" 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "!pip install ftfy\n", 66 | "%cd /content/\n", 67 | "!git clone https://github.com/crowsonkb/v-diffusion-pytorch.git\n", 68 | "%cd v-diffusion-pytorch\n", 69 | "!git clone https://github.com/openai/CLIP.git\n", 70 | "%mkdir -p checkpoints\n", 71 | "%mkdir -p frames\n", 72 | "!curl -L \"https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1.pth\" > \"checkpoints/cc12m_1.pth\"" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "source": [ 78 | "# Run" 79 | ], 80 | "metadata": { 81 | "id": "a053QXsbuIcy" 82 | } 83 | }, 84 | { 85 | "cell_type": "code", 86 | "source": [ 87 | "#!/usr/bin/env python3\n", 88 | "\n", 89 | "\"\"\"CLIP guided sampling from a diffusion model.\"\"\"\n", 90 | "\n", 91 | "import argparse\n", 92 | "from pathlib import Path\n", 93 | "\n", 94 | "from PIL import Image\n", 95 | "import torch\n", 96 | "from torch import nn\n", 97 | "from torch.nn import functional as F\n", 98 | "from torchvision import transforms\n", 99 | "from torchvision.transforms import functional as TF\n", 100 | "from tqdm.notebook import trange\n", 101 | "from IPython import display\n", 102 | "from shutil import rmtree\n", 103 | "import os\n", 104 | "\n", 105 | "from CLIP import clip\n", 106 | "from diffusion import get_model, get_models, utils\n", 107 | "\n", 108 | "MODULE_DIR = Path(\"/content/v-diffusion-pytorch/\").resolve()\n", 109 | "\n", 110 | "\n", 111 | "@torch.no_grad()\n", 112 | "def sample(model, x, steps, check_in, eta, extra_args):\n", 113 | " \"\"\"Draws samples from a model given starting noise.\"\"\"\n", 114 | " ts = x.new_ones([x.shape[0]])\n", 115 | "\n", 116 | " # Create the noise schedule\n", 117 | " alphas, sigmas = utils.t_to_alpha_sigma(steps)\n", 118 | "\n", 119 | " # The sampling loop\n", 120 | " for i in trange(len(steps)):\n", 121 | "\n", 122 | " # Get the model output (v, the predicted velocity)\n", 123 | " with torch.cuda.amp.autocast():\n", 124 | " v = model(x, ts * steps[i], **extra_args).float()\n", 125 | "\n", 126 | " # Predict the noise and the denoised image\n", 127 | " pred = x * alphas[i] - v * sigmas[i]\n", 128 | " eps = x * sigmas[i] + v * alphas[i]\n", 129 | "\n", 130 | " if i % check_in == 0:\n", 131 | " outfile = f'frames/{str(i).zfill(4)}.png'\n", 132 | " utils.to_pil_image(pred).save(outfile)\n", 133 | " display.display(display.Image(outfile))\n", 134 | "\n", 135 | " # If we are not on the last timestep, compute the noisy image for the\n", 136 | " # next timestep.\n", 137 | " if i < len(steps) - 1:\n", 138 | " # If eta > 0, adjust the scaling factor for the predicted noise\n", 139 | " # downward according to the amount of additional noise to add\n", 140 | " ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \\\n", 141 | " (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()\n", 142 | " adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()\n", 143 | "\n", 144 | " # Recombine the predicted noise and predicted denoised image in the\n", 145 | " # correct proportions for the next step\n", 146 | " x = pred * alphas[i + 1] + eps * adjusted_sigma\n", 147 | "\n", 148 | " # Add the correct amount of fresh noise\n", 149 | " if eta:\n", 150 | " x += torch.randn_like(x) * ddim_sigma\n", 151 | "\n", 152 
| " # If we are on the last timestep, output the denoised image\n", 153 | " return pred\n", 154 | "\n", 155 | "\n", 156 | "@torch.no_grad()\n", 157 | "def cond_sample(model, x, steps, check_in, eta, extra_args, cond_fn):\n", 158 | " \"\"\"Draws guided samples from a model given starting noise.\"\"\"\n", 159 | " ts = x.new_ones([x.shape[0]])\n", 160 | "\n", 161 | " # Create the noise schedule\n", 162 | " alphas, sigmas = utils.t_to_alpha_sigma(steps)\n", 163 | "\n", 164 | " # The sampling loop\n", 165 | " for i in trange(len(steps)):\n", 166 | "\n", 167 | " # Get the model output\n", 168 | " with torch.enable_grad():\n", 169 | " x = x.detach().requires_grad_()\n", 170 | " with torch.cuda.amp.autocast():\n", 171 | " v = model(x, ts * steps[i], **extra_args)\n", 172 | "\n", 173 | " if steps[i] < 1:\n", 174 | " pred = x * alphas[i] - v * sigmas[i]\n", 175 | " if i % check_in == 0:\n", 176 | " outfile = f'frames/{str(i).zfill(4)}.png'\n", 177 | " utils.to_pil_image(pred).save(outfile)\n", 178 | " display.display(display.Image(outfile))\n", 179 | " cond_grad = cond_fn(x, ts * steps[i], pred, **extra_args).detach()\n", 180 | " v = v.detach() - cond_grad * (sigmas[i] / alphas[i])\n", 181 | " else:\n", 182 | " v = v.detach()\n", 183 | "\n", 184 | " # Predict the noise and the denoised image\n", 185 | " pred = x * alphas[i] - v * sigmas[i]\n", 186 | " eps = x * sigmas[i] + v * alphas[i]\n", 187 | "\n", 188 | " # If we are not on the last timestep, compute the noisy image for the\n", 189 | " # next timestep.\n", 190 | " if i < len(steps) - 1:\n", 191 | " # If eta > 0, adjust the scaling factor for the predicted noise\n", 192 | " # downward according to the amount of additional noise to add\n", 193 | " ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \\\n", 194 | " (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()\n", 195 | " adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()\n", 196 | "\n", 197 | " # Recombine the predicted noise and predicted denoised image in the\n", 198 | " # correct proportions for the next step\n", 199 | " x = pred * alphas[i + 1] + eps * adjusted_sigma\n", 200 | "\n", 201 | " # Add the correct amount of fresh noise\n", 202 | " if eta:\n", 203 | " x += torch.randn_like(x) * ddim_sigma\n", 204 | "\n", 205 | " # If we are on the last timestep, output the denoised image\n", 206 | " return pred\n", 207 | "\n", 208 | "\n", 209 | "class MakeCutouts(nn.Module):\n", 210 | " def __init__(self, cut_size, cutn, cut_pow=1.):\n", 211 | " super().__init__()\n", 212 | " self.cut_size = cut_size\n", 213 | " self.cutn = cutn\n", 214 | " self.cut_pow = cut_pow\n", 215 | "\n", 216 | " def forward(self, input):\n", 217 | " sideY, sideX = input.shape[2:4]\n", 218 | " max_size = min(sideX, sideY)\n", 219 | " min_size = min(sideX, sideY, self.cut_size)\n", 220 | " cutouts = []\n", 221 | " for _ in range(self.cutn):\n", 222 | " size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)\n", 223 | " offsetx = torch.randint(0, sideX - size + 1, ())\n", 224 | " offsety = torch.randint(0, sideY - size + 1, ())\n", 225 | " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 226 | " cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)\n", 227 | " cutouts.append(cutout)\n", 228 | " return torch.cat(cutouts)\n", 229 | "\n", 230 | "\n", 231 | "def spherical_dist_loss(x, y):\n", 232 | " x = F.normalize(x, dim=-1)\n", 233 | " y = F.normalize(y, dim=-1)\n", 234 | " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n", 235 | "\n", 236 | "\n", 237 | "def 
parse_prompt(prompt):\n", 238 | " if prompt.startswith('http://') or prompt.startswith('https://'):\n", 239 | " vals = prompt.rsplit(':', 2)\n", 240 | " vals = [vals[0] + ':' + vals[1], *vals[2:]]\n", 241 | " else:\n", 242 | " vals = prompt.rsplit(':', 1)\n", 243 | " vals = vals + ['', '1'][len(vals):]\n", 244 | " return vals[0], float(vals[1])\n", 245 | "\n", 246 | "\n", 247 | "def main():\n", 248 | "\n", 249 | " #@markdown `prompts`: the text prompts to use. Relative weights for text prompts can be specified by putting the weight after a colon. The vertical bar character can be used to denote multiple prompts.\n", 250 | " prompts = \"an armchair in the shape of an avocado|conceptual art\" #@param {type:\"string\"}\n", 251 | " #@markdown `batch_size`: sample this many images at a time (default 1)\n", 252 | " batch_size = 1 #@param {type:\"integer\"}\n", 253 | " #@markdown `checkpoint`: manually specify the model checkpoint file\n", 254 | " checkpoint = \"\" #@param {type:\"string\"}\n", 255 | " #@markdown `clip_guidance_scale`: how strongly the result should match the text prompt (default 500). If set to 0, the cc12m_1 model will still be CLIP conditioned and sampling will go faster and use less memory.\n", 256 | " clip_guidance_scale = 500 #@param {type:\"number\"}\n", 257 | " #@markdown `device`: the PyTorch device name to use (default autodetects)\n", 258 | " device = \"cuda:0\" #@param {type:\"string\"}\n", 259 | " #@markdown `eta`: set to 0 for deterministic (DDIM) sampling, 1 (the default) for stochastic (DDPM) sampling, and in between to interpolate between the two. DDIM is preferred for low numbers of timesteps.\n", 260 | " eta = 1.0 #@param {type:\"number\"}\n", 261 | " #@markdown `images`: the image prompts to use (local files or HTTP(S) URLs). 
Relative weights for image prompts can be specified by putting the weight after a colon, for example: `\"image_1.png:0.5\"`.\n", 262 | " images = \"\" #@param {type:\"string\"}\n", 263 | " #@markdown `model`: specify the model to use (default cc12m_1)\n", 264 | " model = \"cc12m_1\" #@param {type:\"string\"}\n", 265 | " #@markdown `n`: sample until this many images are sampled (default 1)\n", 266 | " n = 1 #@param {type:\"integer\"}\n", 267 | " #@markdown `seed`: specify the random seed (default 0)\n", 268 | " seed = 0 #@param {type:\"integer\"}\n", 269 | " #@markdown `steps`: specify the number of diffusion timesteps (default is 1000, can lower for faster but lower quality sampling)\n", 270 | " steps = 1000 #@param {type:\"integer\"}\n", 271 | " #@markdown `check_in`: specify the number of steps between each image update\n", 272 | " check_in = 100 #@param {type:\"integer\"}\n", 273 | " #@markdown `cutn`: specify the number of cuts to observe when guiding\n", 274 | " cutn = 16 #@param {type:\"integer\"}\n", 275 | " #@markdown `cut_pow`: specify the cut power\n", 276 | " cut_pow = 1.0 #@param {type:\"number\"}\n", 277 | " #@markdown `width`: specify the width\n", 278 | " width = 256 #@param {type:\"integer\"}\n", 279 | " #@markdown `height`: specify the height\n", 280 | " height = 256 #@param {type:\"integer\"}\n", 281 | "\n", 282 | " prompts = [x.strip() for x in prompts.split('|')]\n", 283 | " prompts = [x for x in prompts if x != '']\n", 284 | " images = [x.strip() for x in images.split('|')]\n", 285 | " images = [x for x in images if x != '']\n", 286 | "\n", 287 | " args = argparse.Namespace(\n", 288 | " prompts = prompts,\n", 289 | " batch_size = batch_size,\n", 290 | " checkpoint = checkpoint,\n", 291 | " clip_guidance_scale = clip_guidance_scale,\n", 292 | " device = device,\n", 293 | " eta = eta,\n", 294 | " images = images,\n", 295 | " model = model,\n", 296 | " n = n,\n", 297 | " seed = seed,\n", 298 | " steps = steps,\n", 299 | " check_in = check_in,\n", 300 | " cutn = cutn,\n", 301 | " cut_pow = cut_pow,\n", 302 | " width = width,\n", 303 | " height = height\n", 304 | " )\n", 305 | "\n", 306 | " if args.device:\n", 307 | " device = torch.device(args.device)\n", 308 | " else:\n", 309 | " device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 310 | " print('Using device:', device)\n", 311 | "\n", 312 | " model = get_model(args.model)()\n", 313 | " # _, side_y, side_x = model.shape\n", 314 | " side_y, side_x = (args.height//64)*64, (args.width//64)*64\n", 315 | " checkpoint = args.checkpoint\n", 316 | " if not checkpoint:\n", 317 | " checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'\n", 318 | " model.load_state_dict(torch.load(checkpoint, map_location='cpu'))\n", 319 | " if device.type == 'cuda':\n", 320 | " model = model.half()\n", 321 | " model = model.to(device).eval().requires_grad_(False)\n", 322 | " clip_model = clip.load(model.clip_model, jit=False, device=device)[0]\n", 323 | " clip_model.eval().requires_grad_(False)\n", 324 | " normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n", 325 | " std=[0.26862954, 0.26130258, 0.27577711])\n", 326 | " cutn = args.cutn\n", 327 | " make_cutouts = MakeCutouts(clip_model.visual.input_resolution, cutn=cutn, cut_pow=args.cut_pow)\n", 328 | "\n", 329 | " target_embeds, weights = [], []\n", 330 | "\n", 331 | " for prompt in args.prompts:\n", 332 | " txt, weight = parse_prompt(prompt)\n", 333 | " target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())\n", 334 | 
" weights.append(weight)\n", 335 | "\n", 336 | " for prompt in args.images:\n", 337 | " path, weight = parse_prompt(prompt)\n", 338 | " img = Image.open(utils.fetch(path)).convert('RGB')\n", 339 | " img = TF.resize(img, min(side_x, side_y, *img.size),\n", 340 | " transforms.InterpolationMode.LANCZOS)\n", 341 | " batch = make_cutouts(TF.to_tensor(img)[None].to(device))\n", 342 | " embeds = F.normalize(clip_model.encode_image(normalize(batch)).float(), dim=-1)\n", 343 | " target_embeds.append(embeds)\n", 344 | " weights.extend([weight / cutn] * cutn)\n", 345 | "\n", 346 | " if not target_embeds:\n", 347 | " raise RuntimeError('At least one text or image prompt must be specified.')\n", 348 | " target_embeds = torch.cat(target_embeds)\n", 349 | " weights = torch.tensor(weights, device=device)\n", 350 | " if weights.sum().abs() < 1e-3:\n", 351 | " raise RuntimeError('The weights must not sum to 0.')\n", 352 | " weights /= weights.sum().abs()\n", 353 | "\n", 354 | " clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)\n", 355 | " clip_embed = clip_embed.repeat([args.n, 1])\n", 356 | "\n", 357 | " torch.manual_seed(args.seed)\n", 358 | "\n", 359 | " def cond_fn(x, t, pred, clip_embed):\n", 360 | " clip_in = normalize(make_cutouts((pred + 1) / 2))\n", 361 | " image_embeds = clip_model.encode_image(clip_in).view([cutn, x.shape[0], -1])\n", 362 | " losses = spherical_dist_loss(image_embeds, clip_embed[None])\n", 363 | " loss = losses.mean(0).sum() * args.clip_guidance_scale\n", 364 | " grad = -torch.autograd.grad(loss, x)[0]\n", 365 | " return grad\n", 366 | "\n", 367 | " def run(x, clip_embed):\n", 368 | " t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]\n", 369 | " steps = utils.get_spliced_ddpm_cosine_schedule(t)\n", 370 | " extra_args = {'clip_embed': clip_embed}\n", 371 | " if not args.clip_guidance_scale:\n", 372 | " return sample(model, x, steps, args.check_in, args.eta, extra_args)\n", 373 | " return cond_sample(model, x, steps, args.check_in, args.eta, extra_args, cond_fn)\n", 374 | "\n", 375 | " def run_all(n, batch_size):\n", 376 | " x = torch.randn([args.n, 3, side_y, side_x], device=device)\n", 377 | " for i in trange(0, n, batch_size):\n", 378 | " cur_batch_size = min(n - i, batch_size)\n", 379 | " outs = run(x[i:i+cur_batch_size], clip_embed[i:i+cur_batch_size])\n", 380 | " for j, out in enumerate(outs):\n", 381 | " outfile = f'out_{i + j:05}.png'\n", 382 | " utils.to_pil_image(out).save(outfile)\n", 383 | " display.display(display.Image(outfile))\n", 384 | "\n", 385 | " try:\n", 386 | " run_all(args.n, args.batch_size)\n", 387 | " except KeyboardInterrupt:\n", 388 | " pass\n", 389 | "\n", 390 | "\n", 391 | "if __name__ == '__main__':\n", 392 | " if os.path.exists(\"frames\"):\n", 393 | " rmtree(\"frames\")\n", 394 | " os.makedirs(\"frames\")\n", 395 | " main()\n" 396 | ], 397 | "metadata": { 398 | "id": "h-hZ1MbgvR0Y", 399 | "cellView": "form" 400 | }, 401 | "execution_count": null, 402 | "outputs": [] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "source": [ 407 | "" 408 | ], 409 | "metadata": { 410 | "id": "r0hhpsuv0NIz" 411 | }, 412 | "execution_count": null, 413 | "outputs": [] 414 | } 415 | ] 416 | } -------------------------------------------------------------------------------- /diffusion_14_Augmented CLIP Guided Diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "1YwMUyt9LHG1" 7 | }, 8 | 
"source": [ 9 | "# Generates images from text prompts with CLIP guided diffusion.\n", 10 | "\n", 11 | "By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.\n", 12 | "\n", 13 | "Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.\n", 14 | "\n", 15 | "**Update**: Sep 19th 2021\n", 16 | "\n", 17 | "\n", 18 | "Further improvements from Dango233 and nsheppard helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.\n", 19 | "\n", 20 | "Katherine's original notebook can be found here:\n", 21 | "https://colab.research.google.com/drive/1QBsaDAZv8np29FPbvjffbE1eytoJcsgA\n", 22 | "\n", 23 | "Vark has added some code to load in multiple Clip models at once, which all prompts are evaluated against, which may greatly improve accuracy.\n", 24 | "\n", 25 | "I, pbaylies, have added some code to augment a CLIP model, which may also improve the results." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": { 32 | "id": "qZ3rNuAWAewx" 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "import torch\n", 37 | "# Check the GPU status\n", 38 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 39 | "print('Using device:', device)\n", 40 | "!nvidia-smi" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "cellView": "form", 48 | "id": "yZsjzwS0YGo6" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "#@title Choose model here:\n", 53 | "diffusion_model = \"256x256_diffusion_uncond\" #@param [\"256x256_diffusion_uncond\", \"512x512_diffusion_uncond_finetune_008100\"]\n", 54 | "#@markdown If you connect your Google Drive, you can save the final image of each run on your drive.\n", 55 | "\n", 56 | "google_drive = False #@param {type:\"boolean\"}\n", 57 | "\n", 58 | "#@markdown You can use your mounted Google Drive to load the model checkpoint file if you've already got a copy downloaded there. This will save time (and resources!) 
when you re-visit this notebook in the future.\n", 59 | "\n", 60 | "#@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n", 61 | "yes_please = False #@param {type:\"boolean\"}" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "cellView": "form", 69 | "id": "mQE-fIMnYKYK" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "#@title Download diffusion model\n", 74 | "\n", 75 | "import os\n", 76 | "model_path = os.getcwd() + '/'\n", 77 | "if google_drive:\n", 78 | " from google.colab import drive\n", 79 | " drive.mount('/content/drive')\n", 80 | " if yes_please:\n", 81 | " model_path = '/content/drive/MyDrive/' \n", 82 | "\n", 83 | "if diffusion_model == '256x256_diffusion_uncond':\n", 84 | " !wget --continue 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' -P {model_path}\n", 85 | "elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", 86 | " !wget --continue 'https://the-eye.eu/public/AI/models/512x512_diffusion_unconditional_ImageNet/512x512_diffusion_uncond_finetune_008100.pt' -P {model_path}\n", 87 | "\n", 88 | "if google_drive and not yes_please:\n", 89 | " model_path = '/content/drive/MyDrive/' \n", 90 | "try:\n", 91 | " os.mkdir(model_path + 'image_storage')\n", 92 | "except:\n", 93 | " pass" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "id": "4jxCQbtInUCN" 100 | }, 101 | "source": [ 102 | "# Install and import dependencies" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "id": "-_UVMZCIAq_r" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "!git clone https://github.com/openai/CLIP\n", 114 | "!git clone https://github.com/crowsonkb/guided-diffusion\n", 115 | "!pip3.8 install -e ./CLIP\n", 116 | "!pip3.8 install -e ./guided-diffusion\n", 117 | "!pip3.8 install lpips datetime" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": { 124 | "id": "JmbrcrhpBPC6" 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "import gc\n", 129 | "import io\n", 130 | "import math\n", 131 | "import sys\n", 132 | "from IPython import display\n", 133 | "import lpips\n", 134 | "from PIL import Image, ImageOps\n", 135 | "import requests\n", 136 | "import torch\n", 137 | "from torch import nn\n", 138 | "from torch.nn import functional as F\n", 139 | "import torchvision.transforms as T\n", 140 | "import torchvision.transforms.functional as TF\n", 141 | "from tqdm.notebook import tqdm\n", 142 | "sys.path.append('./CLIP')\n", 143 | "sys.path.append('./guided-diffusion')\n", 144 | "import clip\n", 145 | "from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n", 146 | "from datetime import datetime\n", 147 | "import numpy as np\n", 148 | "import matplotlib.pyplot as plt\n", 149 | "import random" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "id": "h4MHBPT1nirT" 156 | }, 157 | "source": [ 158 | "# Define necessary functions" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "id": "FpZczxnOnPIU" 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n", 170 | "\n", 171 | "def interp(t):\n", 172 | " return 3 * t**2 - 2 * t ** 3\n", 173 | "\n", 174 | "def perlin(width, height, scale=10, 
device=None):\n", 175 | " gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n", 176 | " xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n", 177 | " ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n", 178 | " wx = 1 - interp(xs)\n", 179 | " wy = 1 - interp(ys)\n", 180 | " dots = 0\n", 181 | " dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n", 182 | " dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n", 183 | " dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n", 184 | " dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n", 185 | " return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n", 186 | "\n", 187 | "def perlin_ms(octaves, width, height, grayscale, device=device):\n", 188 | " out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n", 189 | " # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n", 190 | " for i in range(1 if grayscale else 3):\n", 191 | " scale = 2 ** len(octaves)\n", 192 | " oct_width = width\n", 193 | " oct_height = height\n", 194 | " for oct in octaves:\n", 195 | " p = perlin(oct_width, oct_height, scale, device)\n", 196 | " out_array[i] += p * oct\n", 197 | " scale //= 2\n", 198 | " oct_width *= 2\n", 199 | " oct_height *= 2\n", 200 | " return torch.cat(out_array)\n", 201 | "\n", 202 | "def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n", 203 | " out = perlin_ms(octaves, width, height, grayscale)\n", 204 | " if grayscale:\n", 205 | " out = TF.resize(size=(side_x, side_y), img=out.unsqueeze(0))\n", 206 | " out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n", 207 | " else:\n", 208 | " out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n", 209 | " out = TF.resize(size=(side_x, side_y), img=out)\n", 210 | " out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n", 211 | "\n", 212 | " out = ImageOps.autocontrast(out)\n", 213 | " return out" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": { 220 | "id": "YHOj78Yvx8jP" 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "def fetch(url_or_path):\n", 225 | " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n", 226 | " r = requests.get(url_or_path)\n", 227 | " r.raise_for_status()\n", 228 | " fd = io.BytesIO()\n", 229 | " fd.write(r.content)\n", 230 | " fd.seek(0)\n", 231 | " return fd\n", 232 | " return open(url_or_path, 'rb')\n", 233 | "\n", 234 | "\n", 235 | "def parse_prompt(prompt):\n", 236 | " if prompt.startswith('http://') or prompt.startswith('https://'):\n", 237 | " vals = prompt.rsplit(':', 2)\n", 238 | " vals = [vals[0] + ':' + vals[1], *vals[2:]]\n", 239 | " else:\n", 240 | " vals = prompt.rsplit(':', 1)\n", 241 | " vals = vals + ['', '1'][len(vals):]\n", 242 | " return vals[0], float(vals[1])\n", 243 | "\n", 244 | "def sinc(x):\n", 245 | " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n", 246 | "\n", 247 | "def lanczos(x, a):\n", 248 | " cond = torch.logical_and(-a < x, x < a)\n", 249 | " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n", 250 | " return out / out.sum()\n", 251 | "\n", 252 | "def ramp(ratio, width):\n", 253 | " n = math.ceil(width / ratio + 1)\n", 254 | " out = torch.empty([n])\n", 255 | " cur = 0\n", 256 | " for i in range(out.shape[0]):\n", 257 | " out[i] = cur\n", 258 | " cur += ratio\n", 259 | " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n", 260 | "\n", 261 | "def 
resample(input, size, align_corners=True):\n", 262 | " n, c, h, w = input.shape\n", 263 | " dh, dw = size\n", 264 | "\n", 265 | " input = input.reshape([n * c, 1, h, w])\n", 266 | "\n", 267 | " if dh < h:\n", 268 | " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n", 269 | " pad_h = (kernel_h.shape[0] - 1) // 2\n", 270 | " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n", 271 | " input = F.conv2d(input, kernel_h[None, None, :, None])\n", 272 | "\n", 273 | " if dw < w:\n", 274 | " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n", 275 | " pad_w = (kernel_w.shape[0] - 1) // 2\n", 276 | " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n", 277 | " input = F.conv2d(input, kernel_w[None, None, None, :])\n", 278 | "\n", 279 | " input = input.reshape([n, c, h, w])\n", 280 | " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n", 281 | "\n", 282 | "class MakeCutouts(nn.Module):\n", 283 | " def __init__(self, cut_size, cutn, skip_augs=False):\n", 284 | " super().__init__()\n", 285 | " self.cut_size = cut_size\n", 286 | " self.cutn = cutn\n", 287 | " self.skip_augs = skip_augs\n", 288 | " self.augs = T.Compose([\n", 289 | " T.RandomHorizontalFlip(p=0.5),\n", 290 | " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", 291 | " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n", 292 | " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", 293 | " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n", 294 | " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", 295 | " T.RandomGrayscale(p=0.15),\n", 296 | " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", 297 | " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n", 298 | " ])\n", 299 | "\n", 300 | " def forward(self, input):\n", 301 | " input = T.Pad(input.shape[2]//4, fill=0)(input)\n", 302 | " sideY, sideX = input.shape[2:4]\n", 303 | " max_size = min(sideX, sideY)\n", 304 | "\n", 305 | " cutouts = []\n", 306 | " for ch in range(cutn):\n", 307 | " if ch > cutn - cutn//4:\n", 308 | " cutout = input.clone()\n", 309 | " else:\n", 310 | " size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n", 311 | " offsetx = torch.randint(0, abs(sideX - size + 1), ())\n", 312 | " offsety = torch.randint(0, abs(sideY - size + 1), ())\n", 313 | " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", 314 | "\n", 315 | " if not self.skip_augs:\n", 316 | " cutout = self.augs(cutout)\n", 317 | " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n", 318 | " del cutout\n", 319 | "\n", 320 | " cutouts = torch.cat(cutouts, dim=0)\n", 321 | " return cutouts\n", 322 | "\n", 323 | "\n", 324 | "def spherical_dist_loss(x, y):\n", 325 | " x = F.normalize(x, dim=-1)\n", 326 | " y = F.normalize(y, dim=-1)\n", 327 | " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n", 328 | "\n", 329 | "\n", 330 | "def tv_loss(input):\n", 331 | " \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n", 332 | " input = F.pad(input, (0, 1, 0, 1), 'replicate')\n", 333 | " x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n", 334 | " y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n", 335 | " return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n", 336 | "\n", 337 | "\n", 338 | "def range_loss(input):\n", 339 | " return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n", 340 | "\n", 341 | "def unitwise_norm(x, norm_type=2.0):\n", 342 | " if x.ndim <= 1:\n", 343 | " return 
x.norm(norm_type)\n", 344 | " else:\n", 345 | " # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor\n", 346 | " # might need special cases for other weights (possibly MHA) where this may not be true\n", 347 | " return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)\n", 348 | "\n", 349 | "def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0, rand=0.02):\n", 350 | " if isinstance(parameters, torch.Tensor):\n", 351 | " parameters = [parameters]\n", 352 | " for p in parameters:\n", 353 | " if p.grad is None:\n", 354 | " continue\n", 355 | " p_data = p.detach()\n", 356 | " g_data = p.grad.detach()\n", 357 | " max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)\n", 358 | " grad_norm = unitwise_norm(g_data, norm_type=norm_type)\n", 359 | " # add stochastic gradient clipping\n", 360 | " #clipped_grad = g_data * (((max_norm * torch.rand_like(g_data)).clamp_(min=1e-6)) / grad_norm.clamp(min=1e-6))\n", 361 | " # add noise\n", 362 | " #clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + (max_norm / 50.0) * torch.randn_like(g_data)\n", 363 | " # do both?\n", 364 | " #clipped_grad = g_data * (((max_norm * torch.rand_like(g_data)).clamp_(min=1e-6)) / grad_norm.clamp(min=1e-6)) + (max_norm / 100.0) * torch.randn_like(g_data)\n", 365 | " if rand > 0.0:\n", 366 | " clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + (rand*max_norm) * torch.randn_like(g_data)\n", 367 | " else:\n", 368 | " clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))\n", 369 | " #new_grads = torch.where(grad_norm < max_norm, g_data + (max_norm / 30.0) * torch.randn_like(g_data), clipped_grad)\n", 370 | " new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)\n", 371 | " p.grad.detach().copy_(new_grads)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": { 378 | "id": "X5gODNAMEUCR" 379 | }, 380 | "outputs": [], 381 | "source": [ 382 | "def do_run():\n", 383 | " loss_values = []\n", 384 | " final_values = []\n", 385 | " images = []\n", 386 | " \n", 387 | " target_embeds, weights = [], []\n", 388 | " model_stats = []\n", 389 | " first = True\n", 390 | "\n", 391 | " for clip_model in clip_models:\n", 392 | " model_stat = {\"clip_model\":None,\"target_embeds\":[],\"make_cutouts\":None,\"weights\":[]}\n", 393 | " model_stat[\"clip_model\"] = clip_model\n", 394 | " model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs)\n", 395 | " for prompt in text_prompts:\n", 396 | " txt, weight = parse_prompt(prompt)\n", 397 | " txt = clip_model.encode_text(clip.tokenize(prompt, truncate=True).to(device)).float()\n", 398 | " \n", 399 | " # normalize and weighted average with augmented embeddings below\n", 400 | " with torch.no_grad():\n", 401 | " orig_txt = txt.clone()\n", 402 | " if first:\n", 403 | " txts = []\n", 404 | " for i in range(0, num_models):\n", 405 | " txt = orig_txt.clone()\n", 406 | " for j in range(0, model_depth):\n", 407 | " first_txt = txt.clone()\n", 408 | " (std1, mean1) = torch.std_mean(txt)\n", 409 | " txt = text_to_images[i](txt)\n", 410 | " (std2, mean2) = torch.std_mean(txt)\n", 411 | " txt = mean1+std1*((txt-mean2)/(std2))\n", 412 | " txt0 = txt.clone()\n", 413 | " (std1, mean1) = torch.std_mean(txt)\n", 414 | " txt = image_to_texts[i](txt)\n", 415 | " (std2, mean2) = torch.std_mean(txt)\n", 416 | " txt = mean1+std1*((txt-mean2)/(std2))\n", 417 | " txt = 
0.25*orig_txt+0.25*first_txt+0.25*txt+0.25*txt0\n", 418 | " if (random.randint(0, model_avg_freq) == 0 and len(txts) > 0):\n", 419 | " txt = 0.5*txt + txts[-random.randint(0, j)]\n", 420 | " txts.append(txt.clone())\n", 421 | " else:\n", 422 | " txts = [orig_txt]\n", 423 | " \n", 424 | " if fuzzy_prompt:\n", 425 | " for i in range(25):\n", 426 | " for txt in txts:\n", 427 | " model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * rand_mag).clamp(0,1))\n", 428 | " model_stat[\"weights\"].append(weight)\n", 429 | " else:\n", 430 | " for txt in txts:\n", 431 | " model_stat[\"target_embeds\"].append(txt)\n", 432 | " model_stat[\"weights\"].append(weight)\n", 433 | " \n", 434 | " for prompt in image_prompts:\n", 435 | " path, weight = parse_prompt(prompt)\n", 436 | " img = Image.open(fetch(path)).convert('RGB')\n", 437 | " img = TF.resize(img, min(side_x, side_y, *img.size), Image.LANCZOS)\n", 438 | " batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n", 439 | " embed = clip_model.encode_image(normalize(batch)).float()\n", 440 | " if fuzzy_prompt:\n", 441 | " for i in range(25):\n", 442 | " model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n", 443 | " weights.extend([weight / cutn] * cutn)\n", 444 | " else:\n", 445 | " model_stat[\"target_embeds\"].append(embed)\n", 446 | " model_stat[\"weights\"].extend([weight / cutn] * cutn)\n", 447 | " \n", 448 | " model_stat[\"target_embeds\"] = torch.cat(model_stat[\"target_embeds\"])\n", 449 | " model_stat[\"weights\"] = torch.tensor(model_stat[\"weights\"], device=device)\n", 450 | " if model_stat[\"weights\"].sum().abs() < 1e-3:\n", 451 | " raise RuntimeError('The weights must not sum to 0.')\n", 452 | " model_stat[\"weights\"] /= model_stat[\"weights\"].sum().abs()\n", 453 | " model_stats.append(model_stat)\n", 454 | " first = False\n", 455 | "\n", 456 | " \n", 457 | " init = None\n", 458 | " if init_image is not None:\n", 459 | " init = Image.open(fetch(init_image)).convert('RGB')\n", 460 | " init = init.resize((side_x, side_y), Image.LANCZOS)\n", 461 | " init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)\n", 462 | " \n", 463 | " if perlin_init:\n", 464 | " if perlin_mode == 'color':\n", 465 | " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", 466 | " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n", 467 | " elif perlin_mode == 'gray':\n", 468 | " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n", 469 | " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", 470 | " else:\n", 471 | " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", 472 | " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", 473 | " \n", 474 | " init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n", 475 | " del init2\n", 476 | " \n", 477 | " cur_t = None\n", 478 | " \n", 479 | " def cond_fn(x, t, y=None):\n", 480 | " with torch.enable_grad():\n", 481 | " x = x.detach().requires_grad_()\n", 482 | " n = x.shape[0]\n", 483 | " my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t\n", 484 | " out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})\n", 485 | " fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n", 486 | " x_in = out['pred_xstart'] * fac + x * (1 - fac)\n", 487 | " x_in_grad = 
torch.zeros_like(x_in)\n", 488 | "\n", 489 | " for model_stat in model_stats:\n", 490 | " for i in range(cutn_batches):\n", 491 | " clip_in = normalize(model_stat[\"make_cutouts\"](x_in.add(1).div(2)))\n", 492 | " image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n", 493 | " dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat[\"target_embeds\"].unsqueeze(0))\n", 494 | " dists = dists.view([cutn, n, -1])\n", 495 | " losses = dists.mul(model_stat[\"weights\"]).sum(2).mean(0)\n", 496 | " loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch\n", 497 | " x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches\n", 498 | " tv_losses = tv_loss(x_in)\n", 499 | " range_losses = range_loss(out['pred_xstart'])\n", 500 | " sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()\n", 501 | " loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale\n", 502 | " if init is not None and init_scale:\n", 503 | " init_losses = lpips_model(x_in, init)\n", 504 | " loss = loss + init_losses.sum() * init_scale\n", 505 | " x_in_grad += torch.autograd.grad(loss, x_in)[0]\n", 506 | " grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]\n", 507 | " if clamp_grad:\n", 508 | " adaptive_clip_grad([x], rand=1.0)\n", 509 | " magnitude = grad.square().mean().sqrt()\n", 510 | " return grad * magnitude.clamp(max=0.05) / magnitude\n", 511 | " return grad\n", 512 | " \n", 513 | " if model_config['timestep_respacing'].startswith('ddim'):\n", 514 | " sample_fn = diffusion.ddim_sample_loop_progressive\n", 515 | " else:\n", 516 | " sample_fn = diffusion.p_sample_loop_progressive\n", 517 | " \n", 518 | " for i in range(n_batches):\n", 519 | " cur_t = diffusion.num_timesteps - skip_timesteps - 1\n", 520 | " \n", 521 | " if model_config['timestep_respacing'].startswith('ddim'):\n", 522 | " samples = sample_fn(\n", 523 | " model,\n", 524 | " (batch_size, 3, side_y, side_x),\n", 525 | " clip_denoised=clip_denoised,\n", 526 | " model_kwargs={},\n", 527 | " cond_fn=cond_fn,\n", 528 | " progress=True,\n", 529 | " skip_timesteps=skip_timesteps,\n", 530 | " init_image=init,\n", 531 | " randomize_class=randomize_class,\n", 532 | " eta=eta,\n", 533 | " )\n", 534 | " else:\n", 535 | " samples = sample_fn(\n", 536 | " model,\n", 537 | " (batch_size, 3, side_y, side_x),\n", 538 | " clip_denoised=clip_denoised,\n", 539 | " model_kwargs={},\n", 540 | " cond_fn=cond_fn,\n", 541 | " progress=True,\n", 542 | " skip_timesteps=skip_timesteps,\n", 543 | " init_image=init,\n", 544 | " randomize_class=randomize_class,\n", 545 | " )\n", 546 | "\n", 547 | " for j, sample in enumerate(samples):\n", 548 | " cur_t -= 1\n", 549 | " if j % display_rate == 0 or cur_t == -1:\n", 550 | " display.clear_output(wait=True)\n", 551 | " for k, image in enumerate(sample['pred_xstart']):\n", 552 | " tqdm.write(f'Batch {i}, step {j}, output {k}:')\n", 553 | " current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')\n", 554 | " filename = f'progress_batch{i:05}_iteration{j:05}_output{k:05}_{current_time}.png'\n", 555 | " image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))\n", 556 | " #image.save(model_path + filename)\n", 557 | " display.display(image)\n", 558 | " if cur_t == -1:\n", 559 | " fpath = model_path + 'image_storage/' + filename\n", 560 | " image.save(fpath)\n", 561 | " final_values.append(loss_values[-1])\n", 562 | " images.append(fpath)\n", 563 | " \n", 564 | " plt.plot(np.array(loss_values), 'r')\n", 
565 | " return (final_values, images)" 566 | ] 567 | }, 568 | { 569 | "cell_type": "markdown", 570 | "metadata": { 571 | "id": "CQVtY1Ixnqx4" 572 | }, 573 | "source": [ 574 | "# Load Diffusion and CLIP models" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "metadata": { 581 | "id": "Fpbody2NCR7w" 582 | }, 583 | "outputs": [], 584 | "source": [ 585 | "timestep_respacing = 'ddim200' # Modify this value to change the number of timesteps.\n", 586 | "diffusion_steps = 1000\n", 587 | "\n", 588 | "model_config = model_and_diffusion_defaults()\n", 589 | "if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n", 590 | " model_config.update({\n", 591 | " 'attention_resolutions': '32, 16, 8',\n", 592 | " 'class_cond': False,\n", 593 | " 'diffusion_steps': diffusion_steps,\n", 594 | " 'rescale_timesteps': True,\n", 595 | " 'timestep_respacing': timestep_respacing,\n", 596 | " 'image_size': 512,\n", 597 | " 'learn_sigma': True,\n", 598 | " 'noise_schedule': 'linear',\n", 599 | " 'num_channels': 256,\n", 600 | " 'num_head_channels': 64,\n", 601 | " 'num_res_blocks': 2,\n", 602 | " 'resblock_updown': True,\n", 603 | " 'use_fp16': True,\n", 604 | " 'use_scale_shift_norm': True,\n", 605 | " })\n", 606 | "elif diffusion_model == '256x256_diffusion_uncond':\n", 607 | " model_config.update({\n", 608 | " 'attention_resolutions': '32, 16, 8',\n", 609 | " 'class_cond': False,\n", 610 | " 'diffusion_steps': diffusion_steps,\n", 611 | " 'rescale_timesteps': True,\n", 612 | " 'timestep_respacing': timestep_respacing,\n", 613 | " 'image_size': 256,\n", 614 | " 'learn_sigma': True,\n", 615 | " 'noise_schedule': 'linear',\n", 616 | " 'num_channels': 256,\n", 617 | " 'num_head_channels': 64,\n", 618 | " 'num_res_blocks': 2,\n", 619 | " 'resblock_updown': True,\n", 620 | " 'use_fp16': True,\n", 621 | " 'use_scale_shift_norm': True,\n", 622 | " })\n", 623 | "side_x = side_y = model_config['image_size']\n", 624 | "\n", 625 | "model, diffusion = create_model_and_diffusion(**model_config)\n", 626 | "model.load_state_dict(torch.load(f'{model_path}{diffusion_model}.pt', map_location='cpu'))\n", 627 | "model.requires_grad_(False).eval().to(device)\n", 628 | "for name, param in model.named_parameters():\n", 629 | " if 'qkv' in name or 'norm' in name or 'proj' in name:\n", 630 | " param.requires_grad_()\n", 631 | "if model_config['use_fp16']:\n", 632 | " model.convert_to_fp16()" 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": null, 638 | "metadata": { 639 | "id": "VnQjGugaDZPJ" 640 | }, 641 | "outputs": [], 642 | "source": [ 643 | "#clip_models = [clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device),clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device),clip.load('RN50x4', jit=False)[0].eval().requires_grad_(False).to(device)]\n", 644 | "#clip_models = [clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)]\n", 645 | "clip_models = [clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device),clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device)]\n", 646 | "\n", 647 | "\n", 648 | "normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n", 649 | "#lpips_model = lpips.LPIPS(net='vgg').to(device)\n", 650 | "#AlexNet model also works fine for less RAM usage\n", 651 | "lpips_model = lpips.LPIPS(net='alex').to(device)" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": { 657 | 
"id": "9zY-8I90LkC6" 658 | }, 659 | "source": [ 660 | "# Settings" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": null, 666 | "metadata": { 667 | "id": "U0PwzFZbLfcy" 668 | }, 669 | "outputs": [], 670 | "source": [ 671 | "_project_name = 'painting' #@param{type:\"string\"}\n", 672 | "_text_prompt = \"Stunning landscape painting in spring, by Claude Monet\"\n", 673 | "_text_prompt1 = \"Photorealistic proportional and well-composed portrait painting of a beautiful young lady, pretty face, by van Gogh\"\n", 674 | "_text_prompt2 = \"An aesthetically beautiful picture of a young lady at the park, trending on Instagram\"\n", 675 | "\n", 676 | "# Feel free to use multiple prompts!\n", 677 | "text_prompts = [\n", 678 | " _text_prompt,\n", 679 | " _text_prompt1,\n", 680 | " _text_prompt2,\n", 681 | "]\n", 682 | "\n", 683 | "image_prompts = [\n", 684 | " #'mona.jpg',\n", 685 | "]\n", 686 | "\n", 687 | "# 350/50/50/32 and 500/0/0/64 have worked well for 25 timesteps on 256px\n", 688 | "# Also, sometimes 1 cutn actually works out fine\n", 689 | "\n", 690 | "clip_guidance_scale = 5000 # 1000 - Controls how much the image should look like the prompt.\n", 691 | "tv_scale = 150 # 150 - Controls the smoothness of the final output.\n", 692 | "range_scale = 150 # 150 - Controls how far out of range RGB values are allowed to be.\n", 693 | "sat_scale = 0 # 0 - Controls how much saturation is allowed. From nshepperd's JAX notebook.\n", 694 | "cutn = 4 # 16 - Controls how many crops to take from the image.\n", 695 | "cutn_batches = 1 # 2 - Accumulate CLIP gradient from multiple batches of cuts [Can help with OOM errors / Low VRAM]\n", 696 | "\n", 697 | "#init_image = 'mona_256.png' # None - URL or local path\n", 698 | "init_image = None\n", 699 | "init_scale = 0 # 0 - This enhances the effect of the init image, a good value is 1000\n", 700 | "skip_timesteps = 0 # 0 - Controls the starting point along the diffusion timesteps\n", 701 | "perlin_init = False # False - Option to start with random perlin noise\n", 702 | "perlin_mode = 'gray' # 'mixed' ('gray', 'color')\n", 703 | "\n", 704 | "skip_augs = False # False - Controls whether to skip torchvision augmentations\n", 705 | "randomize_class = True # True - Controls whether the imagenet class is randomly changed each iteration\n", 706 | "clip_denoised = False # False - Determines whether CLIP discriminates a noisy or denoised image\n", 707 | "clamp_grad = True # True - Experimental: Using adaptive clip grad in the cond_fn\n", 708 | "\n", 709 | "fuzzy_prompt = False # False - Controls whether to add multiple noisy prompts to the prompt losses\n", 710 | "rand_mag = 1.0 # 0.1 - Controls the magnitude of the random noise\n", 711 | "eta = 1.0 # 0.0 - DDIM hyperparameter\n", 712 | "\n", 713 | "seed = random.randint(0, 2**32) # Choose a random seed and print it at end of run for reproduction\n", 714 | "#seed = 2547079922\n", 715 | "if seed is not None:\n", 716 | " np.random.seed(seed)\n", 717 | " random.seed(seed)\n", 718 | " torch.manual_seed(seed)\n", 719 | " torch.cuda.manual_seed_all(seed)\n", 720 | " torch.backends.cudnn.deterministic = True\n" 721 | ] 722 | }, 723 | { 724 | "cell_type": "markdown", 725 | "metadata": { 726 | "id": "Nf9hTc8YLoLx" 727 | }, 728 | "source": [ 729 | "## Diffuse!" 
730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": null, 735 | "metadata": { 736 | "id": "4sjiHy2ygOJX", 737 | "scrolled": false 738 | }, 739 | "outputs": [], 740 | "source": [ 741 | "display_rate = 20\n", 742 | "n_batches = 4 # 1 - Controls how many consecutive batches of images are generated\n", 743 | "batch_size = 1 # 1 - Controls how many images are generated in parallel in a batch\n", 744 | "gc.collect()\n", 745 | "torch.cuda.empty_cache()\n", 746 | "best_losses = []\n", 747 | "best_images = []\n", 748 | "best_images_log = []\n", 749 | "ls_run = []\n", 750 | "im_run = []\n", 751 | "model_used_log = []\n", 752 | "aug_model_path = 'checkpoints/'\n", 753 | "aug_models = [\"_media\",\"_sampo\",\"_nothing2\",\"s1\",\"s2\",\"s3\"]\n", 754 | "num_models = 6\n", 755 | "model_depth = 50\n", 756 | "model_avg_freq = 5\n", 757 | "start = True\n", 758 | "try:\n", 759 | " while True:\n", 760 | " if start or random.randint(0, len(aug_models)-1) == 0:\n", 761 | " text_to_images = []\n", 762 | " image_to_texts = []\n", 763 | " pick_models = np.random.choice(aug_models, size=[num_models], replace=False)\n", 764 | " print(pick_models)\n", 765 | " model_used_log.append(pick_models)\n", 766 | " for pm in pick_models:\n", 767 | " text_to_image = torch.load(aug_model_path + \"t2i\" + pm + \".pt\")\n", 768 | " text_to_image.requires_grad_(False).eval().to(device)\n", 769 | " image_to_text = torch.load(aug_model_path + \"i2t\" + pm + \".pt\")\n", 770 | " image_to_text.requires_grad_(False).eval().to(device)\n", 771 | " text_to_images.append(text_to_image)\n", 772 | " image_to_texts.append(image_to_text)\n", 773 | " start = False\n", 774 | " (ls_run, im_run) = do_run()\n", 775 | " best_losses = best_losses + ls_run\n", 776 | " best_images = best_images + im_run\n", 777 | " init_scale = 0\n", 778 | " skip_timesteps = random.randint(40, 100)\n", 779 | " seed = random.randint(0, 2**32)\n", 780 | " best_indexes = np.argpartition(np.array(best_losses), kth=2)[0:3]\n", 781 | " random.shuffle(best_indexes)\n", 782 | " best_losses = list(np.array(best_losses)[best_indexes][0:3])\n", 783 | " best_images = list(np.array(best_images)[best_indexes][0:3])\n", 784 | " init_image = best_images[random.randint(0, 2)]\n", 785 | " best_images_log.append(init_image)\n", 786 | "except KeyboardInterrupt:\n", 787 | " pass\n", 788 | "finally:\n", 789 | " print('seed', seed)\n", 790 | " gc.collect()\n", 791 | " torch.cuda.empty_cache()" 792 | ] 793 | } 794 | ], 795 | "metadata": { 796 | "accelerator": "GPU", 797 | "colab": { 798 | "collapsed_sections": [], 799 | "machine_shape": "hm", 800 | "name": "Multi-Perceptor CLIP Guided Diffusion HQ 256x256 and 512x512.ipynb", 801 | "private_outputs": true, 802 | "provenance": [] 803 | }, 804 | "kernelspec": { 805 | "display_name": "Python 3 (ipykernel)", 806 | "language": "python", 807 | "name": "python3" 808 | }, 809 | "language_info": { 810 | "codemirror_mode": { 811 | "name": "ipython", 812 | "version": 3 813 | }, 814 | "file_extension": ".py", 815 | "mimetype": "text/x-python", 816 | "name": "python", 817 | "nbconvert_exporter": "python", 818 | "pygments_lexer": "ipython3", 819 | "version": "3.9.7" 820 | } 821 | }, 822 | "nbformat": 4, 823 | "nbformat_minor": 1 824 | } 825 | --------------------------------------------------------------------------------
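
Editor's note (illustrative sketch, not a file from this repository): inside the notebook's `cond_fn`, the CLIP spherical distance, total-variation, range and saturation penalties (plus an optional LPIPS term against the init image) are summed and differentiated with respect to the denoised image estimate `x_in`. The standalone PyTorch sketch below reproduces just the first three terms so the guidance objective can be poked at outside the notebook. The three loss helpers mirror the notebook's own definitions; the random image tensor, the random projection standing in for CLIP's image encoder, and the scale values are placeholders chosen for illustration only.

import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    # proportional to the squared angle between L2-normalized embeddings
    x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

def tv_loss(img):
    # L2 total variation: penalizes differences between neighbouring pixels
    img = F.pad(img, (0, 1, 0, 1), 'replicate')
    x_diff = img[..., :-1, 1:] - img[..., :-1, :-1]
    y_diff = img[..., 1:, :-1] - img[..., :-1, :-1]
    return (x_diff ** 2 + y_diff ** 2).mean([1, 2, 3])

def range_loss(img):
    # penalizes pixels that drift outside the [-1, 1] range the model expects
    return (img - img.clamp(-1, 1)).pow(2).mean([1, 2, 3])

# Toy stand-ins: a "denoised image" estimate and CLIP-like embeddings.
x_in = torch.randn(1, 3, 64, 64, requires_grad=True)
proj = torch.randn(3, 512)                     # placeholder for clip_model.encode_image
image_embed = x_in.mean(dim=(2, 3)) @ proj     # differentiable w.r.t. x_in
target_embed = torch.randn(1, 512)             # placeholder for clip_model.encode_text

clip_guidance_scale, tv_scale, range_scale = 5000, 150, 150
loss = (spherical_dist_loss(image_embed, target_embed).sum() * clip_guidance_scale
        + tv_loss(x_in).sum() * tv_scale
        + range_loss(x_in).sum() * range_scale)
grad = -torch.autograd.grad(loss, x_in)[0]     # the direction cond_fn feeds back to the sampler
print(grad.shape)

Because the embeddings are normalized, the spherical distance is a monotone function of cosine similarity, so guiding with it ranks candidate images the same way a cosine objective would; the notebook simply uses the squared geodesic form.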