├── .gitignore ├── LICENSE ├── README.md ├── augs.py ├── ddim_sampler.py ├── diffuse.py ├── old_diffusion_notebook.ipynb ├── perlin.py └── secondary_diffusion.py /.gitignore: -------------------------------------------------------------------------------- 1 | # cloned repos 2 | ResizeRight/ 3 | guided-diffusion/ 4 | latent-diffusion/ 5 | out_diffusion/ 6 | taming-transformers/ 7 | CLIP/ 8 | 9 | # files 10 | *.jpg 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | Pipfile.lock 104 | Pipfile 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Katherine Crowson 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in 11 | # all copies or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | # THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Gen 2 | 3 | An adaptation of DiscoDiffusion (https://colab.research.google.com/drive/1sHfRn5Y0YKYKi1k-ifUSBFRNJ8_1sa39#scrollTo=BGBzhk3dpcGO) to run it locally, improve the code quality, and speed it up. So far the code has just been cleaned up a bit, and the lpips network initialization is skipped when only an input text is used. 4 | 5 | Around 11 GB of GPU VRAM is needed for the current default settings of `--width` 1280 and `--height` 768. Decreasing the image size is the easiest way to make it fit on smaller GPUs.
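For example, to fit on a card with less VRAM you could reduce the resolution like this (832x512 is reported below as roughly the largest size that fits in 8 GB):

```
python3 diffuse.py --text "The meaning of life" --width 832 --height 512
```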
6 | With default settings it takes 07:46 minutes on an RTX 2080 Ti, 19:01 minutes on a GTX 1080 Ti, and 17:01 minutes on a Titan XP to generate images like these: 7 | 8 | 9 | **The meaning of life** 10 | ![meaning(0)_0](https://user-images.githubusercontent.com/19983153/150617587-0b1396bd-339f-4867-8a4a-c15bb75fd71a.png) 11 | 12 | 13 | **The meaning of life by Picasso** 14 | ![meaning(2)_0](https://user-images.githubusercontent.com/19983153/150617599-4ceb2896-9aa1-4497-b7ad-80c488f68938.png) 15 | 16 | 17 | **The meaning of life by Greg Rutkowski** 18 | ![meaning_rutkowski](https://user-images.githubusercontent.com/19983153/150616859-0630e090-d737-4ced-9893-4a2c9937a949.png) 19 | 20 | **Consciousness** 21 | ![out_image(0)_0](https://user-images.githubusercontent.com/19983153/150617545-1048b160-084c-4854-adc3-6afb13731fdf.png) 22 | 23 | *forgot the prompt, but it was about Pikachu staring at a tumultuous sea of blood, adapted from the original DiscoDiffusion notebook* 24 | ![Screenshot from 2022-01-21 15-35-09](https://user-images.githubusercontent.com/19983153/150616643-54436dbc-1e38-4127-b0dd-f0097470ae0f.png) 25 | 26 | 27 | ## Setup 28 | If you're using Windows, please also refer to the section below called `Setup for Windows`! 29 | 30 | First run `ipython3 diffuse.py` to set everything up and to clone the repositories. IMPORTANT: you need to use ipython instead of python here because the setup commands (git clone etc.) are run via IPython. 31 | 32 | At the moment you can only set a single text as a target, but this should be improved in the future. It currently only runs with GPU support. 33 | 34 | Use it like this: 35 | 36 | ``` 37 | python3 diffuse.py --text "The meaning of life" --gpu [Optional: device number of GPU to run this on] --root_path [Optional: path to output folder, default is "out_diffusion" in local dir] 38 | ``` 39 | If you only have 8 GB of VRAM on your GPU, the highest resolution you can run is 832x512 or 896x448. Set it by adding `--width 832 --height 512`, for example. Thanks @Jotunblood for testing! 40 | 41 | You can also set: `--out_name [Optional: set naming in your root_path according to this for better overview]` and `--sharpen_preset [Optional: set it to any of ('Off', 'Faster', 'Fast', 'Slow', 'Very Slow') to modify the sharpening process at the end. Default: Off]` 42 | 43 | 44 | ## Setup for Windows 45 | See https://github.com/NotNANtoN/diffusion_gen/issues/1; the instructions from @JotunBlood are adopted here. 46 | 47 | Instructions: 48 | - Install Anaconda 49 | - Create and activate a new environment (don't use base) 50 | - Install PyTorch via the install command from their website, using pip (not conda) 51 | - Install IPython 52 | - Add the conda-forge channel to Anaconda: 53 | - `conda config --add channels conda-forge` 54 | - Install dependency packages using conda (for those available), otherwise use pip. Packages of relevance: OpenCV, pandas, timm, lpips, requests, pytorch-lightning, and omegaconf. There might be one or two others. 55 | - Run `ipython diffuse.py` 56 | - If it goes all the way, congrats.
If you hit SSL errors, open diffuse.py and add the following lines near the top (I did it around line 7): 57 | ``` 58 | import ssl 59 | ssl._create_default_https_context = ssl._create_unverified_context 60 | ``` 61 | - If you get `Frame Prompt: ['']` and a failed output, make sure you're using python3 to run diffuse.py and not IPython :) 62 | - If you get a CUDA out of memory warning, pass a lower resolution like `--width 720 --height 480` when you run it 63 | 64 | ## Tutorial (copypasta from old colab notebook) 65 | 66 | ### **Diffusion settings** 67 | --- 68 | 69 | This section is outdated as of v2 70 | 71 | Setting | Description | Default 72 | --- | --- | --- 73 | **Your vision:** 74 | `text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A 75 | `image_prompts` | Think of these images more as a description of their contents. | N/A 76 | **Image quality:** 77 | `clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000 78 | `tv_scale` | Controls the smoothness of the final output. | 150 79 | `range_scale` | Controls how far out of range RGB values are allowed to be. | 150 80 | `sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0 81 | `cutn` | Controls how many crops to take from the image. | 16 82 | `cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts | 2 83 | **Init settings:** 84 | `init_image` | URL or local path | None 85 | `init_scale` | This enhances the effect of the init image, a good value is 1000 | 0 86 | `skip_steps` | Controls the starting point along the diffusion timesteps | 0 87 | `perlin_init` | Option to start with random perlin noise | False 88 | `perlin_mode` | One of ('gray', 'color', 'mixed') | 'mixed' 89 | **Advanced:** 90 | `skip_augs` | Controls whether to skip torchvision augmentations | False 91 | `randomize_class` | Controls whether the imagenet class is randomly changed each iteration | True 92 | `clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image | False 93 | `clamp_grad` | Experimental: use adaptive gradient clipping in the cond_fn | True 94 | `seed` | Choose a random seed and print it at end of run for reproduction | random_seed 95 | `fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses | False 96 | `rand_mag` | Controls the magnitude of the random noise | 0.1 97 | `eta` | DDIM hyperparameter | 0.5 98 | -------------------------------------------------------------------------------- /augs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch.nn as nn 3 | import torch 4 | import math 5 | import torchvision.transforms as T 6 | import torchvision.transforms.functional as TF 7 | from torch.nn import functional as F 8 | 9 | sys.path.append("ResizeRight") 10 | from resize_right import resize 11 | 12 | 13 | class MakeCutouts(nn.Module): 14 | def __init__(self, cut_size, cutn, skip_augs=False): 15 | super().__init__() 16 | self.cut_size = cut_size 17 | self.cutn = cutn 18 | self.skip_augs = skip_augs 19 | self.augs = T.Compose([ 20 | T.RandomHorizontalFlip(p=0.5), 21 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 22 | T.RandomAffine(degrees=15, translate=(0.1, 0.1)), 23 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 24 | T.RandomPerspective(distortion_scale=0.4, p=0.7), 25 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 26 | T.RandomGrayscale(p=0.15), 27 | T.Lambda(lambda x: x
+ torch.randn_like(x) * 0.01), 28 | # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 29 | ]) 30 | 31 | def forward(self, input): 32 | input = T.Pad(input.shape[2]//4, fill=0)(input) 33 | sideY, sideX = input.shape[2:4] 34 | max_size = min(sideX, sideY) 35 | 36 | cutouts = [] 37 | for ch in range(self.cutn): 38 | if ch > self.cutn - self.cutn//4: 39 | cutout = input.clone() 40 | else: 41 | size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.)) 42 | offsetx = torch.randint(0, abs(sideX - size + 1), ()) 43 | offsety = torch.randint(0, abs(sideY - size + 1), ()) 44 | cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] 45 | 46 | if not self.skip_augs: 47 | cutout = self.augs(cutout) 48 | cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) 49 | del cutout 50 | 51 | cutouts = torch.cat(cutouts, dim=0) 52 | return cutouts 53 | 54 | 55 | class MakeCutoutsDango(nn.Module): 56 | def __init__(self, cut_size, 57 | Overview=4, 58 | InnerCrop = 0, 59 | IC_Size_Pow=0.5, 60 | IC_Grey_P = 0.2, 61 | animation_mode='None', 62 | ): 63 | super().__init__() 64 | self.cut_size = cut_size 65 | self.Overview = Overview 66 | self.InnerCrop = InnerCrop 67 | self.IC_Size_Pow = IC_Size_Pow 68 | self.IC_Grey_P = IC_Grey_P 69 | if animation_mode == 'None': 70 | self.augs = T.Compose([ 71 | T.RandomHorizontalFlip(p=0.5), 72 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 73 | T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR), 74 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 75 | T.RandomGrayscale(p=0.1), 76 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 77 | T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 78 | ]) 79 | elif animation_mode == 'Video Input': 80 | self.augs = T.Compose([ 81 | T.RandomHorizontalFlip(p=0.5), 82 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 83 | T.RandomAffine(degrees=15, translate=(0.1, 0.1)), 84 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 85 | T.RandomPerspective(distortion_scale=0.4, p=0.7), 86 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 87 | T.RandomGrayscale(p=0.15), 88 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 89 | # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), 90 | ]) 91 | elif animation_mode == '2D': 92 | self.augs = T.Compose([ 93 | T.RandomHorizontalFlip(p=0.4), 94 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 95 | T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR), 96 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 97 | T.RandomGrayscale(p=0.1), 98 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), 99 | T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3), 100 | ]) 101 | 102 | 103 | def forward(self, input): 104 | cutouts = [] 105 | gray = T.Grayscale(3) 106 | sideY, sideX = input.shape[2:4] 107 | max_size = min(sideX, sideY) 108 | min_size = min(sideX, sideY, self.cut_size) 109 | l_size = max(sideX, sideY) 110 | output_shape = [1,3,self.cut_size,self.cut_size] 111 | output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2] 112 | pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2)) 113 | cutout = resize(pad_input, out_shape=output_shape) 114 | 115 | if self.Overview>0: 116 | if self.Overview <= 4: 117 | if self.Overview >= 1: 118 | cutouts.append(cutout) 119 | if self.Overview >= 2: 120 | 
cutouts.append(gray(cutout)) 121 | if self.Overview >= 3: 122 | cutouts.append(TF.hflip(cutout)) 123 | if self.Overview == 4: 124 | cutouts.append(gray(TF.hflip(cutout))) 125 | else: 126 | cutout = resize(pad_input, out_shape=output_shape) 127 | for _ in range(self.Overview): 128 | cutouts.append(cutout) 129 | 130 | #if cutout_debug: 131 | # TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save("/content/cutout_overview0.jpg",quality=99) 132 | 133 | if self.InnerCrop >0: 134 | for i in range(self.InnerCrop): 135 | size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size) 136 | offsetx = torch.randint(0, sideX - size + 1, ()) 137 | offsety = torch.randint(0, sideY - size + 1, ()) 138 | cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] 139 | if i <= int(self.IC_Grey_P * self.InnerCrop): 140 | cutout = gray(cutout) 141 | cutout = resize(cutout, out_shape=output_shape) 142 | cutouts.append(cutout) 143 | #if cutout_debug: 144 | # TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save("/content/cutout_InnerCrop.jpg",quality=99) 145 | cutouts = torch.cat(cutouts) 146 | #if skip_augs is not True: cutouts=self.augs(cutouts) 147 | return cutouts 148 | 149 | 150 | def sinc(x): 151 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) 152 | 153 | 154 | def lanczos(x, a): 155 | cond = torch.logical_and(-a < x, x < a) 156 | out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([])) 157 | return out / out.sum() 158 | 159 | 160 | def ramp(ratio, width): 161 | n = math.ceil(width / ratio + 1) 162 | out = torch.empty([n]) 163 | cur = 0 164 | for i in range(out.shape[0]): 165 | out[i] = cur 166 | cur += ratio 167 | return torch.cat([-out[1:].flip([0]), out])[1:-1] 168 | 169 | 170 | def resample(input, size, align_corners=True): 171 | n, c, h, w = input.shape 172 | dh, dw = size 173 | 174 | input = input.reshape([n * c, 1, h, w]) 175 | 176 | if dh < h: 177 | kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) 178 | pad_h = (kernel_h.shape[0] - 1) // 2 179 | input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect') 180 | input = F.conv2d(input, kernel_h[None, None, :, None]) 181 | 182 | if dw < w: 183 | kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) 184 | pad_w = (kernel_w.shape[0] - 1) // 2 185 | input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect') 186 | input = F.conv2d(input, kernel_w[None, None, None, :]) 187 | 188 | input = input.reshape([n, c, h, w]) 189 | return F.interpolate(input, size, mode='bicubic', align_corners=align_corners) 190 | -------------------------------------------------------------------------------- /ddim_sampler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | sys.path.append("latent-diffusion") 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | 12 | class DDIMSampler: 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | 
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 27 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 28 | alphas_cumprod = self.model.alphas_cumprod 29 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 30 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 31 | 32 | self.register_buffer('betas', to_torch(self.model.betas)) 33 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 34 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 35 | 36 | # calculations for diffusion q(x_t | x_{t-1}) and others 37 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 38 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 42 | 43 | # ddim sampling parameters 44 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 45 | ddim_timesteps=self.ddim_timesteps, 46 | eta=ddim_eta,verbose=verbose) 47 | self.register_buffer('ddim_sigmas', ddim_sigmas) 48 | self.register_buffer('ddim_alphas', ddim_alphas) 49 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 50 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 51 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 52 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 53 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 54 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 55 | 56 | @torch.no_grad() 57 | def sample(self, 58 | S, 59 | batch_size, 60 | shape, 61 | conditioning=None, 62 | callback=None, 63 | normals_sequence=None, 64 | img_callback=None, 65 | quantize_x0=False, 66 | eta=0., 67 | mask=None, 68 | x0=None, 69 | temperature=1., 70 | noise_dropout=0., 71 | score_corrector=None, 72 | corrector_kwargs=None, 73 | verbose=True, 74 | x_T=None, 75 | log_every_t=100, 76 | **kwargs 77 | ): 78 | if conditioning is not None: 79 | if isinstance(conditioning, dict): 80 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 81 | if cbs != batch_size: 82 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 83 | else: 84 | if conditioning.shape[0] != batch_size: 85 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 86 | 87 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 88 | # sampling 89 | C, H, W = shape 90 | size = (batch_size, C, H, W) 91 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}') 92 | 93 | samples, intermediates = self.ddim_sampling(conditioning, size, 94 | callback=callback, 95 | img_callback=img_callback, 96 | quantize_denoised=quantize_x0, 97 | mask=mask, x0=x0, 98 | ddim_use_original_steps=False, 99 | noise_dropout=noise_dropout, 100 | temperature=temperature, 101 | score_corrector=score_corrector, 102 | corrector_kwargs=corrector_kwargs, 103 | x_T=x_T, 104 | log_every_t=log_every_t 105 | ) 106 | return samples, intermediates 107 | 108 | @torch.no_grad() 109 | def ddim_sampling(self, cond, 
shape, 110 | x_T=None, ddim_use_original_steps=False, 111 | callback=None, timesteps=None, quantize_denoised=False, 112 | mask=None, x0=None, img_callback=None, log_every_t=100, 113 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): 114 | device = self.model.betas.device 115 | b = shape[0] 116 | if x_T is None: 117 | img = torch.randn(shape, device=device) 118 | else: 119 | img = x_T 120 | 121 | if timesteps is None: 122 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 123 | elif timesteps is not None and not ddim_use_original_steps: 124 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 125 | timesteps = self.ddim_timesteps[:subset_end] 126 | 127 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 128 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 129 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 130 | print(f"Running DDIM Sharpening with {total_steps} timesteps") 131 | 132 | iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps) 133 | 134 | for i, step in enumerate(iterator): 135 | index = total_steps - i - 1 136 | ts = torch.full((b,), step, device=device, dtype=torch.long) 137 | 138 | if mask is not None: 139 | assert x0 is not None 140 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 141 | img = img_orig * mask + (1. - mask) * img 142 | 143 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 144 | quantize_denoised=quantize_denoised, temperature=temperature, 145 | noise_dropout=noise_dropout, score_corrector=score_corrector, 146 | corrector_kwargs=corrector_kwargs) 147 | img, pred_x0 = outs 148 | if callback: callback(i) 149 | if img_callback: img_callback(pred_x0, i) 150 | 151 | if index % log_every_t == 0 or index == total_steps - 1: 152 | intermediates['x_inter'].append(img) 153 | intermediates['pred_x0'].append(pred_x0) 154 | 155 | return img, intermediates 156 | 157 | @torch.no_grad() 158 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 159 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): 160 | b, *_, device = *x.shape, x.device 161 | e_t = self.model.apply_model(x, t, c) 162 | if score_corrector is not None: 163 | assert self.model.parameterization == "eps" 164 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 165 | 166 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 167 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 168 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 169 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 170 | # select parameters corresponding to the currently considered timestep 171 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 172 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 173 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 174 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 175 | 176 | # current prediction for x_0 177 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 178 | if quantize_denoised: 179 | pred_x0, _, *_ = 
self.model.first_stage_model.quantize(pred_x0) 180 | # direction pointing to x_t 181 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 182 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 183 | if noise_dropout > 0.: 184 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 185 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 186 | return x_prev, pred_x0 187 | -------------------------------------------------------------------------------- /diffuse.py: -------------------------------------------------------------------------------- 1 | # # 1. Set Up 2 | import os 3 | from os import path 4 | import sys 5 | from argparse import ArgumentParser, Namespace 6 | 7 | def run_from_ipython(): 8 | try: 9 | __IPYTHON__ 10 | return True 11 | except NameError: 12 | return False 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument("--gpu", default="0", type=str) 16 | parser.add_argument("--text", default="", type=str) 17 | parser.add_argument("--root_path", default="out_diffusion") 18 | parser.add_argument("--setup", default=False, type=bool) 19 | parser.add_argument("--out_name", default="out_image", type=str) 20 | parser.add_argument("--sharpen_preset", default="Off", type=str, choices=['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']) 21 | parser.add_argument("--width", default=1280, type=int) 22 | parser.add_argument("--height", default=768, type=int) 23 | parser.add_argument("--init_image", default=None, type=str) 24 | parser.add_argument("--steps", default=250, type=int) 25 | parser.add_argument("--skip_steps", default=None, type=int) 26 | parser.add_argument("--inter_saves", default=3, type=int) 27 | 28 | 29 | if run_from_ipython(): 30 | argparse_args = parser.parse_args({}) 31 | argparse_args.setup = 1 32 | else: 33 | argparse_args = parser.parse_args() 34 | # read args to determine GPU before importing torch and befor importing other files (torch is also imported in other files) 35 | os.environ["CUDA_VISIBLE_DEVICES"] = argparse_args.gpu 36 | 37 | if argparse_args.setup: 38 | get_ipython().system('python3 -m pip install torch lpips datetime timm ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb') 39 | get_ipython().system('git clone https://github.com/CompVis/latent-diffusion.git') 40 | get_ipython().system('git clone https://github.com/openai/CLIP') 41 | get_ipython().system('pip3 install -e ./CLIP') 42 | get_ipython().system('git clone https://github.com/assafshocher/ResizeRight.git') 43 | get_ipython().system('git clone https://github.com/crowsonkb/guided-diffusion') 44 | get_ipython().system('python3 -m pip install -e ./guided-diffusion') 45 | 46 | get_ipython().system('apt install imagemagick') 47 | 48 | #SuperRes 49 | get_ipython().system('git clone https://github.com/CompVis/latent-diffusion.git') 50 | get_ipython().system('git clone https://github.com/CompVis/taming-transformers') 51 | get_ipython().system('pip install -e ./taming-transformers') 52 | 53 | 54 | 55 | 56 | # sys.path.append('./SLIP') 57 | sys.path.append('./ResizeRight') 58 | sys.path.append("latent-diffusion") 59 | 60 | from secondary_diffusion import SecondaryDiffusionImageNet, SecondaryDiffusionImageNet2 61 | from ddim_sampler import DDIMSampler 62 | from augs import MakeCutouts, MakeCutoutsDango 63 | from perlin import create_perlin_noise, regen_perlin 64 | 65 | 66 | def alpha_sigma_to_t(alpha, sigma): 67 | return torch.atan2(sigma, alpha) * 2 / math.pi 68 | 69 | 70 | 71 | root_path = argparse_args.root_path 72 | initDirPath = 
f'{root_path}/init_images' 73 | os.makedirs(initDirPath, exist_ok=1) 74 | outDirPath = f'{root_path}/images_out' 75 | os.makedirs(outDirPath, exist_ok=1) 76 | model_path = f'{root_path}/models' 77 | os.makedirs(model_path, exist_ok=1) 78 | 79 | 80 | 81 | #@title ### 2.1 Install and import dependencies 82 | model_256_downloaded = False 83 | model_512_downloaded = False 84 | model_secondary_downloaded = False 85 | 86 | from functools import partial 87 | import cv2 88 | import pandas as pd 89 | import gc 90 | import io 91 | import math 92 | import timm 93 | from IPython import display 94 | import lpips 95 | from PIL import Image, ImageOps 96 | import requests 97 | from glob import glob 98 | import json 99 | from types import SimpleNamespace 100 | import torch 101 | from torch import nn 102 | from torch.nn import functional as F 103 | import torchvision.transforms as T 104 | import torchvision.transforms.functional as TF 105 | from tqdm.notebook import tqdm 106 | #sys.path.append('./CLIP') 107 | sys.path.append('./guided-diffusion') 108 | import clip 109 | from resize_right import resize 110 | # from models import SLIP_VITB16, SLIP, SLIP_VITL16 111 | from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults 112 | from datetime import datetime 113 | import numpy as np 114 | import matplotlib.pyplot as plt 115 | import random 116 | from ipywidgets import Output 117 | import hashlib 118 | 119 | 120 | 121 | 122 | #SuperRes 123 | import ipywidgets as widgets 124 | import os 125 | sys.path.append(".") 126 | sys.path.append('./taming-transformers') 127 | from taming.models import vqgan # checking correct import from taming 128 | from torchvision.datasets.utils import download_url 129 | 130 | sys.path.append("./latent-diffusion") 131 | if argparse_args.setup: 132 | get_ipython().run_line_magic('cd', "latent-diffusion") 133 | from ldm.util import instantiate_from_config 134 | # from ldm.models.diffusion.ddim import DDIMSampler 135 | from ldm.util import ismap 136 | if argparse_args.setup: 137 | get_ipython().run_line_magic('cd', '..') 138 | #from google.colab import files 139 | from IPython.display import Image as ipyimg 140 | from numpy import asarray 141 | from einops import rearrange, repeat 142 | import torch, torchvision 143 | import time 144 | from omegaconf import OmegaConf 145 | import warnings 146 | warnings.filterwarnings("ignore", category=UserWarning) 147 | 148 | 149 | import torch 150 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 151 | print('Using device:', device) 152 | 153 | if torch.cuda.get_device_capability(device) == (8,0): ## A100 fix thanks to Emad 154 | print('Disabling CUDNN for A100 gpu', file=sys.stderr) 155 | torch.backends.cudnn.enabled = False 156 | 157 | 158 | #@title 2.2 Define necessary functions 159 | 160 | def fetch(url_or_path): 161 | if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'): 162 | r = requests.get(url_or_path) 163 | r.raise_for_status() 164 | fd = io.BytesIO() 165 | fd.write(r.content) 166 | fd.seek(0) 167 | return fd 168 | return open(url_or_path, 'rb') 169 | 170 | 171 | def parse_prompt(prompt): 172 | if prompt.startswith('http://') or prompt.startswith('https://'): 173 | vals = prompt.rsplit(':', 2) 174 | vals = [vals[0] + ':' + vals[1], *vals[2:]] 175 | else: 176 | vals = prompt.rsplit(':', 1) 177 | vals = vals + ['', '1'][len(vals):] 178 | return vals[0], float(vals[1]) 179 | 180 | 181 | def spherical_dist_loss(x, y): 182 | x = F.normalize(x, dim=-1) 
183 | y = F.normalize(y, dim=-1) 184 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) 185 | 186 | def tv_loss(input): 187 | """L2 total variation loss, as in Mahendran et al.""" 188 | input = F.pad(input, (0, 1, 0, 1), 'replicate') 189 | x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] 190 | y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] 191 | return (x_diff**2 + y_diff**2).mean([1, 2, 3]) 192 | 193 | 194 | def range_loss(input): 195 | return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3]) 196 | 197 | stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete 198 | 199 | def do_run(): 200 | seed = args.seed 201 | print(range(args.start_frame, args.max_frames)) 202 | for frame_num in range(args.start_frame, args.max_frames): 203 | if stop_on_next_loop: 204 | break 205 | 206 | display.clear_output(wait=True) 207 | 208 | # Print Frame progress if animation mode is on 209 | if args.animation_mode != "None": 210 | batchBar = tqdm(range(args.max_frames), desc ="Frames") 211 | batchBar.n = frame_num 212 | batchBar.refresh() 213 | 214 | 215 | # Inits if not video frames 216 | if args.animation_mode != "Video Input": 217 | if args.init_image == '': 218 | init_image = None 219 | else: 220 | init_image = args.init_image 221 | init_scale = args.init_scale 222 | skip_steps = args.skip_steps 223 | 224 | if args.animation_mode == "2D": 225 | if args.key_frames: 226 | angle = args.angle_series[frame_num] 227 | zoom = args.zoom_series[frame_num] 228 | translation_x = args.translation_x_series[frame_num] 229 | translation_y = args.translation_y_series[frame_num] 230 | print( 231 | f'angle: {angle}', 232 | f'zoom: {zoom}', 233 | f'translation_x: {translation_x}', 234 | f'translation_y: {translation_y}', 235 | ) 236 | 237 | if frame_num > 0: 238 | seed = seed + 1 239 | if resume_run and frame_num == start_frame: 240 | img_0 = cv2.imread(batchFolder+f"/{batch_name}({batchNum})_{start_frame-1:04}.png") 241 | else: 242 | img_0 = cv2.imread('prevFrame.png') 243 | center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2) 244 | trans_mat = np.float32( 245 | [[1, 0, translation_x], 246 | [0, 1, translation_y]] 247 | ) 248 | rot_mat = cv2.getRotationMatrix2D( center, angle, zoom ) 249 | trans_mat = np.vstack([trans_mat, [0,0,1]]) 250 | rot_mat = np.vstack([rot_mat, [0,0,1]]) 251 | transformation_matrix = np.matmul(rot_mat, trans_mat) 252 | img_0 = cv2.warpPerspective( 253 | img_0, 254 | transformation_matrix, 255 | (img_0.shape[1], img_0.shape[0]), 256 | borderMode=cv2.BORDER_WRAP 257 | ) 258 | cv2.imwrite('prevFrameScaled.png', img_0) 259 | init_image = 'prevFrameScaled.png' 260 | init_scale = args.frames_scale 261 | skip_steps = args.calc_frames_skip_steps 262 | 263 | if args.animation_mode == "Video Input": 264 | seed = seed + 1 265 | init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg' 266 | init_scale = args.frames_scale 267 | skip_steps = args.calc_frames_skip_steps 268 | 269 | loss_values = [] 270 | 271 | if seed is not None: 272 | np.random.seed(seed) 273 | random.seed(seed) 274 | torch.manual_seed(seed) 275 | torch.cuda.manual_seed_all(seed) 276 | torch.backends.cudnn.deterministic = True 277 | 278 | target_embeds, weights = [], [] 279 | 280 | if args.prompts_series is not None and frame_num >= len(args.prompts_series): 281 | frame_prompt = args.prompts_series[-1] 282 | elif args.prompts_series is not None: 283 | frame_prompt = args.prompts_series[frame_num] 284 | else: 285 | frame_prompt = [] 286 | 287 | 
print(args.image_prompts_series) 288 | if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series): 289 | image_prompt = args.image_prompts_series[-1] 290 | elif args.image_prompts_series is not None: 291 | image_prompt = args.image_prompts_series[frame_num] 292 | else: 293 | image_prompt = [] 294 | 295 | print(f'Frame Prompt: {frame_prompt}') 296 | 297 | model_stats = [] 298 | for clip_model in clip_models: 299 | cutn = 16 300 | model_stat = {"clip_model":None,"target_embeds":[],"make_cutouts":None,"weights":[]} 301 | model_stat["clip_model"] = clip_model 302 | 303 | 304 | for prompt in frame_prompt: 305 | txt, weight = parse_prompt(prompt) 306 | txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float() 307 | 308 | if args.fuzzy_prompt: 309 | for i in range(25): 310 | model_stat["target_embeds"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1)) 311 | model_stat["weights"].append(weight) 312 | else: 313 | model_stat["target_embeds"].append(txt) 314 | model_stat["weights"].append(weight) 315 | 316 | if image_prompt: 317 | model_stat["make_cutouts"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs) 318 | for prompt in image_prompt: 319 | path, weight = parse_prompt(prompt) 320 | img = Image.open(fetch(path)).convert('RGB') 321 | img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS) 322 | batch = model_stat["make_cutouts"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1)) 323 | embed = clip_model.encode_image(normalize(batch)).float() 324 | if fuzzy_prompt: 325 | for i in range(25): 326 | model_stat["target_embeds"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1)) 327 | weights.extend([weight / cutn] * cutn) 328 | else: 329 | model_stat["target_embeds"].append(embed) 330 | model_stat["weights"].extend([weight / cutn] * cutn) 331 | 332 | model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"]) 333 | model_stat["weights"] = torch.tensor(model_stat["weights"], device=device) 334 | if model_stat["weights"].sum().abs() < 1e-3: 335 | raise RuntimeError('The weights must not sum to 0.') 336 | model_stat["weights"] /= model_stat["weights"].sum().abs() 337 | model_stats.append(model_stat) 338 | 339 | init = None 340 | if init_image is not None: 341 | init = Image.open(fetch(init_image)).convert('RGB') 342 | init = init.resize((args.side_x, args.side_y), Image.LANCZOS) 343 | init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1) 344 | 345 | if args.perlin_init: 346 | if args.perlin_mode == 'color': 347 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False) 348 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False) 349 | elif args.perlin_mode == 'gray': 350 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True) 351 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True) 352 | else: 353 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False) 354 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True) 355 | # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device) 356 | init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1) 357 | del init2 358 | 359 | cur_t = None 360 | if init is not None and args.init_scale: 361 | lpips_model = lpips.LPIPS(net='vgg').to(device) 362 | 363 | def cond_fn(x, t, y=None): 364 | with torch.enable_grad(): 365 | x_is_NaN = False 366 | x = 
x.detach().requires_grad_() 367 | n = x.shape[0] 368 | if use_secondary_model is True: 369 | alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32) 370 | sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32) 371 | cosine_t = alpha_sigma_to_t(alpha, sigma) 372 | out = secondary_model(x, cosine_t[None].repeat([n])).pred 373 | fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t] 374 | x_in = out * fac + x * (1 - fac) 375 | x_in_grad = torch.zeros_like(x_in) 376 | else: 377 | my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t 378 | out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y}) 379 | fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t] 380 | x_in = out['pred_xstart'] * fac + x * (1 - fac) 381 | x_in_grad = torch.zeros_like(x_in) 382 | for model_stat in model_stats: 383 | for i in range(args.cutn_batches): 384 | t_int = int(t.item())+1 #errors on last step without +1, need to find source 385 | #when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution' 386 | try: 387 | input_resolution=model_stat["clip_model"].visual.input_resolution 388 | except: 389 | input_resolution=224 390 | 391 | cuts = MakeCutoutsDango(input_resolution, 392 | Overview= args.cut_overview[1000-t_int], 393 | InnerCrop = args.cut_innercut[1000-t_int], 394 | IC_Size_Pow=args.cut_ic_pow, 395 | IC_Grey_P = args.cut_icgray_p[1000-t_int], 396 | animation_mode=args.animation_mode, 397 | ) 398 | clip_in = normalize(cuts(x_in.add(1).div(2))) 399 | image_embeds = model_stat["clip_model"].encode_image(clip_in).float() 400 | dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat["target_embeds"].unsqueeze(0)) 401 | dists = dists.view([args.cut_overview[1000-t_int]+args.cut_innercut[1000-t_int], n, -1]) 402 | losses = dists.mul(model_stat["weights"]).sum(2).mean(0) 403 | loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch 404 | x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches 405 | tv_losses = tv_loss(x_in) 406 | if use_secondary_model is True: 407 | range_losses = range_loss(out) 408 | else: 409 | range_losses = range_loss(out['pred_xstart']) 410 | sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean() 411 | loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale 412 | 413 | if init is not None and args.init_scale: 414 | init_losses = lpips_model(x_in, init) 415 | loss = loss + init_losses.sum() * args.init_scale 416 | x_in_grad += torch.autograd.grad(loss, x_in)[0] 417 | if torch.isnan(x_in_grad).any() == False: 418 | grad = -torch.autograd.grad(x_in, x, x_in_grad)[0] 419 | else: 420 | # print("NaN'd") 421 | x_is_NaN = True 422 | grad = torch.zeros_like(x) 423 | if args.clamp_grad and x_is_NaN == False: 424 | magnitude = grad.square().mean().sqrt() 425 | return grad * magnitude.clamp(max=args.clamp_max) / magnitude #min=-0.02, min=-clamp_max, 426 | return grad 427 | 428 | if model_config['timestep_respacing'].startswith('ddim'): 429 | sample_fn = diffusion.ddim_sample_loop_progressive 430 | else: 431 | sample_fn = diffusion.p_sample_loop_progressive 432 | 433 | 434 | image_display = Output() 435 | for i in range(args.n_batches): 436 | if args.animation_mode == 'None': 437 | display.clear_output(wait=True) 438 | batchBar = tqdm(range(args.n_batches), desc 
="Batches") 439 | batchBar.n = i 440 | batchBar.refresh() 441 | print('') 442 | display.display(image_display) 443 | gc.collect() 444 | torch.cuda.empty_cache() 445 | cur_t = diffusion.num_timesteps - skip_steps - 1 446 | total_steps = cur_t 447 | 448 | if perlin_init: 449 | init = regen_perlin() 450 | 451 | if model_config['timestep_respacing'].startswith('ddim'): 452 | samples = sample_fn( 453 | model, 454 | (batch_size, 3, args.side_y, args.side_x), 455 | clip_denoised=clip_denoised, 456 | model_kwargs={}, 457 | cond_fn=cond_fn, 458 | progress=True, 459 | skip_timesteps=skip_steps, 460 | init_image=init, 461 | randomize_class=randomize_class, 462 | eta=eta, 463 | ) 464 | else: 465 | samples = sample_fn( 466 | model, 467 | (batch_size, 3, args.side_y, args.side_x), 468 | clip_denoised=clip_denoised, 469 | model_kwargs={}, 470 | cond_fn=cond_fn, 471 | progress=True, 472 | skip_timesteps=skip_steps, 473 | init_image=init, 474 | randomize_class=randomize_class, 475 | ) 476 | 477 | 478 | # with run_display: 479 | # display.clear_output(wait=True) 480 | imgToSharpen = None 481 | for j, sample in enumerate(samples): 482 | cur_t -= 1 483 | intermediateStep = False 484 | if args.steps_per_checkpoint is not None: 485 | if j % steps_per_checkpoint == 0 and j > 0: 486 | intermediateStep = True 487 | elif j in args.intermediate_saves: 488 | intermediateStep = True 489 | with image_display: 490 | if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True: 491 | for k, image in enumerate(sample['pred_xstart']): 492 | # tqdm.write(f'Batch {i}, step {j}, output {k}:') 493 | current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f') 494 | percent = math.ceil(j/total_steps*100) 495 | if args.n_batches > 0: 496 | #if intermediates are saved to the subfolder, don't append a step or percentage to the name 497 | if cur_t == -1 and args.intermediates_in_subfolder is True: 498 | save_num = f'{frame_num:04}' if animation_mode != "None" else i 499 | filename = f'{args.batch_name}({args.batchNum})_{save_num}.png' 500 | else: 501 | #If we're working with percentages, append it 502 | if args.steps_per_checkpoint is not None: 503 | filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png' 504 | # Or else, iIf we're working with specific steps, append those 505 | else: 506 | filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png' 507 | image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1)) 508 | if j % args.display_rate == 0 or cur_t == -1: 509 | image.save('progress.jpg', subsampling=0, quality=95) 510 | #display.clear_output(wait=True) 511 | #display.display(display.Image('progress.png')) 512 | if args.steps_per_checkpoint is not None: 513 | if j % args.steps_per_checkpoint == 0 and j > 0: 514 | if args.intermediates_in_subfolder is True: 515 | image.save(f'{partialFolder}/{filename}') 516 | else: 517 | image.save(f'{batchFolder}/{filename}') 518 | else: 519 | if j in args.intermediate_saves: 520 | if args.intermediates_in_subfolder is True: 521 | image.save(f'{partialFolder}/{filename}') 522 | else: 523 | image.save(f'{batchFolder}/{filename}') 524 | if cur_t == -1: 525 | if frame_num == 0: 526 | save_settings() 527 | if args.animation_mode != "None": 528 | image.save('prevFrame.jpg', subsampling=0, quality=95) 529 | if args.sharpen_preset != "Off" and animation_mode == "None": 530 | imgToSharpen = image 531 | if args.keep_unsharp is True: 532 | image.save(f'{unsharpenFolder}/{filename}') 533 | else: 534 | image.save(f'{batchFolder}/{filename}') 535 | # if frame_num != 
args.max_frames-1: 536 | # display.clear_output() 537 | 538 | with image_display: 539 | if args.sharpen_preset != "Off" and animation_mode == "None": 540 | print('Starting Diffusion Sharpening...') 541 | do_superres(imgToSharpen, f'{batchFolder}/{filename}') 542 | display.clear_output() 543 | 544 | plt.plot(np.array(loss_values), 'r') 545 | 546 | def save_settings(): 547 | setting_list = { 548 | 'text_prompts': text_prompts, 549 | 'image_prompts': image_prompts, 550 | 'clip_guidance_scale': clip_guidance_scale, 551 | 'tv_scale': tv_scale, 552 | 'range_scale': range_scale, 553 | 'sat_scale': sat_scale, 554 | # 'cutn': cutn, 555 | 'cutn_batches': cutn_batches, 556 | 'max_frames': max_frames, 557 | 'interp_spline': interp_spline, 558 | # 'rotation_per_frame': rotation_per_frame, 559 | 'init_image': init_image, 560 | 'init_scale': init_scale, 561 | 'skip_steps': skip_steps, 562 | # 'zoom_per_frame': zoom_per_frame, 563 | 'frames_scale': frames_scale, 564 | 'frames_skip_steps': frames_skip_steps, 565 | 'perlin_init': perlin_init, 566 | 'perlin_mode': perlin_mode, 567 | 'skip_augs': skip_augs, 568 | 'randomize_class': randomize_class, 569 | 'clip_denoised': clip_denoised, 570 | 'clamp_grad': clamp_grad, 571 | 'clamp_max': clamp_max, 572 | 'seed': seed, 573 | 'fuzzy_prompt': fuzzy_prompt, 574 | 'rand_mag': rand_mag, 575 | 'eta': eta, 576 | 'width': width_height[0], 577 | 'height': width_height[1], 578 | 'diffusion_model': diffusion_model, 579 | 'use_secondary_model': use_secondary_model, 580 | 'steps': steps, 581 | 'diffusion_steps': diffusion_steps, 582 | 'ViTB32': ViTB32, 583 | 'ViTB16': ViTB16, 584 | 'RN101': RN101, 585 | 'RN50': RN50, 586 | 'RN50x4': RN50x4, 587 | 'RN50x16': RN50x16, 588 | 'cut_overview': str(cut_overview), 589 | 'cut_innercut': str(cut_innercut), 590 | 'cut_ic_pow': cut_ic_pow, 591 | 'cut_icgray_p': str(cut_icgray_p), 592 | 'key_frames': key_frames, 593 | 'max_frames': max_frames, 594 | 'angle': angle, 595 | 'zoom': zoom, 596 | 'translation_x': translation_x, 597 | 'translation_y': translation_y, 598 | 'video_init_path':video_init_path, 599 | 'extract_nth_frame':extract_nth_frame, 600 | } 601 | # print('Settings:', setting_list) 602 | with open(f"{batchFolder}/{batch_name}({batchNum})_settings.txt", "w+") as f: #save settings 603 | json.dump(setting_list, f, ensure_ascii=False, indent=4) 604 | 605 | #@title 2.4 SuperRes Define 606 | 607 | def download_models(): 608 | # this is the small bsr light model 609 | url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' 610 | url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' 611 | 612 | path_conf = f'{model_path}/superres/' 613 | path_ckpt = f'{model_path}/superres/' 614 | 615 | download_url(url_conf, path_conf, 'project.yaml') 616 | download_url(url_ckpt, path_ckpt, 'last.ckpt') 617 | 618 | path_conf = path_conf + 'project.yaml' # fix it 619 | path_ckpt = path_ckpt + 'last.ckpt' # fix it 620 | return path_conf, path_ckpt 621 | 622 | 623 | def load_model_from_config(config, ckpt): 624 | print(f"Loading model from {ckpt}") 625 | pl_sd = torch.load(ckpt, map_location="cpu") 626 | global_step = pl_sd["global_step"] 627 | sd = pl_sd["state_dict"] 628 | model = instantiate_from_config(config.model) 629 | m, u = model.load_state_dict(sd, strict=False) 630 | model.cuda() 631 | model.eval() 632 | return {"model": model}, global_step 633 | 634 | 635 | def get_model(mode): 636 | path_conf, path_ckpt = download_models() 637 | config = OmegaConf.load(path_conf) 638 | model, step = 
load_model_from_config(config, path_ckpt) 639 | return model 640 | 641 | 642 | def get_custom_cond(mode): 643 | dest = "data/example_conditioning" 644 | 645 | if mode == "superresolution": 646 | uploaded_img = files.upload() 647 | filename = next(iter(uploaded_img)) 648 | name, filetype = filename.split(".") # todo assumes just one dot in name ! 649 | os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}") 650 | 651 | elif mode == "text_conditional": 652 | w = widgets.Text(value='A cake with cream!', disabled=True) 653 | display.display(w) 654 | 655 | with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f: 656 | f.write(w.value) 657 | 658 | elif mode == "class_conditional": 659 | w = widgets.IntSlider(min=0, max=1000) 660 | display.display(w) 661 | with open(f"{dest}/{mode}/custom.txt", 'w') as f: 662 | f.write(w.value) 663 | 664 | else: 665 | raise NotImplementedError(f"cond not implemented for mode{mode}") 666 | 667 | 668 | def get_cond_options(mode): 669 | path = "data/example_conditioning" 670 | path = os.path.join(path, mode) 671 | onlyfiles = [f for f in sorted(os.listdir(path))] 672 | return path, onlyfiles 673 | 674 | 675 | def select_cond_path(mode): 676 | path = "data/example_conditioning" # todo 677 | path = os.path.join(path, mode) 678 | onlyfiles = [f for f in sorted(os.listdir(path))] 679 | 680 | selected = widgets.RadioButtons( 681 | options=onlyfiles, 682 | description='Select conditioning:', 683 | disabled=False 684 | ) 685 | display.display(selected) 686 | selected_path = os.path.join(path, selected.value) 687 | return selected_path 688 | 689 | 690 | def get_cond(mode, img): 691 | example = dict() 692 | if mode == "superresolution": 693 | up_f = 4 694 | # visualize_cond_img(selected_path) 695 | 696 | c = img 697 | c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) 698 | c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True) 699 | c_up = rearrange(c_up, '1 c h w -> 1 h w c') 700 | c = rearrange(c, '1 c h w -> 1 h w c') 701 | c = 2. * c - 1. 702 | 703 | c = c.to(torch.device("cuda")) 704 | example["LR_image"] = c 705 | example["image"] = c_up 706 | 707 | return example 708 | 709 | 710 | def visualize_cond_img(path): 711 | display.display(ipyimg(filename=path)) 712 | 713 | 714 | def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None): 715 | # global stride 716 | 717 | example = get_cond(task, img) 718 | 719 | save_intermediate_vid = False 720 | n_runs = 1 721 | masked = False 722 | guider = None 723 | ckwargs = None 724 | mode = 'ddim' 725 | ddim_use_x0_pred = False 726 | temperature = 1. 
727 | eta = eta 728 | make_progrow = True 729 | custom_shape = None 730 | 731 | height, width = example["image"].shape[1:3] 732 | split_input = height >= 128 and width >= 128 733 | 734 | if split_input: 735 | ks = 128 736 | stride = 64 737 | vqf = 4 # 738 | model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), 739 | "vqf": vqf, 740 | "patch_distributed_vq": True, 741 | "tie_braker": False, 742 | "clip_max_weight": 0.5, 743 | "clip_min_weight": 0.01, 744 | "clip_max_tie_weight": 0.5, 745 | "clip_min_tie_weight": 0.01} 746 | else: 747 | if hasattr(model, "split_input_params"): 748 | delattr(model, "split_input_params") 749 | 750 | invert_mask = False 751 | 752 | x_T = None 753 | for n in range(n_runs): 754 | if custom_shape is not None: 755 | x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) 756 | x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0]) 757 | 758 | logs = make_convolutional_sample(example, model, 759 | mode=mode, custom_steps=custom_steps, 760 | eta=eta, swap_mode=False , masked=masked, 761 | invert_mask=invert_mask, quantize_x0=False, 762 | custom_schedule=None, decode_interval=10, 763 | resize_enabled=resize_enabled, custom_shape=custom_shape, 764 | temperature=temperature, noise_dropout=0., 765 | corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid, 766 | make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred 767 | ) 768 | return logs 769 | 770 | 771 | @torch.no_grad() 772 | def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, 773 | mask=None, x0=None, quantize_x0=False, img_callback=None, 774 | temperature=1., noise_dropout=0., score_corrector=None, 775 | corrector_kwargs=None, x_T=None, log_every_t=None 776 | ): 777 | 778 | ddim = DDIMSampler(model) 779 | bs = shape[0] # dont know where this comes from but wayne 780 | shape = shape[1:] # cut batch dim 781 | # print(f"Sampling with eta = {eta}; steps: {steps}") 782 | samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, 783 | normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, 784 | mask=mask, x0=x0, temperature=temperature, verbose=False, 785 | score_corrector=score_corrector, 786 | corrector_kwargs=corrector_kwargs, x_T=x_T) 787 | 788 | return samples, intermediates 789 | 790 | 791 | @torch.no_grad() 792 | def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False, 793 | invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000, 794 | resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, 795 | corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False): 796 | log = dict() 797 | 798 | z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, 799 | return_first_stage_outputs=True, 800 | force_c_encode=not (hasattr(model, 'split_input_params') 801 | and model.cond_stage_key == 'coordinates_bbox'), 802 | return_original_cond=True) 803 | 804 | log_every_t = 1 if save_intermediate_vid else None 805 | 806 | if custom_shape is not None: 807 | z = torch.randn(custom_shape) 808 | # print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") 809 | 810 | z0 = None 811 | 812 | log["input"] = x 813 | log["reconstruction"] = xrec 814 | 815 | if ismap(xc): 816 | log["original_conditioning"] = model.to_rgb(xc) 817 | if hasattr(model, 
'cond_stage_key'): 818 | log[model.cond_stage_key] = model.to_rgb(xc) 819 | 820 | else: 821 | log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) 822 | if model.cond_stage_model: 823 | log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) 824 | if model.cond_stage_key =='class_label': 825 | log[model.cond_stage_key] = xc[model.cond_stage_key] 826 | 827 | with model.ema_scope("Plotting"): 828 | t0 = time.time() 829 | img_cb = None 830 | 831 | sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, 832 | eta=eta, 833 | quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0, 834 | temperature=temperature, noise_dropout=noise_dropout, 835 | score_corrector=corrector, corrector_kwargs=corrector_kwargs, 836 | x_T=x_T, log_every_t=log_every_t) 837 | t1 = time.time() 838 | 839 | if ddim_use_x0_pred: 840 | sample = intermediates['pred_x0'][-1] 841 | 842 | x_sample = model.decode_first_stage(sample) 843 | 844 | try: 845 | x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) 846 | log["sample_noquant"] = x_sample_noquant 847 | log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) 848 | except: 849 | pass 850 | 851 | log["sample"] = x_sample 852 | log["time"] = t1 - t0 853 | 854 | return log 855 | 856 | sr_diffMode = 'superresolution' 857 | sr_model = get_model('superresolution') if argparse_args.sharpen_preset != "Off" else None 858 | 859 | 860 | def do_superres(img, filepath): 861 | 862 | if args.sharpen_preset == 'Faster': 863 | sr_diffusion_steps = "25" 864 | sr_pre_downsample = '1/2' 865 | if args.sharpen_preset == 'Fast': 866 | sr_diffusion_steps = "100" 867 | sr_pre_downsample = '1/2' 868 | if args.sharpen_preset == 'Slow': 869 | sr_diffusion_steps = "25" 870 | sr_pre_downsample = 'None' 871 | if args.sharpen_preset == 'Very Slow': 872 | sr_diffusion_steps = "100" 873 | sr_pre_downsample = 'None' 874 | 875 | 876 | sr_post_downsample = 'Original Size' 877 | sr_diffusion_steps = int(sr_diffusion_steps) 878 | sr_eta = 1.0 879 | sr_downsample_method = 'Lanczos' 880 | 881 | gc.collect() 882 | torch.cuda.empty_cache() 883 | 884 | im_og = img 885 | width_og, height_og = im_og.size 886 | 887 | #Downsample Pre 888 | if sr_pre_downsample == '1/2': 889 | downsample_rate = 2 890 | elif sr_pre_downsample == '1/4': 891 | downsample_rate = 4 892 | else: 893 | downsample_rate = 1 894 | 895 | width_downsampled_pre = width_og//downsample_rate 896 | height_downsampled_pre = height_og//downsample_rate 897 | 898 | if downsample_rate != 1: 899 | # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') 900 | im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) 901 | # im_og.save('/content/temp.png') 902 | # filepath = '/content/temp.png' 903 | 904 | logs = sr_run(sr_model["model"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta) 905 | 906 | sample = logs["sample"] 907 | sample = sample.detach().cpu() 908 | sample = torch.clamp(sample, -1., 1.) 909 | sample = (sample + 1.) / 2. 
* 255 910 | sample = sample.numpy().astype(np.uint8) 911 | sample = np.transpose(sample, (0, 2, 3, 1)) 912 | a = Image.fromarray(sample[0]) 913 | 914 | #Downsample Post 915 | if sr_post_downsample == '1/2': 916 | downsample_rate = 2 917 | elif sr_post_downsample == '1/4': 918 | downsample_rate = 4 919 | else: 920 | downsample_rate = 1 921 | 922 | width, height = a.size 923 | width_downsampled_post = width//downsample_rate 924 | height_downsampled_post = height//downsample_rate 925 | 926 | if sr_downsample_method == 'Lanczos': 927 | aliasing = Image.LANCZOS 928 | else: 929 | aliasing = Image.NEAREST 930 | 931 | if downsample_rate != 1: 932 | # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]') 933 | a = a.resize((width_downsampled_post, height_downsampled_post), aliasing) 934 | elif sr_post_downsample == 'Original Size': 935 | # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]') 936 | a = a.resize((width_og, height_og), aliasing) 937 | 938 | display.display(a) 939 | a.save(filepath) 940 | return 941 | print(f'Processing finished!') 942 | 943 | 944 | # # 3. Diffusion and CLIP model settings 945 | 946 | #@markdown ####**Models Settings:** 947 | diffusion_model = "512x512_diffusion_uncond_finetune_008100" #@param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"] 948 | use_secondary_model = True #@param {type: 'boolean'} 949 | 950 | timestep_respacing = '50' # param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] 951 | diffusion_steps = 1000 # param {type: 'number'} 952 | use_checkpoint = True #@param {type: 'boolean'} 953 | ViTB32 = True #@param{type:"boolean"} 954 | ViTB16 = True #@param{type:"boolean"} 955 | RN101 = False #@param{type:"boolean"} 956 | RN50 = True #@param{type:"boolean"} 957 | RN50x4 = False #@param{type:"boolean"} 958 | RN50x16 = False #@param{type:"boolean"} 959 | SLIPB16 = False # param{type:"boolean"} 960 | SLIPL16 = False # param{type:"boolean"} 961 | 962 | #@markdown If you're having issues with model downloads, check this to compare SHA's: 963 | check_model_SHA = False #@param{type:"boolean"} 964 | 965 | model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' 966 | model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648' 967 | model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a' 968 | 969 | model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' 970 | model_512_link = 'http://batbot.tv/ai/models/guided-diffusion/512x512_diffusion_uncond_finetune_008100.pt' 971 | model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth' 972 | 973 | model_256_path = f'{model_path}/256x256_diffusion_uncond.pt' 974 | model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt' 975 | model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth' 976 | 977 | # Download the diffusion model 978 | if diffusion_model == '256x256_diffusion_uncond': 979 | if os.path.exists(model_256_path) and check_model_SHA: 980 | print('Checking 256 Diffusion File') 981 | with open(model_256_path,"rb") as f: 982 | bytes = f.read() 983 | hash = hashlib.sha256(bytes).hexdigest() 984 | if hash == model_256_SHA: 985 | print('256 Model SHA matches') 986 | model_256_downloaded = True 987 | else: 988 | print("256 Model SHA 
doesn't match, redownloading...") 989 | if argparse_args.setup: 990 | download_url(model_256_link, model_path, '256x256_diffusion_uncond.pt') 991 | model_256_downloaded = True 992 | elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True: 993 | print('256 Model already downloaded, check check_model_SHA if the file is corrupt') 994 | else: 995 | if argparse_args.setup: 996 | download_url(model_256_link, model_path, '256x256_diffusion_uncond.pt') 997 | model_256_downloaded = True 998 | elif diffusion_model == '512x512_diffusion_uncond_finetune_008100': 999 | if os.path.exists(model_512_path) and check_model_SHA: 1000 | print('Checking 512 Diffusion File') 1001 | with open(model_512_path,"rb") as f: 1002 | bytes = f.read() 1003 | hash = hashlib.sha256(bytes).hexdigest() 1004 | if hash == model_512_SHA: 1005 | print('512 Model SHA matches') 1006 | model_512_downloaded = True 1007 | else: 1008 | print("512 Model SHA doesn't match, redownloading...") 1009 | if argparse_args.setup: 1010 | download_url(model_512_link, model_path, '512x512_diffusion_uncond_finetune_008100.pt') 1011 | model_512_downloaded = True 1012 | elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True: 1013 | print('512 Model already downloaded, check check_model_SHA if the file is corrupt') 1014 | else: 1015 | if argparse_args.setup: 1016 | download_url(model_512_link, model_path, '512x512_diffusion_uncond_finetune_008100.pt') 1017 | model_512_downloaded = True 1018 | 1019 | 1020 | # Download the secondary diffusion model v2 1021 | if use_secondary_model == True: 1022 | if os.path.exists(model_secondary_path) and check_model_SHA: 1023 | print('Checking Secondary Diffusion File') 1024 | with open(model_secondary_path,"rb") as f: 1025 | bytes = f.read() 1026 | hash = hashlib.sha256(bytes).hexdigest() 1027 | if hash == model_secondary_SHA: 1028 | print('Secondary Model SHA matches') 1029 | model_secondary_downloaded = True 1030 | else: 1031 | print("Secondary Model SHA doesn't match, redownloading...") 1032 | if argparse_args.setup: 1033 | download_url(model_secondary_link, model_path, 'secondary_model_imagenet_2.pth') 1034 | model_secondary_downloaded = True 1035 | elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True: 1036 | print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt') 1037 | else: 1038 | if argparse_args.setup: 1039 | download_url(model_secondary_link, model_path, 'secondary_model_imagenet_2.pth') 1040 | model_secondary_downloaded = True 1041 | 1042 | model_config = model_and_diffusion_defaults() 1043 | if diffusion_model == '512x512_diffusion_uncond_finetune_008100': 1044 | model_config.update({ 1045 | 'attention_resolutions': '32, 16, 8', 1046 | 'class_cond': False, 1047 | 'diffusion_steps': diffusion_steps, 1048 | 'rescale_timesteps': True, 1049 | 'timestep_respacing': timestep_respacing, 1050 | 'image_size': 512, 1051 | 'learn_sigma': True, 1052 | 'noise_schedule': 'linear', 1053 | 'num_channels': 256, 1054 | 'num_head_channels': 64, 1055 | 'num_res_blocks': 2, 1056 | 'resblock_updown': True, 1057 | 'use_checkpoint': use_checkpoint, 1058 | 'use_fp16': True, 1059 | 'use_scale_shift_norm': True, 1060 | }) 1061 | elif diffusion_model == '256x256_diffusion_uncond': 1062 | model_config.update({ 1063 | 'attention_resolutions': '32, 16, 8', 1064 | 'class_cond': False, 1065 | 'diffusion_steps': diffusion_steps, 1066 | 'rescale_timesteps': True, 1067 | 
'timestep_respacing': timestep_respacing, 1068 | 'image_size': 256, 1069 | 'learn_sigma': True, 1070 | 'noise_schedule': 'linear', 1071 | 'num_channels': 256, 1072 | 'num_head_channels': 64, 1073 | 'num_res_blocks': 2, 1074 | 'resblock_updown': True, 1075 | 'use_checkpoint': use_checkpoint, 1076 | 'use_fp16': True, 1077 | 'use_scale_shift_norm': True, 1078 | }) 1079 | 1080 | model_default = model_config['image_size'] 1081 | 1082 | secondary_model = SecondaryDiffusionImageNet2() 1083 | secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu')) 1084 | secondary_model.eval().requires_grad_(False).to(device) 1085 | 1086 | 1087 | clip_dict = {'ViT-B/32': ViTB32, 'ViT-B/16': ViTB16, 'RN50': RN50, 'RN50x4': RN50x4, 'RN50x16': RN50x16, 'RN101': RN101} 1088 | clip_models = [clip.load(clip_name, jit=False)[0].eval().requires_grad_(False).to(device) for clip_name in clip_dict if clip_dict[clip_name]] 1089 | 1090 | if SLIPB16: 1091 | SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256) 1092 | if argparse_args.setup: 1093 | if not os.path.exists(f'{model_path}/slip_base_100ep.pt'): 1094 | get_ipython().system('wget https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt -P {model_path}') 1095 | sd = torch.load(f'{model_path}/slip_base_100ep.pt') 1096 | real_sd = {} 1097 | for k, v in sd['state_dict'].items(): 1098 | real_sd['.'.join(k.split('.')[1:])] = v 1099 | del sd 1100 | SLIPB16model.load_state_dict(real_sd) 1101 | SLIPB16model.requires_grad_(False).eval().to(device) 1102 | 1103 | clip_models.append(SLIPB16model) 1104 | 1105 | if SLIPL16: 1106 | SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256) 1107 | if argparse_args.setup: 1108 | if not os.path.exists(f'{model_path}/slip_large_100ep.pt'): 1109 | get_ipython().system('wget https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt -P {model_path}') 1110 | sd = torch.load(f'{model_path}/slip_large_100ep.pt') 1111 | real_sd = {} 1112 | for k, v in sd['state_dict'].items(): 1113 | real_sd['.'.join(k.split('.')[1:])] = v 1114 | del sd 1115 | SLIPL16model.load_state_dict(real_sd) 1116 | SLIPL16model.requires_grad_(False).eval().to(device) 1117 | 1118 | clip_models.append(SLIPL16model) 1119 | 1120 | normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) 1121 | 1122 | # # 4. 
Settings 1123 | 1124 | 1125 | #@markdown ####**Basic Settings:** 1126 | batch_name = argparse_args.out_name #@param{type: 'string'} 1127 | steps = argparse_args.steps #@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true} 1128 | width_height = [argparse_args.width, argparse_args.height]#@param{type: 'raw'} 1129 | clip_guidance_scale = 5000 #@param{type: 'number'} 1130 | tv_scale = 0#@param{type: 'number'} 1131 | range_scale = 150#@param{type: 'number'} 1132 | sat_scale = 0#@param{type: 'number'} 1133 | cutn_batches = 4 #@param{type: 'number'} 1134 | skip_augs = False#@param{type: 'boolean'} 1135 | 1136 | #@markdown --- 1137 | 1138 | #@markdown ####**Init Settings:** 1139 | init_image = argparse_args.init_image #@param{type: 'string'} 1140 | init_scale = 1000 #@param{type: 'integer'} 1141 | skip_steps = argparse_args.skip_steps if argparse_args.skip_steps is not None else (steps // 2 if init_image is not None else 0) 1142 | 1143 | #Get corrected sizes 1144 | side_x = (width_height[0]//64)*64; 1145 | side_y = (width_height[1]//64)*64; 1146 | if side_x != width_height[0] or side_y != width_height[1]: 1147 | print(f'Changing output size to {side_x}x{side_y}. Dimensions must be multiples of 64.') 1148 | 1149 | #Update Model Settings 1150 | timestep_respacing = f'ddim{steps}' 1151 | diffusion_steps = (1000//steps)*steps if steps < 1000 else steps 1152 | model_config.update({ 1153 | 'timestep_respacing': timestep_respacing, 1154 | 'diffusion_steps': diffusion_steps, 1155 | }) 1156 | 1157 | #Make folder for batch 1158 | batchFolder = f'{outDirPath}/{batch_name}' 1159 | os.makedirs(batchFolder, exist_ok=True) 1160 | 1161 | 1162 | # ###Animation Settings 1163 | 1164 | # In[13]: 1165 | 1166 | 1167 | #@markdown ####**Animation Mode:** 1168 | animation_mode = "None" #@param['None', '2D', 'Video Input'] 1169 | #@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.* 1170 | 1171 | 1172 | #@markdown --- 1173 | 1174 | #@markdown ####**Video Input Settings:** 1175 | video_init_path = "/content/training.mp4" #@param {type: 'string'} 1176 | extract_nth_frame = 2 #@param {type:"number"} 1177 | 1178 | if animation_mode == "Video Input": 1179 | videoFramesFolder = f'/content/videoFrames' 1180 | os.makedirs(videoFramesFolder, exist_ok=True) 1181 | print(f"Exporting Video Frames (1 every {extract_nth_frame})...") 1182 | if argparse_args.setup: 1183 | try: 1184 | get_ipython().system('rm {videoFramesFolder}/*.jpg') 1185 | except: 1186 | print('') 1187 | vf = f'"select=not(mod(n\,{extract_nth_frame}))"' 1188 | if argparse_args.setup: 1189 | get_ipython().system('ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg') 1190 | 1191 | 1192 | #@markdown --- 1193 | 1194 | #@markdown ####**2D Animation Settings:** 1195 | #@markdown `zoom` is a multiplier of dimensions, 1 is no zoom. 1196 | 1197 | key_frames = True #@param {type:"boolean"} 1198 | max_frames = 10000#@param {type:"number"} 1199 | 1200 | if animation_mode == "Video Input": 1201 | max_frames = len(glob(f'{videoFramesFolder}/*.jpg')) 1202 | 1203 | interp_spline = 'Linear' #Do not change, currently will not look good. 
param ['Linear','Quadratic','Cubic']{type:"string"} 1204 | angle = "0:(0)"#@param {type:"string"} 1205 | zoom = "0: (1), 10: (1.05)"#@param {type:"string"} 1206 | translation_x = "0: (0)"#@param {type:"string"} 1207 | translation_y = "0: (0)"#@param {type:"string"} 1208 | 1209 | #@markdown --- 1210 | 1211 | #@markdown ####**Coherency Settings:** 1212 | #@markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500. 1213 | frames_scale = 1500 #@param{type: 'integer'} 1214 | #@markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into. 1215 | frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'} 1216 | 1217 | 1218 | def parse_key_frames(string, prompt_parser=None): 1219 | """Given a string representing frame numbers paired with parameter values at that frame, 1220 | return a dictionary with the frame numbers as keys and the parameter values as the values. 1221 | 1222 | Parameters 1223 | ---------- 1224 | string: string 1225 | Frame numbers paired with parameter values at that frame number, in the format 1226 | 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...' 1227 | prompt_parser: function or None, optional 1228 | If provided, prompt_parser will be applied to each string of parameter values. 1229 | 1230 | Returns 1231 | ------- 1232 | dict 1233 | Frame numbers as keys, parameter values at that frame number as values 1234 | 1235 | Raises 1236 | ------ 1237 | RuntimeError 1238 | If the input string does not match the expected format. 1239 | 1240 | Examples 1241 | -------- 1242 | >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)") 1243 | {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'} 1244 | 1245 | >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower())) 1246 | {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'} 1247 | """ 1248 | import re 1249 | pattern = r'((?P[0-9]+):[\s]*[\(](?P[\S\s]*?)[\)])' 1250 | frames = dict() 1251 | for match_object in re.finditer(pattern, string): 1252 | frame = int(match_object.groupdict()['frame']) 1253 | param = match_object.groupdict()['param'] 1254 | if prompt_parser: 1255 | frames[frame] = prompt_parser(param) 1256 | else: 1257 | frames[frame] = param 1258 | 1259 | if frames == {} and len(string) != 0: 1260 | raise RuntimeError('Key Frame string not correctly formatted') 1261 | return frames 1262 | 1263 | def get_inbetweens(key_frames, integer=False): 1264 | """Given a dict with frame numbers as keys and a parameter value as values, 1265 | return a pandas Series containing the value of the parameter at every frame from 0 to max_frames. 1266 | Any values not provided in the input dict are calculated by linear interpolation between 1267 | the values of the previous and next provided frames. If there is no previous provided frame, then 1268 | the value is equal to the value of the next provided frame, or if there is no next provided frame, 1269 | then the value is equal to the value of the previous provided frame. If no frames are provided, 1270 | all frame values are NaN. 1271 | 1272 | Parameters 1273 | ---------- 1274 | key_frames: dict 1275 | A dict with integer frame numbers as keys and numerical values of a particular parameter as values. 1276 | integer: Bool, optional 1277 | If True, the values of the output series are converted to integers. 
1278 | Otherwise, the values are floats. 1279 | 1280 | Returns 1281 | ------- 1282 | pd.Series 1283 | A Series with length max_frames representing the parameter values for each frame. 1284 | 1285 | Examples 1286 | -------- 1287 | >>> max_frames = 5 1288 | >>> get_inbetweens({1: 5, 3: 6}) 1289 | 0 5.0 1290 | 1 5.0 1291 | 2 5.5 1292 | 3 6.0 1293 | 4 6.0 1294 | dtype: float64 1295 | 1296 | >>> get_inbetweens({1: 5, 3: 6}, integer=True) 1297 | 0 5 1298 | 1 5 1299 | 2 5 1300 | 3 6 1301 | 4 6 1302 | dtype: int64 1303 | """ 1304 | key_frame_series = pd.Series([np.nan for a in range(max_frames)]) 1305 | 1306 | for i, value in key_frames.items(): 1307 | key_frame_series[i] = value 1308 | key_frame_series = key_frame_series.astype(float) 1309 | 1310 | interp_method = interp_spline 1311 | 1312 | if interp_method == 'Cubic' and len(key_frames.items()) <=3: 1313 | interp_method = 'Quadratic' 1314 | 1315 | if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: 1316 | interp_method = 'Linear' 1317 | 1318 | 1319 | key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] 1320 | key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] 1321 | # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both') 1322 | key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both') 1323 | if integer: 1324 | return key_frame_series.astype(int) 1325 | return key_frame_series 1326 | 1327 | 1328 | def split_prompts(prompts): 1329 | prompt_series = pd.Series([np.nan for a in range(max_frames)]) 1330 | for i, prompt in prompts.items(): 1331 | prompt_series[i] = prompt 1332 | # prompt_series = prompt_series.astype(str) 1333 | prompt_series = prompt_series.ffill().bfill() 1334 | return prompt_series 1335 | 1336 | 1337 | if key_frames: 1338 | try: 1339 | angle_series = get_inbetweens(parse_key_frames(angle)) 1340 | except RuntimeError as e: 1341 | print( 1342 | "WARNING: You have selected to use key frames, but you have not " 1343 | "formatted `angle` correctly for key frames.\n" 1344 | "Attempting to interpret `angle` as " 1345 | f'"0: ({angle})"\n' 1346 | "Please read the instructions to find out how to use key frames " 1347 | "correctly.\n" 1348 | ) 1349 | angle = f"0: ({angle})" 1350 | angle_series = get_inbetweens(parse_key_frames(angle)) 1351 | 1352 | try: 1353 | zoom_series = get_inbetweens(parse_key_frames(zoom)) 1354 | except RuntimeError as e: 1355 | print( 1356 | "WARNING: You have selected to use key frames, but you have not " 1357 | "formatted `zoom` correctly for key frames.\n" 1358 | "Attempting to interpret `zoom` as " 1359 | f'"0: ({zoom})"\n' 1360 | "Please read the instructions to find out how to use key frames " 1361 | "correctly.\n" 1362 | ) 1363 | zoom = f"0: ({zoom})" 1364 | zoom_series = get_inbetweens(parse_key_frames(zoom)) 1365 | 1366 | try: 1367 | translation_x_series = get_inbetweens(parse_key_frames(translation_x)) 1368 | except RuntimeError as e: 1369 | print( 1370 | "WARNING: You have selected to use key frames, but you have not " 1371 | "formatted `translation_x` correctly for key frames.\n" 1372 | "Attempting to interpret `translation_x` as " 1373 | f'"0: ({translation_x})"\n' 1374 | "Please read the instructions to find out how to use key frames " 1375 | "correctly.\n" 1376 | ) 1377 | translation_x = f"0: ({translation_x})" 1378 | translation_x_series = get_inbetweens(parse_key_frames(translation_x)) 1379 | 1380 | try: 1381 | translation_y_series = 
get_inbetweens(parse_key_frames(translation_y)) 1382 | except RuntimeError as e: 1383 | print( 1384 | "WARNING: You have selected to use key frames, but you have not " 1385 | "formatted `translation_y` correctly for key frames.\n" 1386 | "Attempting to interpret `translation_y` as " 1387 | f'"0: ({translation_y})"\n' 1388 | "Please read the instructions to find out how to use key frames " 1389 | "correctly.\n" 1390 | ) 1391 | translation_y = f"0: ({translation_y})" 1392 | translation_y_series = get_inbetweens(parse_key_frames(translation_y)) 1393 | 1394 | else: 1395 | angle = float(angle) 1396 | zoom = float(zoom) 1397 | translation_x = float(translation_x) 1398 | translation_y = float(translation_y) 1399 | 1400 | 1401 | # ### Extra Settings 1402 | # Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling 1403 | 1404 | # In[14]: 1405 | 1406 | 1407 | #@markdown ####**Saving:** 1408 | 1409 | intermediate_saves = argparse_args.inter_saves #@param{type: 'raw'} 1410 | intermediates_in_subfolder = True #@param{type: 'boolean'} 1411 | #@markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps 1412 | 1413 | #@markdown A value of `2` will save a copy at 33% and 66%. 0 will save none. 1414 | 1415 | #@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets) 1416 | 1417 | 1418 | if type(intermediate_saves) is not list: 1419 | if intermediate_saves: 1420 | steps_per_checkpoint = math.floor((steps - skip_steps - 1) // (intermediate_saves+1)) 1421 | steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1 1422 | print(f'Will save every {steps_per_checkpoint} steps') 1423 | else: 1424 | steps_per_checkpoint = steps+10 1425 | else: 1426 | steps_per_checkpoint = None 1427 | 1428 | if intermediate_saves and intermediates_in_subfolder is True: 1429 | partialFolder = f'{batchFolder}/partials' 1430 | os.makedirs(partialFolder, exist_ok=True) 1431 | 1432 | #@markdown --- 1433 | 1434 | #@markdown ####**SuperRes Sharpening:** 1435 | #@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. 
`keep_unsharp` will save both versions.* 1436 | sharpen_preset = argparse_args.sharpen_preset #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow'] 1437 | keep_unsharp = True #@param{type: 'boolean'} 1438 | 1439 | if sharpen_preset != 'Off' and keep_unsharp is True: 1440 | unsharpenFolder = f'{batchFolder}/unsharpened' 1441 | os.makedirs(unsharpenFolder, exist_ok=True) 1442 | 1443 | 1444 | #@markdown --- 1445 | 1446 | #@markdown ####**Advanced Settings:** 1447 | #@markdown *There are a few extra advanced settings available if you double click this cell.* 1448 | 1449 | #@markdown *Perlin init will replace your init, so uncheck if using one.* 1450 | 1451 | perlin_init = False #@param{type: 'boolean'} 1452 | perlin_mode = 'mixed' #@param ['mixed', 'color', 'gray'] 1453 | set_seed = 'random_seed' #@param{type: 'string'} 1454 | eta = 0.8#@param{type: 'number'} 1455 | clamp_grad = True #@param{type: 'boolean'} 1456 | clamp_max = 0.05 #@param{type: 'number'} 1457 | 1458 | 1459 | ### EXTRA ADVANCED SETTINGS: 1460 | randomize_class = True 1461 | clip_denoised = False 1462 | fuzzy_prompt = False 1463 | rand_mag = 0.05 1464 | 1465 | 1466 | #@markdown --- 1467 | 1468 | #@markdown ####**Cutn Scheduling:** 1469 | #@markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400 /1000 steps, then 20 for the last 600/1000 1470 | 1471 | #@markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure, innercuts are your standard cutn. 1472 | 1473 | cut_overview = "[12]*400+[4]*600" #@param {type: 'string'} 1474 | cut_innercut ="[4]*400+[12]*600"#@param {type: 'string'} 1475 | cut_ic_pow = 1#@param {type: 'number'} 1476 | cut_icgray_p = "[0.2]*400+[0]*600"#@param {type: 'string'} 1477 | 1478 | 1479 | # ###Prompts 1480 | # `animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one. 1481 | 1482 | # In[15]: 1483 | 1484 | 1485 | text_prompts = { 1486 | 0: [argparse_args.text] 1487 | #"A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation."], 1488 | # 100: ["This set of prompts start at frame 100", "This prompt has weight five:5"], 1489 | } 1490 | 1491 | image_prompts = { 1492 | 1493 | #0:['base_images/tree_of_life/tree_new_1.jpeg',], 1494 | } 1495 | 1496 | 1497 | # # 5. Diffuse! 1498 | #@title Do the Run! 1499 | #@markdown `n_batches` ignored with animation modes. 
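# A brief illustrative note on the cutn scheduling strings set above (cut_overview,
# cut_innercut, cut_icgray_p): they are passed through eval() when the args dict is
# assembled further below, which turns a string like "[12]*400+[4]*600" into a plain
# Python list with one cut count per step on the 1000-step scale, e.g.:
#
#     schedule = eval("[12]*400+[4]*600")
#     len(schedule)    # -> 1000
#     schedule[0]      # -> 12 cuts while early structure is forming
#     schedule[999]    # -> 4 cuts during late detail refinement
#
# How that list is indexed at each step is up to do_run().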
1500 | display_rate = 10 #@param{type: 'number'} 1501 | n_batches = 1 #@param{type: 'number'} 1502 | batch_size = 1 1503 | 1504 | 1505 | resume_run = False #@param{type: 'boolean'} 1506 | run_to_resume = 'latest' #@param{type: 'string'} 1507 | resume_from_frame = 'latest' #@param{type: 'string'} 1508 | retain_overwritten_frames = False #@param{type: 'boolean'} 1509 | if retain_overwritten_frames is True: 1510 | retainFolder = f'{batchFolder}/retained' 1511 | os.makedirs(retainFolder, exist_ok=True) 1512 | 1513 | 1514 | skip_step_ratio = int(frames_skip_steps.rstrip("%")) / 100 1515 | calc_frames_skip_steps = math.floor(steps * skip_step_ratio) 1516 | 1517 | if steps <= calc_frames_skip_steps: 1518 | sys.exit("ERROR: You can't skip more steps than your total steps") 1519 | 1520 | if resume_run: 1521 | if run_to_resume == 'latest': 1522 | try: 1523 | batchNum 1524 | except: 1525 | batchNum = len(glob(f"{batchFolder}/{batch_name}(*)_settings.txt"))-1 1526 | else: 1527 | batchNum = int(run_to_resume) 1528 | if resume_from_frame == 'latest': 1529 | start_frame = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png")) 1530 | else: 1531 | start_frame = int(resume_from_frame)+1 1532 | if retain_overwritten_frames is True: 1533 | existing_frames = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png")) 1534 | frames_to_save = existing_frames - start_frame 1535 | print(f'Moving {frames_to_save} frames to the Retained folder') 1536 | move_files(start_frame, existing_frames, batchFolder, retainFolder) 1537 | else: 1538 | start_frame = 0 1539 | batchNum = len(glob(batchFolder+"/*.txt")) 1540 | while path.isfile(f"{batchFolder}/{batch_name}({batchNum})_settings.txt") is True or path.isfile(f"{batchFolder}/{batch_name}-{batchNum}_settings.txt") is True: 1541 | batchNum += 1 1542 | 1543 | print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}') 1544 | 1545 | if set_seed == 'random_seed': 1546 | random.seed() 1547 | seed = random.randint(0, 2**32) 1548 | # print(f'Using seed: {seed}') 1549 | else: 1550 | seed = int(set_seed) 1551 | 1552 | args = { 1553 | 'batchNum': batchNum, 1554 | 'prompts_series':split_prompts(text_prompts) if text_prompts else None, 1555 | 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None, 1556 | 'seed': seed, 1557 | 'display_rate':display_rate, 1558 | 'n_batches':n_batches if animation_mode == 'None' else 1, 1559 | 'batch_size':batch_size, 1560 | 'batch_name': batch_name, 1561 | 'steps': steps, 1562 | 'width_height': width_height, 1563 | 'clip_guidance_scale': clip_guidance_scale, 1564 | 'tv_scale': tv_scale, 1565 | 'range_scale': range_scale, 1566 | 'sat_scale': sat_scale, 1567 | 'cutn_batches': cutn_batches, 1568 | 'init_image': init_image, 1569 | 'init_scale': init_scale, 1570 | 'skip_steps': skip_steps, 1571 | 'sharpen_preset': sharpen_preset, 1572 | 'keep_unsharp': keep_unsharp, 1573 | 'side_x': side_x, 1574 | 'side_y': side_y, 1575 | 'timestep_respacing': timestep_respacing, 1576 | 'diffusion_steps': diffusion_steps, 1577 | 'animation_mode': animation_mode, 1578 | 'video_init_path': video_init_path, 1579 | 'extract_nth_frame': extract_nth_frame, 1580 | 'key_frames': key_frames, 1581 | 'max_frames': max_frames if animation_mode != "None" else 1, 1582 | 'interp_spline': interp_spline, 1583 | 'start_frame': start_frame, 1584 | 'angle': angle, 1585 | 'zoom': zoom, 1586 | 'translation_x': translation_x, 1587 | 'translation_y': translation_y, 1588 | 'angle_series':angle_series, 1589 | 'zoom_series':zoom_series, 1590 | 
'translation_x_series':translation_x_series, 1591 | 'translation_y_series':translation_y_series, 1592 | 'frames_scale': frames_scale, 1593 | 'calc_frames_skip_steps': calc_frames_skip_steps, 1594 | 'skip_step_ratio': skip_step_ratio, 1595 | 1596 | 'text_prompts': text_prompts, 1597 | 'image_prompts': image_prompts, 1598 | 'cut_overview': eval(cut_overview), 1599 | 'cut_innercut': eval(cut_innercut), 1600 | 'cut_ic_pow': cut_ic_pow, 1601 | 'cut_icgray_p': eval(cut_icgray_p), 1602 | 'intermediate_saves': intermediate_saves, 1603 | 'intermediates_in_subfolder': intermediates_in_subfolder, 1604 | 'steps_per_checkpoint': steps_per_checkpoint, 1605 | 'perlin_init': perlin_init, 1606 | 'perlin_mode': perlin_mode, 1607 | 'set_seed': set_seed, 1608 | 'eta': eta, 1609 | 'clamp_grad': clamp_grad, 1610 | 'clamp_max': clamp_max, 1611 | 'skip_augs': skip_augs, 1612 | 'randomize_class': randomize_class, 1613 | 'clip_denoised': clip_denoised, 1614 | 'fuzzy_prompt': fuzzy_prompt, 1615 | 'rand_mag': rand_mag, 1616 | } 1617 | 1618 | args = SimpleNamespace(**args) 1619 | 1620 | print('Prepping model...') 1621 | model, diffusion = create_model_and_diffusion(**model_config) 1622 | model.load_state_dict(torch.load(f'{model_path}/{diffusion_model}.pt', map_location='cpu')) 1623 | model.requires_grad_(False).eval().to(device) 1624 | for name, param in model.named_parameters(): 1625 | if 'qkv' in name or 'norm' in name or 'proj' in name: 1626 | param.requires_grad_() 1627 | if model_config['use_fp16']: 1628 | model.convert_to_fp16() 1629 | 1630 | gc.collect() 1631 | torch.cuda.empty_cache() 1632 | try: 1633 | do_run() 1634 | except KeyboardInterrupt: 1635 | pass 1636 | finally: 1637 | print('Seed used:', seed) 1638 | gc.collect() 1639 | torch.cuda.empty_cache() 1640 | 1641 | 1642 | # # 6. Create the video 1643 | 1644 | # @title ### **Create video** 1645 | #@markdown The video file will be saved in the same folder as your images. 1646 | 1647 | skip_video_for_run_all = True #@param {type: 'boolean'} 1648 | 1649 | if skip_video_for_run_all == False: 1650 | # import subprocess in case this cell is run without the above cells 1651 | import subprocess 1652 | from base64 import b64encode 1653 | 1654 | latest_run = batchNum 1655 | 1656 | folder = batch_name #@param 1657 | run = latest_run #@param 1658 | final_frame = 'final_frame' 1659 | 1660 | 1661 | init_frame = 1#@param {type:"number"} This is the frame where the video will start 1662 | last_frame = final_frame#@param {type:"number"} You can change it to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist. 
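# The ffmpeg call constructed below stitches the saved PNG frames into an H.264 mp4:
# -start_number picks the first index of the %04d image pattern, -frames:v caps how
# many frames are encoded, -vf fps=... sets the output frame rate, -pix_fmt yuv420p
# keeps the file playable in common players, and -crf 17 together with -preset
# veryslow trades encoding time for quality.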
1663 | fps = 12#@param {type:"number"} 1664 | view_video_in_cell = False #@param {type: 'boolean'} 1665 | 1666 | frames = [] 1667 | # tqdm.write('Generating video...') 1668 | 1669 | if last_frame == 'final_frame': 1670 | last_frame = len(glob(batchFolder+f"/{folder}({run})_*.png")) 1671 | print(f'Total frames: {last_frame}') 1672 | 1673 | image_path = f"{outDirPath}/{folder}/{folder}({run})_%04d.png" 1674 | filepath = f"{outDirPath}/{folder}/{folder}({run}).mp4" 1675 | 1676 | 1677 | cmd = [ 1678 | 'ffmpeg', 1679 | '-y', 1680 | '-vcodec', 1681 | 'png', 1682 | '-r', 1683 | str(fps), 1684 | '-start_number', 1685 | str(init_frame), 1686 | '-i', 1687 | image_path, 1688 | '-frames:v', 1689 | str(last_frame+1), 1690 | '-c:v', 1691 | 'libx264', 1692 | '-vf', 1693 | f'fps={fps}', 1694 | '-pix_fmt', 1695 | 'yuv420p', 1696 | '-crf', 1697 | '17', 1698 | '-preset', 1699 | 'veryslow', 1700 | filepath 1701 | ] 1702 | 1703 | process = subprocess.Popen(cmd, cwd=f'{batchFolder}', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 1704 | stdout, stderr = process.communicate() 1705 | if process.returncode != 0: 1706 | print(stderr) 1707 | raise RuntimeError(stderr) 1708 | else: 1709 | print("The video is ready") 1710 | 1711 | if view_video_in_cell: 1712 | mp4 = open(filepath,'rb').read() 1713 | data_url = "data:video/mp4;base64," + b64encode(mp4).decode() 1714 | display.HTML(""" 1715 | 1718 | """ % data_url) 1719 | -------------------------------------------------------------------------------- /perlin.py: -------------------------------------------------------------------------------- 1 | # from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869 2 | 3 | 4 | import torch 5 | import torchvision.transforms as TF 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | def interp(t): 10 | return 3 * t**2 - 2 * t ** 3 11 | 12 | 13 | def perlin(width, height, scale=10, device=None): 14 | gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device) 15 | xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device) 16 | ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device) 17 | wx = 1 - interp(xs) 18 | wy = 1 - interp(ys) 19 | dots = 0 20 | dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys) 21 | dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys) 22 | dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys)) 23 | dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys)) 24 | return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale) 25 | 26 | 27 | def perlin_ms(octaves, width, height, grayscale, device=None): 28 | out_array = [0.5] if grayscale else [0.5, 0.5, 0.5] 29 | # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0] 30 | for i in range(1 if grayscale else 3): 31 | scale = 2 ** len(octaves) 32 | oct_width = width 33 | oct_height = height 34 | for oct in octaves: 35 | p = perlin(oct_width, oct_height, scale, device) 36 | out_array[i] += p * oct 37 | scale //= 2 38 | oct_width *= 2 39 | oct_height *= 2 40 | return torch.cat(out_array) 41 | 42 | 43 | def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True): 44 | out = perlin_ms(octaves, width, height, grayscale) 45 | if grayscale: 46 | out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0)) 47 | out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB') 48 | else: 49 | out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1]) 50 | out = TF.resize(size=(side_y, side_x), img=out) 51 | out = 
TF.to_pil_image(out.clamp(0, 1).squeeze()) 52 | 53 | out = ImageOps.autocontrast(out) 54 | return out 55 | 56 | 57 | def regen_perlin(): 58 | if perlin_mode == 'color': 59 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False) 60 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False) 61 | elif perlin_mode == 'gray': 62 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True) 63 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True) 64 | else: 65 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False) 66 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True) 67 | 68 | init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1) 69 | del init2 70 | return init.expand(batch_size, -1, -1, -1) 71 | -------------------------------------------------------------------------------- /secondary_diffusion.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import math 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | @dataclass 10 | class DiffusionOutput: 11 | v: torch.Tensor 12 | pred: torch.Tensor 13 | eps: torch.Tensor 14 | 15 | 16 | class ConvBlock(nn.Sequential): 17 | def __init__(self, c_in, c_out): 18 | super().__init__( 19 | nn.Conv2d(c_in, c_out, 3, padding=1), 20 | nn.ReLU(inplace=True), 21 | ) 22 | 23 | 24 | class SkipBlock(nn.Module): 25 | def __init__(self, main, skip=None): 26 | super().__init__() 27 | self.main = nn.Sequential(*main) 28 | self.skip = skip if skip else nn.Identity() 29 | 30 | def forward(self, input): 31 | return torch.cat([self.main(input), self.skip(input)], dim=1) 32 | 33 | 34 | class FourierFeatures(nn.Module): 35 | def __init__(self, in_features, out_features, std=1.): 36 | super().__init__() 37 | assert out_features % 2 == 0 38 | self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std) 39 | 40 | def forward(self, input): 41 | f = 2 * math.pi * input @ self.weight.T 42 | return torch.cat([f.cos(), f.sin()], dim=-1) 43 | 44 | 45 | class SecondaryDiffusionImageNet(nn.Module): 46 | def __init__(self): 47 | super().__init__() 48 | c = 64 # The base channel count 49 | 50 | self.timestep_embed = FourierFeatures(1, 16) 51 | 52 | self.net = nn.Sequential( 53 | ConvBlock(3 + 16, c), 54 | ConvBlock(c, c), 55 | SkipBlock([ 56 | nn.AvgPool2d(2), 57 | ConvBlock(c, c * 2), 58 | ConvBlock(c * 2, c * 2), 59 | SkipBlock([ 60 | nn.AvgPool2d(2), 61 | ConvBlock(c * 2, c * 4), 62 | ConvBlock(c * 4, c * 4), 63 | SkipBlock([ 64 | nn.AvgPool2d(2), 65 | ConvBlock(c * 4, c * 8), 66 | ConvBlock(c * 8, c * 4), 67 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 68 | ]), 69 | ConvBlock(c * 8, c * 4), 70 | ConvBlock(c * 4, c * 2), 71 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 72 | ]), 73 | ConvBlock(c * 4, c * 2), 74 | ConvBlock(c * 2, c), 75 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 76 | ]), 77 | ConvBlock(c * 2, c), 78 | nn.Conv2d(c, 3, 3, padding=1), 79 | ) 80 | 81 | def forward(self, input, t): 82 | timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape) 83 | v = self.net(torch.cat([input, timestep_embed], dim=1)) 84 | alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) 85 | pred = input * alphas - v * sigmas 86 | eps = input * sigmas + v * alphas 87 | return DiffusionOutput(v, pred, 
eps) 88 | 89 | 90 | class SecondaryDiffusionImageNet2(nn.Module): 91 | def __init__(self): 92 | super().__init__() 93 | c = 64 # The base channel count 94 | cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8] 95 | 96 | self.timestep_embed = FourierFeatures(1, 16) 97 | self.down = nn.AvgPool2d(2) 98 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 99 | 100 | self.net = nn.Sequential( 101 | ConvBlock(3 + 16, cs[0]), 102 | ConvBlock(cs[0], cs[0]), 103 | SkipBlock([ 104 | self.down, 105 | ConvBlock(cs[0], cs[1]), 106 | ConvBlock(cs[1], cs[1]), 107 | SkipBlock([ 108 | self.down, 109 | ConvBlock(cs[1], cs[2]), 110 | ConvBlock(cs[2], cs[2]), 111 | SkipBlock([ 112 | self.down, 113 | ConvBlock(cs[2], cs[3]), 114 | ConvBlock(cs[3], cs[3]), 115 | SkipBlock([ 116 | self.down, 117 | ConvBlock(cs[3], cs[4]), 118 | ConvBlock(cs[4], cs[4]), 119 | SkipBlock([ 120 | self.down, 121 | ConvBlock(cs[4], cs[5]), 122 | ConvBlock(cs[5], cs[5]), 123 | ConvBlock(cs[5], cs[5]), 124 | ConvBlock(cs[5], cs[4]), 125 | self.up, 126 | ]), 127 | ConvBlock(cs[4] * 2, cs[4]), 128 | ConvBlock(cs[4], cs[3]), 129 | self.up, 130 | ]), 131 | ConvBlock(cs[3] * 2, cs[3]), 132 | ConvBlock(cs[3], cs[2]), 133 | self.up, 134 | ]), 135 | ConvBlock(cs[2] * 2, cs[2]), 136 | ConvBlock(cs[2], cs[1]), 137 | self.up, 138 | ]), 139 | ConvBlock(cs[1] * 2, cs[1]), 140 | ConvBlock(cs[1], cs[0]), 141 | self.up, 142 | ]), 143 | ConvBlock(cs[0] * 2, cs[0]), 144 | nn.Conv2d(cs[0], 3, 3, padding=1), 145 | ) 146 | 147 | def forward(self, input, t): 148 | timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape) 149 | v = self.net(torch.cat([input, timestep_embed], dim=1)) 150 | alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) 151 | pred = input * alphas - v * sigmas 152 | eps = input * sigmas + v * alphas 153 | return DiffusionOutput(v, pred, eps) 154 | 155 | 156 | def append_dims(x, n): 157 | return x[(Ellipsis, *(None,) * (n - x.ndim))] 158 | 159 | 160 | def expand_to_planes(x, shape): 161 | return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]]) 162 | 163 | 164 | def t_to_alpha_sigma(t): 165 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 166 | --------------------------------------------------------------------------------
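# A minimal sketch of the v-prediction convention used by secondary_diffusion.py
# (assumes the module is importable from the working directory; an untrained model
# instance is enough to check the algebra). With alphas = cos(t*pi/2) and
# sigmas = sin(t*pi/2), the forward pass returns pred = x*alphas - v*sigmas and
# eps = x*sigmas + v*alphas, so the input can always be recovered as
# x = pred*alphas + eps*sigmas, since alphas**2 + sigmas**2 == 1.

import torch

from secondary_diffusion import SecondaryDiffusionImageNet2, t_to_alpha_sigma

model = SecondaryDiffusionImageNet2().eval()
x = torch.randn(1, 3, 64, 64)   # any RGB-shaped tensor; weights are random here
t = torch.rand(1)               # diffusion time in [0, 1]
with torch.no_grad():
    out = model(x, t)
alphas, sigmas = t_to_alpha_sigma(t)
# pred*alphas + eps*sigmas collapses back to the input exactly (up to float error)
assert torch.allclose(out.pred * alphas + out.eps * sigmas, x, atol=1e-4)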