4 |
5 | This is a collection of text-to-image tools, evolved from the [artwork] of the same name.
6 | Based on the [CLIP] model and the [Lucent] library, with FFT/DWT/RGB parameterizers (no-GAN generation).
7 | *Updated: the old depth estimation method has been replaced with [Depth Anything 2].*
8 | Tested on Python 3.7-3.11 with PyTorch from 1.7.1 to 2.3.1.
9 |
10 | *[Aphantasia] is the inability to visualize mental images, the deprivation of visual dreams.
11 | The image in the header is generated by the tool from this word.*
12 |
13 | **Please kindly mention this project if you employ it for your masterpieces.**
14 |
15 | ## Features
16 | * generating massive detailed textures, a la deepdream
17 | * fullHD/4K resolutions and above
18 | * various CLIP models
19 | * continuous mode to process phrase lists (e.g. illustrating lyrics)
20 | * pan/zoom motion with smooth interpolation
21 | * direct RGB pixels optimization (very stable)
22 | * 3D look, based on [Depth Anything 2]
23 | * complex queries:
24 | * text and/or image as main prompts
25 | * separate text prompts for style and to subtract (avoid) topics
26 | * starting/resuming process from saved parameters or from an image
27 |
28 | Set up [CLIP] et cetera:
29 | ```
30 | pip install -r requirements.txt
31 | pip install git+https://github.com/openai/CLIP.git
32 | ```
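To quickly verify the install (optional), list the visual models available via `clip.available_models()`, which is part of the OpenAI CLIP package:
```
python -c "import clip; print(clip.available_models())"
```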
33 |
34 | ## Operations
35 |
36 | [Open in Colab](https://colab.research.google.com/github/eps696/aphantasia/blob/master/Aphantasia.ipynb)
37 |
38 | * Generate an image from the text prompt (set the size as you wish):
39 | ```
40 | python clip_fft.py -t "the text" --size 1280-720
41 | ```
42 | * Reproduce an image:
43 | ```
44 | python clip_fft.py -i theimage.jpg --sync 0.4
45 | ```
46 | If the `--sync X` argument is > 0, [LPIPS] loss is added to keep the composition similar to the original image.
47 |
48 | You can combine both text and image prompts.
49 | For non-English languages use `--translate` (Google translation).
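For example, combining a text prompt with an image prompt (the file name here is just a placeholder; append `--translate` for a non-English prompt):
```
python clip_fft.py -t "the text" -i theimage.jpg --size 1280-720
```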
50 |
51 | * Set a more specific query like this:
52 | ```
53 | python clip_fft.py -t "topic sentence" -t2 "style description" -t0 "avoid this" --size 1280-720
54 | ```
55 | * Other options (a combined example follows below):
56 | Text inputs understand syntax with weights, like `good prompt :1 | also good prompt :1 | bad prompt :-0.5`.
57 | `--model M` selects one of the released CLIP visual models: `ViT-B/32` (default), `ViT-B/16`, `RN50`, `RN50x4`, `RN50x16`, `RN101`.
58 | One can also set `--dualmod` to use `ViT-B/32` and `ViT-B/16` at once (preferable).
59 | `--dwt` switches to the DWT (wavelets) generator instead of FFT. There are a few wavelet types, chosen by `--wave X`, e.g. `db2`, `db3`, `coif1`, `coif2`, etc.
60 | The `--align XX` option is about composition (or sampling distribution, to be more precise): `uniform` is probably the most adequate; `overscan` can make semi-seamless tileable textures.
61 | `--steps N` sets the iteration count. 100-200 is enough for a starter; 500-1000 would elaborate it more thoroughly.
62 | `--samples N` sets the number of image cuts (samples) processed at each step. With more samples you can set fewer iterations for a similar result (and vice versa). 200/200 is a good guess. NB: GPU memory consumption is driven mostly by this count (not by resolution)!
63 | `--aest X` enforces overall cuteness by employing [aesthetic loss](https://github.com/LAION-AI/aesthetic-predictor). Try various values (it may be negative).
64 | `--decay X` (compositional softness), `--colors X` (saturation) and `--contrast X` may be useful, especially for ResNet models (they tend to burn the colors).
65 | `--sharp X` may be useful to increase sharpness if the image becomes "myopic" after increasing `decay`. It affects the other color parameters, so better tweak them all together!
66 | Current defaults are `--decay 1.5 --colors 1.8 --contrast 1.1 --sharp 0`.
67 | `--transform X` applies some augmentations, usually enhancing the result (but slower). There are a few choices; `fast` seems optimal.
68 | `--optimizer` can be `adam`, `adamw`, `adam_custom` or `adamw_custom`. Custom options are noisier but stable; pure `adam` is softer, but may tend toward colored blurring.
69 | `--invert` negates the whole criterion, if you fancy checking the "totally opposite".
70 | `--save_pt myfile.pt` will save the FFT/DWT parameters, to resume the next query with `--resume myfile.pt`. One can also start/resume directly from an image file.
71 | `--opt_step N` tells it to save every Nth frame (useful with high iteration counts; default is 1).
72 | `--verbose` ('on' by default) enables some printouts and realtime image preview.
73 | * Some experimental tricks with less definite effects:
74 | `--enforce X` adds more details by boosting similarity between two parallel samples. A good start is ~0.1.
75 | `--expand X` boosts diversity by enforcing difference between prev/next samples. A good start is ~0.3.
76 | `--noise X` adds some noise to the parameters, possibly making the composition less clogged (to a degree).
77 | `--macro X` (from 0 to 1) shifts generation toward bigger forms and a less dispersed composition. It should not be too close to 1, since the quality depends on the variety of samples.
78 | `--prog` sets a progressive learning rate (from 0.1x to 2x of the value set by `--lrate`). It may boost macro forms creation in some cases (see more [here](https://github.com/eps696/aphantasia/issues/2)).
79 | `--lrate` controls the learning rate. The range is quite wide (tested at least within 0.001 to 10).
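A combined example putting several of the options above together (the values here are illustrative starting points, not tuned defaults):
```
python clip_fft.py -t "the text" --size 1280-720 --steps 300 --samples 200 --transform fast --align uniform --aest 0.5 --decay 1.5 --colors 1.8 --contrast 1.1
```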
80 |
81 | ## Text-to-video [continuous mode]
82 |
83 | Here are two ways of making a video from text file(s), processing them line by line in one shot.
84 |
85 | ### Illustrip
86 |
87 | A newer method, interpolating topics as a continuous flow with constant pan/zoom motion and an optional 3D look.
88 |
89 | [Open in Colab](https://colab.research.google.com/github/eps696/aphantasia/blob/master/IllusTrip3D.ipynb)
90 |
91 | * Make video from two text files, processing them line by line, rendering 100 frames per line:
92 | ```
93 | python illustrip.py --in_txt mycontent.txt --in_txt2 mystyles.txt --size 1280-720 --steps 100
94 | ```
95 | * Make video from two phrases, with a total length of 500 frames:
96 | ```
97 | python illustrip.py --in_txt "my super content" --in_txt2 "my super style" --size 1280-720 --steps 500
98 | ```
99 | Prefixes (`-pre`), postfixes (`-post`) and "stop words" (`--in_txt0`) may be loaded as phrases or text files as well.
100 | All text inputs understand syntax with weights, like `good prompt :1 | also good prompt :1 | bad prompt :-0.5` (within one line).
101 | One can also use image(s) as references with the `--in_img` argument. Explore other arguments for more explicit control.
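For instance, wrapping every content line with a fixed prefix and postfix (the phrases here are illustrative):
```
python illustrip.py --in_txt mycontent.txt -pre "a painting of" -post "in soft pastel colors" --size 1280-720 --steps 100
```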
102 | This method works best with direct RGB pixels optimization, but can also be used with FFT parameterization:
103 | ```
104 | python illustrip.py ... --gen FFT --smooth --align uniform --colors 1.8 --contrast 1.1
105 | ```
106 |
107 | To add a 3D look, add `--depth 0.01` to the command.
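For example, building on the file-based command above:
```
python illustrip.py --in_txt mycontent.txt --in_txt2 mystyles.txt --size 1280-720 --steps 100 --depth 0.01
```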
108 |
109 | ### Illustra
110 |
111 | Generates a separate image for every text line (with sequences and training videos, as in the single-image mode above), then renders the final video of `length` seconds duration by mixing those images in FFT space.
112 |
113 | [Open in Colab](https://colab.research.google.com/github/eps696/aphantasia/blob/master/Illustra.ipynb)
114 |
115 | * Make video from a text file, processing it line by line:
116 | ```
117 | python illustra.py -t mysong.txt --size 1280-720 --length 155
118 | ```
119 | The `--keep X` parameter controls how closely the next line/image generation follows the previous one. 0 means it is initiated randomly; the higher the value, the more strictly it keeps the original composition. Safe values are 1~2 (much higher numbers may cause the imagery to get stuck).
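For example, to keep the composition fairly consistent from line to line:
```
python illustra.py -t mysong.txt --size 1280-720 --length 155 --keep 1.5
```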
120 |
121 | * Make video from a directory with saved *.pt snapshots (just interpolate them):
122 | ```
123 | python interpol.py -i mydir --length 155
124 | ```
125 |
126 | ## Other generators
127 |
128 | * VQGAN from [Taming Transformers](https://github.com/CompVis/taming-transformers)
129 | One of the best methods for colors/tones/details (especially with the new Gumbel-F8 model); the resolution is quite limited though (~800x600 max on Colab).
130 | [Open in Colab](https://colab.research.google.com/github/eps696/aphantasia/blob/master/CLIP_VQGAN.ipynb)
131 |
132 |
133 | * CPPN + [export to HLSL shaders](https://github.com/wxs/cppn-to-glsl)
134 | One of the very first methods, with exports for TouchDesigner, vvvv, Shadertoy, etc.
135 | [Open in Colab](https://colab.research.google.com/drive/1Kbbbwoet3igHPJ4KpNh8z3V-RxtstAcz)
136 | ```
137 | python cppn.py -v -t "the text" --aest 0.5
138 | ```
139 |
140 | * SIREN + [Fourier feature modulation](https://github.com/tancik/fourier-feature-networks)
141 | Another early method, not so interesting on its own.
142 | [Open in Colab](https://colab.research.google.com/drive/1L14q4To5rMK8q2E6whOibQBnPnVbRJ_7)
143 |
144 |
145 | ## Credits
146 |
147 | Based on [CLIP] model by OpenAI ([paper]).
148 | FFT encoding is taken from the [Lucent] library; 3D depth processing was made by [deKxi].
149 |
150 | Thanks to [Ryan Murdock], [Jonathan Fly], [Hannu Toyryla], [@eduwatch2], [torridgristle] for ideas.
151 |
152 |
153 |
154 | [artwork]:
155 | [Aphantasia]:
156 | [CLIP]:
157 | [SBERT]:
158 | [Lucent]:
159 | [Depth Anything 2]:
160 | [LPIPS]:
161 | [Taming Transformers]:
162 | [Ryan Murdock]:
163 | [Jonathan Fly]:
164 | [Hannu Toyryla]:
165 | [@eduwatch2]:
166 | [torridgristle]:
167 | [deKxi]:
168 | [paper]:
169 |
--------------------------------------------------------------------------------
/_out/Aphantasia.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia.jpg
--------------------------------------------------------------------------------
/_out/Aphantasia2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia2.jpg
--------------------------------------------------------------------------------
/_out/Aphantasia3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia3.jpg
--------------------------------------------------------------------------------
/_out/Aphantasia4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/Aphantasia4.jpg
--------------------------------------------------------------------------------
/_out/some_cute_image-FFT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/some_cute_image-FFT.jpg
--------------------------------------------------------------------------------
/_out/some_cute_image-SIREN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/some_cute_image-SIREN.jpg
--------------------------------------------------------------------------------
/_out/some_cute_image-VQGAN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/_out/some_cute_image-VQGAN.jpg
--------------------------------------------------------------------------------
/aphantasia/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/aphantasia/__init__.py
--------------------------------------------------------------------------------
/aphantasia/image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from imageio import imread
4 |
5 | import pywt
6 | from pytorch_wavelets import DWTForward, DWTInverse
7 | # from pytorch_wavelets import DTCWTForward, DTCWTInverse
8 |
9 | import torch
10 |
11 | from aphantasia.utils import slice_imgs, derivat, sim_func, basename, img_list, img_read, plot_text, old_torch
12 | from aphantasia.transforms import normalize
13 |
14 | def to_valid_rgb(image_f, colors=1., decorrelate=True):
15 | color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
16 | color_correlation_svd_sqrt /= torch.tensor([colors, 1., 1.]) # saturate, empirical
17 | max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
18 | color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
19 | colcorr_t = color_correlation_normalized.T.cuda()
20 |
21 | def _linear_decorrelate_color(image):
22 | return torch.einsum('nchw,cd->ndhw', image, colcorr_t) # edit by katherine crowson
23 |
24 | def inner(*args, **kwargs):
25 | image = image_f(*args, **kwargs)
26 | if decorrelate:
27 | image = _linear_decorrelate_color(image)
28 | return torch.sigmoid(image)
29 | return inner
30 |
31 | ### DWT [wavelets]
32 |
33 | def init_dwt(resume=None, shape=None, wave=None, colors=None):
34 | size = None
35 | wp_fake = pywt.WaveletPacket2D(data=np.zeros(shape[2:]), wavelet='db1', mode='symmetric')
36 | xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda()
37 | # xfm = DTCWTForward(J=lvl, biort='near_sym_b', qshift='qshift_b').cuda() # 4x more params, biort ['antonini','legall','near_sym_a','near_sym_b']
38 | ifm = DWTInverse(wave=wave, mode='symmetric').cuda() # symmetric zero periodization
39 | # ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda() # 4x more params, biort ['antonini','legall','near_sym_a','near_sym_b']
40 | if resume is None: # random init
41 | Yl_in, Yh_in = xfm(torch.zeros(shape).cuda())
42 | Ys = [torch.randn(*Y.shape).cuda() for Y in [Yl_in, *Yh_in]]
43 | elif isinstance(resume, str):
44 | if os.path.isfile(resume):
45 | if os.path.splitext(resume)[1].lower()[1:] in ['jpg','png','tif','bmp']:
46 | img_in = imread(resume)
47 | Ys = img2dwt(img_in, wave=wave, colors=colors)
48 | print(' loaded image', resume, img_in.shape, 'level', len(Ys)-1)
49 | size = img_in.shape[:2]
50 | wp_fake = pywt.WaveletPacket2D(data=np.zeros(size), wavelet='db1', mode='symmetric')
51 | xfm = DWTForward(J=wp_fake.maxlevel, wave=wave, mode='symmetric').cuda()
52 | else:
53 | Ys = torch.load(resume)
54 | Ys = [y.detach().cuda() for y in Ys]
55 | else: print(' Snapshot not found:', resume); exit()
56 | else:
57 | Ys = [y.cuda() for y in resume]
58 | # print('level', len(Ys)-1, 'low freq', Ys[0].cpu().numpy().shape)
59 | return Ys, xfm, ifm, size
60 |
61 | def dwt_image(shape, wave='coif2', sharp=0.3, colors=1., resume=None):
62 | Ys, _, ifm, size = init_dwt(resume, shape, wave, colors)
63 | Ys = [y.requires_grad_(True) for y in Ys]
64 | scale = dwt_scale(Ys, sharp)
65 |
66 | def inner(shift=None, contrast=1.):
67 | image = ifm((Ys[0], [Ys[i+1] * float(scale[i]) for i in range(len(Ys)-1)]))
68 | image = image * contrast / image.std() # keep contrast, empirical *1.33
69 | return image
70 |
71 | return Ys, inner, size
72 |
73 | def dwt_scale(Ys, sharp):
74 | scale = []
75 | [h0,w0] = Ys[1].shape[3:5]
76 | for i in range(len(Ys)-1):
77 | [h,w] = Ys[i+1].shape[3:5]
78 | scale.append( ((h0*w0)/(h*w)) ** (1.-sharp) )
79 | # print(i+1, Ys[i+1].shape)
80 | return scale
81 |
82 | def img2dwt(img_in, wave='coif2', sharp=0.3, colors=1.):
83 | image_t = un_rgb(img_in, colors=colors)
84 | with torch.no_grad():
85 | wp_fake = pywt.WaveletPacket2D(data=np.zeros(image_t.shape[2:]), wavelet='db1', mode='zero')
86 | lvl = wp_fake.maxlevel
87 | # print(image_t.shape, lvl)
88 | xfm = DWTForward(J=lvl, wave=wave, mode='symmetric').cuda()
89 | Yl_in, Yh_in = xfm(image_t.cuda())
90 | Ys = [Yl_in, *Yh_in]
91 | scale = dwt_scale(Ys, sharp)
92 | for i in range(len(Ys)-1):
93 | Ys[i+1] /= scale[i]
94 | return Ys
95 |
96 | ### FFT/RGB from Lucent library ### https://github.com/greentfrapp/lucent
97 |
98 | def pixel_image(shape, resume=None, sd=1., *noargs, **nokwargs):
99 | size = None
100 | if resume is None:
101 | image_t = torch.randn(*shape) * sd
102 | elif isinstance(resume, str):
103 | if os.path.isfile(resume):
104 | img_in = img_read(resume)
105 | image_t = 3.3 * un_rgb(img_in, colors=2.)
106 | size = img_in.shape[:2]
107 | print(resume, size)
108 | else: print(' Image not found:', resume); exit()
109 | else:
110 | if isinstance(resume, list): resume = resume[0]
111 | image_t = resume
112 | image_t = image_t.cuda().requires_grad_(True)
113 |
114 | def inner(shift=None, contrast=1., fixcontrast=False): # *noargs, **nokwargs
115 | if fixcontrast is True: # for resuming from image
116 | return image_t * contrast / 3.3
117 | else:
118 | return image_t * contrast / image_t.std()
119 | return [image_t], inner, size # lambda: image_t
120 |
121 | # From https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py
122 | def rfft2d_freqs(h, w):
123 | """Computes 2D spectrum frequencies."""
124 | fy = np.fft.fftfreq(h)[:, None]
125 | # when we have an odd input dimension we need to keep one additional frequency and later cut off 1 pixel
126 | w2 = (w+1)//2 if w%2 == 1 else w//2+1
127 | fx = np.fft.fftfreq(w)[:w2]
128 | return np.sqrt(fx * fx + fy * fy)
129 |
130 | def resume_fft(resume=None, shape=None, decay=None, colors=1.6, sd=0.01):
131 | size = None
132 | if resume is None: # random init
133 | params_shape = [*shape[:3], shape[3]//2+1, 2] # [1,3,512,257,2] for 512x512 (2 for imaginary and real components)
134 | params = 0.01 * torch.randn(*params_shape).cuda()
135 | elif isinstance(resume, str):
136 | if os.path.isfile(resume):
137 | if os.path.splitext(resume)[1].lower()[1:] in ['jpg','png','tif','bmp']:
138 | img_in = img_read(resume)
139 | params = img2fft(img_in, decay, colors)
140 | size = img_in.shape[:2]
141 | else:
142 | params = torch.load(resume)
143 | if isinstance(params, list): params = params[0]
144 | params = params.detach().cuda()
145 | params *= sd
146 | else: print(' Snapshot not found:', resume); exit()
147 | else:
148 | if isinstance(resume, list): resume = resume[0]
149 | params = resume.cuda()
150 | return params, size
151 |
152 | def fft_image(shape, sd=0.01, decay_power=1.0, resume=None): # decay ~ blur
153 |
154 | params, size = resume_fft(resume, shape, decay_power, sd=sd)
155 | spectrum_real_imag_t = params.requires_grad_(True)
156 | if size is not None: shape[2:] = size
157 | [h,w] = list(shape[2:])
158 |
159 | freqs = rfft2d_freqs(h,w)
160 | scale = 1. / np.maximum(freqs, 4./max(h,w)) ** decay_power
161 | scale *= np.sqrt(h*w)
162 | scale = torch.tensor(scale).float()[None, None, ..., None].cuda()
163 |
164 | def inner(shift=None, contrast=1., *noargs, **nokwargs):
165 | scaled_spectrum_t = scale * spectrum_real_imag_t
166 | if shift is not None:
167 | scaled_spectrum_t += scale * shift
168 | if old_torch():
169 | image = torch.irfft(scaled_spectrum_t, 2, normalized=True, signal_sizes=(h, w))
170 | else:
171 | if type(scaled_spectrum_t) is not torch.complex64:
172 | scaled_spectrum_t = torch.view_as_complex(scaled_spectrum_t)
173 | image = torch.fft.irfftn(scaled_spectrum_t, s=(h, w), norm='ortho')
174 | image = image * contrast / image.std() # keep contrast, empirical
175 | return image
176 |
177 | return [spectrum_real_imag_t], inner, size
178 |
179 | def inv_sigmoid(x):
180 | eps = 1.e-12
181 | x = torch.clamp(x.double(), eps, 1-eps)
182 | y = torch.log(x/(1-x))
183 | return y.float()
184 |
185 | def un_rgb(image, colors=1.):
186 | color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
187 | color_correlation_svd_sqrt /= torch.tensor([colors, 1., 1.]) # saturate, empirical
188 | max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
189 | color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
190 | colcorr_t = color_correlation_normalized.T.cuda()
191 | colcorr_t_inv = torch.linalg.inv(colcorr_t)
192 |
193 | if not isinstance(image, torch.Tensor): # numpy int array [0..255]
194 | image = torch.Tensor(image).cuda().permute(2,0,1).unsqueeze(0) / 255.
195 | # image = inv_sigmoid(image)
196 | image = normalize()(image) # experimental
197 | return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv) # edit by katherine crowson
198 |
199 | def un_spectrum(spectrum, decay_power):
200 | h = spectrum.shape[2]
201 | w = (spectrum.shape[3]-1)*2
202 | freqs = rfft2d_freqs(h, w)
203 | scale = 1.0 / np.maximum(freqs, 1.0 / max(w, h)) ** decay_power
204 | scale *= np.sqrt(w*h)
205 | scale = torch.tensor(scale).float()[None, None, ..., None].cuda()
206 | return spectrum / scale
207 |
208 | def img2fft(img_in, decay=1., colors=1.):
209 | image_t = un_rgb(img_in, colors=colors)
210 | h, w = image_t.shape[2], image_t.shape[3]
211 |
212 | with torch.no_grad():
213 | if old_torch():
214 | spectrum = torch.rfft(image_t, 2, normalized=True) # 1.7
215 | else:
216 | spectrum = torch.fft.rfftn(image_t, s=(h, w), dim=[2,3], norm='ortho') # 1.8
217 | spectrum = torch.view_as_real(spectrum)
218 | spectrum = un_spectrum(spectrum, decay_power=decay)
219 | spectrum *= 500000. # [sic!!!]
220 | return spectrum
221 |
--------------------------------------------------------------------------------
/aphantasia/interpol.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | warnings.filterwarnings("ignore")
4 | import argparse
5 | import math
6 | import numpy as np
7 |
8 | import torch
9 |
10 | from clip_fft import to_valid_rgb, fft_image
11 | from aphantasia.utils import basename, file_list, checkout
12 | try: # progress bar for notebooks
13 | get_ipython().__class__.__name__
14 | from aphantasia.progress_bar import ProgressIPy as ProgressBar
15 | except: # normal console
16 | from aphantasia.progress_bar import ProgressBar
17 |
18 | def get_args():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-i', '--in_dir', default='pt')
21 | parser.add_argument('-o', '--out_dir', default='_out')
22 | parser.add_argument('-l', '--length', default=None, type=int, help='Total length in sec')
23 | parser.add_argument('-s', '--steps', default=25, type=int, help='Override length')
24 | parser.add_argument( '--fps', default=25, type=int)
25 | parser.add_argument( '--contrast', default=1.1, type=float)
26 | parser.add_argument( '--colors', default=1.8, type=float)
27 | parser.add_argument('-d', '--decay', default=1.5, type=float)
28 | parser.add_argument('-v', '--verbose', default=True, type=bool)
29 | a = parser.parse_args()
30 | return a
31 |
32 | def read_pt(file):
33 | return torch.load(file)[0].cuda()
34 |
35 | def main():
36 | a = get_args()
37 | tempdir = os.path.join(a.out_dir, 'a')
38 | os.makedirs(tempdir, exist_ok=True)
39 |
40 | ptfiles = file_list(a.in_dir, 'pt')
41 |
42 | ptest = torch.load(ptfiles[0])
43 | if isinstance(ptest, list): ptest = ptest[0]
44 | shape = [*ptest.shape[:3], (ptest.shape[3]-1)*2]
45 |
46 | vsteps = a.lsteps if a.length is None else int(a.length * a.fps / count)
47 | pbar = ProgressBar(vsteps * len(ptfiles))
48 | for px in range(len(ptfiles)):
49 | params1 = read_pt(ptfiles[px])
50 | params2 = read_pt(ptfiles[(px+1) % len(ptfiles)])
51 |
52 | params, image_f, _ = fft_image(shape, resume=params1, sd=1., decay_power=a.decay)
53 | image_f = to_valid_rgb(image_f, colors = a.colors)
54 |
55 | for i in range(vsteps):
56 | with torch.no_grad():
57 | x = i/vsteps # math.sin(1.5708 * i/vsteps)
58 | img = image_f((params2 - params1) * x, contrast=a.contrast).cpu().numpy()[0]
59 | checkout(img, os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), verbose=a.verbose)
60 | pbar.upd()
61 |
62 | os.system('ffmpeg -v warning -y -i %s/\%%05d.jpg "%s-pts.mp4"' % (tempdir, a.in_dir))
63 |
64 |
65 | if __name__ == '__main__':
66 | main()
67 |
--------------------------------------------------------------------------------
/aphantasia/progress_bar.py:
--------------------------------------------------------------------------------
1 | """
2 | from progress_bar import ProgressBar
3 |
4 | pbar = ProgressBar(steps)
5 | pbar.upd()
6 | """
7 |
8 | import os
9 | import sys
10 | import math
11 | os.system('') #enable VT100 Escape Sequence for WINDOWS 10 Ver. 1607
12 |
13 | from shutil import get_terminal_size
14 | import time
15 |
16 | import ipywidgets as ipy
17 | import IPython
18 | class ProgressIPy(object):
19 | def __init__(self, task_num=10):
20 | self.pbar = ipy.IntProgress(min=0, max=task_num, bar_style='') # (value=0, min=0, max=max, step=1, description=description, bar_style='')
21 | self.labl = ipy.Label()
22 | IPython.display.display(ipy.HBox([self.pbar, self.labl]))
23 | self.task_num = task_num
24 | self.completed = 0
25 | self.start()
26 |
27 | def start(self, task_num=None):
28 | if task_num is not None:
29 | self.task_num = task_num
30 | if self.task_num > 0:
31 | self.labl.value = '0/{}'.format(self.task_num)
32 | else:
33 | self.labl.value = 'completed: 0, elapsed: 0s'
34 | self.start_time = time.time()
35 |
36 | def upd(self, *p, **kw):
37 | self.completed += 1
38 | elapsed = time.time() - self.start_time + 0.0000000000001
39 | fps = self.completed / elapsed if elapsed>0 else 0
40 | if self.task_num > 0:
41 | finaltime = time.asctime(time.localtime(self.start_time + self.task_num * elapsed / float(self.completed)))
42 | fin = ' end %s' % finaltime[11:16]
43 | percentage = self.completed / float(self.task_num)
44 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
45 | self.labl.value = '{}/{}, rate {:.3g}s, time {}s, left {}s, {}'.format(self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), fin)
46 | else:
47 | self.labl.value = 'completed {}, time {}s, {:.1f} steps/s'.format(self.completed, int(elapsed + 0.5), fps)
48 | self.pbar.value += 1
49 | if self.completed == self.task_num: self.pbar.bar_style = 'success'
50 | return self.completed
51 |
52 |
53 | class ProgressBar(object):
54 | '''A progress bar which can print the progress
55 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
56 | '''
57 | def __init__(self, task_num=0, bar_width=50, start=True):
58 | self.task_num = task_num
59 | max_bar_width = self._get_max_bar_width()
60 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
61 | self.completed = 0
62 | if start:
63 | self.start()
64 |
65 | def _get_max_bar_width(self):
66 | terminal_width, _ = get_terminal_size()
67 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
68 | if max_bar_width < 10:
69 | print('terminal is small ({}), make it bigger for proper visualization'.format(terminal_width))
70 | max_bar_width = 10
71 | return max_bar_width
72 |
73 | def start(self, task_num=None):
74 | if task_num is not None:
75 | self.task_num = task_num
76 | if self.task_num > 0:
77 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(' ' * self.bar_width, self.task_num, 'Start...'))
78 | else:
79 | sys.stdout.write('completed: 0, elapsed: 0s')
80 | sys.stdout.flush()
81 | self.start_time = time.time()
82 |
83 | def upd(self, msg=None):
84 | self.completed += 1
85 | elapsed = time.time() - self.start_time + 0.0000000000001
86 | fps = self.completed / elapsed if elapsed>0 else 0
87 | if self.task_num > 0:
88 | percentage = self.completed / float(self.task_num)
89 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
90 | finaltime = time.asctime(time.localtime(self.start_time + self.task_num * elapsed / float(self.completed)))
91 | fin_msg = ' %ss left, end %s' % (shortime(eta), finaltime[11:16])
92 | if msg is not None: fin_msg += ' ' + str(msg)
93 | mark_width = int(self.bar_width * percentage)
94 | bar_chars = 'X' * mark_width + '-' * (self.bar_width - mark_width) # ▒ ▓ █
95 | sys.stdout.write('\033[2A') # cursor up 2 lines
96 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
97 | try:
98 | sys.stdout.write('[{}] {}/{}, rate {:.3g}s, time {}s, left {}s \n{}\n'.format(
99 | bar_chars, self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), fin_msg))
100 | except:
101 | sys.stdout.write('[{}] {}/{}, rate {:.3g}s, time {}s, left {}s \n{}\n'.format(
102 | bar_chars, self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), '<< unprintable >>'))
103 | else:
104 | sys.stdout.write('completed {}, time {}s, {:.1f} steps/s'.format(self.completed, int(elapsed + 0.5), fps))
105 | sys.stdout.flush()
106 |
107 | def reset(self, count=None, newline=False):
108 | self.start_time = time.time()
109 | if count is not None:
110 | self.task_num = count
111 | if newline is True:
112 | sys.stdout.write('\n\n')
113 |
114 | def time_days(sec):
115 | return '%dd %d:%02d:%02d' % (sec/86400, (sec/3600)%24, (sec/60)%60, sec%60)
116 | def time_hrs(sec):
117 | return '%d:%02d:%02d' % (sec/3600, (sec/60)%60, sec%60)
118 | def shortime(sec):
119 | if sec < 60:
120 | time_short = '%d' % (sec)
121 | elif sec < 3600:
122 | time_short = '%d:%02d' % ((sec/60)%60, sec%60)
123 | elif sec < 86400:
124 | time_short = time_hrs(sec)
125 | else:
126 | time_short = time_days(sec)
127 | return time_short
128 |
129 |
--------------------------------------------------------------------------------
/aphantasia/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The Lucent Authors. All Rights Reserved.
2 | # http://www.apache.org/licenses/LICENSE-2.0
3 |
4 | import numpy as np
5 | import PIL
6 | import kornia
7 | import kornia.geometry.transform as K
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torchvision import transforms as T
12 |
13 | from .utils import old_torch
14 |
15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 | def random_elastic():
18 | def inner(x):
19 | a = np.random.rand(2)
20 | k = np.random.randint(8,64) * 2 + 1 # 63
21 | s = k / (np.random.rand()+2.) # 2-3 times less than k
22 | # s = float(np.random.randint(8,64)) # 32
23 | noise = torch.zeros([1, 2, x.shape[2], x.shape[3]]).cuda()
24 | return K.elastic_transform2d(x, noise, (k,k), (s,s), tuple(a))
25 | return inner
26 |
27 | def jitter(d):
28 | assert d > 1, "Jitter parameter d must be more than 1, currently {}".format(d)
29 | def inner(image_t):
30 | dx = np.random.choice(d)
31 | dy = np.random.choice(d)
32 | return K.translate(image_t, torch.tensor([[dx, dy]]).float().to(device))
33 | return inner
34 |
35 | def pad(w, mode="reflect", constant_value=0.5):
36 | if mode != "constant":
37 | constant_value = 0
38 | def inner(image_t):
39 | return F.pad(image_t, [w] * 4, mode=mode, value=constant_value,)
40 | return inner
41 |
42 | def random_scale(scales):
43 | def inner(image_t):
44 | scale = np.random.choice(scales)
45 | shp = image_t.shape[2:]
46 | scale_shape = [_roundup(scale * d) for d in shp]
47 | pad_x = max(0, _roundup((shp[1] - scale_shape[1]) / 2))
48 | pad_y = max(0, _roundup((shp[0] - scale_shape[0]) / 2))
49 | upsample = torch.nn.Upsample(size=scale_shape, mode="bilinear", align_corners=True)
50 | return F.pad(upsample(image_t), [pad_y, pad_x] * 2)
51 | return inner
52 |
53 | def random_rotate(angles, units="degrees"):
54 | def inner(image_t):
55 | b, _, h, w = image_t.shape
56 | # kornia takes degrees
57 | alpha = _rads2angle(np.random.choice(angles), units)
58 | angle = torch.ones(b) * alpha
59 | # scale = torch.ones(b)
60 | scale = torch.ones(b, 2)
61 | center = torch.ones(b, 2)
62 | center[..., 0] = (image_t.shape[3] - 1) / 2
63 | center[..., 1] = (image_t.shape[2] - 1) / 2
64 | try:
65 | M = kornia.geometry.transform.get_rotation_matrix2d(center, angle, scale).to(device)
66 | rotated_image = kornia.geometry.transform.warp_affine(image_t.float(), M, dsize=(h, w))
67 | except:
68 | M = kornia.get_rotation_matrix2d(center, angle, scale).to(device)
69 | rotated_image = kornia.warp_affine(image_t.float(), M, dsize=(h, w))
70 | return rotated_image
71 | return inner
72 |
73 | def random_rotate_fast(angles):
74 | def inner(img):
75 | angle = float(np.random.choice(angles))
76 | size = img.shape[-2:]
77 | if old_torch(): # 1.7.1
78 | img = T.functional.affine(img, angle, [0,0], 1, 0, fillcolor=0, resample=PIL.Image.BILINEAR)
79 | else: # 1.8+
80 | img = T.functional.affine(img, angle, [0,0], 1, 0, fill=0, interpolation=T.InterpolationMode.BILINEAR)
81 | img = T.functional.center_crop(img, size) # on 1.8+ also pads
82 | return img
83 | return inner
84 |
85 | def compose(transforms):
86 | def inner(x):
87 | for transform in transforms:
88 | x = transform(x)
89 | return x
90 | return inner
91 |
92 | def _roundup(value):
93 | return np.ceil(value).astype(int)
94 |
95 | def _rads2angle(angle, units):
96 | if units.lower() == "degrees":
97 | return angle
98 | if units.lower() in ["radians", "rads", "rad"]:
99 | angle = angle * 180.0 / np.pi
100 | return angle
101 |
102 | def normalize():
103 | # ImageNet normalization for torchvision models
104 | # see https://pytorch.org/docs/stable/torchvision/models.html
105 | # normal = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
106 | normal = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
107 | def inner(image_t):
108 | return torch.stack([normal(t) for t in image_t])
109 | return inner
110 |
111 | def preprocess_inceptionv1():
112 | # Original Tensorflow's InceptionV1 model takes in [-117, 138]
113 | # See https://github.com/tensorflow/lucid/blob/master/lucid/modelzoo/other_models/InceptionV1.py#L56
114 | # Thanks to ProGamerGov for this!
115 | return lambda x: x * 255 - 117
116 |
117 | # from lucent
118 | transforms_lucent = compose([
119 | pad(12, mode="constant", constant_value=0.5),
120 | jitter(8),
121 | random_scale([1 + (i - 5) / 50.0 for i in range(11)]),
122 | random_rotate(list(range(-10, 11)) + 5 * [0]),
123 | jitter(4),
124 | ])
125 |
126 | # from openai
127 | transforms_openai = compose([
128 | pad(2, mode='constant', constant_value=.5),
129 | jitter(4),
130 | jitter(4),
131 | jitter(4),
132 | jitter(4),
133 | jitter(4),
134 | jitter(4),
135 | jitter(4),
136 | jitter(4),
137 | jitter(4),
138 | jitter(4),
139 | # random_scale([0.995**n for n in range(-5,80)] + [0.998**n for n in 2*list(range(20,40))]),
140 | random_rotate(list(range(-20,20))+list(range(-10,10))+list(range(-5,5))+5*[0]),
141 | jitter(2),
142 | # crop_or_pad_to(resolution, resolution)
143 | ])
144 |
145 | # my compos
146 |
147 | transforms_elastic = compose([
148 | pad(4, mode="constant", constant_value=0.5),
149 | T.RandomErasing(0.2),
150 | random_rotate(list(range(-30, 30)) + 20 * [0]),
151 | random_elastic(),
152 | jitter(8),
153 | normalize()
154 | ])
155 |
156 | transforms_custom = compose([
157 | pad(4, mode="constant", constant_value=0.5),
158 | # T.RandomPerspective(0.33, 0.2),
159 | # T.RandomErasing(0.2),
160 | random_rotate(list(range(-30, 30)) + 20 * [0]),
161 | jitter(8),
162 | normalize()
163 | ])
164 |
165 | transforms_fast = compose([
166 | T.RandomPerspective(0.33, 0.2),
167 | T.RandomErasing(0.2),
168 | random_rotate_fast(list(range(-30, 30)) + 20 * [0]),
169 | normalize()
170 | ])
171 |
172 |
--------------------------------------------------------------------------------
/aphantasia/utils.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import os
3 | import math
4 | import time
5 | from imageio import imread, imsave
6 | import cv2
7 | import numpy as np
8 | import collections
9 | import scipy
10 | from scipy.ndimage import gaussian_filter
11 | from scipy.interpolate import CubicSpline as CubSpline
12 | import matplotlib.pyplot as plt
13 | from kornia.filters.sobel import spatial_gradient
14 |
15 | import torch
16 | import torch.nn.functional as F
17 |
18 | def plot_text(txt, size=224):
19 | fig = plt.figure(figsize=(1,1), dpi=size)
20 | fontsize = size//len(txt) if len(txt) < 15 else 8
21 | plt.text(0.5, 0.5, txt, fontsize=fontsize, ha='center', va='center', wrap=True)
22 | plt.axis('off')
23 | fig.tight_layout(pad=0)
24 | fig.canvas.draw()
25 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
26 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
27 | return img
28 |
29 | def txt_clean(txt):
30 | return txt.translate(str.maketrans(dict.fromkeys(list("\n',.—|!?/:;\\"), ""))).replace(' ', '_').replace('"', '')
31 |
32 | def intrl(a, b, step=2):
33 | assert len(a) == len(b), ' diff lengths: %d %d' % (len(a), len(b))
34 | assert step > 1
35 | nums = list(range(len(a)))[step::step]
36 | for num in nums:
37 | a[num] = b[num]
38 | return a
39 |
40 | def old_torch():
41 | ver = [int(i) for i in torch.__version__.split('.')[:2]]
42 | return True if (ver[0] < 2 and ver[1] < 8) else False
43 |
44 | def basename(file):
45 | return os.path.splitext(os.path.basename(file))[0]
46 |
47 | def file_list(path, ext=None, subdir=None):
48 | if subdir is True:
49 | files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn]
50 | else:
51 | files = [os.path.join(path, f) for f in os.listdir(path)]
52 | if ext is not None:
53 | if isinstance(ext, list):
54 | files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ext]
55 | elif isinstance(ext, str):
56 | files = [f for f in files if f.endswith(ext)]
57 | else:
58 | print(' Unknown extension/type for file list!')
59 | return sorted([f for f in files if os.path.isfile(f)])
60 |
61 | def img_list(path, subdir=None):
62 | if subdir is True:
63 | files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn]
64 | else:
65 | files = [os.path.join(path, f) for f in os.listdir(path)]
66 | files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ['jpg', 'jpeg', 'png', 'ppm', 'tif']]
67 | return sorted([f for f in files if os.path.isfile(f)])
68 |
69 | def img_read(path):
70 | img = imread(path)
71 |     # grayscale to rgb (duplicate the single channel)
72 | if (img.ndim == 2) or (img.shape[2] == 1):
73 | img = np.dstack((img,img,img))
74 | # rgba to rgb
75 | if img.shape[2] == 4:
76 | img = img[:,:,:3]
77 | return img
78 |
79 | def img_save(path, img, norm=True):
80 |     if norm == True and not np.issubdtype(img.dtype, np.integer):  # check the dtype itself, not dtype.kind
81 | img = (img*255).astype(np.uint8)
82 | imsave(path, img)
83 |
84 | def cvshow(img):
85 | img = np.array(img)
86 | if img.shape[0] > 720 or img.shape[1] > 1280:
87 | x_ = 1280 / img.shape[1]
88 | y_ = 720 / img.shape[0]
89 | psize = tuple([int(s * min(x_, y_)) for s in img.shape[:2][::-1]])
90 | img = cv2.resize(img, psize)
91 | cv2.imshow('t', img[:,:,::-1])
92 | cv2.waitKey(1)
93 |
94 | def checkout(img, fname=None, verbose=False):
95 | img = np.transpose(np.array(img)[:,:,:], (1,2,0))
96 | if verbose is True:
97 | cvshow(img)
98 | if fname is not None:
99 | img = np.clip(img*255, 0, 255).astype(np.uint8)
100 | imsave(fname, img)
101 |
102 | def save_cfg(args, dir='./', file='config.txt'):
103 | if dir != '':
104 | os.makedirs(dir, exist_ok=True)
105 | try: args = vars(args)
106 | except: pass
107 | if file is None:
108 | print_dict(args)
109 | else:
110 | with open(os.path.join(dir, file), 'w') as cfg_file:
111 | print_dict(args, cfg_file)
112 |
113 | def print_dict(dict, file=None, path="", indent=''):
114 | for k in sorted(dict.keys()):
115 | if isinstance(dict[k], collections.abc.Mapping):
116 | if file is None:
117 | print(indent + str(k))
118 | else:
119 | file.write(indent + str(k) + ' \n')
120 | path = k if path=="" else path + "->" + k
121 | print_dict(dict[k], file, path, indent + ' ')
122 | else:
123 | if file is None:
124 | print('%s%s: %s' % (indent, str(k), str(dict[k])))
125 | else:
126 | file.write('%s%s: %s \n' % (indent, str(k), str(dict[k])))
127 |
128 | def minmax(x, use_torch=True):  # arg renamed: 'torch' would shadow the torch module below
129 |     if use_torch:
130 | mn = torch.min(x).detach().cpu().numpy()
131 | mx = torch.max(x).detach().cpu().numpy()
132 | else:
133 | mn = np.min(x.detach().cpu().numpy())
134 | mx = np.max(x.detach().cpu().numpy())
135 | return (mn, mx)
136 |
137 | def triangle_blur(x, kernel_size=3, pow=1.0):
138 | padding = (kernel_size-1) // 2
139 | b,c,h,w = x.shape
140 | kernel = torch.linspace(-1,1,kernel_size+2)[1:-1].abs().neg().add(1).reshape(1,1,1,kernel_size).pow(pow).cuda()
141 | kernel = kernel / kernel.sum()
142 | x = x.reshape(b*c,1,h,w)
143 | x = F.pad(x, (padding,padding,padding,padding), mode='reflect')
144 | x = F.conv2d(x, kernel)
145 | x = F.conv2d(x, kernel.permute(0,1,3,2))
146 | x = x.reshape(b,c,h,w)
147 | return x
148 |
149 | # Tiles an array around two points, allowing for pad lengths greater than the input length
150 | # NB: if symm=True, every second tile is mirrored = messed up in GAN
151 | # adapted from https://discuss.pytorch.org/t/symmetric-padding/19866/3
152 | def tile_pad(xt, padding, symm=False):
153 | h, w = xt.shape[-2:]
154 | left, right, top, bottom = padding
155 |
156 | def tile(x, minx, maxx):
157 | rng = maxx - minx
158 | if symm is True: # triangular reflection
159 | double_rng = 2*rng
160 | mod = np.fmod(x - minx, double_rng)
161 | normed_mod = np.where(mod < 0, mod+double_rng, mod)
162 | out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx
163 | else: # repeating tiles
164 | mod = np.remainder(x - minx, rng)
165 | out = mod + minx
166 | return np.array(out, dtype=x.dtype)
167 |
168 | x_idx = np.arange(-left, w+right)
169 | y_idx = np.arange(-top, h+bottom)
170 | x_pad = tile(x_idx, -0.5, w-0.5)
171 | y_pad = tile(y_idx, -0.5, h-0.5)
172 | xx, yy = np.meshgrid(x_pad, y_pad)
173 | return xt[..., yy, xx]
174 |
175 | def pad_up_to(x, size, type='centr'):
176 | sh = x.shape[2:][::-1]
177 | if list(x.shape[2:]) == list(size): return x
178 | padding = []
179 | for i, s in enumerate(size[::-1]):
180 | if 'side' in type.lower():
181 | padding = padding + [0, s-sh[i]]
182 | else: # centr
183 | p0 = (s-sh[i]) // 2
184 | p1 = s-sh[i] - p0
185 | padding = padding + [p0,p1]
186 | y = tile_pad(x, padding, symm = ('symm' in type.lower()))
187 | return y
188 |
189 | def smoothstep(x, NN=1, xmin=0., xmax=1.):
190 | N = math.ceil(NN)
191 | x = np.clip((x - xmin) / (xmax - xmin), 0, 1)
192 | result = 0
193 | for n in range(0, N+1):
194 | result += scipy.special.comb(N+n, n) * scipy.special.comb(2*N+1, N-n) * (-x)**n
195 | result *= x**(N+1)
196 | if NN != N: result = (x + result) / 2
197 | return result
198 |
199 | def slerp(z1, z2, num_steps=None, x=None, smooth=0.5):
200 | z1_norm = z1.norm()
201 | z2_norm = z2.norm()
202 | z2_normal = z2 * (z1_norm / z2_norm)
203 | vectors = []
204 | if num_steps is not None:
205 | xs = [step / (num_steps - 1) for step in range(num_steps)]
206 | else:
207 | xs = [x]
208 | if smooth > 0: xs = [smoothstep(x, smooth) for x in xs]
209 | for x in xs:
210 | interplain = z1 + (z2 - z1) * x
211 | interp = z1 + (z2_normal - z1) * x
212 | interp_norm = interp.norm()
213 | if interp_norm != 0:
214 | interpol_normal = interplain * (z1_norm / interp_norm)
215 | vectors.append(interpol_normal)
216 | return torch.cat(vectors)
217 |
218 | def slice_imgs(imgs, count, size=224, transform=None, align='uniform', macro=0.):
219 | def map(x, a, b):
220 | return x * (b-a) + a
221 |
222 | rnd_size = torch.rand(count)
223 | if align == 'central': # normal around center
224 | rnd_offx = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.)
225 | rnd_offy = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.)
226 | else: # uniform
227 | rnd_offx = torch.rand(count)
228 | rnd_offy = torch.rand(count)
229 |
230 | sz = [img.shape[2:] for img in imgs]
231 | sz_max = [torch.min(torch.tensor(s)) for s in sz]
232 | if 'over' in align: # expand frame to sample outside
233 | if align == 'overmax':
234 | sz = [[2*s[0], 2*s[1]] for s in list(sz)]
235 | else:
236 | sz = [[int(1.5*s[0]), int(1.5*s[1])] for s in list(sz)]
237 | imgs = [pad_up_to(imgs[i], sz[i], type='centr') for i in range(len(imgs))]
238 |
239 | sliced = []
240 | for i, img in enumerate(imgs):
241 | cuts = []
242 | sz_max_i = sz_max[i]
243 | for c in range(count):
244 | sz_min_i = 0.9*sz_max[i] if torch.rand(1) < macro else size
245 | csize = map(rnd_size[c], sz_min_i, sz_max_i).int()
246 | offsetx = map(rnd_offx[c], 0, sz[i][1] - csize).int()
247 | offsety = map(rnd_offy[c], 0, sz[i][0] - csize).int()
248 | cut = img[:, :, offsety:offsety + csize, offsetx:offsetx + csize]
249 | cut = F.interpolate(cut, (size,size), mode='bicubic', align_corners=True) # bilinear
250 | if transform is not None:
251 | cut = transform(cut)
252 | cuts.append(cut)
253 | sliced.append(torch.cat(cuts, 0))
254 | return sliced
255 |
256 | def derivat(img, mode='sobel'):
257 | if mode == 'scharr':
258 | # https://en.wikipedia.org/wiki/Sobel_operator#Alternative_operators
259 | k_scharr = torch.Tensor([[[-0.183,0.,0.183], [-0.634,0.,0.634], [-0.183,0.,0.183]], [[-0.183,-0.634,-0.183], [0.,0.,0.], [0.183,0.634,0.183]]])
260 | k_scharr = k_scharr.unsqueeze(1).tile((1,3,1,1)).cuda()
261 | return 0.2 * torch.mean(torch.abs(F.conv2d(img, k_scharr)))
262 | elif mode == 'sobel':
263 | # https://kornia.readthedocs.io/en/latest/filters.html#edge-detection
264 | return torch.mean(torch.abs(spatial_gradient(img)))
265 | else: # trivial hack
266 | dx = torch.mean(torch.abs(img[:,:,:,1:] - img[:,:,:,:-1]))
267 | dy = torch.mean(torch.abs(img[:,:,1:,:] - img[:,:,:-1,:]))
268 | return 0.5 * (dx+dy)
269 |
270 | def dot_compare(v1, v2, cossim_pow=0):
271 | dot = (v1 * v2).sum()
272 | mag = torch.sqrt(torch.sum(v2**2))
273 | cossim = dot/(1e-6 + mag)
274 | return dot * cossim ** cossim_pow
275 |
276 | def sim_func(v1, v2, type=None):
277 | if type is not None and 'mix' in type: # mixed
278 | coss = torch.cosine_similarity(v1, v2, dim=-1).mean()
279 | v1 = F.normalize(v1, dim=-1)
280 | v2 = F.normalize(v2, dim=-1)
281 | spher = torch.abs((v1 - v2).norm(dim=-1).div(2).arcsin().pow(2).mul(2)).mean()
282 | return coss - 0.25 * spher
283 | elif type is not None and 'spher' in type: # spherical
284 | # from https://colab.research.google.com/drive/1ED6_MYVXTApBHzQObUPaaMolgf9hZOOF
285 | v1 = F.normalize(v1, dim=-1)
286 | v2 = F.normalize(v2, dim=-1)
287 | # return 1 - torch.abs((v1 - v2).norm(dim=-1).div(2).arcsin().pow(2).mul(2)).mean()
288 | return (v1 - v2).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
289 | elif type is not None and 'ang' in type: # angular
290 | # return 1 - torch.acos(torch.cosine_similarity(v1, v2, dim=-1).mean()) / np.pi
291 | return 1 - torch.acos(torch.cosine_similarity(v1, v2, dim=-1)).mean() / np.pi
292 | elif type is not None and 'dot' in type: # dot compare cossim from lucent inversion
293 | return dot_compare(v1, v2, cossim_pow=1) # decrease pow if nan (black output)
294 | else:
295 | return torch.cosine_similarity(v1, v2, dim=-1).mean()
296 |
297 | # = = = = = = = = = = = = = = = = = = = = = = = = = = =
298 |
299 | def get_z(shape, rnd, uniform=False):
300 | if uniform:
301 | return rnd.uniform(0., 1., shape)
302 | else:
303 | return rnd.randn(*shape) # *x unpacks tuple/list to sequence
304 |
305 | def smoothstep(x, NN=1., xmin=0., xmax=1.):
306 | N = math.ceil(NN)
307 | x = np.clip((x - xmin) / (xmax - xmin), 0, 1)
308 | result = 0
309 | for n in range(0, N+1):
310 | result += scipy.special.comb(N+n, n) * scipy.special.comb(2*N+1, N-n) * (-x)**n
311 | result *= x**(N+1)
312 | if NN != N: result = (x + result) / 2
313 | return result
314 |
315 | def lerp(z1, z2, num_steps, smooth=0.):
316 | vectors = []
317 | xs = [step / (num_steps - 1) for step in range(num_steps)]
318 | if smooth > 0: xs = [smoothstep(x, smooth) for x in xs]
319 | for x in xs:
320 | interpol = z1 + (z2 - z1) * x
321 | vectors.append(interpol)
322 | return np.array(vectors)
323 |
324 | # interpolate on hypersphere
325 | def slerp_np(z1, z2, num_steps, smooth=0.):
326 | z1_norm = np.linalg.norm(z1)
327 | z2_norm = np.linalg.norm(z2)
328 | z2_normal = z2 * (z1_norm / z2_norm)
329 | vectors = []
330 | xs = [step / (num_steps - 1) for step in range(num_steps)]
331 | if smooth > 0: xs = [smoothstep(x, smooth) for x in xs]
332 | for x in xs:
333 | interplain = z1 + (z2 - z1) * x
334 | interp = z1 + (z2_normal - z1) * x
335 | interp_norm = np.linalg.norm(interp)
336 | interpol_normal = interplain * (z1_norm / interp_norm)
337 | # interpol_normal = interp * (z1_norm / interp_norm)
338 | vectors.append(interpol_normal)
339 | return np.array(vectors)
340 |
341 | def cublerp(points, steps, fstep, looped=True):
342 | keys = np.array([i*fstep for i in range(steps)] + [steps*fstep])
343 | last_pt_num = 0 if looped is True else -1
344 | points = np.concatenate((points, np.expand_dims(points[last_pt_num], 0)))
345 | cspline = CubSpline(keys, points)
346 | return cspline(range(steps*fstep+1))
347 |
348 | # = = = = = = = = = = = = = = = = = = = = = = = = = = =
349 |
350 | def latent_anima(shape, frames, transit, key_latents=None, smooth=0.5, uniform=False, cubic=False, gauss=False, start_lat=None, seed=None, looped=True, verbose=False):
351 | if key_latents is None:
352 | transit = int(max(1, min(frames//2, transit)))
353 | steps = max(1, math.ceil(frames / transit))
354 | log = ' timeline: %d steps by %d' % (steps, transit)
355 |
356 | if seed is None:
357 | seed = np.random.seed(int((time.time()%1) * 9999))
358 | rnd = np.random.RandomState(seed)
359 |
360 | # make key points
361 | if key_latents is None:
362 | key_latents = np.array([get_z(shape, rnd, uniform) for i in range(steps)])
363 | if start_lat is not None:
364 | key_latents[0] = start_lat
365 |
366 | latents = np.expand_dims(key_latents[0], 0)
367 |
368 | # populate lerp between key points
369 | if transit == 1:
370 | latents = key_latents
371 | else:
372 | if cubic:
373 | latents = cublerp(key_latents, steps, transit, looped)
374 | log += ', cubic'
375 | else:
376 | for i in range(steps):
377 | zA = key_latents[i]
378 | lat_num = (i+1)%steps if looped is True else min(i+1, steps-1)
379 | zB = key_latents[lat_num]
380 | if uniform is True:
381 | interps_z = lerp(zA, zB, transit, smooth=smooth)
382 | else:
383 | interps_z = slerp_np(zA, zB, transit, smooth=smooth)
384 | latents = np.concatenate((latents, interps_z))
385 | latents = np.array(latents)
386 |
387 | if gauss:
388 | lats_post = gaussian_filter(latents, [transit, 0, 0], mode="wrap")
389 | lats_post = (lats_post / np.linalg.norm(lats_post, axis=-1, keepdims=True)) * math.sqrt(np.prod(shape))
390 | log += ', gauss'
391 | latents = lats_post
392 |
393 | if verbose: print(log)
394 | if latents.shape[0] > frames: # extra frame
395 | latents = latents[1:]
396 | return latents
397 |
398 | # = = = = = = = = = = = = = = = = = = = = = = = = = = =
399 |
400 | # from https://github.com/LAION-AI/aesthetic-predictor
401 | from urllib.request import urlretrieve # pylint: disable=import-outside-toplevel
402 | def aesthetic_model(clip_model='ViT-B/32'):
403 | nf = 768 if clip_model == "ViT-L/14" else 512 if clip_model in ['ViT-B/16', 'ViT-B/32'] else None
404 | clip_model = clip_model.replace('/','_').replace('-','_').lower()
405 | path_to_model = 'sa_0_4_%s_linear.pth' % clip_model
406 | if not os.path.isfile(path_to_model):
407 | url_model = "https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_" + clip_model + "_linear.pth?raw=true"
408 | urlretrieve(url_model, path_to_model)
409 | if nf is None or not os.path.isfile(path_to_model): return None
410 | m = torch.nn.Linear(nf, 1)
411 | m.load_state_dict(torch.load(path_to_model))
412 | m.eval().half()
413 | return m
414 |
415 |
--------------------------------------------------------------------------------
/clip_fft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | warnings.filterwarnings("ignore")
4 | import argparse
5 | import numpy as np
6 | from imageio import imread, imsave
7 | import shutil
8 |
9 | try:
10 | from googletrans import Translator
11 | googletrans_ok = True
12 | except:
13 | googletrans_ok = False
14 |
15 | import torch
16 | import torchvision
17 | import torch.nn.functional as F
18 |
19 | import clip
20 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
21 | # from sentence_transformers import SentenceTransformer
22 | import lpips
23 |
24 | from aphantasia.image import to_valid_rgb, fft_image, dwt_image
25 | from aphantasia.utils import slice_imgs, derivat, sim_func, aesthetic_model, basename, img_list, img_read, plot_text, txt_clean, checkout, old_torch
26 | from aphantasia import transforms
27 | try: # progress bar for notebooks
28 | get_ipython().__class__.__name__
29 | from aphantasia.progress_bar import ProgressIPy as ProgressBar
30 | except: # normal console
31 | from aphantasia.progress_bar import ProgressBar
32 |
33 | clip_models = ['ViT-B/16', 'ViT-B/32', 'RN101', 'RN50x16', 'RN50x4', 'RN50']
34 |
35 | def get_args():
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('-t', '--in_txt', default=None, help='input text')
38 | parser.add_argument('-t2', '--in_txt2', default=None, help='input text - style')
39 | parser.add_argument('-t0', '--in_txt0', default=None, help='input text to subtract')
40 | parser.add_argument('-i', '--in_img', default=None, help='input image')
41 | parser.add_argument('-wi', '--weight_img', default=0.5, type=float, help='weight for images')
42 | parser.add_argument( '--out_dir', default='_out')
43 | parser.add_argument('-s', '--size', default='1280-720', help='Output resolution')
44 | parser.add_argument('-r', '--resume', default=None, help='Path to saved FFT snapshots, to resume from')
45 | parser.add_argument('-ops', '--opt_step', default=1, type=int, help='How many optimizing steps per save step')
46 | parser.add_argument('-tr', '--translate', action='store_true', help='Translate text with Google Translate')
47 | # parser.add_argument('-ml', '--multilang', action='store_true', help='Use SBERT multilanguage model for text')
48 | parser.add_argument( '--save_pt', action='store_true', help='Save FFT snapshots for further use')
49 | parser.add_argument('-v', '--verbose', dest='verbose', action='store_true')
50 | parser.add_argument('-nv', '--no-verbose', dest='verbose', action='store_false')
51 | parser.set_defaults(verbose=True)
52 | # training
53 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use')
54 | parser.add_argument( '--steps', default=200, type=int, help='Total iterations')
55 | parser.add_argument( '--samples', default=200, type=int, help='Samples to evaluate')
56 | parser.add_argument('-lr', '--lrate', default=0.05, type=float, help='Learning rate')
57 | parser.add_argument('-p', '--prog', action='store_true', help='Enable progressive lrate growth (up to double a.lrate)')
58 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model')
59 | # wavelet
60 | parser.add_argument( '--dwt', action='store_true', help='Use DWT instead of FFT')
61 | parser.add_argument('-w', '--wave', default='coif2', help='wavelets: db[1..], coif[1..], haar, dmey')
62 | # tweaks
63 | parser.add_argument('-a', '--align', default='uniform', choices=['central', 'uniform', 'overscan', 'overmax'], help='Sampling distribution')
64 | parser.add_argument('-tf', '--transform', default='fast', choices=['none', 'fast', 'custom', 'elastic'], help='augmenting transforms')
65 | parser.add_argument('-opt', '--optimizer', default='adam_custom', choices=['adam', 'adamw', 'adam_custom', 'adamw_custom'], help='Optimizer')
66 | parser.add_argument( '--contrast', default=1.1, type=float)
67 | parser.add_argument( '--colors', default=1.8, type=float)
68 | parser.add_argument( '--decay', default=1.5, type=float)
69 | parser.add_argument('-sh', '--sharp', default=0., type=float)
70 | parser.add_argument('-mm', '--macro', default=0.4, type=float, help='Endorse macro forms 0..1 ')
71 | parser.add_argument( '--aest', default=0., type=float, help='Enhance aesthetics')
72 | parser.add_argument('-e', '--enforce', default=0, type=float, help='Enforce details (by boosting similarity between two parallel samples)')
73 | parser.add_argument('-x', '--expand', default=0, type=float, help='Boosts diversity (by enforcing difference between prev/next samples)')
74 | parser.add_argument('-n', '--noise', default=0, type=float, help='Add noise to suppress accumulation') # < 0.05 ?
75 | parser.add_argument('-c', '--sync', default=0, type=float, help='Sync output to input image')
76 | parser.add_argument( '--invert', action='store_true', help='Invert criteria')
77 | parser.add_argument( '--sim', default='mix', help='Similarity function (dot/angular/spherical/mixed; None = cossim)')
78 | a = parser.parse_args()
79 |
80 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1]
81 | if len(a.size)==1: a.size = a.size * 2
82 | if (a.in_img is not None and a.sync != 0) or a.resume is not None: a.align = 'overscan'
83 | # if a.multilang is True: a.model = 'ViT-B/32' # sbert model is trained with ViT
84 | if a.translate is True and googletrans_ok is not True:
85 | print('\n Install googletrans module to enable translation!'); exit()
86 | if a.dualmod is not None:
87 | a.model = 'ViT-B/32'
88 | a.sim = 'cossim'
89 |
90 | return a
91 |
92 | def main():
93 | a = get_args()
94 |
95 | shape = [1, 3, *a.size]
96 | if a.dwt is True:
97 | params, image_f, sz = dwt_image(shape, a.wave, 0.3, a.colors, a.resume)
98 | else:
99 | params, image_f, sz = fft_image(shape, 0.07, a.decay, a.resume)
100 | if sz is not None: a.size = sz
101 | image_f = to_valid_rgb(image_f, colors = a.colors)
102 |
103 | if a.prog is True:
104 | lr1 = a.lrate * 2
105 | lr0 = lr1 * 0.01
106 | else:
107 | lr0 = a.lrate
108 | if a.optimizer.lower() == 'adamw':
109 | optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01)
110 | elif a.optimizer.lower() == 'adamw_custom':
111 | optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
112 | elif a.optimizer.lower() == 'adam':
113 | optimizer = torch.optim.Adam(params, lr0)
114 | else: # adam_custom
115 | optimizer = torch.optim.Adam(params, lr0, betas=(.0,.999))
116 | sign = 1. if a.invert is True else -1.
117 |
118 | # Load CLIP models
119 | model_clip, _ = clip.load(a.model, jit=old_torch())
120 | try:
121 | a.modsize = model_clip.visual.input_resolution
122 | except:
123 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
124 | if a.verbose is True: print(' using model', a.model)
125 | xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}
126 | if a.model in xmem.keys():
127 | a.samples = int(a.samples * xmem[a.model])
128 |
129 | # if a.multilang is True:
130 | # model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()
131 |
132 | if a.dualmod is not None: # second is vit-16
133 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch())
134 | a.samples = int(a.samples * 0.23)
135 | dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod]
136 |         print(' dual model every %d steps' % a.dualmod)
137 |
138 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']:
139 | aest = aesthetic_model(a.model).cuda()
140 | if a.dualmod is not None:
141 | aest2 = aesthetic_model('ViT-B/16').cuda()
142 |
143 | def enc_text(txt, model_clip=model_clip):
144 | embs = []
145 | for subtxt in txt.split('|'):
146 | if ':' in subtxt:
147 | [subtxt, wt] = subtxt.split(':')
148 | wt = float(wt)
149 | else: wt = 1.
150 | emb = model_clip.encode_text(clip.tokenize(subtxt).cuda())
151 | # if a.multilang is True:
152 | # emb = model_lang.encode([subtxt], convert_to_tensor=True, show_progress_bar=False)
153 | embs.append([emb.detach().clone(), wt])
154 | return embs
155 |
156 | if a.enforce != 0:
157 | a.samples = int(a.samples * 0.5)
158 | if a.sync > 0:
159 | a.samples = int(a.samples * 0.5)
160 |
161 | if 'elastic' in a.transform:
162 | trform_f = transforms.transforms_elastic
163 | a.samples = int(a.samples * 0.95)
164 | elif 'custom' in a.transform:
165 | trform_f = transforms.transforms_custom
166 | a.samples = int(a.samples * 0.95)
167 | elif 'fast' in a.transform:
168 | trform_f = transforms.transforms_fast
169 | a.samples = int(a.samples * 0.95)
170 | else:
171 | trform_f = transforms.normalize()
172 |
173 | out_name = []
174 | if a.in_txt is not None:
175 | if a.verbose is True: print(' topic text: ', a.in_txt)
176 | if a.translate:
177 | translator = Translator()
178 | a.in_txt = translator.translate(a.in_txt, dest='en').text
179 | if a.verbose is True: print(' translated to:', a.in_txt)
180 | txt_enc = enc_text(a.in_txt)
181 | out_name.append(txt_clean(a.in_txt).lower()[:40])
182 | if a.dualmod is not None:
183 | txt_enc2 = enc_text(a.in_txt, model_clip2)
184 |
185 | if a.in_txt2 is not None:
186 | if a.verbose is True: print(' style text:', a.in_txt2)
187 | a.samples = int(a.samples * 0.75)
188 | if a.translate:
189 | translator = Translator()
190 | a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
191 | if a.verbose is True: print(' translated to:', a.in_txt2)
192 | style_enc = enc_text(a.in_txt2)
193 | out_name.append(txt_clean(a.in_txt2).lower()[:40])
194 | if a.dualmod is not None:
195 | style_enc2 = enc_text(a.in_txt2, model_clip2)
196 |
197 | if a.in_txt0 is not None:
198 | if a.verbose is True: print(' subtract text:', a.in_txt0)
199 | a.samples = int(a.samples * 0.75)
200 | if a.translate:
201 | translator = Translator()
202 | a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
203 | if a.verbose is True: print(' translated to:', a.in_txt0)
204 | not_enc = enc_text(a.in_txt0)
205 | out_name.append('off-' + txt_clean(a.in_txt0).lower()[:40])
206 | if a.dualmod is not None:
207 | not_enc2 = enc_text(a.in_txt0, model_clip2)
208 |
209 | # if a.multilang is True: del model_lang
210 |
211 | if a.in_img is not None and os.path.isfile(a.in_img):
212 | if a.verbose is True: print(' ref image:', basename(a.in_img))
213 | img_in = torch.from_numpy(img_read(a.in_img)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
214 | img_in = img_in[:,:3,:,:] # fix rgb channels
215 | in_sliced = slice_imgs([img_in], a.samples, a.modsize, transforms.normalize(), a.align)[0]
216 | img_enc = model_clip.encode_image(in_sliced).detach().clone()
217 | if a.dualmod is not None:
218 | img_enc2 = model_clip2.encode_image(in_sliced).detach().clone()
219 | if a.sync > 0:
220 | sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()
221 | sim_size = [s//2 for s in a.size]
222 | img_in = F.interpolate(img_in, sim_size, mode='bicubic', align_corners=True).float()
223 | else:
224 | del img_in
225 | del in_sliced; torch.cuda.empty_cache()
226 | out_name.append(basename(a.in_img).replace(' ', '_'))
227 |
228 | if a.verbose is True: print(' samples:', a.samples)
229 | out_name = '-'.join(out_name)
230 | out_name += '-%s' % a.model.replace('/','').replace('-','') if a.dualmod is None else '-dm%d' % a.dualmod
231 | tempdir = os.path.join(a.out_dir, out_name)
232 | os.makedirs(tempdir, exist_ok=True)
233 |
234 | prev_enc = 0
235 | def train(i):
236 | loss = 0
237 |
238 | noise = a.noise * torch.rand(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
239 | img_out = image_f(noise)
240 | img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
241 |
242 | if a.in_txt is not None: # input text
243 | txt_enc_ = txt_enc2 if a.dualmod is not None and i in dualmod_nums else txt_enc
244 | if a.in_txt2 is not None:
245 | style_enc_ = style_enc2 if a.dualmod is not None and i in dualmod_nums else style_enc
246 | if a.in_img is not None and os.path.isfile(a.in_img):
247 | img_enc_ = img_enc2 if a.dualmod is not None and i in dualmod_nums else img_enc
248 | if a.in_txt0 is not None:
249 | not_enc_ = not_enc2 if a.dualmod is not None and i in dualmod_nums else not_enc
250 | model_clip_ = model_clip2 if a.dualmod is not None and i in dualmod_nums else model_clip
251 | if a.aest != 0:
252 | aest_ = aest2 if a.dualmod is not None and i in dualmod_nums else aest
253 |
254 | out_enc = model_clip_.encode_image(img_sliced)
255 | if a.aest != 0 and aest_ is not None:
256 | loss -= 0.001 * a.aest * aest_(out_enc).mean()
257 | if a.in_txt is not None: # input text
258 | for enc, wt in txt_enc_:
259 | loss += sign * wt * sim_func(enc, out_enc, a.sim)
260 | if a.in_txt2 is not None: # input text - style
261 | for enc, wt in style_enc_:
262 | loss += sign * wt * sim_func(enc, out_enc, a.sim)
263 | if a.in_txt0 is not None: # subtract text
264 | for enc, wt in not_enc_:
265 | loss += -sign * wt * sim_func(enc, out_enc, a.sim)
266 | if a.in_img is not None and os.path.isfile(a.in_img): # input image
267 | loss += sign * a.weight_img * sim_func(img_enc_, out_enc, a.sim)
268 | if a.sync > 0 and a.in_img is not None and os.path.isfile(a.in_img): # image composition
269 | prog_sync = (a.steps // a.opt_step - i) / (a.steps // a.opt_step)
270 | loss += prog_sync * a.sync * sim_loss(F.interpolate(img_out, sim_size, mode='bicubic', align_corners=True).float(), img_in, normalize=True).squeeze()
271 | if a.sharp != 0 and a.dwt is not True: # scharr|sobel|default
272 | loss -= a.sharp * derivat(img_out, mode='naiv')
273 | # loss -= a.sharp * derivat(img_sliced, mode='scharr')
274 | if a.enforce != 0:
275 | img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
276 | out_enc2 = model_clip_.encode_image(img_sliced)
277 | loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
278 | del out_enc2; torch.cuda.empty_cache()
279 | if a.expand > 0:
280 | global prev_enc
281 | if i > 0:
282 | loss += a.expand * sim_func(out_enc, prev_enc, a.sim)
283 | prev_enc = out_enc.detach() # .clone()
284 |
285 | del img_out, img_sliced, out_enc; torch.cuda.empty_cache()
286 | assert not isinstance(loss, int), ' Loss not defined, check the inputs'
287 |
288 | if a.prog is True:
289 | lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
290 | for g in optimizer.param_groups:
291 | g['lr'] = lr_cur
292 |
293 | optimizer.zero_grad()
294 | loss.backward()
295 | optimizer.step()
296 |
297 | if i % a.opt_step == 0:
298 | with torch.no_grad():
299 | img = image_f(contrast=a.contrast).cpu().numpy()[0]
300 | # empirical tone mapping
301 | if (a.sync > 0 and a.in_img is not None):
302 | img = img **1.3
303 | elif a.sharp != 0:
304 | img = img ** (1 + a.sharp/2.)
305 | checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.opt_step)), verbose=a.verbose)
306 | pbar.upd()
307 |
308 | pbar = ProgressBar(a.steps // a.opt_step)
309 | for i in range(a.steps):
310 | train(i)
311 |
312 | os.system('ffmpeg -v warning -y -i %s/\%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, out_name)))
313 | shutil.copy(img_list(tempdir)[-1], os.path.join(a.out_dir, '%s-%d.jpg' % (out_name, a.steps)))
314 | if a.save_pt is True:
315 | torch.save(params, '%s.pt' % os.path.join(a.out_dir, out_name))
316 |
317 | if __name__ == '__main__':
318 | main()
319 |
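The argument handling and `train()` above implement two schedules worth spelling out: with `--dualmod N`, every N-th iteration is scored with ViT-B/16 instead of ViT-B/32, and with `--prog` the learning rate ramps linearly from 1% of 2×lrate up to 2×lrate. A standalone sketch of both, with made-up values (not part of the script):
```
# Standalone sketch with made-up values, mirroring the scheduling logic in clip_fft.py.
steps, dualmod, lrate = 12, 3, 0.05

# which iterations use the second CLIP model (dualmod_nums in main)
dualmod_nums = list(range(steps))[dualmod::dualmod]
print(dualmod_nums)        # [3, 6, 9] -> ViT-B/16 on these steps, ViT-B/32 on the rest

# progressive learning rate (--prog): linear ramp lr0 -> lr1, as in train()
lr1 = lrate * 2            # 0.1
lr0 = lr1 * 0.01           # 0.001
lrs = [lr0 + (i / steps) * (lr1 - lr0) for i in range(steps)]   # 0.001 ... ~0.092
```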
--------------------------------------------------------------------------------
/cppn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | warnings.filterwarnings("ignore")
4 | import argparse
5 | import numpy as np
6 | import shutil
7 | import math
8 | from collections import OrderedDict
9 |
10 | try:
11 | from googletrans import Translator
12 | googletrans_ok = True
13 | except:
14 | googletrans_ok = False
15 |
16 | import torch
17 | import torchvision
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | import clip
22 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
23 |
24 | from aphantasia.utils import slice_imgs, derivat, aesthetic_model, txt_clean, checkout, old_torch
25 | from aphantasia import transforms
26 | from shader_expo import cppn_to_shader
27 |
28 | from eps.progress_bar import ProgressBar
29 | from eps.data_load import basename, img_list, img_read, file_list, save_cfg
30 |
31 | clip_models = ['ViT-B/16', 'ViT-B/32', 'ViT-L/14', 'RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101']
32 |
33 | def get_args():
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('-i', '--in_img', default=None, help='input image')
36 | parser.add_argument('-t', '--in_txt', default=None, help='input text')
37 | parser.add_argument('-t0', '--in_txt0', default=None, help='input text to subtract')
38 | parser.add_argument( '--out_dir', default='_out')
39 | parser.add_argument('-r', '--resume', default=None, help='Input CPPN model (NPY file) to resume from')
40 | parser.add_argument('-s', '--size', default='512-512', help='Output resolution')
41 | parser.add_argument( '--fstep', default=1, type=int, help='Saving step')
42 | parser.add_argument('-tr', '--translate', action='store_true')
43 | parser.add_argument('-v', '--verbose', action='store_true')
44 | parser.add_argument('-ex', '--export', action='store_true', help="Only export shaders from resumed snapshot")
45 | # networks
46 | parser.add_argument('-l', '--layers', default=10, type=int, help='CPPN layers')
47 | parser.add_argument('-nf', '--nf', default=24, type=int, help='num features') # 256
48 | parser.add_argument('-act', '--actfn', default='unbias', choices=['unbias', 'comp', 'relu'], help='activation function')
49 | parser.add_argument('-dec', '--decim', default=3, type=int, help='Decimal precision for export')
50 | # training
51 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use')
52 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model')
53 | parser.add_argument( '--steps', default=200, type=int, help='Total iterations')
54 | parser.add_argument( '--samples', default=50, type=int, help='Samples to evaluate')
55 | parser.add_argument('-lr', '--lrate', default=0.003, type=float, help='Learning rate')
56 | parser.add_argument('-a', '--align', default='overscan', choices=['central', 'uniform', 'overscan'], help='Sampling distribution')
57 | parser.add_argument('-sh', '--sharp', default=0, type=float)
58 | parser.add_argument('-tf', '--transform', action='store_true', help='use augmenting transforms?')
59 | parser.add_argument('-mc', '--macro', default=0.4, type=float, help='Endorse macro forms 0..1; -1 = normal big')
60 | parser.add_argument( '--aest', default=0., type=float)
61 | a = parser.parse_args()
62 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1]
63 | if len(a.size)==1: a.size = a.size * 2
64 | if a.translate is True and googletrans_ok is not True:
65 | print('\n Install googletrans module to enable translation!'); exit()
66 | if a.dualmod is not None:
67 | a.model = 'ViT-B/32'
68 | return a
69 |
70 |
71 | class ConvLayer(nn.Module):
72 | def __init__(self, nf_in, nf_out, act_fn='relu'):
73 | super().__init__()
74 | self.nf_in = nf_in
75 | self.conv = nn.Conv2d(nf_in, nf_out, 1, 1)
76 | if act_fn == 'comp':
77 | self.act_fn = self.composite_activation
78 | elif act_fn == 'unbias':
79 | self.act_fn = self.composite_activation_unbiased
80 | elif act_fn == 'relu':
81 | self.act_fn = self.relu_normalized
82 | else: # last layer (output)
83 | self.act_fn = torch.sigmoid
84 | with torch.no_grad(): # init
85 | self.conv.weight.normal_(0., math.sqrt(1./self.nf_in))
86 | self.conv.bias.uniform_(-.5, .5)
87 |
88 | def composite_activation(self, x):
89 | x = torch.atan(x)
90 | return torch.cat([x/0.67, (x*x)/0.6], 1)
91 | def composite_activation_unbiased(self, x):
92 | x = torch.atan(x)
93 | return torch.cat([x/0.67, (x*x-0.45)/0.396], 1)
94 | def relu_normalized(self, x):
95 | x = F.relu(x)
96 | return (x-0.40)/0.58
97 | # https://colab.research.google.com/drive/1F1c2ouulmqys-GJBVBHn04I1UVWeexiB
98 |
99 | def forward(self, input):
100 | return self.act_fn(self.conv(input))
101 |
102 | class CPPN(nn.Module):
103 | def __init__(self, nf_in=2, nf_hid=16, num_layers=9, nf_out=3, act_fn='unbias'): # unbias relu
104 | super().__init__()
105 | nf_hid_in = nf_hid if act_fn == 'relu' else nf_hid*2
106 | self.net = []
107 | self.net.append(ConvLayer(nf_in, nf_hid, act_fn))
108 | for i in range(num_layers-1):
109 | self.net.append(ConvLayer(nf_hid_in, nf_hid, act_fn))
110 | self.net.append(ConvLayer(nf_hid_in, nf_out, 'sigmoid'))
111 | self.net = nn.Sequential(*self.net)
112 |
113 | def forward(self, coords):
114 | coords = coords.clone().detach().requires_grad_(True) # [1,3,h,w]
115 | output = self.net(coords.cuda())
116 | return output
117 |
118 | def load_cppn(file, verbose=True): # actfn='unbias'
119 | params = np.load(file, allow_pickle=True)
120 | nf = params[0].shape[-1]
121 | num_layers = len(params) // 2 - 1
122 | act_fn = 'relu' if params[0].shape[-1] == params[2].shape[-2] else 'unbias'
123 | snet = CPPN(2, nf, num_layers, 3, act_fn=act_fn).cuda()
124 | if verbose is True: print(' loaded:', file)
125 | if verbose is True: print(' .. %d vars, %d layers, %d nf, act %s' % (len(params), num_layers, nf, act_fn))
126 | keys = list(snet.state_dict().keys())
127 | assert len(keys) == len(params)
128 | cppn_dict = OrderedDict({})
129 | for lnum in range(0, len(keys), 2):
130 | cppn_dict[keys[lnum]] = np.transpose(torch.from_numpy(params[lnum]), (3,2,1,0))
131 | cppn_dict[keys[lnum+1]] = torch.from_numpy(params[lnum+1])
132 | snet.load_state_dict(cppn_dict)
133 | return snet
134 |
135 | def get_mgrid(sideX, sideY):
136 | tensors = [np.linspace(-1, 1, num=sideY), np.linspace(-1, 1, num=sideX)]
137 | mgrid = np.stack(np.meshgrid(*tensors), axis=-1)
138 | mgrid = np.transpose(mgrid, (2,0,1))[np.newaxis]
139 | return mgrid
140 |
141 | def export_gfx(model, out_name, mode, precision, size):
142 | shader = cppn_to_shader(model, mode=mode, verbose=False, fix_aspect=True, size=size, precision=precision)
143 | if mode == 'vvvv': out_path = out_name + '.tfx'
144 | elif mode == 'buffer': out_path = out_name + '.txt'
145 | else: out_path = out_name + '-%s.glsl' % mode
146 | with open(out_path, 'wt') as f:
147 | f.write(shader)
148 | return out_path
149 |
150 | def export_data(cppn_dict, out_name, size, decim=3, actfn='unbias', shaders=False, npy=True):
151 | if npy is True: arrays = []
152 | if shaders is True: params = []
153 | keys = list(cppn_dict.keys())
154 |
155 | for lnum in range(0, len(keys), 2):
156 | w = cppn_dict[keys[lnum]].permute((3,2,1,0)).cpu().numpy()
157 | b = cppn_dict[keys[lnum+1]].cpu().numpy()
158 | if shaders is True: params.append({'weights': w, 'bias': b, 'activation': actfn})
159 | if npy is True: arrays += [w,b]
160 |
161 | if npy is True:
162 | np.save(out_name + '.npy', np.array(arrays, object))
163 | if shaders is True:
164 | export_gfx(params, out_name, 'td', decim, size)
165 | export_gfx(params, out_name, 'vvvv', decim, size)
166 | export_gfx(params, out_name, 'buffer', decim, size)
167 | export_gfx(params, out_name, 'bookofshaders', decim, size)
168 | export_gfx(params, out_name, 'shadertoy', decim, size)
169 |
170 |
171 | def main():
172 | a = get_args()
173 | bx = 1.
174 |
175 | mgrid = get_mgrid(*a.size)
176 | mgrid = torch.from_numpy(mgrid.astype(np.float32)).cuda()
177 |
178 | # Load models
179 | if a.resume is not None and os.path.isfile(a.resume):
180 | snet = load_cppn(a.resume)
181 | else:
182 | snet = CPPN(mgrid.shape[1], a.nf, a.layers, 3, act_fn=a.actfn).cuda()
183 | print(' .. %d vars, %d layers, %d nf, act %s' % (len(snet.state_dict().keys()), a.layers, a.nf, a.actfn))
184 |
185 | if a.export is True:
186 | print('exporting')
187 | export_data(snet.state_dict(), a.resume.replace('.npy', ''), a.size, a.decim, a.actfn, shaders=True, npy=False)
188 | img = snet(mgrid).detach().cpu().numpy()[0]
189 | checkout(img, a.resume.replace('.npy', '.jpg'), verbose=False)
190 | exit(0)
191 |
192 | model_clip, _ = clip.load(a.model, jit=old_torch())
193 | try:
194 | a.modsize = model_clip.visual.input_resolution
195 | except:
196 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 448 if a.model == 'RN50x64' else 224
197 | xmem = {'ViT-B/16':0.25, 'ViT-L/14':0.11, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN50x64':0.04, 'RN101':0.33}
198 | if a.model in xmem.keys():
199 | a.samples = int(a.samples * xmem[a.model])
200 |
201 | if a.dualmod is not None:
202 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch())
203 | a.samples = int(a.samples * 0.69) # second is vit-16
204 | dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod]
205 |         print(' dual model every %d steps' % a.dualmod)
206 |
207 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']:
208 | aest = aesthetic_model(a.model).cuda()
209 | if a.dualmod is not None:
210 | aest2 = aesthetic_model('ViT-B/16').cuda()
211 |
212 | def enc_text(txt, model_clip=model_clip):
213 | if txt is None or len(txt)==0: return None
214 | emb = model_clip.encode_text(clip.tokenize(txt).cuda()[:,:77])
215 | return emb.detach().clone()
216 |
217 | optimizer = torch.optim.Adam(snet.parameters(), a.lrate) # orig .00001, better 0.0001
218 |
219 | if a.transform is True:
220 | trform_f = transforms.trfm_fast
221 | a.samples = int(a.samples * 0.95)
222 | else:
223 | trform_f = transforms.normalize()
224 |
225 | out_name = []
226 | if a.in_txt is not None:
227 | print(' ref text: ', basename(a.in_txt))
228 | if a.translate:
229 | translator = Translator()
230 | a.in_txt = translator.translate(a.in_txt, dest='en').text
231 | print(' translated to:', a.in_txt)
232 | txt_enc = enc_text(a.in_txt)
233 | if a.dualmod is not None:
234 | txt_enc2 = enc_text(a.in_txt, model_clip2)
235 | out_name.append(txt_clean(a.in_txt))
236 |
237 | if a.in_txt0 is not None:
238 | print(' no text: ', basename(a.in_txt0))
239 | if a.translate:
240 | translator = Translator()
241 | a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
242 | print(' translated to:', a.in_txt0)
243 | not_enc = enc_text(a.in_txt0)
244 | if a.dualmod is not None:
245 | not_enc2 = enc_text(a.in_txt0, model_clip2)
246 |
247 | img_enc = None
248 | if a.in_img is not None and os.path.isfile(a.in_img):
249 | print(' ref image:', basename(a.in_img))
250 | img_in = torch.from_numpy(img_read(a.in_img)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
251 | in_sliced = slice_imgs([img_in], a.samples, a.modsize, transforms.normalize(), a.align)[0]
252 | img_enc = model_clip.encode_image(in_sliced).detach().clone()
253 | if a.dualmod is not None:
254 | img_enc2 = model_clip2.encode_image(in_sliced).detach().clone()
255 | del img_in, in_sliced; torch.cuda.empty_cache()
256 | out_name.append(basename(a.in_img).replace(' ', '_'))
257 |
258 | # Prepare dirs
259 | sfx = '-l%d-n%d' % (a.layers, a.nf)
260 | if a.dualmod is not None: sfx += '-dm%d' % a.dualmod
261 | if a.aest != 0: sfx += '-ae%.2g' % a.aest
262 | workdir = os.path.join(a.out_dir, 'cppn')
263 | out_name = os.path.join(workdir, '-'.join(out_name) + sfx)
264 | tempdir = out_name
265 | os.makedirs(out_name, exist_ok=True)
266 |     print(' samples:', a.samples)
267 |
268 | def train(i, img_enc=None):
269 | loss = 0
270 | img_out = snet(mgrid)
271 |
272 | txt_enc_ = txt_enc2 if a.dualmod is not None and i in dualmod_nums else txt_enc
273 | if a.in_img is not None and os.path.isfile(a.in_img):
274 | img_enc_ = img_enc2 if a.dualmod is not None and i in dualmod_nums else img_enc
275 | if a.in_txt0 is not None:
276 | not_enc_ = not_enc2 if a.dualmod is not None and i in dualmod_nums else not_enc
277 | model_clip_ = model_clip2 if a.dualmod is not None and i in dualmod_nums else model_clip
278 | if a.aest != 0:
279 | aest_ = aest2 if a.dualmod is not None and i in dualmod_nums else aest
280 |
281 | imgs_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)
282 | out_enc = model_clip_.encode_image(imgs_sliced[-1])
283 | if a.aest != 0 and aest_ is not None:
284 | loss -= 0.001 * a.aest * aest_(out_enc).mean()
285 | if a.in_txt is not None:
286 | loss -= torch.cosine_similarity(txt_enc_, out_enc, dim=-1).mean()
287 | if a.in_txt0 is not None:
288 | loss += 0.5 * torch.cosine_similarity(not_enc_, out_enc, dim=-1).mean()
289 | if a.in_img is not None and os.path.isfile(a.in_img):
290 | loss -= torch.cosine_similarity(img_enc_, out_enc, dim=-1).mean()
291 | if a.sharp != 0: # mode = scharr|sobel|default
292 | loss -= a.sharp * derivat(img_out, mode='sobel')
293 | del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache()
294 |
295 | optimizer.zero_grad()
296 | loss.backward()
297 | optimizer.step()
298 |
299 | if i % a.fstep == 0:
300 | with torch.no_grad():
301 | img = snet(mgrid).cpu().numpy()[0]
302 | fname = os.path.join(tempdir, '%04d' % (i // a.fstep))
303 | checkout(img, fname + '.jpg', verbose=a.verbose)
304 | export_data(snet.state_dict(), fname, a.size, a.decim)
305 | return
306 |
307 | pbar = ProgressBar(a.steps)
308 | for i in range(a.steps):
309 | log = train(i, img_enc)
310 | pbar.upd(log)
311 |
312 | export_data(snet.state_dict(), out_name, a.size, a.decim, shaders=True)
313 | os.system('ffmpeg -v warning -y -i %s\%%04d.jpg -c:v mjpeg -pix_fmt yuvj444p -dst_range 1 -q:v 2 "%s.avi"' % (tempdir, out_name))
314 | shutil.copy(img_list(tempdir)[-1], out_name + '-%d.jpg' % a.steps)
315 | # shutil.rmtree(tempdir)
316 |
317 |
318 | if __name__ == '__main__':
319 | main()
320 |
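The network above is a classic CPPN: a fixed coordinate grid is pushed through a stack of 1×1 convolutions, so the output resolution is set only by the grid, not by the weights. A minimal self-contained toy version of that idea (an illustration only, not the CPPN/get_mgrid code above):
```
# Toy illustration of the CPPN idea: 1x1 convs over a coordinate grid yield an RGB image.
import numpy as np
import torch
import torch.nn as nn

def make_grid(w, h):
    xs, ys = np.linspace(-1, 1, w), np.linspace(-1, 1, h)
    grid = np.stack(np.meshgrid(xs, ys), axis=0)[np.newaxis]   # [1,2,h,w]
    return torch.from_numpy(grid.astype(np.float32))

net = nn.Sequential(                     # stand-in for the CPPN class above
    nn.Conv2d(2, 16, 1), nn.Tanh(),
    nn.Conv2d(16, 16, 1), nn.Tanh(),
    nn.Conv2d(16, 3, 1), nn.Sigmoid(),
)
img = net(make_grid(640, 360))           # [1, 3, 360, 640], values in 0..1
print(img.shape)                         # any other resolution works with the same weights
```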
--------------------------------------------------------------------------------
/depth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/aphantasia/8a415286d2891e92d865150d6e0e59fdfd32fb01/depth/__init__.py
--------------------------------------------------------------------------------
/depth/any2/dinov2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9 |
10 | from functools import partial
11 | import math
12 | import logging
13 | from typing import Sequence, Tuple, Union, Callable
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.utils.checkpoint
18 | from torch.nn.init import trunc_normal_
19 |
20 | from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27 | if not depth_first and include_root:
28 | fn(module=module, name=name)
29 | for child_name, child_module in module.named_children():
30 | child_name = ".".join((name, child_name)) if name else child_name
31 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32 | if depth_first and include_root:
33 | fn(module=module, name=name)
34 | return module
35 |
36 |
37 | class BlockChunk(nn.ModuleList):
38 | def forward(self, x):
39 | for b in self:
40 | x = b(x)
41 | return x
42 |
43 |
44 | class DinoVisionTransformer(nn.Module):
45 | def __init__(
46 | self,
47 | img_size=224,
48 | patch_size=16,
49 | in_chans=3,
50 | embed_dim=768,
51 | depth=12,
52 | num_heads=12,
53 | mlp_ratio=4.0,
54 | qkv_bias=True,
55 | ffn_bias=True,
56 | proj_bias=True,
57 | drop_path_rate=0.0,
58 | drop_path_uniform=False,
59 | init_values=None, # for layerscale: None or 0 => no layerscale
60 | embed_layer=PatchEmbed,
61 | act_layer=nn.GELU,
62 | block_fn=Block,
63 | ffn_layer="mlp",
64 | block_chunks=1,
65 | num_register_tokens=0,
66 | interpolate_antialias=False,
67 | interpolate_offset=0.1,
68 | ):
69 | """
70 | Args:
71 | img_size (int, tuple): input image size
72 | patch_size (int, tuple): patch size
73 | in_chans (int): number of input channels
74 | embed_dim (int): embedding dimension
75 | depth (int): depth of transformer
76 | num_heads (int): number of attention heads
77 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78 | qkv_bias (bool): enable bias for qkv if True
79 | proj_bias (bool): enable bias for proj in attn if True
80 | ffn_bias (bool): enable bias for ffn if True
81 | drop_path_rate (float): stochastic depth rate
82 | drop_path_uniform (bool): apply uniform drop rate across blocks
83 | weight_init (str): weight init scheme
84 | init_values (float): layer-scale init values
85 | embed_layer (nn.Module): patch embedding layer
86 | act_layer (nn.Module): MLP activation layer
87 | block_fn (nn.Module): transformer block class
88 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90 | num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91 |             interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
92 | interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93 | """
94 | super().__init__()
95 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
96 |
97 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98 | self.num_tokens = 1
99 | self.n_blocks = depth
100 | self.num_heads = num_heads
101 | self.patch_size = patch_size
102 | self.num_register_tokens = num_register_tokens
103 | self.interpolate_antialias = interpolate_antialias
104 | self.interpolate_offset = interpolate_offset
105 |
106 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107 | num_patches = self.patch_embed.num_patches
108 |
109 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111 | assert num_register_tokens >= 0
112 | self.register_tokens = (
113 | nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114 | )
115 |
116 | if drop_path_uniform is True:
117 | dpr = [drop_path_rate] * depth
118 | else:
119 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120 |
121 | if ffn_layer == "mlp":
122 | logger.info("using MLP layer as FFN")
123 | ffn_layer = Mlp
124 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125 | logger.info("using SwiGLU layer as FFN")
126 | ffn_layer = SwiGLUFFNFused
127 | elif ffn_layer == "identity":
128 | logger.info("using Identity layer as FFN")
129 |
130 | def f(*args, **kwargs):
131 | return nn.Identity()
132 |
133 | ffn_layer = f
134 | else:
135 | raise NotImplementedError
136 |
137 | blocks_list = [
138 | block_fn(
139 | dim=embed_dim,
140 | num_heads=num_heads,
141 | mlp_ratio=mlp_ratio,
142 | qkv_bias=qkv_bias,
143 | proj_bias=proj_bias,
144 | ffn_bias=ffn_bias,
145 | drop_path=dpr[i],
146 | norm_layer=norm_layer,
147 | act_layer=act_layer,
148 | ffn_layer=ffn_layer,
149 | init_values=init_values,
150 | )
151 | for i in range(depth)
152 | ]
153 | if block_chunks > 0:
154 | self.chunked_blocks = True
155 | chunked_blocks = []
156 | chunksize = depth // block_chunks
157 | for i in range(0, depth, chunksize):
158 | # this is to keep the block index consistent if we chunk the block list
159 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161 | else:
162 | self.chunked_blocks = False
163 | self.blocks = nn.ModuleList(blocks_list)
164 |
165 | self.norm = norm_layer(embed_dim)
166 | self.head = nn.Identity()
167 |
168 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169 |
170 | self.init_weights()
171 |
172 | def init_weights(self):
173 | trunc_normal_(self.pos_embed, std=0.02)
174 | nn.init.normal_(self.cls_token, std=1e-6)
175 | if self.register_tokens is not None:
176 | nn.init.normal_(self.register_tokens, std=1e-6)
177 | named_apply(init_weights_vit_timm, self)
178 |
179 | def interpolate_pos_encoding(self, x, w, h):
180 | previous_dtype = x.dtype
181 | npatch = x.shape[1] - 1
182 | N = self.pos_embed.shape[1] - 1
183 | if npatch == N and w == h:
184 | return self.pos_embed
185 | pos_embed = self.pos_embed.float()
186 | class_pos_embed = pos_embed[:, 0]
187 | patch_pos_embed = pos_embed[:, 1:]
188 | dim = x.shape[-1]
189 | w0 = w // self.patch_size
190 | h0 = h // self.patch_size
191 | # we add a small number to avoid floating point error in the interpolation
192 | # see discussion at https://github.com/facebookresearch/dino/issues/8
193 | # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194 | w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195 | # w0, h0 = w0 + 0.1, h0 + 0.1
196 |
197 | sqrt_N = math.sqrt(N)
198 | sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199 | patch_pos_embed = nn.functional.interpolate(
200 | patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201 | scale_factor=(sx, sy),
202 | # (int(w0), int(h0)), # to solve the upsampling shape issue
203 | mode="bicubic",
204 | antialias=self.interpolate_antialias
205 | )
206 |
207 | assert int(w0) == patch_pos_embed.shape[-2]
208 | assert int(h0) == patch_pos_embed.shape[-1]
209 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211 |
212 | def prepare_tokens_with_masks(self, x, masks=None):
213 | B, nc, w, h = x.shape
214 | x = self.patch_embed(x)
215 | if masks is not None:
216 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217 |
218 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219 | x = x + self.interpolate_pos_encoding(x, w, h)
220 |
221 | if self.register_tokens is not None:
222 | x = torch.cat(
223 | (
224 | x[:, :1],
225 | self.register_tokens.expand(x.shape[0], -1, -1),
226 | x[:, 1:],
227 | ),
228 | dim=1,
229 | )
230 |
231 | return x
232 |
233 | def forward_features_list(self, x_list, masks_list):
234 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235 | for blk in self.blocks:
236 | x = blk(x)
237 |
238 | all_x = x
239 | output = []
240 | for x, masks in zip(all_x, masks_list):
241 | x_norm = self.norm(x)
242 | output.append(
243 | {
244 | "x_norm_clstoken": x_norm[:, 0],
245 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247 | "x_prenorm": x,
248 | "masks": masks,
249 | }
250 | )
251 | return output
252 |
253 | def forward_features(self, x, masks=None):
254 | if isinstance(x, list):
255 | return self.forward_features_list(x, masks)
256 |
257 | x = self.prepare_tokens_with_masks(x, masks)
258 |
259 | for blk in self.blocks:
260 | x = blk(x)
261 |
262 | x_norm = self.norm(x)
263 | return {
264 | "x_norm_clstoken": x_norm[:, 0],
265 | "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267 | "x_prenorm": x,
268 | "masks": masks,
269 | }
270 |
271 | def _get_intermediate_layers_not_chunked(self, x, n=1):
272 | x = self.prepare_tokens_with_masks(x)
273 | # If n is an int, take the n last blocks. If it's a list, take them
274 | output, total_block_len = [], len(self.blocks)
275 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276 | for i, blk in enumerate(self.blocks):
277 | x = blk(x)
278 | if i in blocks_to_take:
279 | output.append(x)
280 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281 | return output
282 |
283 | def _get_intermediate_layers_chunked(self, x, n=1):
284 | x = self.prepare_tokens_with_masks(x)
285 | output, i, total_block_len = [], 0, len(self.blocks[-1])
286 | # If n is an int, take the n last blocks. If it's a list, take them
287 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288 | for block_chunk in self.blocks:
289 | for blk in block_chunk[i:]: # Passing the nn.Identity()
290 | x = blk(x)
291 | if i in blocks_to_take:
292 | output.append(x)
293 | i += 1
294 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295 | return output
296 |
297 | def get_intermediate_layers(
298 | self,
299 | x: torch.Tensor,
300 | n: Union[int, Sequence] = 1, # Layers or n last layers to take
301 | reshape: bool = False,
302 | return_class_token: bool = False,
303 | norm=True
304 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305 | if self.chunked_blocks:
306 | outputs = self._get_intermediate_layers_chunked(x, n)
307 | else:
308 | outputs = self._get_intermediate_layers_not_chunked(x, n)
309 | if norm:
310 | outputs = [self.norm(out) for out in outputs]
311 | class_tokens = [out[:, 0] for out in outputs]
312 | outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313 | if reshape:
314 | B, _, w, h = x.shape
315 | outputs = [
316 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317 | for out in outputs
318 | ]
319 | if return_class_token:
320 | return tuple(zip(outputs, class_tokens))
321 | return tuple(outputs)
322 |
323 | def forward(self, *args, is_training=False, **kwargs):
324 | ret = self.forward_features(*args, **kwargs)
325 | if is_training:
326 | return ret
327 | else:
328 | return self.head(ret["x_norm_clstoken"])
329 |
330 |
331 | def init_weights_vit_timm(module: nn.Module, name: str = ""):
332 | """ViT weight initialization, original timm impl (for reproducibility)"""
333 | if isinstance(module, nn.Linear):
334 | trunc_normal_(module.weight, std=0.02)
335 | if module.bias is not None:
336 | nn.init.zeros_(module.bias)
337 |
338 |
339 | def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340 | model = DinoVisionTransformer(
341 | patch_size=patch_size,
342 | embed_dim=384,
343 | depth=12,
344 | num_heads=6,
345 | mlp_ratio=4,
346 | block_fn=partial(Block, attn_class=MemEffAttention),
347 | num_register_tokens=num_register_tokens,
348 | **kwargs,
349 | )
350 | return model
351 |
352 |
353 | def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354 | model = DinoVisionTransformer(
355 | patch_size=patch_size,
356 | embed_dim=768,
357 | depth=12,
358 | num_heads=12,
359 | mlp_ratio=4,
360 | block_fn=partial(Block, attn_class=MemEffAttention),
361 | num_register_tokens=num_register_tokens,
362 | **kwargs,
363 | )
364 | return model
365 |
366 |
367 | def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368 | model = DinoVisionTransformer(
369 | patch_size=patch_size,
370 | embed_dim=1024,
371 | depth=24,
372 | num_heads=16,
373 | mlp_ratio=4,
374 | block_fn=partial(Block, attn_class=MemEffAttention),
375 | num_register_tokens=num_register_tokens,
376 | **kwargs,
377 | )
378 | return model
379 |
380 |
381 | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382 | """
383 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384 | """
385 | model = DinoVisionTransformer(
386 | patch_size=patch_size,
387 | embed_dim=1536,
388 | depth=40,
389 | num_heads=24,
390 | mlp_ratio=4,
391 | block_fn=partial(Block, attn_class=MemEffAttention),
392 | num_register_tokens=num_register_tokens,
393 | **kwargs,
394 | )
395 | return model
396 |
397 |
398 | def DINOv2(model_name):
399 | model_zoo = {
400 | "vits": vit_small,
401 | "vitb": vit_base,
402 | "vitl": vit_large,
403 | "vitg": vit_giant2
404 | }
405 |
406 | return model_zoo[model_name](
407 | img_size=518,
408 | patch_size=14,
409 | init_values=1.0,
410 | ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411 | block_chunks=0,
412 | num_register_tokens=0,
413 | interpolate_antialias=False,
414 | interpolate_offset=0.1
415 | )
416 |
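`interpolate_pos_encoding` above resizes the pretrained position-embedding grid whenever the input is not the pretraining resolution; the arithmetic is just how many patches fit along each side. A worked example with the defaults passed by `DINOv2()` (518 px pretraining, 14 px patches) and a hypothetical 644x532 input:
```
# Worked example of the resize arithmetic in interpolate_pos_encoding,
# with the defaults from DINOv2() above and a hypothetical 644x532 input.
import math

N = (518 // 14) ** 2                  # 1369 pretrained patch positions (37 x 37 grid)
w, h, patch = 644, 532, 14            # hypothetical input resolution
w0, h0 = w // patch, h // patch       # 46 x 38 target grid -> 1748 patch tokens
sqrt_N = math.sqrt(N)                 # 37.0
sx, sy = (w0 + 0.1) / sqrt_N, (h0 + 0.1) / sqrt_N   # bicubic scale factors (+0.1 offset)
print(w0 * h0, round(sx, 3), round(sy, 3))          # 1748 1.246 1.03
```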
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .mlp import Mlp
8 | from .patch_embed import PatchEmbed
9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10 | from .block import NestedTensorBlock
11 | from .attention import MemEffAttention
12 |
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
83 |
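`Attention.forward` above is standard multi-head scaled dot-product attention; `MemEffAttention` only swaps in xFormers' memory-efficient kernel when the library is available. A shape-level sketch of the same computation with arbitrary ViT-like sizes:
```
# Shape-level sketch of the scaled dot-product attention computed in Attention.forward.
import torch

B, N, C, heads = 2, 197, 768, 12                    # arbitrary ViT-like sizes
qkv = torch.randn(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * (C // heads) ** -0.5, qkv[1], qkv[2]

attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)    # [B, heads, N, N] attention weights
out = (attn @ v).transpose(1, 2).reshape(B, N, C)   # back to [B, N, C]
print(out.shape)                                    # torch.Size([2, 197, 768])
```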
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 | if all_shapes not in attn_bias_cache.keys():
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 |                 scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
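`drop_add_residual_stochastic_depth` above applies the residual branch to a random subset of the batch and rescales the result by b / subset_size, so the update stays unbiased in expectation. A numeric sketch of that subset-and-rescale step, with made-up sizes and a dummy residual branch:
```
# Numeric sketch of the subset-and-rescale trick in drop_add_residual_stochastic_depth.
import torch

b, n, d, drop = 8, 4, 16, 0.25
x = torch.randn(b, n, d)

def residual_func(t):                             # dummy stand-in for the attn/ffn branch
    return 0.1 * t

subset = max(int(b * (1 - drop)), 1)              # 6 of 8 samples kept this step
brange = torch.randperm(b)[:subset]
residual = residual_func(x[brange])
scale = b / subset                                # 8/6, keeps the expected update unbiased
out = torch.index_add(x.flatten(1), 0, brange, residual.flatten(1), alpha=scale).view_as(x)
print(out.shape)                                  # torch.Size([8, 4, 16])
```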
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
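`drop_path` zeroes an entire sample's residual with probability drop_prob and divides the survivors by keep_prob, which keeps the expected output unchanged. A quick illustrative check of that expectation:
```
# Quick check that the drop_path mask preserves the expectation: E[mask / keep_prob] = 1.
import torch

keep_prob = 0.8
x = torch.ones(10000, 1, 1)                                     # many "samples"
mask = x.new_empty((x.shape[0], 1, 1)).bernoulli_(keep_prob) / keep_prob
print((x * mask).mean().item())                                 # ~1.0 on average
```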
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 |         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
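A quick shape check (the numbers are just the defaults above): a 224x224 RGB image with 16-pixel patches becomes 14x14 = 196 tokens of width `embed_dim`.
```
import torch
# PatchEmbed as defined above
pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = pe(torch.randn(1, 3, 224, 224))
print(tokens.shape)                    # torch.Size([1, 196, 768]), i.e. (B, N, D)
```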
--------------------------------------------------------------------------------
/depth/any2/dinov2_layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 |         hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8  # scale the hidden width by 2/3 (SwiGLU has three weight matrices vs two in a plain MLP) and round up to a multiple of 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
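A short sketch of the gating in `SwiGLUFFN` (widths are arbitrary examples): `w12` packs both input projections, the result is split in two, one half gates the other via SiLU, and `w3` projects back. `SwiGLUFFNFused` additionally shrinks the requested hidden width by 2/3 and rounds it up to a multiple of 8, keeping the parameter count close to a plain MLP of the same ratio.
```
import torch
# SwiGLUFFN as defined above (the no-xformers fallback)
ffn = SwiGLUFFN(in_features=768, hidden_features=2048)
x = torch.randn(2, 16, 768)
print(ffn(x).shape)                    # torch.Size([2, 16, 768])
print(ffn.w12.out_features)            # 4096 = 2 * hidden_features, split before the SiLU gate
```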
--------------------------------------------------------------------------------
/depth/any2/dpt.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision.transforms import Compose
6 |
7 | from .dinov2 import DINOv2
8 | from .util.blocks import FeatureFusionBlock, _make_scratch
9 | from .util.transform import Resize, NormalizeImage, PrepareForNet
10 |
11 | def _make_fusion_block(features, use_bn, size=None):
12 | return FeatureFusionBlock(features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True, size=size)
13 |
14 | class ConvBlock(nn.Module):
15 | def __init__(self, in_feature, out_feature):
16 | super().__init__()
17 | self.conv_block = nn.Sequential(
18 | nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
19 | nn.BatchNorm2d(out_feature),
20 | nn.ReLU(True)
21 | )
22 | def forward(self, x):
23 | return self.conv_block(x)
24 |
25 | class DPTHead(nn.Module):
26 | def __init__(self, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
27 | super(DPTHead, self).__init__()
28 |
29 | self.use_clstoken = use_clstoken
30 |
31 | self.projects = nn.ModuleList([
32 | nn.Conv2d(in_channels=in_channels, out_channels=out_channel, kernel_size=1, stride=1, padding=0) for out_channel in out_channels
33 | ])
34 |
35 | self.resize_layers = nn.ModuleList([
36 | nn.ConvTranspose2d(in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0),
37 | nn.ConvTranspose2d(in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0),
38 | nn.Identity(),
39 | nn.Conv2d(in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1)
40 | ])
41 |
42 | if use_clstoken:
43 | self.readout_projects = nn.ModuleList()
44 | for _ in range(len(self.projects)):
45 | self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
46 |
47 | self.scratch = _make_scratch(out_channels, features, groups=1, expand=False)
48 |
49 | self.scratch.stem_transpose = None
50 |
51 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
52 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
53 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
54 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
55 |
56 | head_features_1 = features
57 | head_features_2 = 32
58 |
59 | self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
60 | self.scratch.output_conv2 = nn.Sequential(
61 | nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
62 | nn.ReLU(True),
63 | nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
64 | nn.ReLU(True),
65 | nn.Identity(),
66 | )
67 |
68 | def forward(self, out_features, patch_h, patch_w):
69 | out = []
70 | for i, x in enumerate(out_features):
71 | if self.use_clstoken:
72 | x, cls_token = x[0], x[1]
73 | readout = cls_token.unsqueeze(1).expand_as(x)
74 | x = self.readout_projects[i](torch.cat((x, readout), -1))
75 | else:
76 | x = x[0]
77 | x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
78 | x = self.projects[i](x)
79 | x = self.resize_layers[i](x)
80 | out.append(x)
81 |
82 | layer_1, layer_2, layer_3, layer_4 = out
83 |
84 | layer_1_rn = self.scratch.layer1_rn(layer_1)
85 | layer_2_rn = self.scratch.layer2_rn(layer_2)
86 | layer_3_rn = self.scratch.layer3_rn(layer_3)
87 | layer_4_rn = self.scratch.layer4_rn(layer_4)
88 |
89 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
90 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
91 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
92 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
93 |
94 | out = self.scratch.output_conv1(path_1)
95 | out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
96 | out = self.scratch.output_conv2(out)
97 | return out
98 |
99 | class DepthAnythingV2(nn.Module):
100 | def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False):
101 | super(DepthAnythingV2, self).__init__()
102 | self.intermediate_layer_idx = {
103 | 'vits': [2, 5, 8, 11],
104 | 'vitb': [2, 5, 8, 11],
105 | 'vitl': [4, 11, 17, 23],
106 | 'vitg': [9, 19, 29, 39]
107 | }
108 | self.encoder = encoder
109 | self.pretrained = DINOv2(model_name=encoder)
110 | self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
111 |
112 | def forward(self, x):
113 | patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
114 | features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
115 | depth = self.depth_head(features, patch_h, patch_w)
116 | depth = F.relu(depth)
117 | return depth # .squeeze(1)
118 |
119 | @torch.no_grad()
120 | def infer_image(self, image, input_size=518, bgr=True):
121 | image, (h, w) = self.image2tensor(image, input_size, bgr=bgr)
122 | depth = self.forward(image)
123 |         depth = F.interpolate(depth, (h, w), mode="bilinear", align_corners=True)[0, 0]  # forward() already returns [B,1,H,W], so no extra dim is added here
124 | return depth.cpu().numpy()
125 |
126 | def image2tensor(self, image, input_size=518, bgr=True):
127 | transform = Compose([
128 | Resize(width=input_size, height=input_size, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14,
129 | resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC),
130 | NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
131 | PrepareForNet(),
132 | ])
133 | h, w = image.shape[:2]
134 | if bgr: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
135 | image = transform({'image': image})['image']
136 | image = torch.from_numpy(image).unsqueeze(0)
137 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
138 | image = image.to(DEVICE)
139 | return image, (h, w)
140 |
141 |
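A hedged usage sketch (the checkpoint path and image name are assumptions; `run.py` below shows the actual loading): `infer_image` takes a BGR numpy image, resizes its shorter side to `input_size` snapped to a multiple of 14, and returns a depth map at the original resolution.
```
import cv2, torch
# DepthAnythingV2 as defined above; these values match the 'vits' config in run.py
model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
model.load_state_dict(torch.load('models/depth_anything_v2_vits.pth', map_location='cpu'))  # assumed path
model = model.cuda().eval()            # assumes a CUDA device, like the rest of this repo

img = cv2.imread('photo.jpg')          # BGR uint8, any resolution
depth = model.infer_image(img, 518)    # 2D numpy array with the same H, W as img
```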
--------------------------------------------------------------------------------
/depth/any2/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import cv2
4 | import numpy as np
5 | # import matplotlib
6 |
7 | import torch
8 |
9 | from deptha2.dpt import DepthAnythingV2
10 |
11 | from eps import img_list, img_read, basename, progbar
12 |
13 | parser = argparse.ArgumentParser(description='Depth Anything V2')
14 | parser.add_argument('-i', '--input', default='_in', help='Input image or folder')
15 | parser.add_argument('-o', '--out_dir', default='_out')
16 | parser.add_argument('-md','--maindir', default='./', help='Main directory')
17 | parser.add_argument('--encoder', default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
18 | parser.add_argument('-sz', '--size', type=int, default=768) # 518
19 | parser.add_argument('--seed', default=None, type=int, help='Random seed')
20 | # parser.add_argument('--pre', action='store_true', help='display combined mix')
21 | parser.add_argument('-v', '--verbose', action='store_true')
22 | a = parser.parse_args()
23 |
24 | model_configs = {
25 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
26 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
27 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
28 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
29 | }
30 |
31 | def main():
32 | os.makedirs(a.out_dir, exist_ok=True)
33 | device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
34 |
35 | depth_anything = DepthAnythingV2(**model_configs[a.encoder])
36 | depth_anything.load_state_dict(torch.load(os.path.join(a.maindir, 'models', f'depth_anything_v2_{a.encoder}.pth'), map_location='cpu'))
37 | depth_anything = depth_anything.to(device).eval()
38 |
39 | # cmap = matplotlib.colormaps.get_cmap('Spectral_r')
40 |
41 | paths = [a.input] if os.path.isfile(a.input) else img_list(a.input)
42 | pbar = progbar(len(paths))
43 | for k, path in enumerate(paths):
44 | img_in = cv2.imread(path)
45 |
46 | depth = depth_anything.infer_image(img_in, a.size)
47 |
48 | depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
49 | depth = depth.astype(np.uint8)
50 | depth = np.repeat(depth[..., np.newaxis], 3, axis=-1)
51 | # depth = (cmap(depth)[:, :, :3] * 255)[:, :, ::-1].astype(np.uint8)
52 |
53 | # if a.pre:
54 | # split_region = np.ones((img_in.shape[0], 50, 3), dtype=np.uint8) * 255
55 | # depth = cv2.hconcat([img_in, split_region, depth])
56 |
57 | cv2.imwrite(os.path.join(a.out_dir, basename(path) + '.png'), depth)
58 | pbar.upd()
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
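A possible invocation (input/output paths are assumptions): the script processes a single image or every image in a folder, writes a 3-channel depth PNG per input into `--out_dir`, and expects the checkpoint at `<maindir>/models/depth_anything_v2_<encoder>.pth`.
```
python run.py -i _in -o _out_depth --encoder vitl -sz 768
```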
--------------------------------------------------------------------------------
/depth/any2/util/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5 | scratch = nn.Module()
6 |
7 | out_shape1 = out_shape
8 | out_shape2 = out_shape
9 | out_shape3 = out_shape
10 | if len(in_shape) >= 4:
11 | out_shape4 = out_shape
12 |
13 | if expand:
14 | out_shape1 = out_shape
15 | out_shape2 = out_shape * 2
16 | out_shape3 = out_shape * 4
17 | if len(in_shape) >= 4:
18 | out_shape4 = out_shape * 8
19 |
20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23 | if len(in_shape) >= 4:
24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25 |
26 | return scratch
27 |
28 |
29 | class ResidualConvUnit(nn.Module):
30 | """Residual convolution module.
31 | """
32 |
33 | def __init__(self, features, activation, bn):
34 | """Init.
35 |
36 | Args:
37 | features (int): number of features
38 | """
39 | super().__init__()
40 |
41 | self.bn = bn
42 |
43 | self.groups=1
44 |
45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46 |
47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48 |
49 | if self.bn == True:
50 | self.bn1 = nn.BatchNorm2d(features)
51 | self.bn2 = nn.BatchNorm2d(features)
52 |
53 | self.activation = activation
54 |
55 | self.skip_add = nn.quantized.FloatFunctional()
56 |
57 | def forward(self, x):
58 | """Forward pass.
59 |
60 | Args:
61 | x (tensor): input
62 |
63 | Returns:
64 | tensor: output
65 | """
66 |
67 | out = self.activation(x)
68 | out = self.conv1(out)
69 | if self.bn == True:
70 | out = self.bn1(out)
71 |
72 | out = self.activation(out)
73 | out = self.conv2(out)
74 | if self.bn == True:
75 | out = self.bn2(out)
76 |
77 | if self.groups > 1:
78 | out = self.conv_merge(out)
79 |
80 | return self.skip_add.add(out, x)
81 |
82 |
83 | class FeatureFusionBlock(nn.Module):
84 | """Feature fusion block.
85 | """
86 |
87 | def __init__(
88 | self,
89 | features,
90 | activation,
91 | deconv=False,
92 | bn=False,
93 | expand=False,
94 | align_corners=True,
95 | size=None
96 | ):
97 | """Init.
98 |
99 | Args:
100 | features (int): number of features
101 | """
102 | super(FeatureFusionBlock, self).__init__()
103 |
104 | self.deconv = deconv
105 | self.align_corners = align_corners
106 |
107 | self.groups=1
108 |
109 | self.expand = expand
110 | out_features = features
111 | if self.expand == True:
112 | out_features = features // 2
113 |
114 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115 |
116 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118 |
119 | self.skip_add = nn.quantized.FloatFunctional()
120 |
121 | self.size=size
122 |
123 | def forward(self, *xs, size=None):
124 | """Forward pass.
125 |
126 | Returns:
127 | tensor: output
128 | """
129 | output = xs[0]
130 |
131 | if len(xs) == 2:
132 | res = self.resConfUnit1(xs[1])
133 | output = self.skip_add.add(output, res)
134 |
135 | output = self.resConfUnit2(output)
136 |
137 | if (size is None) and (self.size is None):
138 | modifier = {"scale_factor": 2}
139 | elif size is None:
140 | modifier = {"size": self.size}
141 | else:
142 | modifier = {"size": size}
143 |
144 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145 |
146 | output = self.out_conv(output)
147 |
148 | return output
149 |
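A small sketch of how `FeatureFusionBlock` is driven by the DPT head above (shapes are arbitrary examples): both inputs carry `features` channels at the same spatial size; the lateral feature is refined and added to the incoming path, refined again, then resampled to `size` (or upscaled 2x when no size is given).
```
import torch
import torch.nn as nn
# FeatureFusionBlock as defined above, mirroring _make_fusion_block() in dpt.py
fuse = FeatureFusionBlock(64, nn.ReLU(False), align_corners=True)
deep = torch.randn(1, 64, 24, 24)      # output of the previous refinenet
lateral = torch.randn(1, 64, 24, 24)   # projected encoder feature at the same scale
print(fuse(deep, lateral, size=(48, 48)).shape)   # torch.Size([1, 64, 48, 48])
```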
--------------------------------------------------------------------------------
/depth/any2/util/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | class Resize(object):
6 | """Resize sample to given size (width, height).
7 | """
8 |
9 | def __init__(
10 | self,
11 | width,
12 | height,
13 | resize_target=True,
14 | keep_aspect_ratio=False,
15 | ensure_multiple_of=1,
16 | resize_method="lower_bound",
17 | image_interpolation_method=cv2.INTER_AREA,
18 | ):
19 | """Init.
20 |
21 | Args:
22 | width (int): desired output width
23 | height (int): desired output height
24 | resize_target (bool, optional):
25 | True: Resize the full sample (image, mask, target).
26 | False: Resize image only.
27 | Defaults to True.
28 | keep_aspect_ratio (bool, optional):
29 | True: Keep the aspect ratio of the input sample.
30 | Output sample might not have the given width and height, and
31 | resize behaviour depends on the parameter 'resize_method'.
32 | Defaults to False.
33 | ensure_multiple_of (int, optional):
34 | Output width and height is constrained to be multiple of this parameter.
35 | Defaults to 1.
36 | resize_method (str, optional):
37 | "lower_bound": Output will be at least as large as the given size.
38 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
39 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
40 | Defaults to "lower_bound".
41 | """
42 | self.__width = width
43 | self.__height = height
44 |
45 | self.__resize_target = resize_target
46 | self.__keep_aspect_ratio = keep_aspect_ratio
47 | self.__multiple_of = ensure_multiple_of
48 | self.__resize_method = resize_method
49 | self.__image_interpolation_method = image_interpolation_method
50 |
51 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53 |
54 | if max_val is not None and y > max_val:
55 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56 |
57 | if y < min_val:
58 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59 |
60 | return y
61 |
62 | def get_size(self, width, height):
63 | # determine new height and width
64 | scale_height = self.__height / height
65 | scale_width = self.__width / width
66 |
67 | if self.__keep_aspect_ratio:
68 | if self.__resize_method == "lower_bound":
69 | # scale such that output size is lower bound
70 | if scale_width > scale_height:
71 | # fit width
72 | scale_height = scale_width
73 | else:
74 | # fit height
75 | scale_width = scale_height
76 | elif self.__resize_method == "upper_bound":
77 | # scale such that output size is upper bound
78 | if scale_width < scale_height:
79 | # fit width
80 | scale_height = scale_width
81 | else:
82 | # fit height
83 | scale_width = scale_height
84 | elif self.__resize_method == "minimal":
85 |                 # scale as little as possible
86 | if abs(1 - scale_width) < abs(1 - scale_height):
87 | # fit width
88 | scale_height = scale_width
89 | else:
90 | # fit height
91 | scale_width = scale_height
92 | else:
93 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
94 |
95 | if self.__resize_method == "lower_bound":
96 | new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97 | new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98 | elif self.__resize_method == "upper_bound":
99 | new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100 | new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101 | elif self.__resize_method == "minimal":
102 | new_height = self.constrain_to_multiple_of(scale_height * height)
103 | new_width = self.constrain_to_multiple_of(scale_width * width)
104 | else:
105 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
106 |
107 | return (new_width, new_height)
108 |
109 | def __call__(self, sample):
110 | width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111 |
112 | # resize sample
113 | sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114 |
115 | if self.__resize_target:
116 | if "depth" in sample:
117 | sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118 |
119 | if "mask" in sample:
120 | sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121 |
122 | return sample
123 |
124 |
125 | class NormalizeImage(object):
126 |     """Normalize image by given mean and std.
127 | """
128 |
129 | def __init__(self, mean, std):
130 | self.__mean = mean
131 | self.__std = std
132 |
133 | def __call__(self, sample):
134 | sample["image"] = (sample["image"] - self.__mean) / self.__std
135 |
136 | return sample
137 |
138 |
139 | class PrepareForNet(object):
140 | """Prepare sample for usage as network input.
141 | """
142 |
143 | def __init__(self):
144 | pass
145 |
146 | def __call__(self, sample):
147 | image = np.transpose(sample["image"], (2, 0, 1))
148 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149 |
150 | if "depth" in sample:
151 | depth = sample["depth"].astype(np.float32)
152 | sample["depth"] = np.ascontiguousarray(depth)
153 |
154 | if "mask" in sample:
155 | sample["mask"] = sample["mask"].astype(np.float32)
156 | sample["mask"] = np.ascontiguousarray(sample["mask"])
157 |
158 | return sample
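A worked example of the `Resize` rules above, with the settings `image2tensor` in `dpt.py` uses: `lower_bound` plus `keep_aspect_ratio` scales a 1280x720 image so the shorter side reaches 518 and both sides snap to multiples of 14.
```
# Resize as defined above, configured as in dpt.py
resize = Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
                ensure_multiple_of=14, resize_method='lower_bound')
print(resize.get_size(1280, 720))      # (924, 518): 720 -> 518, 1280 -> 920.9 -> 924 (nearest multiple of 14)
```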
--------------------------------------------------------------------------------
/depth/depth.py:
--------------------------------------------------------------------------------
1 | ### original method & code by https://twitter.com/deKxi
2 |
3 | import logging
4 | logging.getLogger('xformers').setLevel(logging.ERROR) # shutup triton, before torch!
5 |
6 | import os
7 | import sys
8 | import cv2
9 | from imageio import imsave
10 | import numpy as np
11 | import PIL
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torchvision import transforms as T
16 |
17 | from aphantasia.utils import triangle_blur
18 | from .any2.dpt import DepthAnythingV2
19 |
20 | class InferDepthAny:
21 | def __init__(self, modtype='B', device=torch.device('cuda')):
22 | modtype = 'Large' if modtype[0].lower()=='l' else 'Small' if modtype[0].lower()=='s' else 'Base'
23 | from transformers import AutoModelForDepthEstimation
24 | model = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-%s-hf" % modtype)
25 |         self.model = model.to(device).eval()  # use the requested device rather than hard-coding cuda
26 |
27 | @torch.no_grad()
28 | def __call__(self, image):
29 | image = T.functional.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
30 | depth = self.model(pixel_values=image).predicted_depth.unsqueeze(0)
31 | return (depth - depth.min()) / (depth.max() - depth.min())
32 |
33 | def save_img(img, fname=None):
34 | if fname is not None:
35 | img = np.array(img)[:,:,:]
36 | img = np.transpose(img, (1,2,0))
37 | img = np.clip(img*255, 0, 255).astype(np.uint8)
38 | if img.shape[-1]==1: img = img[:,:,[0,0,0]]
39 | imsave(fname, np.array(img))
40 |
41 | def resize(img, size):
42 | return F.interpolate(img, size, mode='bicubic', align_corners=True).float().cuda()
43 |
44 | def grid_warp(img, dtensor, H, W, strength, centre, midpoint, dlens=0.05):
45 | # Building the coordinates
46 | xx = torch.linspace(-1, 1, W)
47 | yy = torch.linspace(-1, 1, H)
48 | gy, gx = torch.meshgrid(yy, xx)
49 |
50 | # Apply depth warp
51 | grid = torch.stack([gx, gy], dim=-1).cuda()
52 | d = centre - grid
53 | d_sum = dtensor[0]
54 | # Adjust midpoint / move direction
55 | d_sum = d_sum - torch.max(d_sum) * midpoint
56 | grid_warped = grid + d * d_sum.unsqueeze(-1) * strength
57 | img = F.grid_sample(img, grid_warped.unsqueeze(0).float(), mode='bilinear', align_corners=True, padding_mode='reflection')
58 |
59 | # Apply simple lens distortion to stretch periphery (instead of sphere wrap)
60 | lens_distortion = torch.sqrt((d**2).sum(axis=-1)).cuda()
61 | grid_warped = grid + d * lens_distortion.unsqueeze(-1) * strength * dlens
62 | img = F.grid_sample(img, grid_warped.unsqueeze(0).float(), mode='bilinear', align_corners=True, padding_mode='reflection')
63 |
64 | return img
65 |
66 | def depthwarp(img_t, img, infer_any, strength=0, centre=[0,0], midpoint=0.5, save_path=None, save_num=0, dlens=0.05):
67 | _, _, H, W = img.shape # [1,3,720,1280] [0..1]
68 |
69 | res = 518 # 518 on lower dimension for DepthAny
70 | dim = [res, int(res*W/H)] if H < W else [int(res*H/W), res]
71 | dim = [x - x % 14 for x in dim]
72 |
73 | image = resize(torch.lerp(img, triangle_blur(img, 5, 2), 0.5), dim) # [1,3,518,910] [0..1]
74 | depth = infer_any(image) # [1,1,h,w]
75 | depth = depth * torch.flip(infer_any(torch.flip(image, [-1])), [-1]) # enhance depth with mirrored estimation
76 | depth = resize(depth, (H,W)) # [1,1,H,W]
77 |
78 |     if save_path is not None: # save the depth map out; currently it's written as its own image, but it could be added as an alpha channel to the main image
79 | out_depth = depth.detach().clone().cpu().squeeze(0)
80 | save_img(out_depth, os.path.join(save_path, '%05d.jpg' % save_num))
81 |
82 | img = grid_warp(img_t, depth.squeeze(0), H, W, strength, torch.as_tensor(centre).cuda(), midpoint, dlens)
83 |
84 | return img
85 |
86 |
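For reference, the sizing logic in `depthwarp` on a 1280x720 frame (the shapes in the inline comments above): the shorter side stays at 518, the longer one is scaled proportionally and trimmed to a multiple of 14; the depth estimate is then multiplied by the estimate of the horizontally mirrored frame and resized back to the frame resolution.
```
# sizing example for a 1280x720 frame (H < W), as computed in depthwarp()
res, H, W = 518, 720, 1280
dim = [res, int(res * W / H)]          # [518, 920]
dim = [x - x % 14 for x in dim]        # [518, 910] -> the [1,3,518,910] shape noted above
```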
--------------------------------------------------------------------------------
/illustra.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import os
3 | import time
4 | import warnings
5 | warnings.filterwarnings("ignore")
6 | import argparse
7 | import numpy as np
8 | import random
9 | import shutil
10 |
11 | import torch
12 | import torchvision
13 | import torch.nn.functional as F
14 |
15 | import clip
16 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
17 |
18 | from aphantasia.image import to_valid_rgb, fft_image
19 | from aphantasia.utils import slice_imgs, derivat, checkout, basename, file_list, img_list, img_read, txt_clean, old_torch, save_cfg, sim_func, aesthetic_model
20 | from aphantasia import transforms
21 | try: # progress bar for notebooks
22 | get_ipython().__class__.__name__
23 | from aphantasia.progress_bar import ProgressIPy as ProgressBar
24 | except: # normal console
25 | from aphantasia.progress_bar import ProgressBar
26 |
27 | clip_models = ['ViT-B/16', 'ViT-B/32', 'ViT-L/14', 'ViT-L/14@336px', 'RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101']
28 |
29 | def get_args():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('-s', '--size', default='1280-720', help='Output resolution')
32 | parser.add_argument('-t', '--in_txt', default=None, help='input text or file - main topic')
33 | parser.add_argument('-t2', '--in_txt2', default=None, help='input text or file - style')
34 | parser.add_argument('-im', '--in_img', default=None, help='input image or directory with images')
35 | parser.add_argument('-r', '--resume', default=None, help='Resume from saved params')
36 | parser.add_argument( '--out_dir', default='_out/fft')
37 | parser.add_argument( '--save_step', default=1, type=int, help='Save every this step')
38 | parser.add_argument('-tr', '--translate', action='store_true', help='Translate with Google Translate')
39 | parser.add_argument('-v', '--verbose', dest='verbose', action='store_true')
40 | parser.add_argument('-nv', '--no-verbose', dest='verbose', action='store_false')
41 | parser.set_defaults(verbose=True)
42 | # training
43 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use')
44 | parser.add_argument( '--steps', default=150, type=int, help='Iterations per input')
45 | parser.add_argument( '--samples', default=200, type=int, help='Samples to evaluate')
46 | parser.add_argument('-lr', '--lrate', default=0.05, type=float, help='Learning rate')
47 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model')
48 | # tweaks
49 | parser.add_argument('-opt', '--optimr', default='adam', choices=['adam', 'adamw'], help='Optimizer')
50 | parser.add_argument('-a', '--align', default='uniform', choices=['central', 'uniform', 'overscan', 'overmax'], help='Sampling distribution')
51 | parser.add_argument('-tf', '--transform', default='fast', choices=['none', 'custom', 'fast', 'elastic'], help='augmenting transforms')
52 | parser.add_argument( '--aest', default=1., type=float)
53 | parser.add_argument( '--contrast', default=1.1, type=float)
54 | parser.add_argument( '--colors', default=1.8, type=float)
55 | parser.add_argument('-d', '--decay', default=1.5, type=float)
56 | parser.add_argument('-sh', '--sharp', default=0, type=float)
57 | parser.add_argument('-mc', '--macro', default=0.4, type=float, help='Endorse macro forms 0..1 ')
58 | parser.add_argument('-e', '--enforce', default=0, type=float, help='Enhance consistency, boosts training')
59 | parser.add_argument('-n', '--noise', default=0, type=float, help='Add noise to decrease accumulation')
60 | parser.add_argument( '--sim', default='mix', help='Similarity function (dot/angular/spherical/mixed; None = cossim)')
61 | parser.add_argument( '--loop', action='store_true', help='Loop inputs [or keep the last one]')
62 | parser.add_argument( '--save_pt', action='store_true', help='save fft snapshots to pt file')
63 | # multi input
64 | parser.add_argument('-l', '--length', default=None, type=int, help='Override total length in sec')
65 | parser.add_argument( '--lsteps', default=25, type=int, help='Frames per step')
66 | parser.add_argument( '--fps', default=25, type=int)
67 | parser.add_argument( '--keep', default=1.5, type=float, help='Accumulate imagery: 0 random, 1+ ~prev')
68 | parser.add_argument( '--separate', action='store_true', help='process inputs separately')
69 | a = parser.parse_args()
70 |
71 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1]
72 | if len(a.size)==1: a.size = a.size * 2
73 | if not a.separate: a.save_pt = True
74 | if a.dualmod is not None:
75 | a.model = 'ViT-B/32'
76 | a.sim = 'cossim'
77 |
78 | return a
79 |
80 | a = get_args()
81 |
82 | if a.translate is True:
83 | try:
84 | from googletrans import Translator
85 | except ImportError as e:
86 | print('\n Install googletrans module to enable translation!'); exit()
87 |
88 | def main():
89 | bx = 1.
90 |
91 | model_clip, _ = clip.load(a.model, jit=old_torch())
92 | try:
93 | a.modsize = model_clip.visual.input_resolution
94 | except:
95 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 448 if a.model == 'RN50x64' else 336 if '336' in a.model else 224
96 | model_clip = model_clip.eval().cuda()
97 | xmem = {'ViT-B/16':0.25, 'ViT-L/14':0.04, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN50x64':0.01, 'RN101':0.33}
98 | if a.model in xmem.keys():
99 | bx *= xmem[a.model]
100 |
101 | if a.dualmod is not None:
102 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch())
103 | bx *= 0.23 # second is vit-16
104 |         dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod]  # steps on which the second CLIP model is used
105 | print(' dual model every %d step' % a.dualmod)
106 |
107 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']:
108 | aest = aesthetic_model(a.model).cuda()
109 | if a.dualmod is not None:
110 | aest2 = aesthetic_model('ViT-B/16').cuda()
111 |
112 | if 'elastic' in a.transform:
113 | trform_f = transforms.transforms_elastic
114 | elif 'custom' in a.transform:
115 | trform_f = transforms.transforms_custom
116 | elif 'fast' in a.transform:
117 | trform_f = transforms.transforms_fast
118 | else:
119 | trform_f = transforms.normalize()
120 | bx *= 1.05
121 | bx *= 0.95
122 | if a.enforce != 0:
123 | bx *= 0.5
124 | a.samples = int(bx * a.samples)
125 |
126 | if a.translate:
127 | translator = Translator()
128 |
129 | def enc_text(txt, model_clip=model_clip):
130 | if txt is None or len(txt)==0: return None
131 | embs = []
132 | for subtxt in txt.split('|'):
133 | if ':' in subtxt:
134 | [subtxt, wt] = subtxt.split(':')
135 | wt = float(wt)
136 | else: wt = 1.
137 | emb = model_clip.encode_text(clip.tokenize(subtxt).cuda()[:77])
138 | # emb = emb / emb.norm(dim=-1, keepdim=True)
139 | embs.append([emb.detach().clone(), wt])
140 | return embs
141 |
142 | def enc_image(img, model_clip=model_clip):
143 | emb = model_clip.encode_image(img)
144 | # emb = emb / emb.norm(dim=-1, keepdim=True)
145 | return emb
146 |
147 | def proc_image(img_file, model_clip=model_clip):
148 | img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
149 | in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0]
150 | emb = enc_image(in_sliced, model_clip)
151 | return emb.detach().clone()
152 |
153 | def pick_(list_, num_):
154 | cnt = len(list_)
155 | if cnt == 0: return None
156 | num = num_ % cnt if a.loop is True else min(num_, cnt-1)
157 | return list_[num]
158 |
159 | def read_text(in_txt):
160 | if os.path.isfile(in_txt):
161 | with open(in_txt, 'r', encoding="utf-8") as f:
162 | lines = f.read().splitlines()
163 | texts = []
164 | for tt in lines:
165 | if len(tt.strip()) == 0: texts.append('')
166 | elif tt.strip()[0] != '#': texts.append(tt.strip())
167 | else:
168 | texts = [in_txt]
169 | return texts
170 |
171 | # Encode inputs
172 | count = 0
173 | texts = []
174 | styles = []
175 | img_paths = []
176 |
177 | if a.in_img is not None and os.path.exists(a.in_img):
178 | if a.verbose is True: print(' ref image:', basename(a.in_img))
179 | img_paths = img_list(a.in_img) if os.path.isdir(a.in_img) else [a.in_img]
180 | img_encs = [proc_image(image) for image in img_paths]
181 | if a.dualmod is not None:
182 | img_encs2 = [proc_image(image, model_clip2) for image in img_paths]
183 | count = max(count, len(img_encs))
184 |
185 | if a.in_txt is not None:
186 | if a.verbose is True: print(' topic:', a.in_txt)
187 | texts = read_text(a.in_txt)
188 | if a.translate:
189 | texts = [translator.translate(txt, dest='en').text for txt in texts]
190 | # if a.verbose is True: print(' translated to:', texts)
191 | txt_encs = [enc_text(txt) for txt in texts]
192 | if a.dualmod is not None:
193 | txt_encs2 = [enc_text(txt, model_clip2) for txt in texts]
194 | count = max(count, len(txt_encs))
195 |
196 | if a.in_txt2 is not None:
197 | if a.verbose is True: print(' style:', a.in_txt2)
198 | styles = read_text(a.in_txt2)
199 | if a.translate is True:
200 | styles = [tr.text for tr in translator.translate(styles)]
201 | # if a.verbose is True: print(' translated to:', styles)
202 | styl_encs = [enc_text(style) for style in styles]
203 | if a.dualmod is not None:
204 | styl_encs2 = [enc_text(style, model_clip2) for style in styles]
205 | count = max(count, len(styl_encs))
206 |
207 | assert count > 0, "No inputs found!"
208 |
209 | if a.verbose is True: print(' samples:', a.samples)
210 | sfx = ''
211 | if a.dualmod is None: sfx += '-%s' % a.model.replace('/','').replace('-','')
212 | if a.enforce != 0: sfx += '-e%.2g' % a.enforce
213 | # if a.noise > 0: sfx += '-n%.2g' % a.noise
214 | # if a.aest != 0: sfx += '-ae%.2g' % a.aest
215 |
216 | def train(num, i):
217 | loss = 0
218 | noise = a.noise * (torch.rand(1, 1, *params[0].shape[2:4], 1)-0.5).cuda() if a.noise > 0 else None
219 | img_out = image_f(noise)
220 | img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
221 |
222 | if a.in_txt is not None:
223 | txt_enc = pick_(txt_encs2, num) if a.dualmod is not None and i in dualmod_nums else pick_(txt_encs, num)
224 | if a.in_txt2 is not None:
225 | style_enc = pick_(styl_encs2, num) if a.dualmod is not None and i in dualmod_nums else pick_(styl_encs, num)
226 | if a.in_img is not None and os.path.isfile(a.in_img):
227 | img_enc = pick_(img_encs2, num) if a.dualmod is not None and i in dualmod_nums else pick_(img_encs, num)
228 | model_clip_ = model_clip2 if a.dualmod is not None and i in dualmod_nums else model_clip
229 | if a.aest != 0:
230 | aest_ = aest2 if a.dualmod is not None and i in dualmod_nums else aest
231 |
232 | out_enc = model_clip_.encode_image(img_sliced)
233 | if a.aest != 0 and aest_ is not None:
234 | loss -= 0.001 * a.aest * aest_(out_enc).mean()
235 | if a.in_txt is not None and txt_enc is not None: # input text - main topic
236 | for enc, wt in txt_enc:
237 | loss -= wt * sim_func(enc, out_enc, a.sim)
238 | if a.in_txt2 is not None and style_enc is not None: # input text - style
239 | for enc, wt in style_enc:
240 | loss -= wt * sim_func(enc, out_enc, a.sim)
241 | if a.in_img is not None and img_enc is not None: # input image
242 | loss -= sim_func(img_enc[:len(out_enc)], out_enc, a.sim)
243 | if a.sharp != 0: # scharr|sobel|naiv
244 | loss -= a.sharp * derivat(img_out, mode='naiv')
245 | if a.enforce != 0:
246 | img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
247 | out_enc2 = model_clip_.encode_image(img_sliced)
248 | loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
249 | del out_enc2 # torch.cuda.empty_cache()
250 |
251 | del img_out, img_sliced, out_enc
252 | assert not isinstance(loss, int), ' Loss not defined, check inputs'
253 |
254 | optimizer.zero_grad()
255 | loss.backward()
256 | optimizer.step()
257 |
258 | if i % a.save_step == 0:
259 | with torch.no_grad():
260 | img = image_f(contrast=a.contrast).cpu().numpy()[0]
261 | checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.save_step)), verbose=a.verbose)
262 | pbar.upd()
263 | del img
264 |
265 |
266 | try:
267 | for num in range(count):
268 | shape = [1, 3, *a.size]
269 | global params
270 |
271 | if num == 0 or a.separate is True:
272 | resume_cur = a.resume
273 | else:
274 | opt_state = optimizer.state_dict()
275 | param_ = params[0].detach()
276 | resume_cur = [a.keep * param_ / (param_.max() - param_.min())]
277 |
278 | params, image_f, sz = fft_image(shape, 0.08, a.decay, resume_cur)
279 | if sz is not None: a.size = sz
280 | image_f = to_valid_rgb(image_f, colors = a.colors)
281 |
282 | if a.optimr.lower() == 'adamw':
283 | optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
284 | else:
285 | optimizer = torch.optim.Adam(params, a.lrate, betas=(.0, .999))
286 | if num > 0 and not a.separate: optimizer.load_state_dict(opt_state)
287 |
288 | out_names = []
289 | if a.resume is not None and num == 0: out_names += [basename(a.resume)[:12]]
290 | if a.in_txt is not None: out_names += [txt_clean(pick_(texts, num))[:32]]
291 | if a.in_txt2 is not None: out_names += [txt_clean(pick_(styles, num))[:32]]
292 | out_name = '-'.join(out_names) + sfx
293 | if count > 1: out_name = '%04d-' % (num+1) + out_name
294 | print(out_name)
295 | workdir = a.out_dir
296 | tempdir = os.path.join(workdir, out_name)
297 | os.makedirs(tempdir, exist_ok=True)
298 | if num == 0: save_cfg(a, workdir, out_name + '.txt')
299 |
300 | pbar = ProgressBar(a.steps // a.save_step)
301 | for i in range(a.steps):
302 | train(num, i)
303 |
304 | file_out = os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps))
305 | shutil.copy(img_list(tempdir)[-1], file_out)
306 |             os.system('ffmpeg -v warning -y -i %s/%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, out_name)))
307 | if a.save_pt is True:
308 | torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
309 |
310 | except KeyboardInterrupt:
311 | exit()
312 |
313 | if not a.separate:
314 | vsteps = a.lsteps if a.length is None else int(a.length * a.fps / count)
315 | tempdir = os.path.join(workdir, '_final')
316 | os.makedirs(tempdir, exist_ok=True)
317 |
318 | def read_pt(file):
319 | return torch.load(file).cuda()
320 |
321 | if a.verbose is True: print(' rendering complete piece')
322 | ptfiles = file_list(workdir, 'pt')
323 | pbar = ProgressBar(vsteps * len(ptfiles))
324 | for px in range(len(ptfiles)):
325 | params1 = read_pt(ptfiles[px])
326 | params2 = read_pt(ptfiles[(px+1) % len(ptfiles)])
327 |
328 | params, image_f, sz_ = fft_image([1, 3, *a.size], resume=params1, sd=1., decay_power=a.decay)
329 | image_f = to_valid_rgb(image_f, colors = a.colors)
330 |
331 | for i in range(vsteps):
332 | with torch.no_grad():
333 | x = i/vsteps # math.sin(1.5708 * i/vsteps)
334 | img = image_f((params2 - params1) * x, contrast=a.contrast).cpu().numpy()[0]
335 | checkout(img, os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), verbose=a.verbose)
336 | pbar.upd()
337 |
338 |     os.system('ffmpeg -v warning -y -i %s/%%05d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, basename(a.in_txt))))
339 |
340 |
341 | if __name__ == '__main__':
342 | main()
343 |
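A possible run (the file name is an assumption): given a text file with one phrase per line (lines starting with `#` are skipped), the script renders each phrase in turn; unless `--separate` is set, it also saves per-phrase `.pt` snapshots and finally interpolates between them into one continuous video.
```
python illustra.py -t mylines.txt --size 1280-720 --steps 200
```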
--------------------------------------------------------------------------------
/illustrip.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import os
3 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
4 | import warnings
5 | warnings.filterwarnings("ignore")
6 | import argparse
7 | import numpy as np
8 | import shutil
9 | import PIL
10 | import time
11 | from imageio import imread, imsave
12 |
13 | try:
14 | from googletrans import Translator
15 | googletrans_ok = True
16 | except:
17 | googletrans_ok = False
18 |
19 | import torch
20 | import torchvision
21 | import torch.nn.functional as F
22 | from torchvision import transforms as T
23 |
24 | import clip
25 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
26 |
27 | from aphantasia.image import to_valid_rgb, fft_image, resume_fft, pixel_image
28 | from aphantasia.utils import slice_imgs, derivat, sim_func, aesthetic_model, intrl, slerp, basename, file_list, img_list, img_read, pad_up_to, txt_clean, latent_anima, cvshow, checkout, save_cfg, old_torch
29 | from aphantasia import transforms
30 | from depth import depth
31 | try: # progress bar for notebooks
32 | get_ipython().__class__.__name__
33 | from aphantasia.progress_bar import ProgressIPy as ProgressBar
34 | except: # normal console
35 | from aphantasia.progress_bar import ProgressBar
36 |
37 | clip_models = ['ViT-B/16', 'ViT-B/32', 'RN50', 'RN50x4', 'RN50x16', 'RN101']
38 |
39 | def get_args():
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument('-s', '--size', default='1280-720', help='Output resolution')
42 | parser.add_argument('-t', '--in_txt', default=None, help='Text string or file to process (main topic)')
43 | parser.add_argument('-pre', '--in_txt_pre', default=None, help='Prefix for input text')
44 | parser.add_argument('-post', '--in_txt_post', default=None, help='Postfix for input text')
45 | parser.add_argument('-t2', '--in_txt2', default=None, help='Text string or file to process (style)')
46 | parser.add_argument('-t0', '--in_txt0', default=None, help='input text to subtract')
47 | parser.add_argument('-im', '--in_img', default=None, help='input image or directory with images')
48 | parser.add_argument('-wi', '--weight_img', default=0.5, type=float, help='weight for images')
49 | parser.add_argument('-r', '--resume', default=None, help='Resume from saved params or from an image')
50 | parser.add_argument( '--out_dir', default='_out')
51 | parser.add_argument('-tr', '--translate', action='store_true', help='Translate with Google Translate')
52 | parser.add_argument( '--invert', action='store_true', help='Invert criteria')
53 | parser.add_argument('-v', '--verbose', dest='verbose', action='store_true')
54 | parser.add_argument('-nv', '--no-verbose', dest='verbose', action='store_false')
55 | parser.set_defaults(verbose=True)
56 | # training
57 | parser.add_argument( '--gen', default='RGB', help='Generation (optimization) method: FFT or RGB')
58 | parser.add_argument('-m', '--model', default='ViT-B/32', choices=clip_models, help='Select CLIP model to use')
59 | parser.add_argument( '--steps', default=300, type=int, help='Iterations (frames) per scene (text line)')
60 | parser.add_argument( '--samples', default=100, type=int, help='Samples to evaluate per frame')
61 | parser.add_argument('-lr', '--lrate', default=0.1, type=float, help='Learning rate')
62 | parser.add_argument('-dm', '--dualmod', default=None, type=int, help='Every this step use another CLIP ViT model')
63 | # motion
64 | parser.add_argument('-ops', '--opt_step', default=1, type=int, help='How many optimizing steps per save/transform step')
65 | parser.add_argument('-sm', '--smooth', action='store_true', help='Smoothen interframe jittering for FFT method')
66 | parser.add_argument('-it', '--interpol', default=True, help='Interpolate topics? (or change by cut)')
67 | parser.add_argument( '--fstep', default=100, type=int, help='How many frames before changing motion')
68 | parser.add_argument( '--scale', default=0.012, type=float)
69 | parser.add_argument( '--shift', default=10., type=float, help='in pixels')
70 | parser.add_argument( '--angle', default=0.8, type=float, help='in degrees')
71 | parser.add_argument( '--shear', default=0.4, type=float)
72 | parser.add_argument( '--anima', default=True, help='Animate motion')
73 | # depth
74 | parser.add_argument('-d', '--depth', default=0, type=float, help='Add depth with such strength, if > 0')
75 | parser.add_argument( '--depth_model', default='b', help='Depth Anything model: large, base or small')
76 | parser.add_argument( '--depth_dir', default=None, help='Directory to save depth, if not None')
77 | # tweaks
78 | parser.add_argument('-a', '--align', default='overscan', choices=['central', 'uniform', 'overscan', 'overmax'], help='Sampling distribution')
79 | parser.add_argument('-tf', '--transform', default='fast', choices=['none', 'fast', 'custom', 'elastic'], help='augmenting transforms')
80 | parser.add_argument('-opt', '--optimizer', default='adam_custom', choices=['adam', 'adam_custom', 'adamw', 'adamw_custom'], help='Optimizer')
81 | parser.add_argument( '--fixcontrast', action='store_true', help='Required for proper resuming from image')
82 | parser.add_argument( '--contrast', default=1.2, type=float)
83 | parser.add_argument( '--colors', default=2.3, type=float)
84 | parser.add_argument('-sh', '--sharp', default=0, type=float)
85 | parser.add_argument('-mc', '--macro', default=0.3, type=float, help='Endorse macro forms 0..1 ')
86 | parser.add_argument( '--aest', default=0., type=float, help='Enhance aesthetics')
87 | parser.add_argument('-e', '--enforce', default=0, type=float, help='Enforce details (by boosting similarity between two parallel samples)')
88 | parser.add_argument('-x', '--expand', default=0, type=float, help='Boosts diversity (by enforcing difference between prev/next samples)')
89 | parser.add_argument('-n', '--noise', default=2., type=float, help='Add noise to make composition sparse (FFT only)') # 0.04
90 | parser.add_argument( '--sim', default='mix', help='Similarity function (angular/spherical/mixed; None = cossim)')
91 | parser.add_argument( '--rem', default=None, help='Dummy text to add to project name')
92 | a = parser.parse_args()
93 |
94 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1]
95 | if len(a.size)==1: a.size = a.size * 2
96 | a.gen = a.gen.upper()
97 | a.invert = -1. if a.invert is True else 1.
98 |
99 | # Overriding some parameters, depending on other settings
100 | if a.gen == 'RGB':
101 | a.smooth = False
102 | a.align = 'overscan'
103 | if a.resume is not None: a.fixcontrast = True
104 | if a.model == 'ViT-B/16': a.sim = 'cossim'
105 |
106 | if a.translate is True and googletrans_ok is not True:
107 | print('\n Install googletrans module to enable translation!'); exit()
108 |
109 | if a.dualmod is not None:
110 | a.model = 'ViT-B/32'
111 | a.sim = 'cossim'
112 |
113 | return a
114 |
115 | def depth_transform(img_t, _deptha, depthX=0, scale=1., shift=[0,0], colors=1, depth_dir=None, save_num=0):
116 | if not isinstance(depthX, float): depthX = float(depthX)
117 | if not isinstance(scale, float): scale = float(scale[0])
118 | size = img_t.shape[-2:]
119 | # d X/Y define the origin point of the depth warp, effectively a "3D pan zoom", [-1..1]
120 | # plus = look ahead, minus = look aside
121 | dX = 100. * shift[0] / size[1]
122 | dY = 100. * shift[1] / size[0]
123 | # dZ = movement direction: 1 away (zoom out), 0 towards (zoom in), 0.5 stay
124 | dZ = 0.5 + 32. * (scale-1)
125 |     def ttt(x): return x  # identity image function, so to_valid_rgb below only applies the color mapping to img_t
126 | img = to_valid_rgb(ttt, colors = colors)(img_t)
127 | img = depth.depthwarp(img_t, img, _deptha, depthX, [dX,dY], dZ, save_path=depth_dir, save_num=save_num)
128 | return img
129 |
130 | def frame_transform(img, size, angle, shift, scale, shear):
131 | if old_torch(): # 1.7.1
132 | img = T.functional.affine(img, angle, tuple(shift), scale, shear, fillcolor=0, resample=PIL.Image.BILINEAR)
133 | img = T.functional.center_crop(img, size)
134 | img = pad_up_to(img, size)
135 | else: # 1.8+
136 | img = T.functional.affine(img, angle, tuple(shift), scale, shear, fill=0, interpolation=T.InterpolationMode.BILINEAR)
137 | img = T.functional.center_crop(img, size) # on 1.8+ also pads
138 | return img
139 |
140 | def main():
141 | a = get_args()
142 |
143 | # Load CLIP models
144 | model_clip, _ = clip.load(a.model, jit=old_torch())
145 | try:
146 | a.modsize = model_clip.visual.input_resolution
147 | except:
148 | a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
149 | if a.verbose is True: print(' using model', a.model)
150 | xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}
151 | if a.model in xmem.keys():
152 | a.samples = int(a.samples * xmem[a.model])
153 |
154 | if a.translate:
155 | translator = Translator()
156 |
157 | if a.dualmod is not None: # second is vit-16
158 | model_clip2, _ = clip.load('ViT-B/16', jit=old_torch())
159 | a.samples = int(a.samples * 0.23)
160 | dualmod_nums = list(range(a.steps))[a.dualmod::a.dualmod]
161 | print(' dual model every %d step' % a.dualmod)
162 |
163 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14']:
164 | aest = aesthetic_model(a.model).cuda()
165 | if a.dualmod is not None:
166 | aest2 = aesthetic_model('ViT-B/16').cuda()
167 |
168 | if a.enforce != 0:
169 | a.samples = int(a.samples * 0.5)
170 |
171 | if 'elastic' in a.transform:
172 | trform_f = transforms.transforms_elastic
173 | a.samples = int(a.samples * 0.95)
174 | elif 'custom' in a.transform:
175 | trform_f = transforms.transforms_custom
176 | a.samples = int(a.samples * 0.95)
177 | elif 'fast' in a.transform:
178 | trform_f = transforms.transforms_fast
179 | a.samples = int(a.samples * 0.95)
180 | else:
181 | trform_f = transforms.normalize()
182 |
183 | def enc_text(txt, model_clip=model_clip):
184 | if txt is None or len(txt)==0: return None
185 | embs = []
186 | for subtxt in txt.split('|'):
187 | if ':' in subtxt:
188 | [subtxt, wt] = subtxt.split(':')
189 | wt = float(wt)
190 | else: wt = 1.
191 | emb = model_clip.encode_text(clip.tokenize(subtxt).cuda()[:77])
192 | embs.append([emb.detach().clone(), wt])
193 | return embs
194 |
195 | def enc_image(img_file, model_clip=model_clip):
196 | img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
197 | in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0]
198 | emb = model_clip.encode_image(in_sliced)
199 | return emb.detach().clone()
200 |
201 | def read_text(in_txt):
202 | if os.path.isfile(in_txt):
203 | with open(in_txt, 'r', encoding="utf-8") as f:
204 | lines = f.read().splitlines()
205 | texts = []
206 | for tt in lines:
207 | if len(tt.strip()) == 0: texts.append('')
208 | elif tt.strip()[0] != '#': texts.append(tt.strip())
209 | else:
210 | texts = [in_txt]
211 | return texts
212 |
213 | # Encode inputs
214 | count = 0
215 | texts = []
216 | styles = []
217 | notexts = []
218 | images = []
219 |
220 | if a.in_txt is not None:
221 | texts = read_text(a.in_txt)
222 | if a.in_txt_pre is not None:
223 | pretexts = read_text(a.in_txt_pre)
224 | texts = [' | '.join([pick_(pretexts, n), texts[n]]).strip() for n in range(len(texts))]
225 | if a.in_txt_post is not None:
226 | postexts = read_text(a.in_txt_post)
227 | texts = [' | '.join([texts[n], pick_(postexts, n)]).strip() for n in range(len(texts))]
228 | if a.translate is True:
229 | texts = [tr.text for tr in translator.translate(texts)]
230 | # print(' texts trans', texts)
231 | key_txt_encs = [enc_text(txt) for txt in texts]
232 | if a.dualmod is not None:
233 | key_txt_encs2 = [enc_text(txt, model_clip2) for txt in texts]
234 | count = max(count, len(key_txt_encs))
235 |
236 | if a.in_txt2 is not None:
237 | styles = read_text(a.in_txt2)
238 | if a.translate is True:
239 | styles = [tr.text for tr in translator.translate(styles)]
240 | # print(' styles trans', styles)
241 | key_styl_encs = [enc_text(style) for style in styles]
242 | if a.dualmod is not None:
243 | key_styl_encs2 = [enc_text(style, model_clip2) for style in styles]
244 | count = max(count, len(key_styl_encs))
245 |
246 | if a.in_txt0 is not None:
247 | notexts = read_text(a.in_txt0)
248 | if a.translate is True:
249 | notexts = [tr.text for tr in translator.translate(notexts)]
250 | # print(' notexts trans', notexts)
251 | key_not_encs = [enc_text(notext) for notext in notexts]
252 | if a.dualmod is not None:
253 | key_not_encs2 = [enc_text(notext, model_clip2) for notext in notexts]
254 | count = max(count, len(key_not_encs))
255 |
256 | if a.in_img is not None and os.path.exists(a.in_img):
257 | images = file_list(a.in_img) if os.path.isdir(a.in_img) else [a.in_img]
258 | key_img_encs = [enc_image(image) for image in images]
259 | if a.dualmod is not None:
260 |             key_img_encs2 = [enc_image(image, model_clip2) for image in images]
261 | count = max(count, len(key_img_encs))
262 |
263 | assert count > 0, "No inputs found!"
264 |
265 | if a.verbose is True: print(' samples:', a.samples)
266 |
267 | global params_tmp
268 | shape = [1, 3, *a.size]
269 |
270 | if a.gen == 'RGB':
271 | params_tmp, _, sz = pixel_image(shape, a.resume)
272 | params_tmp = params_tmp[0].cuda().detach()
273 | else:
274 | params_tmp, sz = resume_fft(a.resume, shape, decay=1.5, sd=1)
275 | if sz is not None: a.size = sz
276 |
277 | if a.depth != 0:
278 | _deptha = depth.InferDepthAny(a.depth_model)
279 | if a.depth_dir is not None:
280 | os.makedirs(a.depth_dir, exist_ok=True)
281 | print(' depth dir:', a.depth_dir)
282 |
283 | steps = a.steps
284 | glob_steps = count * steps
285 | if glob_steps == a.fstep: a.fstep = glob_steps // 2 # otherwise no motion
286 |
287 | workname = basename(a.in_txt) if a.in_txt is not None else basename(a.in_img)
288 | workname = txt_clean(workname)
289 | workdir = os.path.join(a.out_dir, workname + '-%s' % a.gen.lower())
290 | if a.rem is not None: workdir += '-%s' % a.rem
291 | if a.dualmod is not None: workdir += '-dm%d' % a.dualmod
292 | if 'RN' in a.model.upper(): workdir += '-%s' % a.model
293 | tempdir = os.path.join(workdir, 'ttt')
294 | os.makedirs(tempdir, exist_ok=True)
295 | save_cfg(a, workdir)
296 | if a.in_txt is not None and os.path.isfile(a.in_txt):
297 | shutil.copy(a.in_txt, os.path.join(workdir, os.path.basename(a.in_txt)))
298 | if a.in_txt2 is not None and os.path.isfile(a.in_txt2):
299 | shutil.copy(a.in_txt2, os.path.join(workdir, os.path.basename(a.in_txt2)))
300 |
301 | midp = 0.5
302 | if a.anima:
303 | if a.gen == 'RGB': # zoom in
304 | m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[-0.3], verbose=False)
305 | m_scale = 1 + (m_scale + 0.3) * a.scale
306 | else:
307 | m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[0.6], verbose=False)
308 | m_scale = 1 - (m_scale-0.6) * a.scale
309 | m_shift = latent_anima([2], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp,midp], verbose=False)
310 | m_angle = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp], verbose=False)
311 | m_shear = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp], verbose=False)
312 | m_shift = (midp-m_shift) * a.shift * abs(m_scale-1) / a.scale
313 | m_angle = (midp-m_angle) * a.angle * abs(m_scale-1) / a.scale
314 | m_shear = (midp-m_shear) * a.shear * abs(m_scale-1) / a.scale
315 |
316 | def get_encs(encs, num):
317 | cnt = len(encs)
318 | if cnt == 0: return []
319 | enc_1 = encs[min(num, cnt-1)]
320 | enc_2 = encs[min(num+1, cnt-1)]
321 | if a.interpol is not True: return [enc_1] * steps
322 | enc_pairs = []
323 | for i in range(steps):
324 | enc1_step = []
325 | if enc_1 is not None:
326 | if isinstance(enc_1, list):
327 | for enc, wt in enc_1:
328 | enc1_step.append([enc, wt * (steps-i)/steps])
329 | else:
330 | enc1_step.append(enc_1 * (steps-i)/steps)
331 | enc2_step = []
332 | if enc_2 is not None:
333 | if isinstance(enc_2, list):
334 | for enc, wt in enc_2:
335 | enc2_step.append([enc, wt * i/steps])
336 | else:
337 |                     enc2_step.append(enc_2 * i/steps)
338 | enc_pairs.append(enc1_step + enc2_step)
339 | return enc_pairs
340 |
341 | prev_enc = 0
342 | def process(num):
343 | global params_tmp, opt_state, params, image_f, optimizer
344 |
345 | txt_encs = get_encs(key_txt_encs, num)
346 | styl_encs = get_encs(key_styl_encs, num)
347 | not_encs = get_encs(key_not_encs, num)
348 | img_encs = get_encs(key_img_encs, num)
349 | if a.dualmod is not None:
350 | txt_encs2 = get_encs(key_txt_encs2, num)
351 | styl_encs2 = get_encs(key_styl_encs2, num)
352 | not_encs2 = get_encs(key_not_encs2, num)
353 | img_encs2 = get_encs(key_img_encs2, num)
354 | txt_encs = intrl(txt_encs, txt_encs2, a.dualmod)
355 | styl_encs = intrl(styl_encs, styl_encs2, a.dualmod)
356 | not_encs = intrl(not_encs, not_encs2, a.dualmod)
357 | img_encs = intrl(img_encs, img_encs2, a.dualmod)
358 | del txt_encs2, styl_encs2, not_encs2, img_encs2
359 |
360 | if a.verbose is True:
361 | if len(texts) > 0: print(' ref text: ', texts[min(num, len(texts)-1)][:80])
362 | if len(styles) > 0: print(' ref style: ', styles[min(num, len(styles)-1)][:80])
363 | if len(notexts) > 0: print(' ref avoid: ', notexts[min(num, len(notexts)-1)][:80])
364 | if len(images) > 0: print(' ref image: ', basename(images[min(num, len(images)-1)])[:80])
365 |
366 | pbar = ProgressBar(steps)
367 | for ii in range(steps):
368 | glob_step = num * steps + ii # save/transform
369 |
370 | txt_enc = txt_encs[ii % len(txt_encs)] if len(txt_encs) > 0 else None
371 | styl_enc = styl_encs[ii % len(styl_encs)] if len(styl_encs) > 0 else None
372 | not_enc = not_encs[ii % len(not_encs)] if len(not_encs) > 0 else None
373 | img_enc = img_encs[ii % len(img_encs)] if len(img_encs) > 0 else None
374 |
375 | model_clip_ = model_clip2 if a.dualmod is not None and ii in dualmod_nums else model_clip
376 | if a.aest != 0:
377 | aest_ = aest2 if a.dualmod is not None and ii in dualmod_nums else aest
378 |
379 | # MOTION: transform frame, reload params
380 |
381 | scale = m_scale[glob_step] if a.anima else 1 + a.scale
382 | shift = m_shift[glob_step] if a.anima else [0, a.shift]
383 | angle = m_angle[glob_step][0] if a.anima else a.angle
384 | shear = m_shear[glob_step][0] if a.anima else a.shear
385 |
386 | if a.gen == 'RGB':
387 | if a.depth > 0:
388 | params_tmp = depth_transform(params_tmp, _deptha, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
389 | params_tmp = frame_transform(params_tmp, a.size, angle, shift, scale, shear)
390 | params, image_f, _ = pixel_image([1, 3, *a.size], resume=params_tmp)
391 | img_tmp = None
392 |
393 | else: # FFT
394 | if old_torch(): # 1.7.1
395 | img_tmp = torch.irfft(params_tmp, 2, normalized=True, signal_sizes=a.size)
396 | if a.depth > 0:
397 | img_tmp = depth_transform(img_tmp, _deptha, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
398 | img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
399 | params_tmp = torch.rfft(img_tmp, 2, normalized=True)
400 | else: # 1.8+
401 |                     if not torch.is_complex(params_tmp):
402 | params_tmp = torch.view_as_complex(params_tmp)
403 | img_tmp = torch.fft.irfftn(params_tmp, s=a.size, norm='ortho')
404 | if a.depth > 0:
405 | img_tmp = depth_transform(img_tmp, _deptha, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
406 | img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
407 | params_tmp = torch.fft.rfftn(img_tmp, s=a.size, dim=[2,3], norm='ortho')
408 | params_tmp = torch.view_as_real(params_tmp)
409 | params, image_f, _ = fft_image([1, 3, *a.size], sd=1, resume=params_tmp)
410 |
411 | if a.optimizer.lower() == 'adamw':
412 | optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01)
413 | elif a.optimizer.lower() == 'adamw_custom':
414 | optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
415 | elif a.optimizer.lower() == 'adam':
416 | optimizer = torch.optim.Adam(params, a.lrate)
417 | else: # adam_custom
418 | optimizer = torch.optim.Adam(params, a.lrate, betas=(.0,.999))
419 | image_f = to_valid_rgb(image_f, colors = a.colors)
420 | del img_tmp
421 |
422 | if a.smooth is True and num + ii > 0:
423 | optimizer.load_state_dict(opt_state)
424 |
425 | ### optimization
426 | for ss in range(a.opt_step):
427 | loss = 0
428 |
429 | noise = a.noise * (torch.rand(1, 1, a.size[0], a.size[1]//2+1, 1)-0.5).cuda() if a.noise>0 else 0.
430 | img_out = image_f(noise, fixcontrast=a.fixcontrast)
431 |
432 | img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
433 | out_enc = model_clip_.encode_image(img_sliced)
434 |
435 | if a.aest != 0 and a.model in ['ViT-B/32', 'ViT-B/16', 'ViT-L/14'] and aest_ is not None:
436 | loss -= 0.001 * a.aest * aest_(out_enc).mean()
437 |
438 | if a.gen == 'RGB': # empirical hack
439 | loss += abs(img_out.mean((2,3)) - 0.45).mean() # fix brightness
440 | loss += abs(img_out.std((2,3)) - 0.17).mean() # fix contrast
441 |
442 | if txt_enc is not None:
443 | for enc, wt in txt_enc:
444 | loss -= a.invert * wt * sim_func(enc, out_enc, a.sim)
445 | if styl_enc is not None:
446 | for enc, wt in styl_enc:
447 | loss -= wt * sim_func(enc, out_enc, a.sim)
448 | if not_enc is not None: # subtract text
449 | for enc, wt in not_enc:
450 | loss += wt * sim_func(enc, out_enc, a.sim)
451 | if img_enc is not None:
452 | for enc in img_enc:
453 | loss -= a.weight_img * sim_func(enc, out_enc, a.sim)
454 | if a.sharp != 0: # scharr|sobel|naive
455 | loss -= a.sharp * derivat(img_out, mode='naive')
456 | if a.enforce != 0:
457 | img_sliced = slice_imgs([image_f(noise, fixcontrast=a.fixcontrast)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
458 | out_enc2 = model_clip_.encode_image(img_sliced)
459 | loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
460 | del out_enc2; torch.cuda.empty_cache()
461 | if a.expand > 0:
462 | global prev_enc
463 | if ii > 0:
464 | loss += a.expand * sim_func(prev_enc, out_enc, a.sim)
465 | prev_enc = out_enc.detach().clone()
466 | del img_out, img_sliced, out_enc; torch.cuda.empty_cache()
467 |
468 | optimizer.zero_grad()
469 | loss.backward()
470 | optimizer.step()
471 |
472 | ### save params & frame
473 |
474 | params_tmp = params[0].detach().clone()
475 | if a.smooth is True:
476 | opt_state = optimizer.state_dict()
477 |
478 | with torch.no_grad():
479 | img_t = image_f(contrast=a.contrast, fixcontrast=a.fixcontrast)[0].permute(1,2,0)
480 | img_np = torch.clip(img_t*255, 0, 255).cpu().numpy().astype(np.uint8)
481 | imsave(os.path.join(tempdir, '%06d.jpg' % glob_step), img_np, quality=95)
482 | if a.verbose is True: cvshow(img_np)
483 | del img_t, img_np
484 | pbar.upd()
485 |
486 | params_tmp = params[0].detach().clone()
487 |
488 | glob_start = time.time()
489 | try:
490 | for i in range(count):
491 | process(i)
492 | except KeyboardInterrupt:
493 | pass
494 |
495 |     os.system('ffmpeg -v warning -y -i %s/%%06d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, workname)))
496 |
497 |
498 | if __name__ == '__main__':
499 | main()
500 |
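The cross-fade in `get_encs` above keeps both the current and the next prompt encoding at every step, weighted by `(steps-i)/steps` and `i/steps` respectively, so the CLIP loss slides linearly from one phrase to the next. A minimal sketch of that weighting on dummy tensors (illustration only, not part of the repo):
```
import torch

steps = 4
enc_cur, enc_next = torch.ones(3), torch.zeros(3)   # stand-ins for CLIP text encodings
# per step, both encodings are kept with complementary weights, as in get_encs
enc_pairs = [[enc_cur * (steps - i) / steps, enc_next * i / steps] for i in range(steps)]
```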
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | imageio
3 | ipywidgets
4 | regex
5 | tqdm
6 | # googletrans==3.1.0a0
7 | torch>=1.7.1
8 | torchvision>=0.8.2
9 | opencv-python
10 | # sentence_transformers
11 | transformers>=4.6.0
12 | kornia>=0.5.3
13 | lpips
14 | omegaconf>=2.0.0
15 | pytorch-lightning>=1.0.8
16 | einops
17 | PyWavelets>=1.1.1
18 | git+https://github.com/fbcotter/pytorch_wavelets
19 |
20 | matplotlib
21 | scipy
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='aphantasia',
5 | version='3.1.0',
6 | description='CLIP + FFT/DWT/RGB text-to-image tools',
7 | url='https://github.com/eps696/aphantasia',
8 | author='vadim epstein',
9 | packages=['aphantasia'],
10 | # packages=find_packages(),
11 | install_requires=[],
12 | classifiers=[],
13 | )
14 |
--------------------------------------------------------------------------------
/shader_expo.py:
--------------------------------------------------------------------------------
1 | # CPPNs in GLSL
2 | # taken from https://github.com/wxs/cppn-to-glsl
3 | # Original code was for the NIPS Creativity Workshop submission 'Interactive CPPNs in GLSL'
4 | # modified from Mordvintsev et al.'s CPPN notebook: https://github.com/tensorflow/lucid/blob/master/notebooks/differentiable-parameterizations/xy2rgb.ipynb
5 | # https://www.apache.org/licenses/LICENSE-2.0
6 |
7 | import numpy as np
8 |
9 | ### Code to convert to GLSL/HLSL
10 |
11 | def cppn_to_shader(layers, fn_name='cppn_fn', mode='shadertoy', verbose=False, fix_aspect=True, size=[1., 1.], precision=8):
12 | """
13 |     Generate shader code from a list of dicts defining the trained CPPN layers
14 | mode='vvvv':
15 | Exports TextureFX shader file for vvvv
16 | mode='buffer':
17 |         Exports a txt file with values for the dynamicbuffer input of a vvvv TextureFX shader (and optionally the shader itself)
18 | mode='td':
19 | Exports code compatible with TouchDesigner: can be dropped into a 'GLSL TOP'
20 | (see https://docs.derivative.ca/GLSL_TOP). TouchDesigner can be found at http://derivative.ca
21 | mode='shadertoy':
22 | Exports code compatible with the ShaderToy editor at http://shadertoy.com
23 | mode='bookofshaders':
24 | Exports code compatible with the Book Of Shaders editor here http://editor.thebookofshaders.com/
25 | """
26 |
27 | # Set True to export TFX template for dynamic buffer mode (just once)
28 | export_tfx = False
29 |
30 | # the xy2rgb cppn's internal size is the output of its first layer (pre-activation)
31 | # so will just inspect that to figure it out
32 | n_hidden = layers[0]['weights'].shape[-1]
33 | if n_hidden % 4 != 0:
34 |         raise ValueError('Currently only multiples of 4 are supported for the hidden layer size')
35 | modes = {'vvvv', 'buffer', 'td', 'shadertoy', 'bookofshaders'}
36 | if mode not in modes:
37 | raise ValueError('Mode {} not one of the supported modes: {}'.format(mode, modes))
38 |
39 | if verbose and precision < 8: print(' .. precision', precision)
40 | fmt = '%' + '.%df' % precision
41 |
42 | global hlsl; hlsl = None
43 |
44 | if mode == 'buffer':
45 | global sbW; sbW = []
46 | buffer = True
47 | else: buffer = False
48 |
49 | if mode in ['vvvv', 'buffer']:
50 | hlsl = True
51 | snippet = """
52 | float2 R:TARGETSIZE;
53 | float4 """
54 | for i in range(2, len(layers)-2):
55 | snippet += "in%d_, " % i
56 | snippet = snippet[:-2] + ';'
57 | if mode == 'buffer':
58 | snippet += '\nStructuredBuffer sbW;'
59 | snippet += """
60 | #define mod(x,y) (x - y * floor(x/y))
61 | #define N_HIDDEN {}
62 | float4 {}(float2 uv) {{
63 | float4 bufA[N_HIDDEN/4];
64 | float4 bufB[N_HIDDEN/2];
65 | float4 tmp;
66 | bufB[0] = float4(uv.x, uv.y, 0., 0.);
67 | """.format(n_hidden, fn_name)
68 | elif mode == 'td':
69 | snippet = """
70 | uniform float uIn0;
71 | uniform float uIn1;
72 | uniform float uIn2;
73 | uniform float uIn3;
74 | out vec4 fragColor;
75 | """
76 | elif mode == 'shadertoy':
77 | snippet ="""
78 | #ifdef GL_ES
79 | precision lowp float;
80 | #endif
81 | """
82 | elif mode == 'bookofshaders':
83 | snippet ="""
84 | #ifdef GL_ES
85 | precision lowp float;
86 | #endif
87 | uniform vec2 u_resolution;
88 | uniform vec2 u_mouse;
89 | uniform float u_time;
90 | """
91 |
92 |     if mode not in ['vvvv', 'buffer']:
93 | snippet += """
94 | #define N_HIDDEN {}
95 | vec4 bufA[N_HIDDEN/4];
96 | vec4 bufB[N_HIDDEN/2];
97 | vec4 {}(vec2 coordinate, float in0, float in1, float in2, float in3) {{
98 | vec4 tmp;
99 | bufB[0] = vec4(coordinate.x, coordinate.y, 0., 0.);
100 | """.format(n_hidden, fn_name)
101 |
102 | def vec(a):
103 | """Take a Python array of length 4 (or less) and output code for a GLSL vec4 or HLSL float4, possibly zero-padded at the end"""
104 | global hlsl, sbW
105 | if len(a) == 4:
106 | if hlsl is True:
107 | if 'sbW' in globals(): # check if sbW defined (working with structbuffer input instead of values)
108 | for i in range(4):
109 | sbW.append(a[i])
110 | return 'sbW[%d]' % (len(sbW)//4-1)
111 | # return 'float4({})'.format(', '.join(str(x) for x in a))
112 | return 'float4({})'.format(', '.join(fmt % x for x in a))
113 | else:
114 | # return 'vec4({})'.format(', '.join(str(x) for x in a))
115 | return 'vec4({})'.format(', '.join(fmt % x for x in a))
116 | else:
117 |             assert len(a) < 4, 'Length must be less than 4'
118 | return vec(np.concatenate([a, [0.]*(4-len(a))]))
119 |
120 | def mat(a):
121 | # Take a numpy matrix of 4 rows and 4 or fewer columns, and output GLSL or HLSL code for a mat4,
122 | # possibly with zeros padded in the last columns
123 | if a.shape[0] < 4:
124 | m2 = np.vstack([a, [[0.,0.,0.,0.]] * (4 - a.shape[0])])
125 | return mat(m2)
126 | assert a.shape[0] == 4, 'Expected a of shape (4,n<=4). Got: {}.'.format(a.shape)
127 | global hlsl
128 | if hlsl is True:
129 | return 'float4x4({})'.format(', '.join(vec(row) for row in a))
130 | else:
131 | return 'mat4({})'.format(', '.join(vec(row) for row in a))
132 |
133 | for layer_i, layer_dict in enumerate(layers):
134 | weight = layer_dict['weights']
135 | bias = layer_dict['bias']
136 | activation = layer_dict['activation']
137 |
138 | _, _, from_size, to_size = weight.shape
139 | if verbose: print('Processing layer {}. from_size={}, to_size={} .. shape {}'.format(layer_i, from_size, to_size, weight.shape))
140 | snippet += '\n // layer {} \n'.format(layer_i)
141 |
142 | # First, compute the transformation from the last layer into bufA
143 | for to_index in range(max(1,to_size//4)):
144 | #Again, the max(1) is important here, because to_size is 3 for the last layer!
145 | if verbose: print(' generating output {} into bufA'.format(to_index))
146 | snippet += 'bufA[{}] = {}'.format(to_index, vec(bias[to_index*4:to_index*4+4]))
147 | if verbose: print('bufA[{}] = {} . . .'.format(to_index, vec(bias[to_index*4:to_index*4+4])))
148 | for from_index in range(max(1,from_size//4)):
149 | # the 'max' in the above loop gives us a special case for the first layer, where there are only two inputs.
150 | if mode in ['vvvv', 'buffer']:
151 | snippet += ' + mul(bufB[{}], {})'.format(from_index, mat(weight[0, 0, from_index*4:from_index*4+4, to_index*4:to_index*4+4]))
152 | # snippet += ' + mul({}, bufB[{}])'.format(mat(weight[0, 0, from_index*4:from_index*4+4, to_index*4:to_index*4+4]), from_index)
153 | else:
154 | snippet += ' + {} * bufB[{}]'.format(mat(weight[0, 0, from_index*4:from_index*4+4, to_index*4:to_index*4+4]), from_index)
155 | if mode in ['vvvv', 'buffer'] and layer_i > 1 and layer_i < len(layers)-2:
156 | suffix = ['x','y','z','w']
157 | snippet += ' + in{}_.{}'.format(layer_i, suffix[to_index%4])
158 | else:
159 | if layer_i == 3:
160 | snippet += ' + in{}'.format(to_index%4)
161 | snippet += ';\n'
162 |
163 | # print('export', layer_i, activation)
164 | if to_size != 3:
165 | if verbose: print(' Doing the activation into bufB')
166 | for to_index in range(to_size//4):
167 | if activation == 'comp':
168 | snippet += 'tmp = atan(bufA[{}]);\n'.format(to_index)
169 | snippet += 'bufB[{}] = tmp/0.67;\n'.format(to_index)
170 | snippet += 'bufB[{}] = (tmp*tmp) / 0.6;\n'.format(to_index + to_size//4)
171 | elif activation == 'unbias':
172 | snippet += 'tmp = atan(bufA[{}]);\n'.format(to_index)
173 | snippet += 'bufB[{}] = tmp/0.67;\n'.format(to_index)
174 | snippet += 'bufB[{}] = (tmp*tmp - 0.45) / 0.396;\n'.format(to_index + to_size//4)
175 | elif activation == 'relu':
176 | snippet += 'bufB[{}] = (max(bufA[{}], 0.) - 0.4) / 0.58;\n'.format(to_index, to_index)
177 | else:
178 |                 raise ValueError('Unknown activation: {}'.format(activation))
179 | else:
180 | if verbose: print(' Sigmoiding the last layer')
181 | # sigmoid at the last layer
182 | sigmoider = lambda s: '1. / (1. + exp(-{}))'.format(s)
183 | if mode in ['vvvv', 'buffer']:
184 | snippet += '\n return float4(({}).rgb, 1.0);\n'.format(sigmoider('bufA[0]'))
185 | # snippet += '\n return float4((1. / (1. + exp(-bufA[0]))).xyz, 1.0);\n}'
186 | else:
187 | snippet += '\n return vec4(({}).xyz, 1.0);\n'.format(sigmoider('bufA[0]'))
188 | # snippet += '\n return vec4((1. / (1. + exp(-bufA[0]))).xyz, 1.0);\n}'
189 | snippet += '}\n'
190 |
191 | if mode in ['vvvv', 'buffer']:
192 | snippet += """
193 | float4 PS(float4 p:SV_Position, float2 uv:TEXCOORD0): SV_Target {
194 | uv = 2 * (uv - 0.5);
195 | """
196 | if fix_aspect:
197 | snippet += """
198 | uv *= R/R.y;
199 | """
200 | snippet += """
201 | return {}(2*uv);
202 | }}
203 | technique10 Process
204 | {{ pass P0
205 | {{ SetPixelShader(CompileShader(ps_4_0,PS())); }}
206 | }}
207 | """.format(fn_name)
208 | elif mode == 'td':
209 | snippet += """
210 | void main() {
211 | // Normalized pixel coordinates (from 0 to 1)
212 | vec2 uv = vUV.xy;
213 | """
214 | if fix_aspect:
215 | snippet += """
216 | // TODO: don't know how to find the resolution of the GLSL Top output to fix aspect...
217 | """
218 | snippet += """
219 | // Shifted to the form expected by the CPPN
220 | uv.xy = vec2(1., -1.) * 2. * (uv.xy - vec2(0.5, 0.5));
221 | uv.y /= {} / {};
222 | // Output to screen
223 | fragColor = TDOutputSwizzle({}(uv.xy, uIn0, uIn1, uIn2, uIn3));
224 | }}
225 | """.format(float(size[0]), float(size[1]), fn_name)
226 | elif mode == 'shadertoy':
227 | snippet += """
228 | void mainImage( out vec4 fragColor, in vec2 fragCoord ) {
229 | // Normalized pixel coordinates (from 0 to 1)
230 | vec2 uv = fragCoord/iResolution.xy;
231 | vec2 mouseNorm = (iMouse.xy / iResolution.xy) - vec2(0.5, 0.5);
232 | """
233 | if fix_aspect:
234 | snippet += """
235 | uv.x *= iResolution.x / iResolution.y;
236 | uv.x -= ((iResolution.x / iResolution.y) - 1.) /2.;
237 | """
238 | snippet += """
239 | // Shifted to the form expected by the CPPN
240 | uv = vec2(1., -1.) * 1.5 * (uv - vec2(0.5, 0.5));
241 | uv.y /= {} / {};
242 | // Output to screen
243 | fragColor = {}(uv, 0.23*sin(iTime), 0.32*sin(0.69*iTime), 0.32*sin(0.44*iTime), 0.23*sin(1.23*iTime));
244 | }}
245 | """.format(float(size[0]), float(size[1]), fn_name)
246 | elif mode=='bookofshaders':
247 | snippet += """
248 | void main() {
249 | vec2 st = gl_FragCoord.xy/u_resolution.xy;
250 | """
251 | if fix_aspect:
252 | snippet += """
253 | st.x *= u_resolution.x/u_resolution.y;
254 | st.x -= ((u_resolution.x / u_resolution.y) - 1.) /2.;
255 | """
256 | snippet += """
257 | st = vec2(1., -1.) * 1.5 * (st - vec2(0.5, 0.5));
258 | st.y /= {} / {};
259 | gl_FragColor = {}(st, 0.23*sin(u_time), 0.32*sin(0.69*u_time), 0.32*sin(0.44*u_time), 0.23*sin(1.23*u_time));
260 | }}
261 | """.format(float(size[0]), float(size[1]), fn_name)
262 |
263 | if buffer is True:
264 | # buffer = ','.join('%.8f'%x for x in sbW)
265 | buffer = ','.join(fmt % x for x in sbW)
266 | if export_tfx == True:
267 | with open('CPPN-%d-%d.tfx' % (len(layers)-1, n_hidden), 'w') as f:
268 | f.write(snippet)
269 | # print(' total values', len(sbW))
270 | return buffer
271 | else:
272 | return snippet
273 |
274 |
--------------------------------------------------------------------------------
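For reference, a minimal usage sketch of `cppn_to_shader` (illustration only: the layer dicts are random dummies with the shapes the exporter expects, i.e. `weights` of shape `(1, 1, from_size, to_size)`, `bias` of length `to_size`, `activation` one of `comp` / `unbias` / `relu`; real values would come from a trained CPPN, and the import assumes the repo root is on the path):
```
import numpy as np
from shader_expo import cppn_to_shader

# dummy CPPN: 2 coordinates -> 8 hidden units -> 8 hidden units -> 3 (RGB) outputs;
# with the 'comp' activation each hidden layer feeds 2x its width into the next one
layers = [
    {'weights': np.random.randn(1, 1, 2, 8),  'bias': np.random.randn(8), 'activation': 'comp'},
    {'weights': np.random.randn(1, 1, 16, 8), 'bias': np.random.randn(8), 'activation': 'comp'},
    {'weights': np.random.randn(1, 1, 16, 3), 'bias': np.random.randn(3), 'activation': 'comp'},
]
code = cppn_to_shader(layers, mode='shadertoy', size=[1280., 720.])
print(code)  # paste into the ShaderToy editor
```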