├── .gitignore
├── LICENSE
├── README.md
├── augs.py
├── ddim_sampler.py
├── diffuse.py
├── old_diffusion_notebook.ipynb
├── perlin.py
└── secondary_diffusion.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # cloned repos
2 | ResizeRight/
3 | guided-diffusion/
4 | latent-diffusion/
5 | out_diffusion/
6 | taming-transformers/
7 | CLIP/
8 |
9 | # files
10 | *.jpg
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | pip-wheel-metadata/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | Pipfile.lock
104 | Pipfile
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Katherine Crowson
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in
11 | # all copies or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | # THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Gen
2 |
3 | An adaptation of DiscoDiffusion (https://colab.research.google.com/drive/1sHfRn5Y0YKYKi1k-ifUSBFRNJ8_1sa39#scrollTo=BGBzhk3dpcGO) to run locally, improve code quality, and speed it up. So far the code has only been cleaned up a bit, and the lpips network initialization is skipped when only an input text is used.
4 |
5 | Around 11 GB of GPU VRAM is needed for the current default settings of `--width 1280` and `--height 768`. Decreasing the image size is the easiest way to make it fit on smaller GPUs.
6 | With default settings it takes 07:46 minutes on an RTX 2080 Ti, 19:01 minutes on a GTX 1080 Ti, and 17:01 minutes on a Titan XP to generate images like these:
7 |
8 |
9 | **The meaning of life**
10 | 
11 |
12 |
13 | **The meaning of life by Picasso**
14 | 
15 |
16 |
17 | **The meaning of life by Greg Rutkowski**
18 | 
19 |
20 | **Consciousness**
21 | 
22 |
23 | *Forgot the prompt, but it was about Pikachu staring at a tumultuous sea of blood, adapted from the original DiscoDiffusion notebook*
24 | 
25 |
26 |
27 | ## Setup
28 | If you're using Windows, please also refer to the section below called `Setup for Windows`!
29 |
30 | First run `ipython3 diffuse.py` to set everything up and to clone the repositories. IMPORTANT: you need to use ipython instead of python because I was lazy and all the `git clone` etc. commands are run via IPython.
31 |
32 | At the moment you can only set a single text prompt as a target, but this should be improved in the future. It currently only runs with GPU support.
33 |
34 | Use it like this:
35 |
36 | ```
37 | python3 diffuse.py --text "The meaning of life" --gpu [Optional: device number of the GPU to run this on] --root_path [Optional: path to the output folder, default is "out_diffusion" in the local dir]
38 | ```
39 | If you only have 8 GB of VRAM on your GPU, the highest resolution you can run is 832x512 or 896x448. Set it by adding `--width 832 --height 512`, for example. Thanks @Jotunblood for testing!
40 |
41 | You can also set: `--out_name [Optional: name the outputs in your root_path according to this for a better overview]` and `--sharpen_preset [Optional: set it to any of ('Off', 'Faster', 'Fast', 'Slow', 'Very Slow') to modify the sharpening process at the end. Default: Off]`
42 |
43 |
44 | ## Setup for Windows
45 | See https://github.com/NotNANtoN/diffusion_gen/issues/1; the instructions from @JotunBlood are adapted here.
46 |
47 | Instructions:
48 | - Install Anaconda
49 | - Create and activate a new environment (don't use base)
50 | - Install PyTorch via the install command from their website, using pip (not conda)
51 | - Install iPython
52 | - Add the forge channel to anaconda
53 | - conda config --add channels conda-forge
54 | - Install dependency packages using conda (for those available), otherwise use pip. Packages of relevance: OpenCV, pandas, timm, lpips, requests, pytorch-lightning, and omegaconf. There might be one or two others.
55 | - Run ipython diffuse.py
56 | - If it goes all the way through, congrats. If you hit SSL errors, open diffuse.py and add the following lines near the top (I did it around line 7):
57 | ```
58 | import ssl
59 | ssl._create_default_https_context = ssl._create_unverified_context
60 | ```
61 | - If you get Frame Prompt: [''] and a failed output, make sure you're using python3 to run diffuse.py and not iPython :)
62 | - If you get a CUDA out of memory warning, pass a lower resolution like `--width 720 --height 480` when you run it
63 |
64 | ## Tutorial (copypasta from old colab notebook)
65 |
66 | ### **Diffusion settings**
67 | ---
68 |
69 | This section is outdated as of v2
70 |
71 | Setting | Description | Default
72 | --- | --- | ---
73 | **Your vision:**
74 | `text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A
75 | `image_prompts` | Think of these images more as a description of their contents. | N/A
76 | **Image quality:**
77 | `clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000
78 | `tv_scale` | Controls the smoothness of the final output. | 150
79 | `range_scale` | Controls how far out of range RGB values are allowed to be. | 150
80 | `sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0
81 | `cutn` | Controls how many crops to take from the image. | 16
82 | `cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts | 2
83 | **Init settings:**
84 | `init_image` | URL or local path | None
85 | `init_scale` | This enhances the effect of the init image, a good value is 1000 | 0
86 | `skip_steps` | Controls the starting point along the diffusion timesteps | 0
87 | `perlin_init` | Option to start with random perlin noise | False
88 | `perlin_mode` | ('gray', 'color') | 'mixed'
89 | **Advanced:**
90 | `skip_augs` | Controls whether to skip torchvision augmentations | False
91 | `randomize_class` | Controls whether the imagenet class is randomly changed each iteration | True
92 | `clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image | False
93 | `clamp_grad` | Experimental: Using adaptive clip grad in the cond_fn | True
94 | `seed` | Choose a random seed and print it at end of run for reproduction | random_seed
95 | `fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses | False
96 | `rand_mag` | Controls the magnitude of the random noise | 0.1
97 | `eta` | DDIM hyperparameter | 0.5
98 |
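99 | As a rough illustration only, these settings could be collected into a plain Python namespace, similar in spirit to the `args` namespace that `diffuse.py` works with (it imports `types.SimpleNamespace`). The defaults below are copied from the table above; the helper name `make_default_settings` is made up for this sketch:
100 | 
101 | ```
102 | from types import SimpleNamespace
103 | 
104 | def make_default_settings():
105 |     # Hypothetical helper mirroring the defaults listed in the table above.
106 |     return SimpleNamespace(
107 |         clip_guidance_scale=1000,  # how much the image should look like the prompt
108 |         tv_scale=150,              # smoothness of the final output
109 |         range_scale=150,           # how far out of range RGB values may go
110 |         sat_scale=0,               # how much saturation is allowed
111 |         cutn=16,                   # how many crops to take from the image
112 |         cutn_batches=2,            # accumulate CLIP gradients over several batches of cuts
113 |         init_image=None,           # URL or local path
114 |         init_scale=0,              # strength of the init image (a good value is 1000)
115 |         skip_steps=0,              # starting point along the diffusion timesteps
116 |         perlin_init=False,         # start from random perlin noise
117 |         skip_augs=False,           # skip torchvision augmentations
118 |         clamp_grad=True,           # adaptive gradient clipping in cond_fn
119 |         fuzzy_prompt=False,        # add multiple noisy prompts to the prompt losses
120 |         rand_mag=0.1,              # magnitude of that random noise
121 |         eta=0.5,                   # DDIM hyperparameter
122 |     )
123 | 
124 | settings = make_default_settings()
125 | print(settings.eta)  # 0.5
126 | ```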
--------------------------------------------------------------------------------
/augs.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import math  # used by sinc() and ramp() below
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms as T
6 | import torchvision.transforms.functional as TF
7 | from torch.nn import functional as F
8 |
9 | sys.path.append("ResizeRight")
10 | from resize_right import resize
11 |
12 |
13 | class MakeCutouts(nn.Module):
14 | def __init__(self, cut_size, cutn, skip_augs=False):
15 | super().__init__()
16 | self.cut_size = cut_size
17 | self.cutn = cutn
18 | self.skip_augs = skip_augs
19 | self.augs = T.Compose([
20 | T.RandomHorizontalFlip(p=0.5),
21 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
22 | T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
23 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
24 | T.RandomPerspective(distortion_scale=0.4, p=0.7),
25 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
26 | T.RandomGrayscale(p=0.15),
27 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
28 | # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
29 | ])
30 |
31 | def forward(self, input):
32 | input = T.Pad(input.shape[2]//4, fill=0)(input)
33 | sideY, sideX = input.shape[2:4]
34 | max_size = min(sideX, sideY)
35 |
36 | cutouts = []
37 | for ch in range(self.cutn):
38 | if ch > self.cutn - self.cutn//4:
39 | cutout = input.clone()
40 | else:
41 | size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
42 | offsetx = torch.randint(0, abs(sideX - size + 1), ())
43 | offsety = torch.randint(0, abs(sideY - size + 1), ())
44 | cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
45 |
46 | if not self.skip_augs:
47 | cutout = self.augs(cutout)
48 | cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
49 | del cutout
50 |
51 | cutouts = torch.cat(cutouts, dim=0)
52 | return cutouts
53 |
54 |
55 | class MakeCutoutsDango(nn.Module):
56 | def __init__(self, cut_size,
57 | Overview=4,
58 | InnerCrop = 0,
59 | IC_Size_Pow=0.5,
60 | IC_Grey_P = 0.2,
61 | animation_mode='None',
62 | ):
63 | super().__init__()
64 | self.cut_size = cut_size
65 | self.Overview = Overview
66 | self.InnerCrop = InnerCrop
67 | self.IC_Size_Pow = IC_Size_Pow
68 | self.IC_Grey_P = IC_Grey_P
69 | if animation_mode == 'None':
70 | self.augs = T.Compose([
71 | T.RandomHorizontalFlip(p=0.5),
72 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
73 | T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),
74 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
75 | T.RandomGrayscale(p=0.1),
76 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
77 | T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
78 | ])
79 | elif animation_mode == 'Video Input':
80 | self.augs = T.Compose([
81 | T.RandomHorizontalFlip(p=0.5),
82 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
83 | T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
84 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
85 | T.RandomPerspective(distortion_scale=0.4, p=0.7),
86 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
87 | T.RandomGrayscale(p=0.15),
88 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
89 | # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
90 | ])
91 | elif animation_mode == '2D':
92 | self.augs = T.Compose([
93 | T.RandomHorizontalFlip(p=0.4),
94 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
95 | T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),
96 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
97 | T.RandomGrayscale(p=0.1),
98 | T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
99 | T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),
100 | ])
101 |
102 |
103 | def forward(self, input):
104 | cutouts = []
105 | gray = T.Grayscale(3)
106 | sideY, sideX = input.shape[2:4]
107 | max_size = min(sideX, sideY)
108 | min_size = min(sideX, sideY, self.cut_size)
109 | l_size = max(sideX, sideY)
110 | output_shape = [1,3,self.cut_size,self.cut_size]
111 | output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]
112 | pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2))
113 | cutout = resize(pad_input, out_shape=output_shape)
114 |
115 | if self.Overview>0:
116 | if self.Overview <= 4:
117 | if self.Overview >= 1:
118 | cutouts.append(cutout)
119 | if self.Overview >= 2:
120 | cutouts.append(gray(cutout))
121 | if self.Overview >= 3:
122 | cutouts.append(TF.hflip(cutout))
123 | if self.Overview == 4:
124 | cutouts.append(gray(TF.hflip(cutout)))
125 | else:
126 | cutout = resize(pad_input, out_shape=output_shape)
127 | for _ in range(self.Overview):
128 | cutouts.append(cutout)
129 |
130 | #if cutout_debug:
131 | # TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save("/content/cutout_overview0.jpg",quality=99)
132 |
133 | if self.InnerCrop >0:
134 | for i in range(self.InnerCrop):
135 | size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)
136 | offsetx = torch.randint(0, sideX - size + 1, ())
137 | offsety = torch.randint(0, sideY - size + 1, ())
138 | cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
139 | if i <= int(self.IC_Grey_P * self.InnerCrop):
140 | cutout = gray(cutout)
141 | cutout = resize(cutout, out_shape=output_shape)
142 | cutouts.append(cutout)
143 | #if cutout_debug:
144 | # TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save("/content/cutout_InnerCrop.jpg",quality=99)
145 | cutouts = torch.cat(cutouts)
146 | #if skip_augs is not True: cutouts=self.augs(cutouts)
147 | return cutouts
148 |
149 |
150 | def sinc(x):
151 | return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
152 |
153 |
154 | def lanczos(x, a):
155 | cond = torch.logical_and(-a < x, x < a)
156 | out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
157 | return out / out.sum()
158 |
159 |
160 | def ramp(ratio, width):
161 | n = math.ceil(width / ratio + 1)
162 | out = torch.empty([n])
163 | cur = 0
164 | for i in range(out.shape[0]):
165 | out[i] = cur
166 | cur += ratio
167 | return torch.cat([-out[1:].flip([0]), out])[1:-1]
168 |
169 |
170 | def resample(input, size, align_corners=True):
171 | n, c, h, w = input.shape
172 | dh, dw = size
173 |
174 | input = input.reshape([n * c, 1, h, w])
175 |
176 | if dh < h:
177 | kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
178 | pad_h = (kernel_h.shape[0] - 1) // 2
179 | input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
180 | input = F.conv2d(input, kernel_h[None, None, :, None])
181 |
182 | if dw < w:
183 | kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
184 | pad_w = (kernel_w.shape[0] - 1) // 2
185 | input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
186 | input = F.conv2d(input, kernel_w[None, None, None, :])
187 |
188 | input = input.reshape([n, c, h, w])
189 | return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
190 |
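191 | # Illustrative usage (a sketch only), mirroring how cond_fn() in diffuse.py uses this module.
192 | # The input tensor and the cutout counts below are assumptions, not fixed values:
193 | #
194 | #   cuts = MakeCutoutsDango(224, Overview=4, InnerCrop=2, IC_Size_Pow=0.5, IC_Grey_P=0.2)
195 | #   x_in = torch.rand(1, 3, 768, 1280)  # image batch in [0, 1], NCHW
196 | #   crops = cuts(x_in)                  # (Overview + InnerCrop, 3, 224, 224) cutouts for CLIP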
--------------------------------------------------------------------------------
/ddim_sampler.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | sys.path.append("latent-diffusion")
8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9 |
10 |
11 |
12 | class DDIMSampler:
13 | def __init__(self, model, schedule="linear", **kwargs):
14 | super().__init__()
15 | self.model = model
16 | self.ddpm_num_timesteps = model.num_timesteps
17 | self.schedule = schedule
18 |
19 | def register_buffer(self, name, attr):
20 | if type(attr) == torch.Tensor:
21 | if attr.device != torch.device("cuda"):
22 | attr = attr.to(torch.device("cuda"))
23 | setattr(self, name, attr)
24 |
25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28 | alphas_cumprod = self.model.alphas_cumprod
29 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31 |
32 | self.register_buffer('betas', to_torch(self.model.betas))
33 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35 |
36 | # calculations for diffusion q(x_t | x_{t-1}) and others
37 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42 |
43 | # ddim sampling parameters
44 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45 | ddim_timesteps=self.ddim_timesteps,
46 | eta=ddim_eta,verbose=verbose)
47 | self.register_buffer('ddim_sigmas', ddim_sigmas)
48 | self.register_buffer('ddim_alphas', ddim_alphas)
49 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55 |
56 | @torch.no_grad()
57 | def sample(self,
58 | S,
59 | batch_size,
60 | shape,
61 | conditioning=None,
62 | callback=None,
63 | normals_sequence=None,
64 | img_callback=None,
65 | quantize_x0=False,
66 | eta=0.,
67 | mask=None,
68 | x0=None,
69 | temperature=1.,
70 | noise_dropout=0.,
71 | score_corrector=None,
72 | corrector_kwargs=None,
73 | verbose=True,
74 | x_T=None,
75 | log_every_t=100,
76 | **kwargs
77 | ):
78 | if conditioning is not None:
79 | if isinstance(conditioning, dict):
80 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
81 | if cbs != batch_size:
82 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
83 | else:
84 | if conditioning.shape[0] != batch_size:
85 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
86 |
87 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
88 | # sampling
89 | C, H, W = shape
90 | size = (batch_size, C, H, W)
91 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
92 |
93 | samples, intermediates = self.ddim_sampling(conditioning, size,
94 | callback=callback,
95 | img_callback=img_callback,
96 | quantize_denoised=quantize_x0,
97 | mask=mask, x0=x0,
98 | ddim_use_original_steps=False,
99 | noise_dropout=noise_dropout,
100 | temperature=temperature,
101 | score_corrector=score_corrector,
102 | corrector_kwargs=corrector_kwargs,
103 | x_T=x_T,
104 | log_every_t=log_every_t
105 | )
106 | return samples, intermediates
107 |
108 | @torch.no_grad()
109 | def ddim_sampling(self, cond, shape,
110 | x_T=None, ddim_use_original_steps=False,
111 | callback=None, timesteps=None, quantize_denoised=False,
112 | mask=None, x0=None, img_callback=None, log_every_t=100,
113 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
114 | device = self.model.betas.device
115 | b = shape[0]
116 | if x_T is None:
117 | img = torch.randn(shape, device=device)
118 | else:
119 | img = x_T
120 |
121 | if timesteps is None:
122 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
123 | elif timesteps is not None and not ddim_use_original_steps:
124 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
125 | timesteps = self.ddim_timesteps[:subset_end]
126 |
127 | intermediates = {'x_inter': [img], 'pred_x0': [img]}
128 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
129 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
130 | print(f"Running DDIM Sharpening with {total_steps} timesteps")
131 |
132 | iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)
133 |
134 | for i, step in enumerate(iterator):
135 | index = total_steps - i - 1
136 | ts = torch.full((b,), step, device=device, dtype=torch.long)
137 |
138 | if mask is not None:
139 | assert x0 is not None
140 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
141 | img = img_orig * mask + (1. - mask) * img
142 |
143 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
144 | quantize_denoised=quantize_denoised, temperature=temperature,
145 | noise_dropout=noise_dropout, score_corrector=score_corrector,
146 | corrector_kwargs=corrector_kwargs)
147 | img, pred_x0 = outs
148 | if callback: callback(i)
149 | if img_callback: img_callback(pred_x0, i)
150 |
151 | if index % log_every_t == 0 or index == total_steps - 1:
152 | intermediates['x_inter'].append(img)
153 | intermediates['pred_x0'].append(pred_x0)
154 |
155 | return img, intermediates
156 |
157 | @torch.no_grad()
158 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
159 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
160 | b, *_, device = *x.shape, x.device
161 | e_t = self.model.apply_model(x, t, c)
162 | if score_corrector is not None:
163 | assert self.model.parameterization == "eps"
164 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
165 |
166 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
167 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
168 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
169 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
170 | # select parameters corresponding to the currently considered timestep
171 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
172 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
173 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
174 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
175 |
176 | # current prediction for x_0
177 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
178 | if quantize_denoised:
179 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
180 | # direction pointing to x_t
181 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
182 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
183 | if noise_dropout > 0.:
184 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
185 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
186 | return x_prev, pred_x0
187 |
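188 | # Illustrative usage (a sketch only), mirroring convsample_ddim() in diffuse.py.
189 | # Here `model`, `cond` and `shape` are assumptions; they must come from a loaded latent-diffusion model:
190 | #
191 | #   ddim = DDIMSampler(model)
192 | #   samples, intermediates = ddim.sample(S=100, batch_size=shape[0], shape=shape[1:],
193 | #                                        conditioning=cond, eta=1.0, verbose=False)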
--------------------------------------------------------------------------------
/diffuse.py:
--------------------------------------------------------------------------------
1 | # # 1. Set Up
2 | import os
3 | from os import path
4 | import sys
5 | from argparse import ArgumentParser, Namespace
6 |
7 | def run_from_ipython():
8 | try:
9 | __IPYTHON__
10 | return True
11 | except NameError:
12 | return False
13 |
14 | parser = ArgumentParser()
15 | parser.add_argument("--gpu", default="0", type=str)
16 | parser.add_argument("--text", default="", type=str)
17 | parser.add_argument("--root_path", default="out_diffusion")
18 | parser.add_argument("--setup", default=False, type=bool)
19 | parser.add_argument("--out_name", default="out_image", type=str)
20 | parser.add_argument("--sharpen_preset", default="Off", type=str, choices=['Off', 'Faster', 'Fast', 'Slow', 'Very Slow'])
21 | parser.add_argument("--width", default=1280, type=int)
22 | parser.add_argument("--height", default=768, type=int)
23 | parser.add_argument("--init_image", default=None, type=str)
24 | parser.add_argument("--steps", default=250, type=int)
25 | parser.add_argument("--skip_steps", default=None, type=int)
26 | parser.add_argument("--inter_saves", default=3, type=int)
27 |
28 |
29 | if run_from_ipython():
30 | argparse_args = parser.parse_args({})
31 | argparse_args.setup = 1
32 | else:
33 | argparse_args = parser.parse_args()
34 | # read args to determine GPU before importing torch and before importing other files (torch is also imported in other files)
35 | os.environ["CUDA_VISIBLE_DEVICES"] = argparse_args.gpu
36 |
37 | if argparse_args.setup:
38 | get_ipython().system('python3 -m pip install torch lpips datetime timm ipywidgets "omegaconf>=2.0.0" "pytorch-lightning>=1.0.8" torch-fidelity einops wandb')
39 | get_ipython().system('git clone https://github.com/CompVis/latent-diffusion.git')
40 | get_ipython().system('git clone https://github.com/openai/CLIP')
41 | get_ipython().system('pip3 install -e ./CLIP')
42 | get_ipython().system('git clone https://github.com/assafshocher/ResizeRight.git')
43 | get_ipython().system('git clone https://github.com/crowsonkb/guided-diffusion')
44 | get_ipython().system('python3 -m pip install -e ./guided-diffusion')
45 |
46 | get_ipython().system('apt install imagemagick')
47 |
48 | #SuperRes
49 | get_ipython().system('git clone https://github.com/CompVis/latent-diffusion.git')
50 | get_ipython().system('git clone https://github.com/CompVis/taming-transformers')
51 | get_ipython().system('pip install -e ./taming-transformers')
52 |
53 |
54 |
55 |
56 | # sys.path.append('./SLIP')
57 | sys.path.append('./ResizeRight')
58 | sys.path.append("latent-diffusion")
59 |
60 | from secondary_diffusion import SecondaryDiffusionImageNet, SecondaryDiffusionImageNet2
61 | from ddim_sampler import DDIMSampler
62 | from augs import MakeCutouts, MakeCutoutsDango
63 | from perlin import create_perlin_noise, regen_perlin
64 |
65 |
66 | def alpha_sigma_to_t(alpha, sigma):
67 | return torch.atan2(sigma, alpha) * 2 / math.pi
68 |
69 |
70 |
71 | root_path = argparse_args.root_path
72 | initDirPath = f'{root_path}/init_images'
73 | os.makedirs(initDirPath, exist_ok=1)
74 | outDirPath = f'{root_path}/images_out'
75 | os.makedirs(outDirPath, exist_ok=1)
76 | model_path = f'{root_path}/models'
77 | os.makedirs(model_path, exist_ok=1)
78 |
79 |
80 |
81 | #@title ### 2.1 Install and import dependencies
82 | model_256_downloaded = False
83 | model_512_downloaded = False
84 | model_secondary_downloaded = False
85 |
86 | from functools import partial
87 | import cv2
88 | import pandas as pd
89 | import gc
90 | import io
91 | import math
92 | import timm
93 | from IPython import display
94 | import lpips
95 | from PIL import Image, ImageOps
96 | import requests
97 | from glob import glob
98 | import json
99 | from types import SimpleNamespace
100 | import torch
101 | from torch import nn
102 | from torch.nn import functional as F
103 | import torchvision.transforms as T
104 | import torchvision.transforms.functional as TF
105 | from tqdm.notebook import tqdm
106 | #sys.path.append('./CLIP')
107 | sys.path.append('./guided-diffusion')
108 | import clip
109 | from resize_right import resize
110 | # from models import SLIP_VITB16, SLIP, SLIP_VITL16
111 | from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
112 | from datetime import datetime
113 | import numpy as np
114 | import matplotlib.pyplot as plt
115 | import random
116 | from ipywidgets import Output
117 | import hashlib
118 |
119 |
120 |
121 |
122 | #SuperRes
123 | import ipywidgets as widgets
124 | import os
125 | sys.path.append(".")
126 | sys.path.append('./taming-transformers')
127 | from taming.models import vqgan # checking correct import from taming
128 | from torchvision.datasets.utils import download_url
129 |
130 | sys.path.append("./latent-diffusion")
131 | if argparse_args.setup:
132 | get_ipython().run_line_magic('cd', "latent-diffusion")
133 | from ldm.util import instantiate_from_config
134 | # from ldm.models.diffusion.ddim import DDIMSampler
135 | from ldm.util import ismap
136 | if argparse_args.setup:
137 | get_ipython().run_line_magic('cd', '..')
138 | #from google.colab import files
139 | from IPython.display import Image as ipyimg
140 | from numpy import asarray
141 | from einops import rearrange, repeat
142 | import torch, torchvision
143 | import time
144 | from omegaconf import OmegaConf
145 | import warnings
146 | warnings.filterwarnings("ignore", category=UserWarning)
147 |
148 |
149 | import torch
150 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
151 | print('Using device:', device)
152 |
153 | if torch.cuda.get_device_capability(device) == (8,0): ## A100 fix thanks to Emad
154 | print('Disabling CUDNN for A100 gpu', file=sys.stderr)
155 | torch.backends.cudnn.enabled = False
156 |
157 |
158 | #@title 2.2 Define necessary functions
159 |
160 | def fetch(url_or_path):
161 | if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
162 | r = requests.get(url_or_path)
163 | r.raise_for_status()
164 | fd = io.BytesIO()
165 | fd.write(r.content)
166 | fd.seek(0)
167 | return fd
168 | return open(url_or_path, 'rb')
169 |
170 |
171 | def parse_prompt(prompt):
172 | if prompt.startswith('http://') or prompt.startswith('https://'):
173 | vals = prompt.rsplit(':', 2)
174 | vals = [vals[0] + ':' + vals[1], *vals[2:]]
175 | else:
176 | vals = prompt.rsplit(':', 1)
177 | vals = vals + ['', '1'][len(vals):]
178 | return vals[0], float(vals[1])
179 |
180 |
181 | def spherical_dist_loss(x, y):
182 | x = F.normalize(x, dim=-1)
183 | y = F.normalize(y, dim=-1)
184 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
185 |
186 | def tv_loss(input):
187 | """L2 total variation loss, as in Mahendran et al."""
188 | input = F.pad(input, (0, 1, 0, 1), 'replicate')
189 | x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
190 | y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
191 | return (x_diff**2 + y_diff**2).mean([1, 2, 3])
192 |
193 |
194 | def range_loss(input):
195 | return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
196 |
197 | stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete
198 |
199 | def do_run():
200 | seed = args.seed
201 | print(range(args.start_frame, args.max_frames))
202 | for frame_num in range(args.start_frame, args.max_frames):
203 | if stop_on_next_loop:
204 | break
205 |
206 | display.clear_output(wait=True)
207 |
208 | # Print Frame progress if animation mode is on
209 | if args.animation_mode != "None":
210 | batchBar = tqdm(range(args.max_frames), desc ="Frames")
211 | batchBar.n = frame_num
212 | batchBar.refresh()
213 |
214 |
215 | # Inits if not video frames
216 | if args.animation_mode != "Video Input":
217 | if args.init_image == '':
218 | init_image = None
219 | else:
220 | init_image = args.init_image
221 | init_scale = args.init_scale
222 | skip_steps = args.skip_steps
223 |
224 | if args.animation_mode == "2D":
225 | if args.key_frames:
226 | angle = args.angle_series[frame_num]
227 | zoom = args.zoom_series[frame_num]
228 | translation_x = args.translation_x_series[frame_num]
229 | translation_y = args.translation_y_series[frame_num]
230 | print(
231 | f'angle: {angle}',
232 | f'zoom: {zoom}',
233 | f'translation_x: {translation_x}',
234 | f'translation_y: {translation_y}',
235 | )
236 |
237 | if frame_num > 0:
238 | seed = seed + 1
239 | if resume_run and frame_num == start_frame:
240 | img_0 = cv2.imread(batchFolder+f"/{batch_name}({batchNum})_{start_frame-1:04}.png")
241 | else:
242 | img_0 = cv2.imread('prevFrame.png')
243 | center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)
244 | trans_mat = np.float32(
245 | [[1, 0, translation_x],
246 | [0, 1, translation_y]]
247 | )
248 | rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )
249 | trans_mat = np.vstack([trans_mat, [0,0,1]])
250 | rot_mat = np.vstack([rot_mat, [0,0,1]])
251 | transformation_matrix = np.matmul(rot_mat, trans_mat)
252 | img_0 = cv2.warpPerspective(
253 | img_0,
254 | transformation_matrix,
255 | (img_0.shape[1], img_0.shape[0]),
256 | borderMode=cv2.BORDER_WRAP
257 | )
258 | cv2.imwrite('prevFrameScaled.png', img_0)
259 | init_image = 'prevFrameScaled.png'
260 | init_scale = args.frames_scale
261 | skip_steps = args.calc_frames_skip_steps
262 |
263 | if args.animation_mode == "Video Input":
264 | seed = seed + 1
265 | init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'
266 | init_scale = args.frames_scale
267 | skip_steps = args.calc_frames_skip_steps
268 |
269 | loss_values = []
270 |
271 | if seed is not None:
272 | np.random.seed(seed)
273 | random.seed(seed)
274 | torch.manual_seed(seed)
275 | torch.cuda.manual_seed_all(seed)
276 | torch.backends.cudnn.deterministic = True
277 |
278 | target_embeds, weights = [], []
279 |
280 | if args.prompts_series is not None and frame_num >= len(args.prompts_series):
281 | frame_prompt = args.prompts_series[-1]
282 | elif args.prompts_series is not None:
283 | frame_prompt = args.prompts_series[frame_num]
284 | else:
285 | frame_prompt = []
286 |
287 | print(args.image_prompts_series)
288 | if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series):
289 | image_prompt = args.image_prompts_series[-1]
290 | elif args.image_prompts_series is not None:
291 | image_prompt = args.image_prompts_series[frame_num]
292 | else:
293 | image_prompt = []
294 |
295 | print(f'Frame Prompt: {frame_prompt}')
296 |
297 | model_stats = []
298 | for clip_model in clip_models:
299 | cutn = 16
300 | model_stat = {"clip_model":None,"target_embeds":[],"make_cutouts":None,"weights":[]}
301 | model_stat["clip_model"] = clip_model
302 |
303 |
304 | for prompt in frame_prompt:
305 | txt, weight = parse_prompt(prompt)
306 | txt = clip_model.encode_text(clip.tokenize(txt).to(device)).float()  # encode the parsed prompt text (weight suffix stripped by parse_prompt)
307 |
308 | if args.fuzzy_prompt:
309 | for i in range(25):
310 | model_stat["target_embeds"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))
311 | model_stat["weights"].append(weight)
312 | else:
313 | model_stat["target_embeds"].append(txt)
314 | model_stat["weights"].append(weight)
315 |
316 | if image_prompt:
317 | model_stat["make_cutouts"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs)
318 | for prompt in image_prompt:
319 | path, weight = parse_prompt(prompt)
320 | img = Image.open(fetch(path)).convert('RGB')
321 | img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)
322 | batch = model_stat["make_cutouts"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))
323 | embed = clip_model.encode_image(normalize(batch)).float()
324 | if args.fuzzy_prompt:
325 | for i in range(25):
326 | model_stat["target_embeds"].append((embed + torch.randn(embed.shape).cuda() * args.rand_mag).clamp(0,1))
327 | weights.extend([weight / cutn] * cutn)
328 | else:
329 | model_stat["target_embeds"].append(embed)
330 | model_stat["weights"].extend([weight / cutn] * cutn)
331 |
332 | model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"])
333 | model_stat["weights"] = torch.tensor(model_stat["weights"], device=device)
334 | if model_stat["weights"].sum().abs() < 1e-3:
335 | raise RuntimeError('The weights must not sum to 0.')
336 | model_stat["weights"] /= model_stat["weights"].sum().abs()
337 | model_stats.append(model_stat)
338 |
339 | init = None
340 | if init_image is not None:
341 | init = Image.open(fetch(init_image)).convert('RGB')
342 | init = init.resize((args.side_x, args.side_y), Image.LANCZOS)
343 | init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
344 |
345 | if args.perlin_init:
346 | if args.perlin_mode == 'color':
347 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
348 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)
349 | elif args.perlin_mode == 'gray':
350 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)
351 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
352 | else:
353 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
354 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
355 | # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)
356 | init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)
357 | del init2
358 |
359 | cur_t = None
360 | if init is not None and args.init_scale:
361 | lpips_model = lpips.LPIPS(net='vgg').to(device)
362 |
363 | def cond_fn(x, t, y=None):
364 | with torch.enable_grad():
365 | x_is_NaN = False
366 | x = x.detach().requires_grad_()
367 | n = x.shape[0]
368 | if use_secondary_model is True:
369 | alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32)
370 | sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32)
371 | cosine_t = alpha_sigma_to_t(alpha, sigma)
372 | out = secondary_model(x, cosine_t[None].repeat([n])).pred
373 | fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
374 | x_in = out * fac + x * (1 - fac)
375 | x_in_grad = torch.zeros_like(x_in)
376 | else:
377 | my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
378 | out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
379 | fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
380 | x_in = out['pred_xstart'] * fac + x * (1 - fac)
381 | x_in_grad = torch.zeros_like(x_in)
382 | for model_stat in model_stats:
383 | for i in range(args.cutn_batches):
384 | t_int = int(t.item())+1 #errors on last step without +1, need to find source
385 | #when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'
386 | try:
387 | input_resolution=model_stat["clip_model"].visual.input_resolution
388 | except:
389 | input_resolution=224
390 |
391 | cuts = MakeCutoutsDango(input_resolution,
392 | Overview= args.cut_overview[1000-t_int],
393 | InnerCrop = args.cut_innercut[1000-t_int],
394 | IC_Size_Pow=args.cut_ic_pow,
395 | IC_Grey_P = args.cut_icgray_p[1000-t_int],
396 | animation_mode=args.animation_mode,
397 | )
398 | clip_in = normalize(cuts(x_in.add(1).div(2)))
399 | image_embeds = model_stat["clip_model"].encode_image(clip_in).float()
400 | dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat["target_embeds"].unsqueeze(0))
401 | dists = dists.view([args.cut_overview[1000-t_int]+args.cut_innercut[1000-t_int], n, -1])
402 | losses = dists.mul(model_stat["weights"]).sum(2).mean(0)
403 | loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch
404 | x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches
405 | tv_losses = tv_loss(x_in)
406 | if use_secondary_model is True:
407 | range_losses = range_loss(out)
408 | else:
409 | range_losses = range_loss(out['pred_xstart'])
410 | sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()
411 | loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale
412 |
413 | if init is not None and args.init_scale:
414 | init_losses = lpips_model(x_in, init)
415 | loss = loss + init_losses.sum() * args.init_scale
416 | x_in_grad += torch.autograd.grad(loss, x_in)[0]
417 | if torch.isnan(x_in_grad).any() == False:
418 | grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
419 | else:
420 | # print("NaN'd")
421 | x_is_NaN = True
422 | grad = torch.zeros_like(x)
423 | if args.clamp_grad and x_is_NaN == False:
424 | magnitude = grad.square().mean().sqrt()
425 | return grad * magnitude.clamp(max=args.clamp_max) / magnitude #min=-0.02, min=-clamp_max,
426 | return grad
427 |
428 | if model_config['timestep_respacing'].startswith('ddim'):
429 | sample_fn = diffusion.ddim_sample_loop_progressive
430 | else:
431 | sample_fn = diffusion.p_sample_loop_progressive
432 |
433 |
434 | image_display = Output()
435 | for i in range(args.n_batches):
436 | if args.animation_mode == 'None':
437 | display.clear_output(wait=True)
438 | batchBar = tqdm(range(args.n_batches), desc ="Batches")
439 | batchBar.n = i
440 | batchBar.refresh()
441 | print('')
442 | display.display(image_display)
443 | gc.collect()
444 | torch.cuda.empty_cache()
445 | cur_t = diffusion.num_timesteps - skip_steps - 1
446 | total_steps = cur_t
447 |
448 | if perlin_init:
449 | init = regen_perlin()
450 |
451 | if model_config['timestep_respacing'].startswith('ddim'):
452 | samples = sample_fn(
453 | model,
454 | (batch_size, 3, args.side_y, args.side_x),
455 | clip_denoised=clip_denoised,
456 | model_kwargs={},
457 | cond_fn=cond_fn,
458 | progress=True,
459 | skip_timesteps=skip_steps,
460 | init_image=init,
461 | randomize_class=randomize_class,
462 | eta=eta,
463 | )
464 | else:
465 | samples = sample_fn(
466 | model,
467 | (batch_size, 3, args.side_y, args.side_x),
468 | clip_denoised=clip_denoised,
469 | model_kwargs={},
470 | cond_fn=cond_fn,
471 | progress=True,
472 | skip_timesteps=skip_steps,
473 | init_image=init,
474 | randomize_class=randomize_class,
475 | )
476 |
477 |
478 | # with run_display:
479 | # display.clear_output(wait=True)
480 | imgToSharpen = None
481 | for j, sample in enumerate(samples):
482 | cur_t -= 1
483 | intermediateStep = False
484 | if args.steps_per_checkpoint is not None:
485 | if j % steps_per_checkpoint == 0 and j > 0:
486 | intermediateStep = True
487 | elif j in args.intermediate_saves:
488 | intermediateStep = True
489 | with image_display:
490 | if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True:
491 | for k, image in enumerate(sample['pred_xstart']):
492 | # tqdm.write(f'Batch {i}, step {j}, output {k}:')
493 | current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')
494 | percent = math.ceil(j/total_steps*100)
495 | if args.n_batches > 0:
496 | #if intermediates are saved to the subfolder, don't append a step or percentage to the name
497 | if cur_t == -1 and args.intermediates_in_subfolder is True:
498 | save_num = f'{frame_num:04}' if animation_mode != "None" else i
499 | filename = f'{args.batch_name}({args.batchNum})_{save_num}.png'
500 | else:
501 | #If we're working with percentages, append it
502 | if args.steps_per_checkpoint is not None:
503 | filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png'
504 | # Or else, if we're working with specific steps, append those
505 | else:
506 | filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png'
507 | image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
508 | if j % args.display_rate == 0 or cur_t == -1:
509 | image.save('progress.jpg', subsampling=0, quality=95)
510 | #display.clear_output(wait=True)
511 | #display.display(display.Image('progress.png'))
512 | if args.steps_per_checkpoint is not None:
513 | if j % args.steps_per_checkpoint == 0 and j > 0:
514 | if args.intermediates_in_subfolder is True:
515 | image.save(f'{partialFolder}/{filename}')
516 | else:
517 | image.save(f'{batchFolder}/{filename}')
518 | else:
519 | if j in args.intermediate_saves:
520 | if args.intermediates_in_subfolder is True:
521 | image.save(f'{partialFolder}/{filename}')
522 | else:
523 | image.save(f'{batchFolder}/{filename}')
524 | if cur_t == -1:
525 | if frame_num == 0:
526 | save_settings()
527 | if args.animation_mode != "None":
528 | image.save('prevFrame.jpg', subsampling=0, quality=95)
529 | if args.sharpen_preset != "Off" and animation_mode == "None":
530 | imgToSharpen = image
531 | if args.keep_unsharp is True:
532 | image.save(f'{unsharpenFolder}/{filename}')
533 | else:
534 | image.save(f'{batchFolder}/{filename}')
535 | # if frame_num != args.max_frames-1:
536 | # display.clear_output()
537 |
538 | with image_display:
539 | if args.sharpen_preset != "Off" and animation_mode == "None":
540 | print('Starting Diffusion Sharpening...')
541 | do_superres(imgToSharpen, f'{batchFolder}/{filename}')
542 | display.clear_output()
543 |
544 | plt.plot(np.array(loss_values), 'r')
545 |
546 | def save_settings():
547 | setting_list = {
548 | 'text_prompts': text_prompts,
549 | 'image_prompts': image_prompts,
550 | 'clip_guidance_scale': clip_guidance_scale,
551 | 'tv_scale': tv_scale,
552 | 'range_scale': range_scale,
553 | 'sat_scale': sat_scale,
554 | # 'cutn': cutn,
555 | 'cutn_batches': cutn_batches,
556 | 'max_frames': max_frames,
557 | 'interp_spline': interp_spline,
558 | # 'rotation_per_frame': rotation_per_frame,
559 | 'init_image': init_image,
560 | 'init_scale': init_scale,
561 | 'skip_steps': skip_steps,
562 | # 'zoom_per_frame': zoom_per_frame,
563 | 'frames_scale': frames_scale,
564 | 'frames_skip_steps': frames_skip_steps,
565 | 'perlin_init': perlin_init,
566 | 'perlin_mode': perlin_mode,
567 | 'skip_augs': skip_augs,
568 | 'randomize_class': randomize_class,
569 | 'clip_denoised': clip_denoised,
570 | 'clamp_grad': clamp_grad,
571 | 'clamp_max': clamp_max,
572 | 'seed': seed,
573 | 'fuzzy_prompt': fuzzy_prompt,
574 | 'rand_mag': rand_mag,
575 | 'eta': eta,
576 | 'width': width_height[0],
577 | 'height': width_height[1],
578 | 'diffusion_model': diffusion_model,
579 | 'use_secondary_model': use_secondary_model,
580 | 'steps': steps,
581 | 'diffusion_steps': diffusion_steps,
582 | 'ViTB32': ViTB32,
583 | 'ViTB16': ViTB16,
584 | 'RN101': RN101,
585 | 'RN50': RN50,
586 | 'RN50x4': RN50x4,
587 | 'RN50x16': RN50x16,
588 | 'cut_overview': str(cut_overview),
589 | 'cut_innercut': str(cut_innercut),
590 | 'cut_ic_pow': cut_ic_pow,
591 | 'cut_icgray_p': str(cut_icgray_p),
592 | 'key_frames': key_frames,
593 | 'max_frames': max_frames,
594 | 'angle': angle,
595 | 'zoom': zoom,
596 | 'translation_x': translation_x,
597 | 'translation_y': translation_y,
598 | 'video_init_path':video_init_path,
599 | 'extract_nth_frame':extract_nth_frame,
600 | }
601 | # print('Settings:', setting_list)
602 | with open(f"{batchFolder}/{batch_name}({batchNum})_settings.txt", "w+") as f: #save settings
603 | json.dump(setting_list, f, ensure_ascii=False, indent=4)
604 |
605 | #@title 2.4 SuperRes Define
606 |
607 | def download_models():
608 | # this is the small bsr light model
609 | url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
610 | url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'
611 |
612 | path_conf = f'{model_path}/superres/'
613 | path_ckpt = f'{model_path}/superres/'
614 |
615 | download_url(url_conf, path_conf, 'project.yaml')
616 | download_url(url_ckpt, path_ckpt, 'last.ckpt')
617 |
618 | path_conf = path_conf + 'project.yaml' # fix it
619 | path_ckpt = path_ckpt + 'last.ckpt' # fix it
620 | return path_conf, path_ckpt
621 |
622 |
623 | def load_model_from_config(config, ckpt):
624 | print(f"Loading model from {ckpt}")
625 | pl_sd = torch.load(ckpt, map_location="cpu")
626 | global_step = pl_sd["global_step"]
627 | sd = pl_sd["state_dict"]
628 | model = instantiate_from_config(config.model)
629 | m, u = model.load_state_dict(sd, strict=False)
630 | model.cuda()
631 | model.eval()
632 | return {"model": model}, global_step
633 |
634 |
635 | def get_model(mode):
636 | path_conf, path_ckpt = download_models()
637 | config = OmegaConf.load(path_conf)
638 | model, step = load_model_from_config(config, path_ckpt)
639 | return model
640 |
641 |
642 | def get_custom_cond(mode):
643 | dest = "data/example_conditioning"
644 |
645 | if mode == "superresolution":
646 | uploaded_img = files.upload()
647 | filename = next(iter(uploaded_img))
648 | name, filetype = filename.split(".") # todo assumes just one dot in name !
649 | os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
650 |
651 | elif mode == "text_conditional":
652 | w = widgets.Text(value='A cake with cream!', disabled=True)
653 | display.display(w)
654 |
655 | with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
656 | f.write(w.value)
657 |
658 | elif mode == "class_conditional":
659 | w = widgets.IntSlider(min=0, max=1000)
660 | display.display(w)
661 | with open(f"{dest}/{mode}/custom.txt", 'w') as f:
662 | f.write(w.value)
663 |
664 | else:
665 | raise NotImplementedError(f"cond not implemented for mode {mode}")
666 |
667 |
668 | def get_cond_options(mode):
669 | path = "data/example_conditioning"
670 | path = os.path.join(path, mode)
671 | onlyfiles = [f for f in sorted(os.listdir(path))]
672 | return path, onlyfiles
673 |
674 |
675 | def select_cond_path(mode):
676 | path = "data/example_conditioning" # todo
677 | path = os.path.join(path, mode)
678 | onlyfiles = [f for f in sorted(os.listdir(path))]
679 |
680 | selected = widgets.RadioButtons(
681 | options=onlyfiles,
682 | description='Select conditioning:',
683 | disabled=False
684 | )
685 | display.display(selected)
686 | selected_path = os.path.join(path, selected.value)
687 | return selected_path
688 |
689 |
690 | def get_cond(mode, img):
691 | example = dict()
692 | if mode == "superresolution":
693 | up_f = 4
694 | # visualize_cond_img(selected_path)
695 |
696 | c = img
697 | c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
698 | c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
699 | c_up = rearrange(c_up, '1 c h w -> 1 h w c')
700 | c = rearrange(c, '1 c h w -> 1 h w c')
701 | c = 2. * c - 1.
702 |
703 | c = c.to(torch.device("cuda"))
704 | example["LR_image"] = c
705 | example["image"] = c_up
706 |
707 | return example
708 |
709 |
710 | def visualize_cond_img(path):
711 | display.display(ipyimg(filename=path))
712 |
713 |
714 | def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):
715 | # global stride
716 |
717 | example = get_cond(task, img)
718 |
719 | save_intermediate_vid = False
720 | n_runs = 1
721 | masked = False
722 | guider = None
723 | ckwargs = None
724 | mode = 'ddim'
725 | ddim_use_x0_pred = False
726 | temperature = 1.
727 | eta = eta
728 | make_progrow = True
729 | custom_shape = None
730 |
731 | height, width = example["image"].shape[1:3]
732 | split_input = height >= 128 and width >= 128
733 |
734 | if split_input:
735 | ks = 128
736 | stride = 64
737 | vqf = 4 #
738 | model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
739 | "vqf": vqf,
740 | "patch_distributed_vq": True,
741 | "tie_braker": False,
742 | "clip_max_weight": 0.5,
743 | "clip_min_weight": 0.01,
744 | "clip_max_tie_weight": 0.5,
745 | "clip_min_tie_weight": 0.01}
746 | else:
747 | if hasattr(model, "split_input_params"):
748 | delattr(model, "split_input_params")
749 |
750 | invert_mask = False
751 |
752 | x_T = None
753 | for n in range(n_runs):
754 | if custom_shape is not None:
755 | x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
756 | x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])
757 |
758 | logs = make_convolutional_sample(example, model,
759 | mode=mode, custom_steps=custom_steps,
760 | eta=eta, swap_mode=False , masked=masked,
761 | invert_mask=invert_mask, quantize_x0=False,
762 | custom_schedule=None, decode_interval=10,
763 | resize_enabled=resize_enabled, custom_shape=custom_shape,
764 | temperature=temperature, noise_dropout=0.,
765 | corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,
766 | make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred
767 | )
768 | return logs
769 |
770 |
771 | @torch.no_grad()
772 | def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
773 | mask=None, x0=None, quantize_x0=False, img_callback=None,
774 | temperature=1., noise_dropout=0., score_corrector=None,
775 | corrector_kwargs=None, x_T=None, log_every_t=None
776 | ):
777 |
778 | ddim = DDIMSampler(model)
779 | bs = shape[0] # don't know where this comes from, but whatever
780 | shape = shape[1:] # cut batch dim
781 | # print(f"Sampling with eta = {eta}; steps: {steps}")
782 | samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
783 | normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
784 | mask=mask, x0=x0, temperature=temperature, verbose=False,
785 | score_corrector=score_corrector,
786 | corrector_kwargs=corrector_kwargs, x_T=x_T)
787 |
788 | return samples, intermediates
789 |
790 |
791 | @torch.no_grad()
792 | def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
793 | invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
794 | resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
795 | corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):
796 | log = dict()
797 |
798 | z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
799 | return_first_stage_outputs=True,
800 | force_c_encode=not (hasattr(model, 'split_input_params')
801 | and model.cond_stage_key == 'coordinates_bbox'),
802 | return_original_cond=True)
803 |
804 | log_every_t = 1 if save_intermediate_vid else None
805 |
806 | if custom_shape is not None:
807 | z = torch.randn(custom_shape)
808 | # print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
809 |
810 | z0 = None
811 |
812 | log["input"] = x
813 | log["reconstruction"] = xrec
814 |
815 | if ismap(xc):
816 | log["original_conditioning"] = model.to_rgb(xc)
817 | if hasattr(model, 'cond_stage_key'):
818 | log[model.cond_stage_key] = model.to_rgb(xc)
819 |
820 | else:
821 | log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
822 | if model.cond_stage_model:
823 | log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
824 | if model.cond_stage_key =='class_label':
825 | log[model.cond_stage_key] = xc[model.cond_stage_key]
826 |
827 | with model.ema_scope("Plotting"):
828 | t0 = time.time()
829 | img_cb = None
830 |
831 | sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
832 | eta=eta,
833 | quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,
834 | temperature=temperature, noise_dropout=noise_dropout,
835 | score_corrector=corrector, corrector_kwargs=corrector_kwargs,
836 | x_T=x_T, log_every_t=log_every_t)
837 | t1 = time.time()
838 |
839 | if ddim_use_x0_pred:
840 | sample = intermediates['pred_x0'][-1]
841 |
842 | x_sample = model.decode_first_stage(sample)
843 |
844 | try:
845 | x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
846 | log["sample_noquant"] = x_sample_noquant
847 | log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
848 |         except Exception:
849 |             pass  # some first-stage decoders don't accept force_not_quantize
850 |
851 | log["sample"] = x_sample
852 | log["time"] = t1 - t0
853 |
854 | return log
855 |
856 | sr_diffMode = 'superresolution'
857 | sr_model = get_model('superresolution') if argparse_args.sharpen_preset != "Off" else None
858 |
859 |
860 | def do_superres(img, filepath):
861 |
862 |   if args.sharpen_preset == 'Faster':
863 |     sr_diffusion_steps = "25"
864 |     sr_pre_downsample = '1/2'
865 |   elif args.sharpen_preset == 'Fast':
866 |     sr_diffusion_steps = "100"
867 |     sr_pre_downsample = '1/2'
868 |   elif args.sharpen_preset == 'Slow':
869 |     sr_diffusion_steps = "25"
870 |     sr_pre_downsample = 'None'
871 |   elif args.sharpen_preset == 'Very Slow':
872 |     sr_diffusion_steps = "100"
873 |     sr_pre_downsample = 'None'
874 |
875 |
876 | sr_post_downsample = 'Original Size'
877 | sr_diffusion_steps = int(sr_diffusion_steps)
878 | sr_eta = 1.0
879 | sr_downsample_method = 'Lanczos'
880 |
881 | gc.collect()
882 | torch.cuda.empty_cache()
883 |
884 | im_og = img
885 | width_og, height_og = im_og.size
886 |
887 | #Downsample Pre
888 | if sr_pre_downsample == '1/2':
889 | downsample_rate = 2
890 | elif sr_pre_downsample == '1/4':
891 | downsample_rate = 4
892 | else:
893 | downsample_rate = 1
894 |
895 | width_downsampled_pre = width_og//downsample_rate
896 | height_downsampled_pre = height_og//downsample_rate
897 |
898 | if downsample_rate != 1:
899 | # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
900 | im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
901 | # im_og.save('/content/temp.png')
902 | # filepath = '/content/temp.png'
903 |
904 | logs = sr_run(sr_model["model"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)
905 |
906 | sample = logs["sample"]
907 | sample = sample.detach().cpu()
908 | sample = torch.clamp(sample, -1., 1.)
909 | sample = (sample + 1.) / 2. * 255
910 | sample = sample.numpy().astype(np.uint8)
911 | sample = np.transpose(sample, (0, 2, 3, 1))
912 | a = Image.fromarray(sample[0])
913 |
914 | #Downsample Post
915 | if sr_post_downsample == '1/2':
916 | downsample_rate = 2
917 | elif sr_post_downsample == '1/4':
918 | downsample_rate = 4
919 | else:
920 | downsample_rate = 1
921 |
922 | width, height = a.size
923 | width_downsampled_post = width//downsample_rate
924 | height_downsampled_post = height//downsample_rate
925 |
926 | if sr_downsample_method == 'Lanczos':
927 | aliasing = Image.LANCZOS
928 | else:
929 | aliasing = Image.NEAREST
930 |
931 | if downsample_rate != 1:
932 | # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')
933 | a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)
934 | elif sr_post_downsample == 'Original Size':
935 | # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')
936 | a = a.resize((width_og, height_og), aliasing)
937 |
938 | display.display(a)
939 | a.save(filepath)
940 |   print('Processing finished!')
941 |   return
942 |
943 |
944 | # # 3. Diffusion and CLIP model settings
945 |
946 | #@markdown ####**Models Settings:**
947 | diffusion_model = "512x512_diffusion_uncond_finetune_008100" #@param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"]
948 | use_secondary_model = True #@param {type: 'boolean'}
949 |
950 | timestep_respacing = '50' # param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000']
951 | diffusion_steps = 1000 # param {type: 'number'}
952 | use_checkpoint = True #@param {type: 'boolean'}
953 | ViTB32 = True #@param{type:"boolean"}
954 | ViTB16 = True #@param{type:"boolean"}
955 | RN101 = False #@param{type:"boolean"}
956 | RN50 = True #@param{type:"boolean"}
957 | RN50x4 = False #@param{type:"boolean"}
958 | RN50x16 = False #@param{type:"boolean"}
959 | SLIPB16 = False # param{type:"boolean"}
960 | SLIPL16 = False # param{type:"boolean"}
961 |
962 | #@markdown If you're having issues with model downloads, check this to compare SHA's:
963 | check_model_SHA = False #@param{type:"boolean"}
964 |
965 | model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
966 | model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'
967 | model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
968 |
969 | model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'
970 | model_512_link = 'http://batbot.tv/ai/models/guided-diffusion/512x512_diffusion_uncond_finetune_008100.pt'
971 | model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'
972 |
973 | model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'
974 | model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'
975 | model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'
976 |
977 | # Download the diffusion model
978 | if diffusion_model == '256x256_diffusion_uncond':
979 | if os.path.exists(model_256_path) and check_model_SHA:
980 | print('Checking 256 Diffusion File')
981 | with open(model_256_path,"rb") as f:
982 | bytes = f.read()
983 | hash = hashlib.sha256(bytes).hexdigest()
984 | if hash == model_256_SHA:
985 | print('256 Model SHA matches')
986 | model_256_downloaded = True
987 | else:
988 | print("256 Model SHA doesn't match, redownloading...")
989 | if argparse_args.setup:
990 | download_url(model_256_link, model_path, '256x256_diffusion_uncond.pt')
991 | model_256_downloaded = True
992 |   elif (os.path.exists(model_256_path) and not check_model_SHA) or model_256_downloaded == True:
993 |     print('256 Model already downloaded; enable check_model_SHA if the file seems corrupt')
994 | else:
995 | if argparse_args.setup:
996 | download_url(model_256_link, model_path, '256x256_diffusion_uncond.pt')
997 | model_256_downloaded = True
998 | elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':
999 | if os.path.exists(model_512_path) and check_model_SHA:
1000 | print('Checking 512 Diffusion File')
1001 | with open(model_512_path,"rb") as f:
1002 | bytes = f.read()
1003 | hash = hashlib.sha256(bytes).hexdigest()
1004 | if hash == model_512_SHA:
1005 | print('512 Model SHA matches')
1006 | model_512_downloaded = True
1007 | else:
1008 | print("512 Model SHA doesn't match, redownloading...")
1009 | if argparse_args.setup:
1010 | download_url(model_512_link, model_path, '512x512_diffusion_uncond_finetune_008100.pt')
1011 | model_512_downloaded = True
1012 |   elif (os.path.exists(model_512_path) and not check_model_SHA) or model_512_downloaded == True:
1013 |     print('512 Model already downloaded; enable check_model_SHA if the file seems corrupt')
1014 | else:
1015 | if argparse_args.setup:
1016 | download_url(model_512_link, model_path, '512x512_diffusion_uncond_finetune_008100.pt')
1017 | model_512_downloaded = True
1018 |
1019 |
1020 | # Download the secondary diffusion model v2
1021 | if use_secondary_model == True:
1022 | if os.path.exists(model_secondary_path) and check_model_SHA:
1023 | print('Checking Secondary Diffusion File')
1024 | with open(model_secondary_path,"rb") as f:
1025 | bytes = f.read()
1026 | hash = hashlib.sha256(bytes).hexdigest()
1027 | if hash == model_secondary_SHA:
1028 | print('Secondary Model SHA matches')
1029 | model_secondary_downloaded = True
1030 | else:
1031 | print("Secondary Model SHA doesn't match, redownloading...")
1032 | if argparse_args.setup:
1033 | download_url(model_secondary_link, model_path, 'secondary_model_imagenet_2.pth')
1034 | model_secondary_downloaded = True
1035 |   elif (os.path.exists(model_secondary_path) and not check_model_SHA) or model_secondary_downloaded == True:
1036 |     print('Secondary Model already downloaded; enable check_model_SHA if the file seems corrupt')
1037 | else:
1038 | if argparse_args.setup:
1039 | download_url(model_secondary_link, model_path, 'secondary_model_imagenet_2.pth')
1040 | model_secondary_downloaded = True
1041 |
1042 | model_config = model_and_diffusion_defaults()
1043 | if diffusion_model == '512x512_diffusion_uncond_finetune_008100':
1044 | model_config.update({
1045 | 'attention_resolutions': '32, 16, 8',
1046 | 'class_cond': False,
1047 | 'diffusion_steps': diffusion_steps,
1048 | 'rescale_timesteps': True,
1049 | 'timestep_respacing': timestep_respacing,
1050 | 'image_size': 512,
1051 | 'learn_sigma': True,
1052 | 'noise_schedule': 'linear',
1053 | 'num_channels': 256,
1054 | 'num_head_channels': 64,
1055 | 'num_res_blocks': 2,
1056 | 'resblock_updown': True,
1057 | 'use_checkpoint': use_checkpoint,
1058 | 'use_fp16': True,
1059 | 'use_scale_shift_norm': True,
1060 | })
1061 | elif diffusion_model == '256x256_diffusion_uncond':
1062 | model_config.update({
1063 | 'attention_resolutions': '32, 16, 8',
1064 | 'class_cond': False,
1065 | 'diffusion_steps': diffusion_steps,
1066 | 'rescale_timesteps': True,
1067 | 'timestep_respacing': timestep_respacing,
1068 | 'image_size': 256,
1069 | 'learn_sigma': True,
1070 | 'noise_schedule': 'linear',
1071 | 'num_channels': 256,
1072 | 'num_head_channels': 64,
1073 | 'num_res_blocks': 2,
1074 | 'resblock_updown': True,
1075 | 'use_checkpoint': use_checkpoint,
1076 | 'use_fp16': True,
1077 | 'use_scale_shift_norm': True,
1078 | })
1079 |
1080 | model_default = model_config['image_size']
1081 |
1082 | secondary_model = SecondaryDiffusionImageNet2()
1083 | secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))
1084 | secondary_model.eval().requires_grad_(False).to(device)
1085 |
1086 |
1087 | clip_dict = {'ViT-B/32': ViTB32, 'ViT-B/16': ViTB16, 'RN50': RN50, 'RN50x4': RN50x4, 'RN50x16': RN50x16, 'RN101': RN101}
1088 | clip_models = [clip.load(clip_name, jit=False)[0].eval().requires_grad_(False).to(device) for clip_name in clip_dict if clip_dict[clip_name]]
1089 |
1090 | if SLIPB16:
1091 | SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
1092 | if argparse_args.setup:
1093 | if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):
1094 | get_ipython().system('wget https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt -P {model_path}')
1095 | sd = torch.load(f'{model_path}/slip_base_100ep.pt')
1096 | real_sd = {}
1097 | for k, v in sd['state_dict'].items():
1098 | real_sd['.'.join(k.split('.')[1:])] = v
1099 | del sd
1100 | SLIPB16model.load_state_dict(real_sd)
1101 | SLIPB16model.requires_grad_(False).eval().to(device)
1102 |
1103 | clip_models.append(SLIPB16model)
1104 |
1105 | if SLIPL16:
1106 | SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)
1107 | if argparse_args.setup:
1108 | if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):
1109 | get_ipython().system('wget https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt -P {model_path}')
1110 | sd = torch.load(f'{model_path}/slip_large_100ep.pt')
1111 | real_sd = {}
1112 | for k, v in sd['state_dict'].items():
1113 | real_sd['.'.join(k.split('.')[1:])] = v
1114 | del sd
1115 | SLIPL16model.load_state_dict(real_sd)
1116 | SLIPL16model.requires_grad_(False).eval().to(device)
1117 |
1118 | clip_models.append(SLIPL16model)
1119 |
1120 | normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
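# (The mean/std above are the normalization constants CLIP uses for its image inputs.)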
1121 |
1122 | # # 4. Settings
1123 |
1124 |
1125 | #@markdown ####**Basic Settings:**
1126 | batch_name = argparse_args.out_name #@param{type: 'string'}
1127 | steps = argparse_args.steps #@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}
1128 | width_height = [argparse_args.width, argparse_args.height]#@param{type: 'raw'}
1129 | clip_guidance_scale = 5000 #@param{type: 'number'}
1130 | tv_scale = 0#@param{type: 'number'}
1131 | range_scale = 150#@param{type: 'number'}
1132 | sat_scale = 0#@param{type: 'number'}
1133 | cutn_batches = 4 #@param{type: 'number'}
1134 | skip_augs = False#@param{type: 'boolean'}
1135 |
1136 | #@markdown ---
1137 |
1138 | #@markdown ####**Init Settings:**
1139 | init_image = argparse_args.init_image #@param{type: 'string'}
1140 | init_scale = 1000 #@param{type: 'integer'}
1141 | skip_steps = argparse_args.skip_steps if argparse_args.skip_steps is not None else (steps // 2 if init_image is not None else 0)
1142 |
1143 | #Get corrected sizes
1144 | side_x = (width_height[0]//64)*64
1145 | side_y = (width_height[1]//64)*64
1146 | if side_x != width_height[0] or side_y != width_height[1]:
1147 |   print(f'Changing output size to {side_x}x{side_y}. Dimensions must be multiples of 64.')
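# For example, a requested 1280x720 would be adjusted to 1280x704 (720//64 = 11, 11*64 = 704).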
1148 |
1149 | #Update Model Settings
1150 | timestep_respacing = f'ddim{steps}'
1151 | diffusion_steps = (1000//steps)*steps if steps < 1000 else steps
1152 | model_config.update({
1153 | 'timestep_respacing': timestep_respacing,
1154 | 'diffusion_steps': diffusion_steps,
1155 | })
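# With the 1000-step training schedule, `timestep_respacing = f'ddim{steps}'` asks guided-diffusion
# to re-space sampling down to `steps` DDIM steps, while diffusion_steps is kept an exact multiple
# of `steps` (e.g. steps=150 gives (1000//150)*150 = 900).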
1156 |
1157 | #Make folder for batch
1158 | batchFolder = f'{outDirPath}/{batch_name}'
1159 | os.makedirs(batchFolder, exist_ok=1)
1160 |
1161 |
1162 | # ###Animation Settings
1163 |
1164 | # In[13]:
1165 |
1166 |
1167 | #@markdown ####**Animation Mode:**
1168 | animation_mode = "None" #@param['None', '2D', 'Video Input']
1169 | #@markdown *For animation, you probably want to set `cutn_batches` to 1 to make it quicker.*
1170 |
1171 |
1172 | #@markdown ---
1173 |
1174 | #@markdown ####**Video Input Settings:**
1175 | video_init_path = "/content/training.mp4" #@param {type: 'string'}
1176 | extract_nth_frame = 2 #@param {type:"number"}
1177 |
1178 | if animation_mode == "Video Input":
1179 | videoFramesFolder = f'/content/videoFrames'
1180 | os.makedirs(videoFramesFolder, exist_ok=True)
1181 | print(f"Exporting Video Frames (1 every {extract_nth_frame})...")
1182 | if argparse_args.setup:
1183 | try:
1184 | get_ipython().system('rm {videoFramesFolder}/*.jpg')
1185 | except:
1186 | print('')
1187 | vf = f'"select=not(mod(n\,{extract_nth_frame}))"'
1188 | if argparse_args.setup:
1189 | get_ipython().system('ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg')
1190 |
1191 |
1192 | #@markdown ---
1193 |
1194 | #@markdown ####**2D Animation Settings:**
1195 | #@markdown `zoom` is a multiplier of dimensions, 1 is no zoom.
1196 |
1197 | key_frames = True #@param {type:"boolean"}
1198 | max_frames = 10000#@param {type:"number"}
1199 |
1200 | if animation_mode == "Video Input":
1201 | max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))
1202 |
1203 | interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:"string"}
1204 | angle = "0:(0)"#@param {type:"string"}
1205 | zoom = "0: (1), 10: (1.05)"#@param {type:"string"}
1206 | translation_x = "0: (0)"#@param {type:"string"}
1207 | translation_y = "0: (0)"#@param {type:"string"}
1208 |
1209 | #@markdown ---
1210 |
1211 | #@markdown ####**Coherency Settings:**
1212 | #@markdown `frames_scale` tries to guide the new frame to look like the old one. A good default is 1500.
1213 | frames_scale = 1500 #@param{type: 'integer'}
1214 | #@markdown `frames_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.
1215 | frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}
1216 |
1217 |
1218 | def parse_key_frames(string, prompt_parser=None):
1219 | """Given a string representing frame numbers paired with parameter values at that frame,
1220 | return a dictionary with the frame numbers as keys and the parameter values as the values.
1221 |
1222 | Parameters
1223 | ----------
1224 | string: string
1225 | Frame numbers paired with parameter values at that frame number, in the format
1226 | 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'
1227 | prompt_parser: function or None, optional
1228 | If provided, prompt_parser will be applied to each string of parameter values.
1229 |
1230 | Returns
1231 | -------
1232 | dict
1233 | Frame numbers as keys, parameter values at that frame number as values
1234 |
1235 | Raises
1236 | ------
1237 | RuntimeError
1238 | If the input string does not match the expected format.
1239 |
1240 | Examples
1241 | --------
1242 | >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)")
1243 | {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}
1244 |
1245 |     >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower())
1246 | {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}
1247 | """
1248 | import re
1249 |     pattern = r'((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])'
1250 | frames = dict()
1251 | for match_object in re.finditer(pattern, string):
1252 | frame = int(match_object.groupdict()['frame'])
1253 | param = match_object.groupdict()['param']
1254 | if prompt_parser:
1255 | frames[frame] = prompt_parser(param)
1256 | else:
1257 | frames[frame] = param
1258 |
1259 | if frames == {} and len(string) != 0:
1260 | raise RuntimeError('Key Frame string not correctly formatted')
1261 | return frames
1262 |
1263 | def get_inbetweens(key_frames, integer=False):
1264 | """Given a dict with frame numbers as keys and a parameter value as values,
1265 | return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.
1266 | Any values not provided in the input dict are calculated by linear interpolation between
1267 | the values of the previous and next provided frames. If there is no previous provided frame, then
1268 | the value is equal to the value of the next provided frame, or if there is no next provided frame,
1269 | then the value is equal to the value of the previous provided frame. If no frames are provided,
1270 | all frame values are NaN.
1271 |
1272 | Parameters
1273 | ----------
1274 | key_frames: dict
1275 | A dict with integer frame numbers as keys and numerical values of a particular parameter as values.
1276 | integer: Bool, optional
1277 | If True, the values of the output series are converted to integers.
1278 | Otherwise, the values are floats.
1279 |
1280 | Returns
1281 | -------
1282 | pd.Series
1283 | A Series with length max_frames representing the parameter values for each frame.
1284 |
1285 | Examples
1286 | --------
1287 | >>> max_frames = 5
1288 | >>> get_inbetweens({1: 5, 3: 6})
1289 | 0 5.0
1290 | 1 5.0
1291 | 2 5.5
1292 | 3 6.0
1293 | 4 6.0
1294 | dtype: float64
1295 |
1296 | >>> get_inbetweens({1: 5, 3: 6}, integer=True)
1297 | 0 5
1298 | 1 5
1299 | 2 5
1300 | 3 6
1301 | 4 6
1302 | dtype: int64
1303 | """
1304 | key_frame_series = pd.Series([np.nan for a in range(max_frames)])
1305 |
1306 | for i, value in key_frames.items():
1307 | key_frame_series[i] = value
1308 | key_frame_series = key_frame_series.astype(float)
1309 |
1310 | interp_method = interp_spline
1311 |
1312 | if interp_method == 'Cubic' and len(key_frames.items()) <=3:
1313 | interp_method = 'Quadratic'
1314 |
1315 | if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:
1316 | interp_method = 'Linear'
1317 |
1318 |
1319 | key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
1320 | key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]
1321 | # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')
1322 | key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')
1323 | if integer:
1324 | return key_frame_series.astype(int)
1325 | return key_frame_series
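# Worked example with the defaults above (values shown are illustrative):
#   parse_key_frames("0: (1), 10: (1.05)")  ->  {0: '1', 10: '1.05'}
#   get_inbetweens(...) then fills all max_frames values by linear interpolation, so
#   frame 5 evaluates to roughly 1.025 and every frame after 10 holds 1.05.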
1326 |
1327 |
1328 | def split_prompts(prompts):
1329 | prompt_series = pd.Series([np.nan for a in range(max_frames)])
1330 | for i, prompt in prompts.items():
1331 | prompt_series[i] = prompt
1332 | # prompt_series = prompt_series.astype(str)
1333 | prompt_series = prompt_series.ffill().bfill()
1334 | return prompt_series
1335 |
1336 |
1337 | if key_frames:
1338 | try:
1339 | angle_series = get_inbetweens(parse_key_frames(angle))
1340 | except RuntimeError as e:
1341 | print(
1342 | "WARNING: You have selected to use key frames, but you have not "
1343 | "formatted `angle` correctly for key frames.\n"
1344 | "Attempting to interpret `angle` as "
1345 | f'"0: ({angle})"\n'
1346 | "Please read the instructions to find out how to use key frames "
1347 | "correctly.\n"
1348 | )
1349 | angle = f"0: ({angle})"
1350 | angle_series = get_inbetweens(parse_key_frames(angle))
1351 |
1352 | try:
1353 | zoom_series = get_inbetweens(parse_key_frames(zoom))
1354 | except RuntimeError as e:
1355 | print(
1356 | "WARNING: You have selected to use key frames, but you have not "
1357 | "formatted `zoom` correctly for key frames.\n"
1358 | "Attempting to interpret `zoom` as "
1359 | f'"0: ({zoom})"\n'
1360 | "Please read the instructions to find out how to use key frames "
1361 | "correctly.\n"
1362 | )
1363 | zoom = f"0: ({zoom})"
1364 | zoom_series = get_inbetweens(parse_key_frames(zoom))
1365 |
1366 | try:
1367 | translation_x_series = get_inbetweens(parse_key_frames(translation_x))
1368 | except RuntimeError as e:
1369 | print(
1370 | "WARNING: You have selected to use key frames, but you have not "
1371 | "formatted `translation_x` correctly for key frames.\n"
1372 | "Attempting to interpret `translation_x` as "
1373 | f'"0: ({translation_x})"\n'
1374 | "Please read the instructions to find out how to use key frames "
1375 | "correctly.\n"
1376 | )
1377 | translation_x = f"0: ({translation_x})"
1378 | translation_x_series = get_inbetweens(parse_key_frames(translation_x))
1379 |
1380 | try:
1381 | translation_y_series = get_inbetweens(parse_key_frames(translation_y))
1382 | except RuntimeError as e:
1383 | print(
1384 | "WARNING: You have selected to use key frames, but you have not "
1385 | "formatted `translation_y` correctly for key frames.\n"
1386 | "Attempting to interpret `translation_y` as "
1387 | f'"0: ({translation_y})"\n'
1388 | "Please read the instructions to find out how to use key frames "
1389 | "correctly.\n"
1390 | )
1391 | translation_y = f"0: ({translation_y})"
1392 | translation_y_series = get_inbetweens(parse_key_frames(translation_y))
1393 |
1394 | else:
1395 | angle = float(angle)
1396 | zoom = float(zoom)
1397 | translation_x = float(translation_x)
1398 | translation_y = float(translation_y)
1399 |
1400 |
1401 | # ### Extra Settings
1402 | # Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling
1403 |
1404 | # In[14]:
1405 |
1406 |
1407 | #@markdown ####**Saving:**
1408 |
1409 | intermediate_saves = argparse_args.inter_saves #@param{type: 'raw'}
1410 | intermediates_in_subfolder = True #@param{type: 'boolean'}
1411 | #@markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps
1412 |
1413 | #@markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.
1414 |
1415 | #@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets)
1416 |
1417 |
1418 | if type(intermediate_saves) is not list:
1419 | if intermediate_saves:
1420 | steps_per_checkpoint = math.floor((steps - skip_steps - 1) // (intermediate_saves+1))
1421 | steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1
1422 | print(f'Will save every {steps_per_checkpoint} steps')
1423 | else:
1424 | steps_per_checkpoint = steps+10
1425 | else:
1426 | steps_per_checkpoint = None
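# Worked example: with steps=250, skip_steps=0 and intermediate_saves=2,
# steps_per_checkpoint = (250 - 0 - 1) // 3 = 83, so partials land near the 33% and 66% marks.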
1427 |
1428 | if intermediate_saves and intermediates_in_subfolder is True:
1429 | partialFolder = f'{batchFolder}/partials'
1430 | os.makedirs(partialFolder, exist_ok=True)
1431 |
1432 | #@markdown ---
1433 |
1434 | #@markdown ####**SuperRes Sharpening:**
1435 | #@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*
1436 | sharpen_preset = argparse_args.sharpen_preset #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']
1437 | keep_unsharp = True #@param{type: 'boolean'}
1438 |
1439 | if sharpen_preset != 'Off' and keep_unsharp is True:
1440 | unsharpenFolder = f'{batchFolder}/unsharpened'
1441 | os.makedirs(unsharpenFolder, exist_ok=True)
1442 |
1443 |
1444 | #@markdown ---
1445 |
1446 | #@markdown ####**Advanced Settings:**
1447 | #@markdown *There are a few extra advanced settings available if you double click this cell.*
1448 |
1449 | #@markdown *Perlin init will replace your init, so uncheck if using one.*
1450 |
1451 | perlin_init = False #@param{type: 'boolean'}
1452 | perlin_mode = 'mixed' #@param ['mixed', 'color', 'gray']
1453 | set_seed = 'random_seed' #@param{type: 'string'}
1454 | eta = 0.8#@param{type: 'number'}
1455 | clamp_grad = True #@param{type: 'boolean'}
1456 | clamp_max = 0.05 #@param{type: 'number'}
1457 |
1458 |
1459 | ### EXTRA ADVANCED SETTINGS:
1460 | randomize_class = True
1461 | clip_denoised = False
1462 | fuzzy_prompt = False
1463 | rand_mag = 0.05
1464 |
1465 |
1466 | #@markdown ---
1467 |
1468 | #@markdown ####**Cutn Scheduling:**
1469 | #@markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400 /1000 steps, then 20 for the last 600/1000
1470 |
1471 | #@markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure, innercuts are your standard cutn.
1472 |
1473 | cut_overview = "[12]*400+[4]*600" #@param {type: 'string'}
1474 | cut_innercut ="[4]*400+[12]*600"#@param {type: 'string'}
1475 | cut_ic_pow = 1#@param {type: 'number'}
1476 | cut_icgray_p = "[0.2]*400+[0]*600"#@param {type: 'string'}
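# Minimal sketch of how the schedule strings above expand (the `_demo_overview` name is
# illustrative only and is not used by the run): eval() turns each string into a
# 1000-entry per-step list.
_demo_overview = eval("[12]*400+[4]*600")  # 400 entries of 12 overview cuts, then 600 of 4
assert len(_demo_overview) == 1000
# The run is assumed to pick the entry matching the current step on the 0-999 scale,
# e.g. _demo_overview[100] == 12 early on and _demo_overview[800] == 4 later.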
1477 |
1478 |
1479 | # ###Prompts
1480 | # `animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one.
1481 |
1482 | # In[15]:
1483 |
1484 |
1485 | text_prompts = {
1486 | 0: [argparse_args.text]
1487 | #"A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation."],
1488 | # 100: ["This set of prompts start at frame 100", "This prompt has weight five:5"],
1489 | }
1490 |
1491 | image_prompts = {
1492 |
1493 | #0:['base_images/tree_of_life/tree_new_1.jpeg',],
1494 | }
1495 |
1496 |
1497 | # # 5. Diffuse!
1498 | #@title Do the Run!
1499 | #@markdown `n_batches` is ignored with animation modes.
1500 | display_rate = 10 #@param{type: 'number'}
1501 | n_batches = 1 #@param{type: 'number'}
1502 | batch_size = 1
1503 |
1504 |
1505 | resume_run = False #@param{type: 'boolean'}
1506 | run_to_resume = 'latest' #@param{type: 'string'}
1507 | resume_from_frame = 'latest' #@param{type: 'string'}
1508 | retain_overwritten_frames = False #@param{type: 'boolean'}
1509 | if retain_overwritten_frames is True:
1510 | retainFolder = f'{batchFolder}/retained'
1511 | os.makedirs(retainFolder, exist_ok=True)
1512 |
1513 |
1514 | skip_step_ratio = int(frames_skip_steps.rstrip("%")) / 100
1515 | calc_frames_skip_steps = math.floor(steps * skip_step_ratio)
1516 |
1517 | if steps <= calc_frames_skip_steps:
1518 | sys.exit("ERROR: You can't skip more steps than your total steps")
1519 |
1520 | if resume_run:
1521 | if run_to_resume == 'latest':
1522 | try:
1523 | batchNum
1524 |     except NameError:
1525 | batchNum = len(glob(f"{batchFolder}/{batch_name}(*)_settings.txt"))-1
1526 | else:
1527 | batchNum = int(run_to_resume)
1528 | if resume_from_frame == 'latest':
1529 | start_frame = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png"))
1530 | else:
1531 | start_frame = int(resume_from_frame)+1
1532 | if retain_overwritten_frames is True:
1533 | existing_frames = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png"))
1534 | frames_to_save = existing_frames - start_frame
1535 | print(f'Moving {frames_to_save} frames to the Retained folder')
1536 | move_files(start_frame, existing_frames, batchFolder, retainFolder)
1537 | else:
1538 | start_frame = 0
1539 | batchNum = len(glob(batchFolder+"/*.txt"))
1540 | while path.isfile(f"{batchFolder}/{batch_name}({batchNum})_settings.txt") is True or path.isfile(f"{batchFolder}/{batch_name}-{batchNum}_settings.txt") is True:
1541 | batchNum += 1
1542 |
1543 | print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')
1544 |
1545 | if set_seed == 'random_seed':
1546 | random.seed()
1547 | seed = random.randint(0, 2**32)
1548 | # print(f'Using seed: {seed}')
1549 | else:
1550 | seed = int(set_seed)
1551 |
1552 | args = {
1553 | 'batchNum': batchNum,
1554 | 'prompts_series':split_prompts(text_prompts) if text_prompts else None,
1555 | 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,
1556 | 'seed': seed,
1557 | 'display_rate':display_rate,
1558 | 'n_batches':n_batches if animation_mode == 'None' else 1,
1559 | 'batch_size':batch_size,
1560 | 'batch_name': batch_name,
1561 | 'steps': steps,
1562 | 'width_height': width_height,
1563 | 'clip_guidance_scale': clip_guidance_scale,
1564 | 'tv_scale': tv_scale,
1565 | 'range_scale': range_scale,
1566 | 'sat_scale': sat_scale,
1567 | 'cutn_batches': cutn_batches,
1568 | 'init_image': init_image,
1569 | 'init_scale': init_scale,
1570 | 'skip_steps': skip_steps,
1571 | 'sharpen_preset': sharpen_preset,
1572 | 'keep_unsharp': keep_unsharp,
1573 | 'side_x': side_x,
1574 | 'side_y': side_y,
1575 | 'timestep_respacing': timestep_respacing,
1576 | 'diffusion_steps': diffusion_steps,
1577 | 'animation_mode': animation_mode,
1578 | 'video_init_path': video_init_path,
1579 | 'extract_nth_frame': extract_nth_frame,
1580 | 'key_frames': key_frames,
1581 | 'max_frames': max_frames if animation_mode != "None" else 1,
1582 | 'interp_spline': interp_spline,
1583 | 'start_frame': start_frame,
1584 | 'angle': angle,
1585 | 'zoom': zoom,
1586 | 'translation_x': translation_x,
1587 | 'translation_y': translation_y,
1588 | 'angle_series':angle_series,
1589 | 'zoom_series':zoom_series,
1590 | 'translation_x_series':translation_x_series,
1591 | 'translation_y_series':translation_y_series,
1592 | 'frames_scale': frames_scale,
1593 |     'calc_frames_skip_steps': calc_frames_skip_steps,
1594 |     'skip_step_ratio': skip_step_ratio,
1596 | 'text_prompts': text_prompts,
1597 | 'image_prompts': image_prompts,
1598 | 'cut_overview': eval(cut_overview),
1599 | 'cut_innercut': eval(cut_innercut),
1600 | 'cut_ic_pow': cut_ic_pow,
1601 | 'cut_icgray_p': eval(cut_icgray_p),
1602 | 'intermediate_saves': intermediate_saves,
1603 | 'intermediates_in_subfolder': intermediates_in_subfolder,
1604 | 'steps_per_checkpoint': steps_per_checkpoint,
1605 | 'perlin_init': perlin_init,
1606 | 'perlin_mode': perlin_mode,
1607 | 'set_seed': set_seed,
1608 | 'eta': eta,
1609 | 'clamp_grad': clamp_grad,
1610 | 'clamp_max': clamp_max,
1611 | 'skip_augs': skip_augs,
1612 | 'randomize_class': randomize_class,
1613 | 'clip_denoised': clip_denoised,
1614 | 'fuzzy_prompt': fuzzy_prompt,
1615 | 'rand_mag': rand_mag,
1616 | }
1617 |
1618 | args = SimpleNamespace(**args)
1619 |
1620 | print('Prepping model...')
1621 | model, diffusion = create_model_and_diffusion(**model_config)
1622 | model.load_state_dict(torch.load(f'{model_path}/{diffusion_model}.pt', map_location='cpu'))
1623 | model.requires_grad_(False).eval().to(device)
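# The loop below re-enables gradients only on the attention (qkv/proj) and normalization
# parameters; presumably this lets autograd trace through the (optionally checkpointed)
# UNet so CLIP-guidance gradients with respect to the image can still be computed.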
1624 | for name, param in model.named_parameters():
1625 | if 'qkv' in name or 'norm' in name or 'proj' in name:
1626 | param.requires_grad_()
1627 | if model_config['use_fp16']:
1628 | model.convert_to_fp16()
1629 |
1630 | gc.collect()
1631 | torch.cuda.empty_cache()
1632 | try:
1633 | do_run()
1634 | except KeyboardInterrupt:
1635 | pass
1636 | finally:
1637 | print('Seed used:', seed)
1638 | gc.collect()
1639 | torch.cuda.empty_cache()
1640 |
1641 |
1642 | # # 6. Create the video
1643 |
1644 | # @title ### **Create video**
1645 | #@markdown Video file will save in the same folder as your images.
1646 |
1647 | skip_video_for_run_all = True #@param {type: 'boolean'}
1648 |
1649 | if skip_video_for_run_all == False:
1650 | # import subprocess in case this cell is run without the above cells
1651 | import subprocess
1652 | from base64 import b64encode
1653 |
1654 | latest_run = batchNum
1655 |
1656 | folder = batch_name #@param
1657 | run = latest_run #@param
1658 | final_frame = 'final_frame'
1659 |
1660 |
1661 | init_frame = 1#@param {type:"number"} This is the frame where the video will start
1662 |   last_frame = final_frame#@param {type:"number"} You can change it to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.
1663 | fps = 12#@param {type:"number"}
1664 | view_video_in_cell = False #@param {type: 'boolean'}
1665 |
1666 | frames = []
1667 | # tqdm.write('Generating video...')
1668 |
1669 | if last_frame == 'final_frame':
1670 | last_frame = len(glob(batchFolder+f"/{folder}({run})_*.png"))
1671 | print(f'Total frames: {last_frame}')
1672 |
1673 | image_path = f"{outDirPath}/{folder}/{folder}({run})_%04d.png"
1674 | filepath = f"{outDirPath}/{folder}/{folder}({run}).mp4"
1675 |
1676 |
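  # The command below reads the numbered PNG frames (starting at init_frame, at `fps` frames
  # per second) and encodes them with libx264 (crf 17, preset veryslow), using yuv420p so the
  # resulting mp4 plays in most viewers.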
1677 | cmd = [
1678 | 'ffmpeg',
1679 | '-y',
1680 | '-vcodec',
1681 | 'png',
1682 | '-r',
1683 | str(fps),
1684 | '-start_number',
1685 | str(init_frame),
1686 | '-i',
1687 | image_path,
1688 | '-frames:v',
1689 | str(last_frame+1),
1690 | '-c:v',
1691 | 'libx264',
1692 | '-vf',
1693 | f'fps={fps}',
1694 | '-pix_fmt',
1695 | 'yuv420p',
1696 | '-crf',
1697 | '17',
1698 | '-preset',
1699 | 'veryslow',
1700 | filepath
1701 | ]
1702 |
1703 | process = subprocess.Popen(cmd, cwd=f'{batchFolder}', stdout=subprocess.PIPE, stderr=subprocess.PIPE)
1704 | stdout, stderr = process.communicate()
1705 | if process.returncode != 0:
1706 | print(stderr)
1707 | raise RuntimeError(stderr)
1708 | else:
1709 | print("The video is ready")
1710 |
1711 | if view_video_in_cell:
1712 | mp4 = open(filepath,'rb').read()
1713 | data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
1714 |     display.HTML("""
1715 |     <video width=400 controls>
1716 |         <source src="%s" type="video/mp4">
1717 |     </video>
1718 |     """ % data_url)
1719 |
--------------------------------------------------------------------------------
/perlin.py:
--------------------------------------------------------------------------------
1 | # from https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
2 |
3 |
4 | import torch
5 | import torchvision.transforms.functional as TF
6 | from PIL import ImageOps
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
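# Note: side_x, side_y, perlin_mode and batch_size are not defined in this module; the
# helpers below expect the calling code to provide them before they are used.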
8 |
9 | def interp(t):
10 | return 3 * t**2 - 2 * t ** 3
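# interp() is the smoothstep fade curve 3t^2 - 2t^3; perlin() below uses it to blend the
# dot products of the four surrounding corner gradients.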
11 |
12 |
13 | def perlin(width, height, scale=10, device=None):
14 | gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
15 | xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
16 | ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
17 | wx = 1 - interp(xs)
18 | wy = 1 - interp(ys)
19 | dots = 0
20 | dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)
21 | dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)
22 | dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))
23 | dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))
24 | return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)
25 |
26 |
27 | def perlin_ms(octaves, width, height, grayscale, device=None):
28 | out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]
29 | # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]
30 | for i in range(1 if grayscale else 3):
31 | scale = 2 ** len(octaves)
32 | oct_width = width
33 | oct_height = height
34 | for oct in octaves:
35 | p = perlin(oct_width, oct_height, scale, device)
36 | out_array[i] += p * oct
37 | scale //= 2
38 | oct_width *= 2
39 | oct_height *= 2
40 | return torch.cat(out_array)
41 |
42 |
43 | def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):
44 | out = perlin_ms(octaves, width, height, grayscale)
45 | if grayscale:
46 | out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))
47 | out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')
48 | else:
49 | out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])
50 | out = TF.resize(size=(side_y, side_x), img=out)
51 | out = TF.to_pil_image(out.clamp(0, 1).squeeze())
52 |
53 | out = ImageOps.autocontrast(out)
54 | return out
55 |
56 |
57 | def regen_perlin():
58 | if perlin_mode == 'color':
59 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
60 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)
61 | elif perlin_mode == 'gray':
62 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)
63 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
64 | else:
65 | init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
66 | init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
67 |
68 | init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)
69 | del init2
70 | return init.expand(batch_size, -1, -1, -1)
71 |
--------------------------------------------------------------------------------
/secondary_diffusion.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import math
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | @dataclass
10 | class DiffusionOutput:
11 | v: torch.Tensor
12 | pred: torch.Tensor
13 | eps: torch.Tensor
14 |
15 |
16 | class ConvBlock(nn.Sequential):
17 | def __init__(self, c_in, c_out):
18 | super().__init__(
19 | nn.Conv2d(c_in, c_out, 3, padding=1),
20 | nn.ReLU(inplace=True),
21 | )
22 |
23 |
24 | class SkipBlock(nn.Module):
25 | def __init__(self, main, skip=None):
26 | super().__init__()
27 | self.main = nn.Sequential(*main)
28 | self.skip = skip if skip else nn.Identity()
29 |
30 | def forward(self, input):
31 | return torch.cat([self.main(input), self.skip(input)], dim=1)
32 |
33 |
34 | class FourierFeatures(nn.Module):
35 | def __init__(self, in_features, out_features, std=1.):
36 | super().__init__()
37 | assert out_features % 2 == 0
38 | self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)
39 |
40 | def forward(self, input):
41 | f = 2 * math.pi * input @ self.weight.T
42 | return torch.cat([f.cos(), f.sin()], dim=-1)
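# FourierFeatures projects the scalar timestep t through a randomly initialized matrix W and
# returns [cos(2*pi*W t), sin(2*pi*W t)], giving the network a smooth embedding of t.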
43 |
44 |
45 | class SecondaryDiffusionImageNet(nn.Module):
46 | def __init__(self):
47 | super().__init__()
48 | c = 64 # The base channel count
49 |
50 | self.timestep_embed = FourierFeatures(1, 16)
51 |
52 | self.net = nn.Sequential(
53 | ConvBlock(3 + 16, c),
54 | ConvBlock(c, c),
55 | SkipBlock([
56 | nn.AvgPool2d(2),
57 | ConvBlock(c, c * 2),
58 | ConvBlock(c * 2, c * 2),
59 | SkipBlock([
60 | nn.AvgPool2d(2),
61 | ConvBlock(c * 2, c * 4),
62 | ConvBlock(c * 4, c * 4),
63 | SkipBlock([
64 | nn.AvgPool2d(2),
65 | ConvBlock(c * 4, c * 8),
66 | ConvBlock(c * 8, c * 4),
67 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
68 | ]),
69 | ConvBlock(c * 8, c * 4),
70 | ConvBlock(c * 4, c * 2),
71 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
72 | ]),
73 | ConvBlock(c * 4, c * 2),
74 | ConvBlock(c * 2, c),
75 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
76 | ]),
77 | ConvBlock(c * 2, c),
78 | nn.Conv2d(c, 3, 3, padding=1),
79 | )
80 |
81 | def forward(self, input, t):
82 | timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
83 | v = self.net(torch.cat([input, timestep_embed], dim=1))
84 | alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
85 | pred = input * alphas - v * sigmas
86 | eps = input * sigmas + v * alphas
87 | return DiffusionOutput(v, pred, eps)
88 |
89 |
90 | class SecondaryDiffusionImageNet2(nn.Module):
91 | def __init__(self):
92 | super().__init__()
93 | c = 64 # The base channel count
94 | cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]
95 |
96 | self.timestep_embed = FourierFeatures(1, 16)
97 | self.down = nn.AvgPool2d(2)
98 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
99 |
100 | self.net = nn.Sequential(
101 | ConvBlock(3 + 16, cs[0]),
102 | ConvBlock(cs[0], cs[0]),
103 | SkipBlock([
104 | self.down,
105 | ConvBlock(cs[0], cs[1]),
106 | ConvBlock(cs[1], cs[1]),
107 | SkipBlock([
108 | self.down,
109 | ConvBlock(cs[1], cs[2]),
110 | ConvBlock(cs[2], cs[2]),
111 | SkipBlock([
112 | self.down,
113 | ConvBlock(cs[2], cs[3]),
114 | ConvBlock(cs[3], cs[3]),
115 | SkipBlock([
116 | self.down,
117 | ConvBlock(cs[3], cs[4]),
118 | ConvBlock(cs[4], cs[4]),
119 | SkipBlock([
120 | self.down,
121 | ConvBlock(cs[4], cs[5]),
122 | ConvBlock(cs[5], cs[5]),
123 | ConvBlock(cs[5], cs[5]),
124 | ConvBlock(cs[5], cs[4]),
125 | self.up,
126 | ]),
127 | ConvBlock(cs[4] * 2, cs[4]),
128 | ConvBlock(cs[4], cs[3]),
129 | self.up,
130 | ]),
131 | ConvBlock(cs[3] * 2, cs[3]),
132 | ConvBlock(cs[3], cs[2]),
133 | self.up,
134 | ]),
135 | ConvBlock(cs[2] * 2, cs[2]),
136 | ConvBlock(cs[2], cs[1]),
137 | self.up,
138 | ]),
139 | ConvBlock(cs[1] * 2, cs[1]),
140 | ConvBlock(cs[1], cs[0]),
141 | self.up,
142 | ]),
143 | ConvBlock(cs[0] * 2, cs[0]),
144 | nn.Conv2d(cs[0], 3, 3, padding=1),
145 | )
146 |
147 | def forward(self, input, t):
148 | timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
149 | v = self.net(torch.cat([input, timestep_embed], dim=1))
150 | alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
151 | pred = input * alphas - v * sigmas
152 | eps = input * sigmas + v * alphas
153 | return DiffusionOutput(v, pred, eps)
154 |
155 |
156 | def append_dims(x, n):
157 | return x[(Ellipsis, *(None,) * (n - x.ndim))]
158 |
159 |
160 | def expand_to_planes(x, shape):
161 | return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])
162 |
163 |
164 | def t_to_alpha_sigma(t):
165 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
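# These helpers implement the v-objective parameterization used in the forward() methods above:
# with alpha = cos(t*pi/2) and sigma = sin(t*pi/2), the model output v relates to the denoised
# prediction and the noise as pred = x*alpha - v*sigma and eps = x*sigma + v*alpha.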
166 |
--------------------------------------------------------------------------------