├── .gitignore
├── CITATION.cff
├── README.md
├── colabs
│   ├── Structured_Dreaming_Styledreams.ipynb
│   └── Structured_Dreaming_Styledreams_faces.ipynb
├── deprecated
│   ├── __init__.py
│   ├── autograd.py
│   ├── generate.py
│   ├── model.py
│   ├── training.py
│   └── utils.py
├── requirements.txt
├── res
│   ├── styledream_face_thumb.jpeg
│   └── styledream_thumb.jpeg
├── setup.py
└── structure
    ├── __init__.py
    ├── clip.py
    ├── data
    │   └── bpe_simple_vocab_16e6.txt
    ├── ops.py
    ├── optim.py
    ├── sample.py
    ├── stylegan_utils
    │   ├── __init__.py
    │   └── ops.py
    ├── transform.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: Ekgren
5 | given-names: Ariel
6 | google-scholar-id: fhs8fggAAAAJ
7 | title: "Structured Dreaming"
8 | version: 0.0.1
9 | date-released: 2021-09-01
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structured Dreaming
2 | ###### Note that this is a work in progress: it may keep changing, or it might freeze in time in an unfinished state.
3 |
4 | ## Colabs
5 | 
6 | [Styledreams](https://colab.research.google.com/github/ekgren/StructuredDreaming/blob/main/colabs/Structured_Dreaming_Styledreams.ipynb)
7 | -- CLIP x StyleGAN2 notebook for fine-tuning StyleGAN2 from CLIP.
8 |
9 | 
10 | [Styledreams Faces](https://colab.research.google.com/github/ekgren/StructuredDreaming/blob/main/colabs/Structured_Dreaming_Styledreams_faces.ipynb)
11 | -- CLIP x StyleGAN2 notebook for finding a face from a photo in the StyleGAN latent space and then fine-tuning the model from CLIP to change the face.
12 |
13 | ## Optimizer
14 | This repo uses the ClampSGD optimizer, which also has its own repository here: [https://github.com/ekgren/open-optimizers](https://github.com/ekgren/open-optimizers).
15 |
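Below is a minimal sketch of the gradient-clamping idea: each gradient element is clamped to a small range before the parameter update. It is only an illustration; the actual ClampSGD implementation and its exact interface live in the open-optimizers repository linked above, so the class name and arguments here are assumptions.

```python
import torch


class ClampSGDSketch(torch.optim.Optimizer):
    """Hypothetical clamped-SGD sketch; the real ClampSGD may differ."""

    def __init__(self, params, lr=1e-3, clamp=1e-3):
        super().__init__(params, dict(lr=lr, clamp=clamp))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Clamp each gradient element before taking the SGD step.
                g = p.grad.clamp(-group["clamp"], group["clamp"])
                p.add_(g, alpha=-group["lr"])
```
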
16 | ## Introduction
17 | By now it is well known that neural networks trained to classify images also have the capacity to generate
18 | images [[1]](#1). There are a lot of variations on this theme, with entire artistic movements based on
19 | DeepDream [[4]](#4), libraries such as Lucid [[6]](#6), and advanced feature visualization tools such as
20 | OpenAI Microscope [[7]](#7). With the release of CLIP [[5]](#5) and open research on Twitter [[8]](#8),
21 | generative exploration of image networks has gained a lot of popularity.
22 |
23 | As described in Differentiable image parameterizations [[1]](#1), all these
24 | generative techniques work in the same way. Given a network used for image-related tasks such
25 | as representation learning or classification, we can backpropagate from a desired representation
26 | and optimize the input image towards a high-activation image.
27 |
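As a concrete illustration, the sketch below optimizes the raw pixels of a single image so that its CLIP embedding matches a text prompt, using the loader in `structure/clip.py`. The prompt, resolution, learning rate and iteration count are arbitrary placeholder values, not settings used in this repo.

```python
import torch

from structure.clip import load, tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP in full precision to keep the sketch simple.
perceptor, normalize = load("ViT-B/32", jit=False)
perceptor = perceptor.float().to(device)

text_latent = perceptor.encode_text(tokenize("a painting of a tree").to(device)).detach()

# Parameterize the image directly as RGB pixel values.
img = torch.nn.Parameter(torch.rand(1, 3, 224, 224, device=device))
opt = torch.optim.Adam([img], lr=0.05)

for _ in range(100):
    opt.zero_grad()
    img_latent = perceptor.encode_image(normalize(img.clamp(0, 1)))
    loss = -torch.cosine_similarity(text_latent, img_latent, dim=-1).mean()
    loss.backward()
    opt.step()
```

Run naively like this, pixel optimization tends to produce exactly the noisy, high-frequency images described below.
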
28 | The simplest parametrization of the input image is in the form of RGB values for each pixel.
29 | But naively backpropagating to the image will not work, as described in the chapter
30 | [Enemy of feature visualization](https://distill.pub/2017/feature-visualization/#enemy-of-feature-vis) of Feature Visualization [[2]](#2).
31 | The network ends up “cheating”, and you get an image full of noise and
32 | nonsensical high-frequency patterns that the network responds strongly to.
33 |
34 | In this work we will continue to explore different techniques to avoid this "cheating" and create images
35 | that are informative and/or visually interesting.
36 |
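One family of such techniques used in this repo is to score CLIP on many randomly sampled, regularized patches instead of on the full image (see `deprecated/utils.py` and `deprecated/generate.py`). A rough sketch using those helpers is shown below; the patch sizes, dropout rate and patch count are arbitrary example values.

```python
import torch

from deprecated.utils import Dropper, Pipeline, SamplePatch, Upscale

# Build a patch-sampling pipeline: random crop -> dropout -> resize to 224.
patch_pipeline = Pipeline(
    SamplePatch(size_min=64, size_max=448),
    Dropper(drop=0.3, drop2d=True),
    Upscale(res_out=224, mode="area"),
)

img = torch.rand(1, 3, 512, 512)  # stand-in for the image being optimized
# Score a batch of random patches instead of the raw image; averaging the
# CLIP loss over such patches suppresses the high-frequency "cheating".
patches = torch.cat([patch_pipeline(img) for _ in range(32)], dim=0)
# patches has shape [32, 3, 224, 224] and can be fed to perceptor.encode_image(...)
```
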
37 | ## References
38 | [1]
39 | Mordvintsev, A., Pezzotti, N., Schubert, L., & Olah, C. (2018).
40 | Differentiable image parameterizations. Distill, 3(7), e12.
41 | https://distill.pub/2018/differentiable-parameterizations/
42 |
43 | [2]
44 | Olah, C., Mordvintsev, A., & Schubert, L. (2017).
45 | Feature visualization. Distill, 2(11), e7.
46 | https://distill.pub/2017/feature-visualization/
47 |
48 | [3]
49 | Goh, G., Cammarata, N., Voss, C., Carter, S., Petrov, M., Schubert, L., ... & Olah, C. (2021).
50 | Multimodal neurons in artificial neural networks. Distill, 6(3), e30. https://distill.pub/2021/multimodal-neurons/
51 |
52 | [4]
53 | https://ai.googleblog.com/2015/06/inceptionism-going-deeper-into-neural.html
54 |
55 | [5]
56 | https://github.com/openai/CLIP
57 |
58 | [6]
59 | https://github.com/tensorflow/lucid
60 |
61 | [7]
62 | https://microscope.openai.com/
63 |
64 | [8]
65 | https://twitter.com/advadnoun/status/1348375026697834496
66 |
--------------------------------------------------------------------------------
/deprecated/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/deprecated/__init__.py
--------------------------------------------------------------------------------
/deprecated/autograd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BoostGradFunc(torch.autograd.Function):
5 | @staticmethod
6 | def forward(ctx, tensor, boost_val):
7 | ctx.set_materialize_grads(False)
8 | ctx.boost_val = boost_val
9 | return tensor
10 |
11 | @staticmethod
12 | def backward(ctx, grad_output):
13 | if grad_output is None:
14 | return None, None
15 | return grad_output * ctx.boost_val, None
16 |
17 |
18 | class ClampGradFunc(torch.autograd.Function):
19 | @staticmethod
20 | def forward(ctx, tensor, clamp_val):
21 | ctx.set_materialize_grads(False)
22 | ctx.clamp_val = clamp_val
23 | return tensor
24 |
25 | @staticmethod
26 | def backward(ctx, grad_output):
27 | if grad_output is None:
28 | return None, None
29 | return grad_output.clamp(-ctx.clamp_val, ctx.clamp_val), None
30 |
31 |
32 | class BoostGrad(torch.nn.Module):
33 | """
34 | This class is a PyTorch module that takes a tensor as input and boosts its gradient.
35 |
36 | Parameters
37 | ----------
38 | boost_val : float
39 | The value by which the gradient of the input tensor is boosted.
40 | """
41 |
42 | def __init__(self, boost_val: float = 1e1):
43 | super().__init__()
44 | self.boost_grad = BoostGradFunc.apply
45 | self.boost_val = boost_val
46 |
47 | def forward(self, input: torch.Tensor) -> torch.Tensor:
48 | return self.boost_grad(input, self.boost_val)
49 |
50 |
51 | class ClampGrad(torch.nn.Module):
52 | """
53 | This class is a PyTorch module that takes a tensor as input and clamps its gradient.
54 |
55 | Parameters
56 | ----------
57 | clamp_val : float
58 | The value to clamp the gradients to.
59 |
60 | Methods
61 | -------
62 | forward(input)
63 | Clamps the gradients of the input tensor to be between -clamp_val and clamp_val.
64 | """
65 |
66 | def __init__(self, clamp_val: float = 1e-6):
67 | super().__init__()
68 | self.clamp_grad = ClampGradFunc.apply
69 | self.clamp_val = clamp_val
70 |
71 | def forward(self, input: torch.Tensor) -> torch.Tensor:
72 | return self.clamp_grad(input, self.clamp_val)
73 |
--------------------------------------------------------------------------------
/deprecated/generate.py:
--------------------------------------------------------------------------------
1 | """ Generate images with CLIP """
2 |
3 | import json
4 | import os
5 |
6 | import click
7 | import numpy as np
8 | import PIL.Image
9 | import torch
10 | import torchvision
11 | from tqdm import tqdm
12 |
13 | from structure.clip import load, tokenize, convert_weights
14 | from deprecated.model import ImgBaseOld
15 | from structure.utils import (
16 | Pipeline,
17 | Upscale,
18 | Pixelate,
19 | Dropper,
20 | Prod,
21 | SamplePatch,
22 | grad_sign,
23 | model_to_fp32,
24 | ArgDict,
25 | )
26 |
27 |
28 | @click.command()
29 | @click.pass_context
30 |
31 | # General options.
32 | @click.option("--text", help="Text prompt", required=True)
33 | @click.option("--seed", type=int, help="Random seed", default=0)
34 | @click.option(
35 | "--outdir",
36 | help="Where to save the output images",
37 | type=str,
38 | required=True,
39 | metavar="DIR",
40 | )
41 |
42 | # Training.
43 | @click.option(
44 | "--iterations",
45 | help="Number of iterations of generating image [default: 1000]",
46 | type=int,
47 | default=1000,
48 | )
49 | @click.option(
50 | "--grad_acc_steps",
51 | help="Gradient accumulation steps [default: 1]",
52 | type=int,
53 | default=1,
54 | )
55 | @click.option("--lr", help="Learning rate [default: 0.01]", type=float, default=0.01)
56 | # TODO: add betas = (0.99, 0.999)
57 |
58 | # Model.
59 | @click.option("--image_size", help="Image size [default: 512]", type=int, default=512)
60 | @click.option(
61 | "--weight_init", help="Image weight init [default: 0.05]", type=float, default=0.05
62 | )
63 | @click.option(
64 |     "--decolorize", help="Decolorization strength [default: 0.001]", type=float, default=0.001
65 | )
66 | @click.option(
67 |     "--darken", help="Darkening strength [default: 0.005]", type=float, default=0.005
68 | )
69 |
70 | # General img options.
71 | # TODO: add mode = 'area'
72 |
73 | # Pixelate pipeline.
74 | @click.option("--px_no", help="Number of patches [default: 32]", type=int, default=32)
75 | @click.option(
76 | "--px_patch_size_min",
77 | help="Pixelate patch min size [default: 256]",
78 | type=int,
79 | default=256,
80 | )
81 | @click.option(
82 | "--px_patch_size_max",
83 | help="Pixelate patch max size [default: 512]",
84 | type=int,
85 | default=512,
86 | )
87 | @click.option(
88 | "--px_size_min", help="Pixelation min size [default: 32]", type=int, default=32
89 | )
90 | @click.option(
91 | "--px_size_max", help="Pixelation max size [default: 224]", type=int, default=224
92 | )
93 | @click.option(
94 | "--px_drop", help="Pixelation dropout [default: 0.3]", type=float, default=0.3
95 | )
96 |
97 | # Upscale pipeline.
98 | @click.option("--up_no", help="Number of patches [default: 32]", type=int, default=32)
99 | @click.option(
100 | "--up_patch_size_min",
101 | help="Upscale patch min size [default: 64]",
102 | type=int,
103 | default=64,
104 | )
105 | @click.option(
106 | "--up_patch_size_max",
107 | help="Upscale patch max size [default: 512]",
108 | type=int,
109 | default=512,
110 | )
111 | @click.option(
112 |     "--up_drop", help="Upscale dropout [default: 0.3]", type=float, default=0.3
113 | )
114 |
115 | # Grow parameters.
116 | @click.option(
117 |     "--grow_init_res", help="Initial grow resolution [default: 32]", type=int, default=32
118 | )
119 | @click.option(
120 |     "--grow_step_size", help="Grow resolution increment [default: 64]", type=int, default=64
121 | )
122 | @click.option(
123 |     "--grow_step", help="Iterations between grow steps [default: 20]", type=int, default=20
124 | )
125 | def main(ctx, **config_kwargs):
126 | """Generate images.
127 | Examples:
128 | """
129 | args = ArgDict(**config_kwargs)
130 | # Print options.
131 | print()
132 | print("Generation options:")
133 | print(json.dumps(args, indent=2))
134 | print()
135 | print(f"Output directory: {args.outdir}")
136 | print(f"Prompt: {args.text}")
137 | print()
138 |
139 | # Initialize
140 | device = torch.device("cuda")
141 | np.random.seed(args["seed"])
142 | torch.manual_seed(args["seed"])
143 | os.makedirs(args["outdir"], exist_ok=True)
144 |
145 | #################
146 | # PARAMETERS
147 | #################
148 |
149 | # Training
150 | betas = (0.99, 0.999)
151 |
152 | # General img options
153 | res_out = 224
154 | mode = "area"
155 | #################
156 |
157 | print("Loading clip.")
158 | perceptor, normalize_image = load("ViT-B/32", jit=False)
159 | txt_tok = tokenize(args.text)
160 | text_latent = perceptor.encode_text(txt_tok.cuda()).detach()
161 |
162 | # Setting up image generation
163 | model = ImgBaseOld(
164 | size=args.image_size,
165 | weight_init=args.weight_init,
166 | decolorize=args.decolorize,
167 | darken=args.darken,
168 | ).cuda()
169 | normalize = torchvision.transforms.Normalize(
170 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
171 | )
172 | optimizer = torch.optim.Adam(model.parameters(), args.lr, betas=betas)
173 |
174 | grow_res = args.grow_init_res
175 | grow_pipeline = Pipeline(
176 | Pixelate(scale_size_min=grow_res, scale_size_max=grow_res),
177 | Upscale(res_out=args.image_size, mode=mode),
178 | )
179 |
180 | px_pipeline = Pipeline(
181 | SamplePatch(size_min=args.px_patch_size_min, size_max=args.px_patch_size_max),
182 | Dropper(drop=args.px_drop, drop2d=True),
183 | Pixelate(scale_size_min=args.px_size_min, scale_size_max=args.px_size_max),
184 | Upscale(res_out=res_out, mode=mode),
185 | Dropper(drop=args.px_drop, drop2d=True),
186 | )
187 |
188 | up_pipeline = Pipeline(
189 | SamplePatch(args.up_patch_size_min, args.up_patch_size_max),
190 | Dropper(drop=args.up_drop, drop2d=True),
191 | Upscale(res_out=res_out, mode=mode),
192 | Dropper(drop=args.up_drop, drop2d=True),
193 | )
194 |
195 | patches = [px_pipeline] * args.px_no + [up_pipeline] * args.up_no
196 | patches = Prod(*patches)
197 |
198 | # Image generation
199 | print("Generating image.")
200 | for i in tqdm(range(args.iterations)):
201 | optimizer.zero_grad()
202 | img = normalize_image(model())
203 | img = grow_pipeline(img)
204 |
205 | img_processed = torch.cat(patches(img), 0)
206 | img_latents = perceptor.encode_image(img_processed)
207 |
208 | loss = (
209 | 10 * torch.cosine_similarity(text_latent, img_latents, dim=-1).mean().neg()
210 | )
211 | loss.backward()
212 |
213 | if (i + 1) % args.grad_acc_steps == 0:
214 | model_to_fp32(perceptor.visual)
215 | model_to_fp32(model)
216 | grad_sign(model.w.grad)
217 | optimizer.step()
218 | convert_weights(perceptor.visual)
219 | convert_weights(model)
220 |
221 | # Post processing
222 | model.post_process()
223 |
224 | # Update grow resolution
225 | if (i + 1) % args.grow_step == 0:
226 | grow_res = min(args.image_size, grow_res + args.grow_step_size)
227 | grow_pipeline = Pipeline(
228 | Pixelate(scale_size_min=grow_res, scale_size_max=grow_res),
229 | Upscale(res_out=args.image_size, mode=mode),
230 | )
231 |
232 | # DEBUG
233 | if (i + 1) % 20 == 0:
234 | with torch.no_grad():
235 | img = model()
236 | _img = (img.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
237 | f_name = f'{args.outdir}/{args.text.replace(" ", "_")}__seed{args.seed:04d}.png'
238 | PIL.Image.fromarray(_img[0].cpu().numpy(), "RGB").save(f_name)
239 |
240 | with torch.no_grad():
241 | img = model()
242 | _img = (img.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
243 | f_name = f'{args.outdir}/{args.text.replace(" ", "_")}__seed{args.seed:04d}.png'
244 | PIL.Image.fromarray(_img[0].cpu().numpy(), "RGB").save(f_name)
245 |
246 |
247 | if __name__ == "__main__":
248 | main()
249 |
--------------------------------------------------------------------------------
/deprecated/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from deprecated.autograd import BoostGrad
4 |
5 |
6 | class ImgBase(torch.nn.Module):
7 | def __init__(self, size: int = 224, k: float = 5.0, weight_init: float = 0.05):
8 | super().__init__()
9 | self.size = size
10 | self.k = k
11 |
12 | self.color = torch.nn.Parameter(
13 | torch.tensor(
14 | [
15 | [-0.1409, 0.0855, -0.7620],
16 | [0.2596, -0.5239, 0.0996],
17 | [0.1653, -0.0719, 0.0889],
18 | ]
19 | )
20 | )
21 | self.w = torch.nn.Parameter(
22 | torch.randn(1, 3, size, size, requires_grad=True) * weight_init
23 | )
24 |
25 | def forward(self) -> torch.Tensor:
26 | img = self.w
27 | color = self.color / self.color.norm(p=2)
28 | img = torch.nn.functional.linear(img.permute(0, 2, 3, 1), color).permute(
29 | 0, 3, 1, 2
30 | )
31 | img = self.to_rgb(img, self.k)
32 | return img
33 |
34 | def to_rgb(self, input: torch.Tensor, k: float) -> torch.Tensor:
35 | return (input.clamp(-k, k) + k) / (2 * k)
36 |
37 |
38 | class ImgBaseOld(torch.nn.Module):
39 |     """Image parameterized directly as RGB pixel values, with clamp/decolorize/darken post-processing."""
40 |
41 | def __init__(self, size=224, weight_init=0.05, decolorize=0.0, darken=0.0):
42 | super().__init__()
43 | self.decolorize = decolorize
44 | self.darken = darken
45 | self.w = torch.ones(1, 3, size, size, requires_grad=True) * weight_init
46 | self.w = torch.nn.Parameter(self.w.half())
47 |
48 | def forward(self):
49 | return self.w
50 |
51 | def post_process(self):
52 | with torch.no_grad():
53 | self.w.clamp_(0.0, 1.0)
54 | if self.decolorize > 0.0:
55 | self.w += self.decolorize * (
56 | -self.w + self.w.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)
57 | )
58 | if self.darken > 0.0:
59 | self.w *= 1.0 - self.darken
60 |
61 |
62 | class ImgBaseFFT(torch.nn.Module):
63 |     """Image parameterized by its 2D Fourier (rfft2) coefficients."""
64 |
65 | def __init__(self, size=224, k=15.0, weight_init=0.05):
66 | super().__init__()
67 | self.size = size
68 | self.k = k
69 | self.color = torch.nn.Linear(3, 3, bias=False)
70 | w = torch.fft.rfft2(
71 | torch.randn(1, 3, size, size, requires_grad=True) * weight_init
72 | )
73 | self.w = torch.nn.Parameter(w)
74 | self.act = torch.sin
75 | self.bg = BoostGrad()
76 | self.norm = ChanNorm(dim=3)
77 |
78 | def forward(self):
79 | img = torch.fft.irfft2(self.w)
80 |         img = self.bg(img)
81 | img = self.norm(img)
82 | img = self.color(img.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
83 | return img
84 |
85 | def get_img(self, size=None):
86 | size = size if size is not None else self.size
87 | img = torch.fft.irfft2(self.w)
88 | if size != self.size:
89 | img = torch.nn.functional.interpolate(img, (size, size), mode="area")
90 | img = self.norm(img)
91 | img = self.color(img.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
92 | return (img.clamp(-self.k, self.k) + self.k) / (2 * self.k)
93 |
94 |
95 | # From: https://github.com/lucidrains/stylegan2-pytorch
96 | class ChanNorm(torch.nn.Module):
97 | def __init__(self, dim, eps=1e-5):
98 | super().__init__()
99 | self.eps = eps
100 | self.g = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
101 | self.b = torch.nn.Parameter(torch.zeros(1, dim, 1, 1))
102 |
103 | def forward(self, x):
104 | std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
105 | mean = torch.mean(x, dim=1, keepdim=True)
106 | return (x - mean) / (std + self.eps) * self.g + self.b
107 |
--------------------------------------------------------------------------------
/deprecated/training.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def training_loop(
6 | random_seed=0, # Global random seed.
7 | ):
8 | # Initialize.
9 | device = torch.device("cuda")
10 | np.random.seed(random_seed)
11 | torch.manual_seed(random_seed)
12 |
--------------------------------------------------------------------------------
/deprecated/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from torch.nn.functional import dropout, dropout2d, interpolate, avg_pool2d
4 | import torch
5 |
6 |
7 | class SamplePatch(object):
8 | def __init__(self, size_min, size_max):
9 | assert (
10 | size_min <= size_max
11 | ), "size_min should be equal or smaller than size_max."
12 | assert size_min >= 2, "size_min can't be smaller than 2."
13 |
14 | self.size_min = size_min
15 | self.size_max = size_max
16 |
17 | def __call__(self, img):
18 | """Get uniformly sampled patch from image."""
19 | bs, ch, h, w = img.shape
20 | assert h == w, "Assumes equal width and height of image."
21 | assert (
22 | self.size_max <= h
23 | ), "size_max has to be smaller than or equal to image height."
24 | if self.size_min == self.size_max:
25 | patch_size = self.size_max
26 | else:
27 | patch_size = random.randint(self.size_min, self.size_max)
28 | h_offset = random.randint(0, h - patch_size)
29 | w_offset = random.randint(0, w - patch_size)
30 | patch_h_s = h_offset
31 | patch_h_e = h_offset + patch_size
32 | patch_w_s = w_offset
33 | patch_w_e = w_offset + patch_size
34 | img_out = img[:, :, patch_h_s:patch_h_e, patch_w_s:patch_w_e]
35 | return img_out
36 |
37 |
38 | class Dropper(object):
39 | def __init__(self, drop=0.0, drop2d=False):
40 | self.drop = drop
41 | self.drop2d = drop2d
42 |
43 | def __call__(self, img):
44 | if not self.drop2d:
45 | return dropout(img, p=self.drop)
46 | else:
47 | return dropout2d(img, p=self.drop)
48 |
49 |
50 | class Upscale(object):
51 | def __init__(self, res_out=224, mode="area"):
52 | self.res_out = res_out
53 | self.mode = mode
54 |
55 | def __call__(self, img):
56 | bs, ch, h, w = img.shape
57 | if h != self.res_out:
58 | return interpolate(img, (self.res_out, self.res_out), mode=self.mode)
59 | else:
60 | return img
61 |
62 |
63 | class Pixelate(object):
64 | def __init__(self, scale_size_min, scale_size_max):
65 | self.scale_size_max = scale_size_max
66 | self.scale_size_min = scale_size_min
67 |
68 | def __call__(self, img):
69 | bs, ch, h, w = img.shape
70 | downscale = random.randint(self.scale_size_min, self.scale_size_max)
71 | kernel_size = int(max(1, h / downscale))
72 | return avg_pool2d(img, kernel_size)
73 |
74 |
75 | class RandomPool(object):
76 |     """
77 |     Randomly average-pools (downsamples) the input image with a randomly sampled kernel size.
78 |
79 |     Args:
80 |         img (torch.Tensor): Input image.
81 |         pad_val (float, optional): Currently unused. Defaults to random.random() * 2. - 1.
82 |
83 |     Returns:
84 |         torch.Tensor: The pooled image.
85 |     """
86 |
87 | def __init__(self, kernel_min=1, kernel_max=8):
88 | self.kernel_min = kernel_min
89 | self.kernel_max = kernel_max
90 |
91 | def __call__(self, img, pad_val=None):
92 | pad_val = pad_val if pad_val is not None else random.random() * 2.0 - 1.0
93 | # Double uniform sample different distribution
94 | kernel_size = random.randint(
95 | self.kernel_min, random.randint(self.kernel_min, self.kernel_max)
96 | )
97 | # kernel_size = random.randint(self.kernel_min, self.kernel_max)
98 | img = avg_pool2d(img, kernel_size=kernel_size, stride=None, padding=0)
99 | return img
100 |
101 |
102 | class RandomMirror(object):
103 | def __init__(self, blend=False):
104 | # Blend not properly implemented
105 | self.modes = 2 if blend is False else 4
106 |
107 | def __call__(self, img):
108 | bs, ch, h, w = img.shape
109 | augmentation = random.randint(0, self.modes)
110 | if augmentation == 0:
111 | pass
112 | elif augmentation == 1:
113 | img_l = img[:, :, :, : w // 2]
114 | img_r = torch.flip(img[:, :, :, : w // 2], [3])
115 | img = torch.cat([img_l, img_r], dim=3)
116 | elif augmentation == 2:
117 | img_l = torch.flip(img[:, :, :, w // 2 :], [3])
118 | img_r = img[:, :, :, w // 2 :]
119 | img = torch.cat([img_l, img_r], dim=3)
120 | elif augmentation == 3:
121 | img_l = img[:, :, :, : w // 2]
122 | img_r = (
123 | img[:, :, :, w // 2 :] + torch.flip(img[:, :, :, : w // 2], [3])
124 | ) / 2
125 | img = torch.cat([img_l, img_r], dim=3)
126 | elif augmentation == 4:
127 | img_l = (
128 | img[:, :, :, : w // 2] + torch.flip(img[:, :, :, w // 2 :], [3])
129 | ) / 2
130 | img_r = img[:, :, :, w // 2 :]
131 | img = torch.cat([img_l, img_r], dim=3)
132 | return img
133 |
134 |
135 | class RandomPad(object):
136 | def __init__(self, pad_min=0, pad_max=224, step=1):
137 | self.pad_min = pad_min
138 | self.pad_max = pad_max
139 | self.step = step
140 |
141 | def __call__(self, img):
142 | pad_modes = ["circular", "reflect", "replicate"]
143 | pad = (
144 | random.randrange(self.pad_min, self.pad_max, self.step),
145 | random.randrange(self.pad_min, self.pad_max, self.step),
146 | random.randrange(self.pad_min, self.pad_max, self.step),
147 | random.randrange(self.pad_min, self.pad_max, self.step),
148 | )
149 | return torch.nn.functional.pad(img, pad, mode=random.choice(pad_modes))
150 |
151 |
152 | class Flip(object):
153 | """Horizontal random flip p=0.5"""
154 |
155 | def __init__(self):
156 | pass
157 |
158 | def __call__(self, img):
159 | if random.randint(0, 1) == 0:
160 | return img
161 | return torch.flip(img, [3])
162 |
163 |
164 | class Pipeline(object):
165 | def __init__(self, *args):
166 | self.functions = args
167 |
168 | def __call__(self, img):
169 | tmp = img
170 | for f in self.functions:
171 | tmp = f(tmp)
172 | return tmp
173 |
174 |
175 | class Prod(object):
176 | def __init__(self, *args):
177 | self.functions = args
178 |
179 | def __call__(self, x):
180 | return [f(x) for f in self.functions]
181 |
182 |
183 | def get_sub(img, steps=2):
184 | """Get grid of image.
185 | Only support certain image sizes."""
186 | bs, ch, h, w = img.shape
187 | outs = []
188 | for h_step in range(steps):
189 | for w_step in range(steps):
190 | h_ixs = torch.arange(h_step, h, steps)
191 | w_ixs = torch.arange(w_step, h, steps)
192 | _w = w_ixs.repeat(int(h / steps))
193 | _h = h_ixs.repeat_interleave(int(h / steps))
194 | outs.append(img[:, :, _h, _w].reshape(1, 3, int(h / steps), int(h / steps)))
195 | return outs
196 |
197 |
198 | def grad_drop(grad, drop=0.0, drop2d=True):
199 | """Dropout gradient."""
200 | with torch.no_grad():
201 | if not drop2d:
202 | grad += -grad + dropout(grad, drop)
203 | else:
204 | grad += -grad + dropout2d(grad, drop)
205 |
206 |
207 | def grad_sign(grad):
208 | """Convert gradient to signed gradient."""
209 | with torch.no_grad():
210 | grad += -grad + grad.sign()
211 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | regex
3 | tqdm
4 | click
5 | torch~=1.7.1
6 | torchvision~=0.8.2
--------------------------------------------------------------------------------
/res/styledream_face_thumb.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/res/styledream_face_thumb.jpeg
--------------------------------------------------------------------------------
/res/styledream_thumb.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/res/styledream_thumb.jpeg
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup, find_packages
4 |
5 | setup(name='structureddreaming',
6 | version='0.0.1',
7 | description='Structured Dreaming',
8 | author='Ariel Ekgren',
9 | author_email='',
10 | url='https://github.com/ekgren/StructuredDreaming',
11 | install_requires=[],
12 | packages=find_packages(),
13 | entry_points={}
14 | )
15 |
--------------------------------------------------------------------------------
/structure/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/structure/__init__.py
--------------------------------------------------------------------------------
/structure/clip.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/openai/CLIP/
2 |
3 | from collections import OrderedDict
4 | from typing import Tuple, Union
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | from pathlib import Path
10 |
11 | import hashlib
12 | import os
13 | import urllib
14 | import warnings
15 | from typing import Union, List
16 |
17 | import torch
18 | from PIL import Image
19 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
20 | from tqdm import tqdm
21 |
22 | _MODELS = {
23 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
24 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
25 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
26 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
27 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
28 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
29 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
30 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
31 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
32 | }
33 |
34 |
35 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
36 | os.makedirs(root, exist_ok=True)
37 | filename = os.path.basename(url)
38 |
39 | expected_sha256 = url.split("/")[-2]
40 | download_target = os.path.join(root, filename)
41 |
42 | if os.path.exists(download_target) and not os.path.isfile(download_target):
43 | raise RuntimeError(f"{download_target} exists and is not a regular file")
44 |
45 | if os.path.isfile(download_target):
46 | if (
47 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
48 | == expected_sha256
49 | ):
50 | return download_target
51 | else:
52 | warnings.warn(
53 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
54 | )
55 |
56 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
57 | with tqdm(
58 | total=int(source.info().get("Content-Length")),
59 | ncols=80,
60 | unit="iB",
61 | unit_scale=True,
62 | ) as loop:
63 | while True:
64 | buffer = source.read(8192)
65 | if not buffer:
66 | break
67 |
68 | output.write(buffer)
69 | loop.update(len(buffer))
70 |
71 | if (
72 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
73 | != expected_sha256
74 | ):
75 | raise RuntimeError(
76 |             f"Model has been downloaded but the SHA256 checksum does not match"
77 | )
78 |
79 | return download_target
80 |
81 |
82 | def _transform():
83 | return Compose(
84 | [
85 | Normalize(
86 | (0.48145466, 0.4578275, 0.40821073),
87 | (0.26862954, 0.26130258, 0.27577711),
88 | ),
89 | ]
90 | )
91 |
92 |
93 | def available_models() -> List[str]:
94 | """Returns the names of available CLIP models"""
95 | return list(_MODELS.keys())
96 |
97 |
98 | def load(
99 | name: str,
100 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
101 | jit=True,
102 | ):
103 | """Load a CLIP model
104 | Parameters
105 | ----------
106 | name : str
107 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
108 | device : Union[str, torch.device]
109 | The device to put the loaded model
110 | jit : bool
111 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
112 | Returns
113 | -------
114 | model : torch.nn.Module
115 | The CLIP model
116 | preprocess : Callable[[PIL.Image], torch.Tensor]
117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118 | """
119 | if name in _MODELS:
120 | model_path = _download(_MODELS[name])
121 | elif os.path.isfile(name):
122 | model_path = name
123 | else:
124 | raise RuntimeError(
125 | f"Model {name} not found; available models = {available_models()}"
126 | )
127 |
128 | try:
129 | # loading JIT archive
130 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
131 | state_dict = None
132 | except RuntimeError:
133 | # loading saved state dict
134 | if jit:
135 | warnings.warn(
136 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
137 | )
138 | jit = False
139 | state_dict = torch.load(model_path, map_location="cpu")
140 |
141 | if not jit:
142 | model = build_model(state_dict or model.state_dict()).to(device)
143 | if str(device) == "cpu":
144 | model.float()
145 | return model, _transform()
146 |
147 | # patch the device names
148 | device_holder = torch.jit.trace(
149 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
150 | )
151 | device_node = [
152 | n
153 | for n in device_holder.graph.findAllNodes("prim::Constant")
154 | if "Device" in repr(n)
155 | ][-1]
156 |
157 | def patch_device(module):
158 | graphs = [module.graph] if hasattr(module, "graph") else []
159 | if hasattr(module, "forward1"):
160 | graphs.append(module.forward1.graph)
161 |
162 | for graph in graphs:
163 | for node in graph.findAllNodes("prim::Constant"):
164 | if "value" in node.attributeNames() and str(node["value"]).startswith(
165 | "cuda"
166 | ):
167 | node.copyAttributes(device_node)
168 |
169 | model.apply(patch_device)
170 | patch_device(model.encode_image)
171 | patch_device(model.encode_text)
172 |
173 | # patch dtype to float32 on CPU
174 | if str(device) == "cpu":
175 | float_holder = torch.jit.trace(
176 | lambda: torch.ones([]).float(), example_inputs=[]
177 | )
178 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
179 | float_node = float_input.node()
180 |
181 | def patch_float(module):
182 | graphs = [module.graph] if hasattr(module, "graph") else []
183 | if hasattr(module, "forward1"):
184 | graphs.append(module.forward1.graph)
185 |
186 | for graph in graphs:
187 | for node in graph.findAllNodes("aten::to"):
188 | inputs = list(node.inputs())
189 | for i in [
190 | 1,
191 | 2,
192 | ]: # dtype can be the second or third argument to aten::to()
193 | if inputs[i].node()["value"] == 5:
194 | inputs[i].node().copyAttributes(float_node)
195 |
196 | model.apply(patch_float)
197 | patch_float(model.encode_image)
198 | patch_float(model.encode_text)
199 |
200 | model.float()
201 |
202 | return model, _transform()
203 |
204 |
205 | def tokenize(
206 | texts: Union[str, List[str]], context_length: int = 77
207 | ) -> torch.LongTensor:
208 | """
209 | Returns the tokenized representation of given input string(s)
210 | Parameters
211 | ----------
212 | texts : Union[str, List[str]]
213 | An input string or a list of input strings to tokenize
214 | context_length : int
215 | The context length to use; all CLIP models use 77 as the context length
216 | Returns
217 | -------
218 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
219 | """
220 | if isinstance(texts, str):
221 | texts = [texts]
222 |
223 | sot_token = _tokenizer.encoder["<|startoftext|>"]
224 | eot_token = _tokenizer.encoder["<|endoftext|>"]
225 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
227 |
228 | for i, tokens in enumerate(all_tokens):
229 | if len(tokens) > context_length:
230 | raise RuntimeError(
231 | f"Input {texts[i]} is too long for context length {context_length}"
232 | )
233 | result[i, : len(tokens)] = torch.tensor(tokens)
234 |
235 | return result
236 |
237 |
238 | class Bottleneck(nn.Module):
239 | expansion = 4
240 |
241 | def __init__(self, inplanes, planes, stride=1):
242 | super().__init__()
243 |
244 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
245 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
246 | self.bn1 = nn.BatchNorm2d(planes)
247 |
248 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
249 | self.bn2 = nn.BatchNorm2d(planes)
250 |
251 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
252 |
253 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
254 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
255 |
256 | self.relu = nn.ReLU(inplace=True)
257 | self.downsample = None
258 | self.stride = stride
259 |
260 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
261 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
262 | self.downsample = nn.Sequential(
263 | OrderedDict(
264 | [
265 | ("-1", nn.AvgPool2d(stride)),
266 | (
267 | "0",
268 | nn.Conv2d(
269 | inplanes,
270 | planes * self.expansion,
271 | 1,
272 | stride=1,
273 | bias=False,
274 | ),
275 | ),
276 | ("1", nn.BatchNorm2d(planes * self.expansion)),
277 | ]
278 | )
279 | )
280 |
281 | def forward(self, x: torch.Tensor):
282 | identity = x
283 |
284 | out = self.relu(self.bn1(self.conv1(x)))
285 | out = self.relu(self.bn2(self.conv2(out)))
286 | out = self.avgpool(out)
287 | out = self.bn3(self.conv3(out))
288 |
289 | if self.downsample is not None:
290 | identity = self.downsample(x)
291 |
292 | out += identity
293 | out = self.relu(out)
294 | return out
295 |
296 |
297 | class AttentionPool2d(nn.Module):
298 | def __init__(
299 | self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
300 | ):
301 | super().__init__()
302 | self.positional_embedding = nn.Parameter(
303 | torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
304 | )
305 | self.k_proj = nn.Linear(embed_dim, embed_dim)
306 | self.q_proj = nn.Linear(embed_dim, embed_dim)
307 | self.v_proj = nn.Linear(embed_dim, embed_dim)
308 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
309 | self.num_heads = num_heads
310 |
311 | def forward(self, x):
312 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
313 | 2, 0, 1
314 | ) # NCHW -> (HW)NC
315 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
316 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
317 | x, _ = F.multi_head_attention_forward(
318 | query=x,
319 | key=x,
320 | value=x,
321 | embed_dim_to_check=x.shape[-1],
322 | num_heads=self.num_heads,
323 | q_proj_weight=self.q_proj.weight,
324 | k_proj_weight=self.k_proj.weight,
325 | v_proj_weight=self.v_proj.weight,
326 | in_proj_weight=None,
327 | in_proj_bias=torch.cat(
328 | [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
329 | ),
330 | bias_k=None,
331 | bias_v=None,
332 | add_zero_attn=False,
333 | dropout_p=0,
334 | out_proj_weight=self.c_proj.weight,
335 | out_proj_bias=self.c_proj.bias,
336 | use_separate_proj_weight=True,
337 | training=self.training,
338 | need_weights=False,
339 | )
340 |
341 | return x[0]
342 |
343 |
344 | class ModifiedResNet(nn.Module):
345 | """
346 | A ResNet class that is similar to torchvision's but contains the following changes:
347 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
348 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
349 | - The final pooling layer is a QKV attention instead of an average pool
350 | """
351 |
352 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
353 | super().__init__()
354 | self.output_dim = output_dim
355 | self.input_resolution = input_resolution
356 |
357 | # the 3-layer stem
358 | self.conv1 = nn.Conv2d(
359 | 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
360 | )
361 | self.bn1 = nn.BatchNorm2d(width // 2)
362 | self.conv2 = nn.Conv2d(
363 | width // 2, width // 2, kernel_size=3, padding=1, bias=False
364 | )
365 | self.bn2 = nn.BatchNorm2d(width // 2)
366 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
367 | self.bn3 = nn.BatchNorm2d(width)
368 | self.avgpool = nn.AvgPool2d(2)
369 | self.relu = nn.ReLU(inplace=True)
370 |
371 | # residual layers
372 | self._inplanes = width # this is a *mutable* variable used during construction
373 | self.layer1 = self._make_layer(width, layers[0])
374 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
375 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
376 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
377 |
378 | embed_dim = width * 32 # the ResNet feature dimension
379 | self.attnpool = AttentionPool2d(
380 | input_resolution // 32, embed_dim, heads, output_dim
381 | )
382 |
383 | def _make_layer(self, planes, blocks, stride=1):
384 | layers = [Bottleneck(self._inplanes, planes, stride)]
385 |
386 | self._inplanes = planes * Bottleneck.expansion
387 | for _ in range(1, blocks):
388 | layers.append(Bottleneck(self._inplanes, planes))
389 |
390 | return nn.Sequential(*layers)
391 |
392 | def forward(self, x):
393 | def stem(x):
394 | for conv, bn in [
395 | (self.conv1, self.bn1),
396 | (self.conv2, self.bn2),
397 | (self.conv3, self.bn3),
398 | ]:
399 | x = self.relu(bn(conv(x)))
400 | x = self.avgpool(x)
401 | return x
402 |
403 | x = x.type(self.conv1.weight.dtype)
404 | x = stem(x)
405 | x = self.layer1(x)
406 | x = self.layer2(x)
407 | x = self.layer3(x)
408 | x = self.layer4(x)
409 | x = self.attnpool(x)
410 |
411 | return x
412 |
413 |
414 | class LayerNorm(nn.LayerNorm):
415 | """Subclass torch's LayerNorm to handle fp16."""
416 |
417 | def forward(self, x: torch.Tensor):
418 | orig_type = x.dtype
419 | ret = super().forward(x.type(torch.float32))
420 | return ret.type(orig_type)
421 |
422 |
423 | class QuickGELU(nn.Module):
424 | def forward(self, x: torch.Tensor):
425 | return x * torch.sigmoid(1.702 * x)
426 |
427 |
428 | class ResidualAttentionBlock(nn.Module):
429 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
430 | super().__init__()
431 |
432 | self.attn = nn.MultiheadAttention(d_model, n_head)
433 | self.ln_1 = LayerNorm(d_model)
434 | self.mlp = nn.Sequential(
435 | OrderedDict(
436 | [
437 | ("c_fc", nn.Linear(d_model, d_model * 4)),
438 | ("gelu", QuickGELU()),
439 | ("c_proj", nn.Linear(d_model * 4, d_model)),
440 | ]
441 | )
442 | )
443 | self.ln_2 = LayerNorm(d_model)
444 | self.attn_mask = attn_mask
445 |
446 | def attention(self, x: torch.Tensor):
447 | self.attn_mask = (
448 | self.attn_mask.to(dtype=x.dtype, device=x.device)
449 | if self.attn_mask is not None
450 | else None
451 | )
452 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
453 |
454 | def forward(self, x: torch.Tensor):
455 | x = x + self.attention(self.ln_1(x))
456 | x = x + self.mlp(self.ln_2(x))
457 | return x
458 |
459 |
460 | class Transformer(nn.Module):
461 | def __init__(
462 | self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
463 | ):
464 | super().__init__()
465 | self.width = width
466 | self.layers = layers
467 | self.resblocks = nn.Sequential(
468 | *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
469 | )
470 |
471 | def forward(self, x: torch.Tensor):
472 | return self.resblocks(x)
473 |
474 |
475 | class VisualTransformer(nn.Module):
476 | def __init__(
477 | self,
478 | input_resolution: int,
479 | patch_size: int,
480 | width: int,
481 | layers: int,
482 | heads: int,
483 | output_dim: int,
484 | ):
485 | super().__init__()
486 | self.input_resolution = input_resolution
487 | self.output_dim = output_dim
488 | self.conv1 = nn.Conv2d(
489 | in_channels=3,
490 | out_channels=width,
491 | kernel_size=patch_size,
492 | stride=patch_size,
493 | bias=False,
494 | )
495 |
496 | scale = width ** -0.5
497 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
498 | self.positional_embedding = nn.Parameter(
499 | scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
500 | )
501 | self.ln_pre = LayerNorm(width)
502 |
503 | self.transformer = Transformer(width, layers, heads)
504 |
505 | self.ln_post = LayerNorm(width)
506 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
507 |
508 | def forward(self, x: torch.Tensor):
509 | x = self.conv1(x) # shape = [*, width, grid, grid]
510 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
511 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
512 | x = torch.cat(
513 | [
514 | self.class_embedding.to(x.dtype)
515 | + torch.zeros(
516 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
517 | ),
518 | x,
519 | ],
520 | dim=1,
521 | ) # shape = [*, grid ** 2 + 1, width]
522 | x = x + self.positional_embedding.to(x.dtype)
523 | x = self.ln_pre(x)
524 |
525 | x = x.permute(1, 0, 2) # NLD -> LND
526 | x = self.transformer(x)
527 | x = x.permute(1, 0, 2) # LND -> NLD
528 |
529 | x = self.ln_post(x[:, 0, :])
530 |
531 | if self.proj is not None:
532 | x = x @ self.proj
533 |
534 | return x
535 |
536 |
537 | class CLIP(nn.Module):
538 | def __init__(
539 | self,
540 | embed_dim: int,
541 | # vision
542 | image_resolution: int,
543 | vision_layers: Union[Tuple[int, int, int, int], int],
544 | vision_width: int,
545 | vision_patch_size: int,
546 | # text
547 | context_length: int,
548 | vocab_size: int,
549 | transformer_width: int,
550 | transformer_heads: int,
551 | transformer_layers: int,
552 | ):
553 | super().__init__()
554 |
555 | self.context_length = context_length
556 |
557 | if isinstance(vision_layers, (tuple, list)):
558 | vision_heads = vision_width * 32 // 64
559 | self.visual = ModifiedResNet(
560 | layers=vision_layers,
561 | output_dim=embed_dim,
562 | heads=vision_heads,
563 | input_resolution=image_resolution,
564 | width=vision_width,
565 | )
566 | else:
567 | vision_heads = vision_width // 64
568 | self.visual = VisualTransformer(
569 | input_resolution=image_resolution,
570 | patch_size=vision_patch_size,
571 | width=vision_width,
572 | layers=vision_layers,
573 | heads=vision_heads,
574 | output_dim=embed_dim,
575 | )
576 |
577 | self.transformer = Transformer(
578 | width=transformer_width,
579 | layers=transformer_layers,
580 | heads=transformer_heads,
581 | attn_mask=self.build_attention_mask(),
582 | )
583 |
584 | self.vocab_size = vocab_size
585 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
586 | self.positional_embedding = nn.Parameter(
587 | torch.empty(self.context_length, transformer_width)
588 | )
589 | self.ln_final = LayerNorm(transformer_width)
590 |
591 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
592 | self.logit_scale = nn.Parameter(torch.ones([]))
593 |
594 | self.initialize_parameters()
595 |
596 | def initialize_parameters(self):
597 | nn.init.normal_(self.token_embedding.weight, std=0.02)
598 | nn.init.normal_(self.positional_embedding, std=0.01)
599 |
600 | if isinstance(self.visual, ModifiedResNet):
601 | if self.visual.attnpool is not None:
602 | std = self.visual.attnpool.c_proj.in_features ** -0.5
603 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
604 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
605 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
606 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
607 |
608 | for resnet_block in [
609 | self.visual.layer1,
610 | self.visual.layer2,
611 | self.visual.layer3,
612 | self.visual.layer4,
613 | ]:
614 | for name, param in resnet_block.named_parameters():
615 | if name.endswith("bn3.weight"):
616 | nn.init.zeros_(param)
617 |
618 | proj_std = (self.transformer.width ** -0.5) * (
619 | (2 * self.transformer.layers) ** -0.5
620 | )
621 | attn_std = self.transformer.width ** -0.5
622 | fc_std = (2 * self.transformer.width) ** -0.5
623 | for block in self.transformer.resblocks:
624 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
625 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
626 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
627 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
628 |
629 | if self.text_projection is not None:
630 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
631 |
632 | def build_attention_mask(self):
633 | # lazily create causal attention mask, with full attention between the vision tokens
634 | # pytorch uses additive attention mask; fill with -inf
635 | mask = torch.empty(self.context_length, self.context_length)
636 | mask.fill_(float("-inf"))
637 | mask.triu_(1) # zero out the lower diagonal
638 | return mask
639 |
640 | @property
641 | def dtype(self):
642 | return self.visual.conv1.weight.dtype
643 |
644 | def encode_image(self, image):
645 | return self.visual(image.type(self.dtype))
646 |
647 | def encode_text(self, text):
648 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
649 |
650 | x = x + self.positional_embedding.type(self.dtype)
651 | x = x.permute(1, 0, 2) # NLD -> LND
652 | x = self.transformer(x)
653 | x = x.permute(1, 0, 2) # LND -> NLD
654 | x = self.ln_final(x).type(self.dtype)
655 |
656 | # x.shape = [batch_size, n_ctx, transformer.width]
657 | # take features from the eot embedding (eot_token is the highest number in each sequence)
658 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
659 |
660 | return x
661 |
662 | def forward(self, image, text):
663 | image_features = self.encode_image(image)
664 | text_features = self.encode_text(text)
665 |
666 | # normalized features
667 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
668 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
669 |
670 | # cosine similarity as logits
671 | logit_scale = self.logit_scale.exp()
672 | logits_per_image = logit_scale * image_features @ text_features.t()
673 | logits_per_text = logit_scale * text_features @ image_features.t()
674 |
675 | # shape = [global_batch_size, global_batch_size]
676 | return logits_per_image, logits_per_text
677 |
678 |
679 | def convert_weights(model: nn.Module):
680 | """Convert applicable model parameters to fp16"""
681 |
682 | def _convert_weights_to_fp16(l):
683 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
684 | l.weight.data = l.weight.data.half()
685 | if l.bias is not None:
686 | l.bias.data = l.bias.data.half()
687 |
688 | if isinstance(l, nn.MultiheadAttention):
689 | for attr in [
690 | *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
691 | "in_proj_bias",
692 | "bias_k",
693 | "bias_v",
694 | ]:
695 | tensor = getattr(l, attr)
696 | if tensor is not None:
697 | tensor.data = tensor.data.half()
698 |
699 | for name in ["text_projection", "proj"]:
700 | if hasattr(l, name):
701 | attr = getattr(l, name)
702 | if attr is not None:
703 | attr.data = attr.data.half()
704 |
705 | model.apply(_convert_weights_to_fp16)
706 |
707 |
708 | def build_model(state_dict: dict):
709 | vit = "visual.proj" in state_dict
710 |
711 | if vit:
712 | vision_width = state_dict["visual.conv1.weight"].shape[0]
713 | vision_layers = len(
714 | [
715 | k
716 | for k in state_dict.keys()
717 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
718 | ]
719 | )
720 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
721 | grid_size = round(
722 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
723 | )
724 | image_resolution = vision_patch_size * grid_size
725 | else:
726 | counts: list = [
727 | len(
728 | set(
729 | k.split(".")[2]
730 | for k in state_dict
731 | if k.startswith(f"visual.layer{b}")
732 | )
733 | )
734 | for b in [1, 2, 3, 4]
735 | ]
736 | vision_layers = tuple(counts)
737 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
738 | output_width = round(
739 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
740 | )
741 | vision_patch_size = None
742 | assert (
743 | output_width ** 2 + 1
744 | == state_dict["visual.attnpool.positional_embedding"].shape[0]
745 | )
746 | image_resolution = output_width * 32
747 |
748 | embed_dim = state_dict["text_projection"].shape[1]
749 | context_length = state_dict["positional_embedding"].shape[0]
750 | vocab_size = state_dict["token_embedding.weight"].shape[0]
751 | transformer_width = state_dict["ln_final.weight"].shape[0]
752 | transformer_heads = transformer_width // 64
753 | transformer_layers = len(
754 | set(
755 | k.split(".")[2]
756 | for k in state_dict
757 | if k.startswith(f"transformer.resblocks")
758 | )
759 | )
760 |
761 | model = CLIP(
762 | embed_dim,
763 | image_resolution,
764 | vision_layers,
765 | vision_width,
766 | vision_patch_size,
767 | context_length,
768 | vocab_size,
769 | transformer_width,
770 | transformer_heads,
771 | transformer_layers,
772 | )
773 |
774 | for key in ["input_resolution", "context_length", "vocab_size"]:
775 | if key in state_dict:
776 | del state_dict[key]
777 |
778 | convert_weights(model)
779 | model.load_state_dict(state_dict)
780 | return model.eval()
781 |
782 |
783 | import gzip
784 | import html
785 | import os
786 | from functools import lru_cache
787 |
788 | import ftfy
789 | import regex as re
790 |
791 |
792 | @lru_cache()
793 | def default_bpe():
794 | return os.path.join(
795 | os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt"
796 | )
797 |
798 |
799 | @lru_cache()
800 | def bytes_to_unicode():
801 | """
802 | Returns list of utf-8 byte and a corresponding list of unicode strings.
803 | The reversible bpe codes work on unicode strings.
804 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
805 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
806 | This is a signficant percentage of your normal, say, 32K bpe vocab.
807 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
808 | And avoids mapping to whitespace/control characters the bpe code barfs on.
809 | """
810 | bs = (
811 | list(range(ord("!"), ord("~") + 1))
812 | + list(range(ord("¡"), ord("¬") + 1))
813 | + list(range(ord("®"), ord("ÿ") + 1))
814 | )
815 | cs = bs[:]
816 | n = 0
817 | for b in range(2 ** 8):
818 | if b not in bs:
819 | bs.append(b)
820 | cs.append(2 ** 8 + n)
821 | n += 1
822 | cs = [chr(n) for n in cs]
823 | return dict(zip(bs, cs))
824 |
825 |
826 | def get_pairs(word):
827 | """Return set of symbol pairs in a word.
828 | Word is represented as tuple of symbols (symbols being variable-length strings).
829 | """
830 | pairs = set()
831 | prev_char = word[0]
832 | for char in word[1:]:
833 | pairs.add((prev_char, char))
834 | prev_char = char
835 | return pairs
836 |
837 |
838 | def basic_clean(text):
839 | text = ftfy.fix_text(text)
840 | text = html.unescape(html.unescape(text))
841 | return text.strip()
842 |
843 |
844 | def whitespace_clean(text):
845 | text = re.sub(r"\s+", " ", text)
846 | text = text.strip()
847 | return text
848 |
849 |
850 | class SimpleTokenizer(object):
851 | def __init__(self, bpe_path: str = default_bpe()):
852 | self.byte_encoder = bytes_to_unicode()
853 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
854 | merges = Path(bpe_path).read_text(encoding="utf8").split("\n")
855 | merges = merges[1 : 49152 - 256 - 2 + 1]
856 | merges = [tuple(merge.split()) for merge in merges]
857 | vocab = list(bytes_to_unicode().values())
858 |         vocab = vocab + [v + "</w>" for v in vocab]
859 | for merge in merges:
860 | vocab.append("".join(merge))
861 | vocab.extend(["<|startoftext|>", "<|endoftext|>"])
862 | self.encoder = dict(zip(vocab, range(len(vocab))))
863 | self.decoder = {v: k for k, v in self.encoder.items()}
864 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
865 | self.cache = {
866 | "<|startoftext|>": "<|startoftext|>",
867 | "<|endoftext|>": "<|endoftext|>",
868 | }
869 | self.pat = re.compile(
870 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
871 | re.IGNORECASE,
872 | )
873 |
874 | def bpe(self, token):
875 | if token in self.cache:
876 | return self.cache[token]
877 |         word = tuple(token[:-1]) + (token[-1] + "</w>",)
878 | pairs = get_pairs(word)
879 |
880 | if not pairs:
881 |             return token + "</w>"
882 |
883 | while True:
884 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
885 | if bigram not in self.bpe_ranks:
886 | break
887 | first, second = bigram
888 | new_word = []
889 | i = 0
890 | while i < len(word):
891 | try:
892 | j = word.index(first, i)
893 | new_word.extend(word[i:j])
894 | i = j
895 |                 except ValueError:
896 | new_word.extend(word[i:])
897 | break
898 |
899 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
900 | new_word.append(first + second)
901 | i += 2
902 | else:
903 | new_word.append(word[i])
904 | i += 1
905 | new_word = tuple(new_word)
906 | word = new_word
907 | if len(word) == 1:
908 | break
909 | else:
910 | pairs = get_pairs(word)
911 | word = " ".join(word)
912 | self.cache[token] = word
913 | return word
914 |
915 | def encode(self, text):
916 | bpe_tokens = []
917 | text = whitespace_clean(basic_clean(text)).lower()
918 | for token in re.findall(self.pat, text):
919 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
920 | bpe_tokens.extend(
921 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
922 | )
923 | return bpe_tokens
924 |
925 | def decode(self, tokens):
926 | text = "".join([self.decoder[token] for token in tokens])
927 | text = (
928 | bytearray([self.byte_decoder[c] for c in text])
929 | .decode("utf-8", errors="replace")
930 |             .replace("</w>", " ")
931 | )
932 | return text
934 |
935 |
936 | _tokenizer = SimpleTokenizer()
937 |
--------------------------------------------------------------------------------
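
For orientation, a minimal usage sketch of the tokenizer defined in /structure/clip.py above. The caption string is purely illustrative; only SimpleTokenizer, its encode/decode methods and the <|startoftext|>/<|endoftext|> vocabulary entries come from the code itself.

    from structure.clip import SimpleTokenizer

    tokenizer = SimpleTokenizer()        # loads data/bpe_simple_vocab_16e6.txt by default

    ids = tokenizer.encode("a structured dream of a cathedral")
    print(ids)                           # BPE token ids
    print(tokenizer.decode(ids))         # round-trips to the cleaned, lower-cased text

    # CLIP-style models expect the ids wrapped in start/end-of-text markers:
    sot = tokenizer.encoder["<|startoftext|>"]
    eot = tokenizer.encoder["<|endoftext|>"]
    tokens = [sot] + ids + [eot]
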
/structure/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # From the paper "CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP"
4 | # Arxiv link: https://arxiv.org/abs/2110.11316
5 | # Code: https://github.com/ml-jku/cloob/
6 | def infoLOOB_loss(x, y, i, inv_tau):
7 | tau = 1 / inv_tau
8 | k = x @ y.T / tau
9 | positives = -torch.mean(torch.sum(k * i, dim=1))
10 |
11 | # For logsumexp the zero entries must be equal to a very large negative number
12 | large_neg = -1000.0
13 | arg_lse = k * torch.logical_not(i) + i * large_neg
14 | negatives = torch.mean(torch.logsumexp(arg_lse, dim=1))
15 |
16 | return tau * (positives + negatives)
--------------------------------------------------------------------------------
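
A short sketch of calling infoLOOB_loss. The batch size, embedding width and inv_tau value are illustrative assumptions, and the sum over both directions mirrors how the CLOOB code typically uses the loss.

    import torch
    import torch.nn.functional as F
    from structure.ops import infoLOOB_loss

    n, d = 8, 512                                  # illustrative batch and width
    img = F.normalize(torch.randn(n, d), dim=-1)   # stand-in image embeddings
    txt = F.normalize(torch.randn(n, d), dim=-1)   # stand-in text embeddings
    i = torch.eye(n)                               # matched pairs sit on the diagonal

    loss = infoLOOB_loss(img, txt, i, inv_tau=30.0) \
         + infoLOOB_loss(txt, img, i, inv_tau=30.0)
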
/structure/optim.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import Tensor
4 | from torch.optim import Optimizer
5 | from typing import List, Optional
6 |
7 |
8 | class SimpleSGD(Optimizer):
9 | def __init__(self, params, lr=0.1):
10 | if not 0.0 <= lr:
11 | raise ValueError("Invalid learning rate: {}".format(lr))
12 | defaults = dict(lr=lr)
13 | super().__init__(params, defaults)
14 |
15 | def step(self, closure=None):
16 | loss = None
17 | if closure is not None:
18 | loss = closure()
19 | for group in self.param_groups:
20 | for p in group["params"]:
21 | if p.grad is None:
22 | continue
23 | lr = group["lr"]
24 | grad = p.grad.data
25 | p.data.add_(grad, alpha=-lr)
26 | return loss
27 |
28 |
29 | class ClampSGD(Optimizer):
30 | r"""Clamp SGD
31 | Args:
32 | params (iterable): iterable of parameters to optimize or dicts defining
33 | parameter groups
34 | lr (float, optional): learning rate (default: 1e-1)
35 | clamp (float, optional): clamping of gradient (default: 1e-30)
36 | drop (float, optional): dropout of gradient (default: 0)
37 | """
38 |
39 | def __init__(self, params, lr=1e-1, clamp=1e-30, drop=0.0):
40 | if not 0.0 <= lr:
41 | raise ValueError("Invalid learning rate: {}".format(lr))
42 | if not 0.0 <= drop < 1.0:
43 | raise ValueError("Invalid dropout value: {}".format(drop))
44 | defaults = dict(lr=lr, clamp=clamp, drop=drop)
45 | super().__init__(params, defaults)
46 |
47 | @torch.no_grad()
48 | def step(self, closure=None):
49 | loss = None
50 | if closure is not None:
51 | loss = closure()
52 | for group in self.param_groups:
53 | for p in group["params"]:
54 | if p.grad is None:
55 | continue
56 | lr = group["lr"]
57 | clamp = group["clamp"]
58 | drop = group["drop"]
59 | grad = torch.nn.functional.dropout(
60 | p.grad.data.clamp_(-clamp, clamp).div_(clamp), p=drop
61 | )
62 | p.data.add_(grad, alpha=-lr)
63 | return loss
64 |
65 |
66 | class SignSGD(Optimizer):
67 | r"""Sign SGD
68 | Args:
69 | params (iterable): iterable of parameters to optimize or dicts defining
70 | parameter groups
71 | lr (float, optional): learning rate (default: 1e-1)
72 | drop (float, optional): dropout of gradient (default: 0)
73 | """
74 |
75 | def __init__(self, params, lr=1e-1, drop=0.0):
76 | if not 0.0 <= lr:
77 | raise ValueError("Invalid learning rate: {}".format(lr))
78 | if not 0.0 <= drop < 1.0:
79 | raise ValueError("Invalid dropout value: {}".format(drop))
80 | defaults = dict(lr=lr, drop=drop)
81 | super().__init__(params, defaults)
82 |
83 | @torch.no_grad()
84 | def step(self, closure=None):
85 | loss = None
86 | if closure is not None:
87 | loss = closure()
88 | for group in self.param_groups:
89 | for p in group["params"]:
90 | if p.grad is None:
91 | continue
92 | lr = group["lr"]
94 | drop = group["drop"]
95 | grad = torch.nn.functional.dropout(
96 | p.grad.data.sign_(), p=drop
97 | )
98 | p.data.add_(grad, alpha=-lr)
99 | return loss
100 |
101 |
102 | class AdamW(Optimizer):
103 | r"""Implements AdamW algorithm.
104 |
105 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
106 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
107 |
108 | Args:
109 | params (iterable): iterable of parameters to optimize or dicts defining
110 | parameter groups
111 | lr (float, optional): learning rate (default: 1e-3)
112 | betas (Tuple[float, float], optional): coefficients used for computing
113 | running averages of gradient and its square (default: (0.9, 0.999))
114 | eps (float, optional): term added to the denominator to improve
115 | numerical stability (default: 1e-8)
116 | weight_decay (float, optional): weight decay coefficient (default: 1e-2)
117 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
118 | algorithm from the paper `On the Convergence of Adam and Beyond`_
119 |             (default: False)
    |         clamp (float, optional): clamping of gradient (default: 1e-30)
    |         drop (float, optional): dropout of gradient (default: 0)
120 |
121 | .. _Adam\: A Method for Stochastic Optimization:
122 | https://arxiv.org/abs/1412.6980
123 | .. _Decoupled Weight Decay Regularization:
124 | https://arxiv.org/abs/1711.05101
125 | .. _On the Convergence of Adam and Beyond:
126 | https://openreview.net/forum?id=ryQu7f-RZ
127 | """
128 |
129 | def __init__(
130 | self,
131 | params,
132 | lr=1e-3,
133 | betas=(0.9, 0.999),
134 | eps=1e-8,
135 | weight_decay=1e-2,
136 | amsgrad=False,
137 | clamp=1e-30,
138 | drop=0.0,
139 | ):
140 | if not 0.0 <= lr:
141 | raise ValueError("Invalid learning rate: {}".format(lr))
142 | if not 0.0 <= eps:
143 | raise ValueError("Invalid epsilon value: {}".format(eps))
144 | if not 0.0 <= betas[0] < 1.0:
145 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
146 | if not 0.0 <= betas[1] < 1.0:
147 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
148 | if not 0.0 <= weight_decay:
149 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
150 | if not 0.0 <= drop < 1.0:
151 | raise ValueError("Invalid dropout value: {}".format(drop))
152 | defaults = dict(
153 | lr=lr,
154 | betas=betas,
155 | eps=eps,
156 | weight_decay=weight_decay,
157 | amsgrad=amsgrad,
158 | clamp=clamp,
159 | drop=drop,
160 | )
161 | super(AdamW, self).__init__(params, defaults)
162 |
163 | def __setstate__(self, state):
164 | super(AdamW, self).__setstate__(state)
165 | for group in self.param_groups:
166 | group.setdefault("amsgrad", False)
167 |
168 | @torch.no_grad()
169 | def step(self, closure=None):
170 | """Performs a single optimization step.
171 |
172 | Args:
173 | closure (callable, optional): A closure that reevaluates the model
174 | and returns the loss.
175 | """
176 | loss = None
177 | if closure is not None:
178 | with torch.enable_grad():
179 | loss = closure()
180 |
181 | for group in self.param_groups:
182 | params_with_grad = []
183 | grads = []
184 | exp_avgs = []
185 | exp_avg_sqs = []
186 | state_sums = []
187 | max_exp_avg_sqs = []
188 | state_steps = []
189 | amsgrad = group["amsgrad"]
190 | beta1, beta2 = group["betas"]
191 | clamp = group["clamp"]
192 | drop = group["drop"]
193 |
194 | for p in group["params"]:
195 | if p.grad is None:
196 | continue
197 | params_with_grad.append(p)
198 | if p.grad.is_sparse:
199 | raise RuntimeError("AdamW does not support sparse gradients")
200 | grads.append(p.grad)
201 |
202 | state = self.state[p]
203 |
204 | # State initialization
205 | if len(state) == 0:
206 | state["step"] = 0
207 | # Exponential moving average of gradient values
208 | state["exp_avg"] = torch.zeros_like(
209 | p, memory_format=torch.preserve_format
210 | )
211 | # Exponential moving average of squared gradient values
212 | state["exp_avg_sq"] = torch.zeros_like(
213 | p, memory_format=torch.preserve_format
214 | )
215 | if amsgrad:
216 | # Maintains max of all exp. moving avg. of sq. grad. values
217 | state["max_exp_avg_sq"] = torch.zeros_like(
218 | p, memory_format=torch.preserve_format
219 | )
220 |
221 | exp_avgs.append(state["exp_avg"])
222 | exp_avg_sqs.append(state["exp_avg_sq"])
223 |
224 | if amsgrad:
225 | max_exp_avg_sqs.append(state["max_exp_avg_sq"])
226 |
227 | # update the steps for each param group update
228 | state["step"] += 1
229 | # record the step after step update
230 | state_steps.append(state["step"])
231 |
232 | adamw(
233 | params_with_grad,
234 | grads,
235 | exp_avgs,
236 | exp_avg_sqs,
237 | max_exp_avg_sqs,
238 | state_steps,
239 | amsgrad=amsgrad,
240 | beta1=beta1,
241 | beta2=beta2,
242 | lr=group["lr"],
243 | weight_decay=group["weight_decay"],
244 | eps=group["eps"],
245 | clamp=clamp,
246 | drop=drop,
247 | )
248 |
249 | return loss
250 |
251 |
252 | def adamw(
253 | params: List[Tensor],
254 | grads: List[Tensor],
255 | exp_avgs: List[Tensor],
256 | exp_avg_sqs: List[Tensor],
257 | max_exp_avg_sqs: List[Tensor],
258 | state_steps: List[int],
259 | *,
260 | amsgrad: bool,
261 | beta1: float,
262 | beta2: float,
263 | lr: float,
264 | weight_decay: float,
265 | eps: float,
266 | clamp: float,
267 | drop: float
268 | ):
269 | r"""Functional API that performs AdamW algorithm computation.
270 | See :class:`~torch.optim.AdamW` for details.
271 | """
272 | for i, param in enumerate(params):
273 | grad = torch.nn.functional.dropout(
274 | grads[i].clamp_(-clamp, clamp).div_(clamp), p=drop
275 | )
276 | exp_avg = exp_avgs[i]
277 | exp_avg_sq = exp_avg_sqs[i]
278 | step = state_steps[i]
279 |
280 | # Perform stepweight decay
281 | param.mul_(1 - lr * weight_decay)
282 |
283 | bias_correction1 = 1 - beta1 ** step
284 | bias_correction2 = 1 - beta2 ** step
285 |
286 | # Decay the first and second moment running average coefficient
287 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
288 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
289 | if amsgrad:
290 | # Maintains the maximum of all 2nd moment running avg. till now
291 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
292 | # Use the max. for normalizing running avg. of gradient
293 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
294 | else:
295 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
296 |
297 | step_size = lr / bias_correction1
298 |
299 | param.addcdiv_(exp_avg, denom, value=-step_size)
300 |
--------------------------------------------------------------------------------
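
A minimal sketch of driving one of the optimizers above. The parameter shape, hyperparameters and the placeholder objective are assumptions; in the Styledreams notebooks the gradient would typically come from a CLIP-based objective instead.

    import torch
    from structure.optim import ClampSGD

    w = torch.nn.Parameter(torch.randn(1, 512))    # e.g. a latent being optimised
    opt = ClampSGD([w], lr=0.03, clamp=1e-30, drop=0.1)

    for _ in range(10):
        opt.zero_grad()
        loss = (w ** 2).mean()                     # placeholder objective
        loss.backward()
        opt.step()                                 # clamps, rescales and drops the gradient, then steps
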
/structure/sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | from . import transform
5 |
6 |
7 | @torch.jit.script
8 | def random_generate_grid(
9 | l: int = 2, mode: str = "all", device: str = "cuda"
10 | ) -> torch.Tensor:
11 | """
12 | Generates a grid of coordinates in the range [-1, 1] to be used with pytorch
13 | grid_sample.
14 |
15 | Parameters
16 | ----------
17 | l : int
18 | The number of coordinates along each dimension.
19 | mode : str
20 |         'all', 'even' or 'repeat'. 'all' draws separate sorted random coordinates
21 |         for every row and column, 'even' gives evenly spaced coordinates with a
22 |         random offset and scale jitter, 'repeat' reuses one sorted random vector per axis.
23 | device : str
24 | 'cuda' or 'cpu'.
25 |
26 | Returns
27 | -------
28 | xy : torch.Tensor
29 |         A sampling grid of shape (1, l, l, 2) with coordinates in [-1, 1].
30 | """
31 | if mode != "all" and mode != "even" and mode != "repeat":
32 | raise ValueError(
33 | "random_generate_grid(): expected mode to be "
34 |             "'all', 'even' or 'repeat', but got: '{}'".format(mode)
35 | )
36 |
37 | if mode == "all":
38 | x, _ = torch.rand(l, l, device=device).sort()
39 | y, _ = torch.rand(l, l, device=device).sort()
40 | x = (x - x.min(dim=1, keepdim=True).values) / (
41 | x.max(dim=1, keepdim=True).values - x.min(dim=1, keepdim=True).values
42 | )
43 | y = (y - y.min(dim=1, keepdim=True).values) / (
44 | y.max(dim=1, keepdim=True).values - y.min(dim=1, keepdim=True).values
45 | )
46 | x = x.view(-1)
47 | y = y.t().reshape(-1)
48 | xy = torch.stack([x, y], dim=1).view(1, l, l, 2) * 2.0 - 1.0
49 | elif mode == "even":
50 | coords = torch.arange(l, device=device)
51 | coords = (coords.float() / (l - 1.0) - 0.5) * 2.0
52 | offset = torch.rand(1, device=device) * 0.6
53 | x_offset = offset * (1.0 + 0.05 * (torch.rand(1, device=device) * 2.0 - 1.0))
54 | y_offset = offset * (1.0 + 0.05 * (torch.rand(1, device=device) * 2.0 - 1.0))
55 | x = (
56 | coords * (1.0 - x_offset)
57 | + (torch.rand(1, device=device) * 2.0 - 1.0) * x_offset
58 | )
59 | y = (
60 | coords * (1.0 - y_offset)
61 | + (torch.rand(1, device=device) * 2.0 - 1.0) * y_offset
62 | )
63 | grid_y, grid_x = torch.meshgrid(y, x)
64 | xy = torch.stack([grid_x, grid_y], dim=2).view(1, l, l, 2)
65 | else: # mode == 'repeat'
66 | x, _ = torch.rand(l, device=device).sort()
67 | y, _ = torch.rand(l, device=device).sort()
68 | x = (x - x.min()) / (x.max() - x.min())
69 | y = (y - y.min()) / (y.max() - y.min())
70 | grid_y, grid_x = torch.meshgrid(y, x)
71 | xy = torch.stack([grid_x, grid_y], dim=2).view(1, l, l, 2) * 2.0 - 1.0
72 | return xy
73 |
74 |
75 | class ImgSampleBase(torch.nn.Module):
76 | def __init__(
77 | self,
78 | kernel_min: int = 1,
79 | kernel_max: int = 8,
80 | grid_size_min: int = 224,
81 | grid_size_max: int = 448,
82 | noise: float = 1.0,
83 | noise_std: float = 0.3,
84 | cutout: float = 1.0,
85 | cutout_size: float = 0.25,
86 | distortion_scale: float = 0.5,
87 | perspective: float = 1.0,
88 | downsamples: int = 1,
89 | ):
90 | super().__init__()
91 | self.kernel_min = kernel_min
92 | self.kernel_max = kernel_max
93 | self.grid_size_min = grid_size_min
94 | self.grid_size_max = grid_size_max
95 | self.noise = noise
96 | self.noise_std = noise_std
97 | self.cutout = cutout
98 | self.cutout_size = cutout_size
99 | self.downsamples = downsamples
100 | self.perspective_transformer = torchvision.transforms.RandomPerspective(
101 | distortion_scale=distortion_scale,
102 | p=perspective,
103 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
104 | )
105 |
106 | def forward(
107 | self, input: torch.Tensor, size: int = 224, bs: int = 1
108 | ) -> torch.Tensor:
109 | imgs = []
110 |
111 | for _ in range(bs):
112 |             # Generate grid for sampling (mode drawn from 'even' or 'repeat')
113 | grid_mode_idx = int(torch.randint(1, 3, ()).item())
114 | grid_mode = ("all", "even", "repeat")
115 | grid_size = int(
116 | self.grid_size_min
117 | + torch.rand(()).item() * (self.grid_size_max - self.grid_size_min)
118 | )
119 | grid = random_generate_grid(grid_size, mode=grid_mode[grid_mode_idx])
120 |
121 | # Sample original input
122 | img = torch.nn.functional.grid_sample(
123 | input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
124 | )
125 | img = torch.nn.functional.interpolate(img, (size, size), mode="area")
126 | imgs.append(img)
127 | imgs.append(self.perspective_transformer(img))
128 | img = torch.flip(img, [3])
129 | imgs.append(img)
130 | imgs.append(self.perspective_transformer(img))
131 |
132 | for i in range(self.downsamples):
133 | # Transform, downsize and sample original input
134 | img = transform.noise(input, noise=self.noise, noise_std=self.noise_std)
135 | img = transform.cutout(
136 | img, cutout=self.cutout, cutout_size=self.cutout_size
137 | )
138 | # Draw kernel size from uniform x uniform distribution
139 | kernel_size = int(
140 | torch.randint(
141 | self.kernel_min,
142 | int(torch.randint(self.kernel_min + 1, self.kernel_max, ())),
143 | (),
144 | ).item()
145 | )
146 | img = torch.nn.functional.avg_pool2d(
147 | img, kernel_size=kernel_size, stride=kernel_size, padding=0
148 | )
149 | img = torch.nn.functional.grid_sample(
150 | img, grid, mode="bilinear", padding_mode="zeros", align_corners=False
151 | )
152 | img = torch.nn.functional.interpolate(img, (size, size), mode="area")
153 | imgs.append(img)
154 | imgs.append(self.perspective_transformer(img))
155 | img = torch.flip(img, [3])
156 | imgs.append(img)
157 | imgs.append(self.perspective_transformer(img))
158 |
159 | return torch.cat(imgs, dim=0)
160 |
161 |
162 | class ImgSampleStylegan(torch.nn.Module):
163 | def __init__(
164 | self,
165 | kernel_min: int = 1,
166 | kernel_max: int = 8,
167 | grid_size_min: int = 224,
168 | grid_size_max: int = 448,
169 | noise: float = 1.0,
170 | noise_std: float = 0.3,
171 | cutout: float = 1.0,
172 | cutout_size: float = 0.25,
173 | ):
174 | super().__init__()
175 | self.kernel_min = kernel_min
176 | self.kernel_max = kernel_max
177 | self.grid_size_min = grid_size_min
178 | self.grid_size_max = grid_size_max
179 | self.noise = noise
180 | self.noise_std = noise_std
181 | self.cutout = cutout
182 | self.cutout_size = cutout_size
183 |
184 | def forward(
185 | self, input: torch.Tensor, size: int = 224, bs: int = 1
186 | ) -> torch.Tensor:
187 | imgs = []
188 |
189 | for _ in range(bs):
190 | # Draw kernel size from uniform x uniform distribution
191 | kernel_size = int(
192 | torch.randint(
193 | self.kernel_min,
194 | int(torch.randint(self.kernel_min + 1, self.kernel_max, ())),
195 | (),
196 | ).item()
197 | )
198 |
199 |             # Generate grid for sampling (mode drawn from 'even' or 'repeat')
200 | grid_mode_idx = int(torch.randint(1, 3, ()).item())
201 | grid_mode = ("all", "even", "repeat")
202 | grid_size = int(
203 | self.grid_size_min
204 | + torch.rand(()).item() * (self.grid_size_max - self.grid_size_min)
205 | )
206 | grid = random_generate_grid(grid_size, mode=grid_mode[grid_mode_idx])
207 |
208 | # Sample original input
209 | img = torch.nn.functional.grid_sample(
210 | input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
211 | )
212 | img = torch.nn.functional.interpolate(img, (size, size), mode="area")
213 | imgs.append(img)
214 |
215 | # Transform, downsize and sample original input
216 | img = transform.noise(input, noise=self.noise, noise_std=self.noise_std)
217 | img = transform.cutout(
218 | img, cutout=self.cutout, cutout_size=self.cutout_size
219 | )
220 | img = torch.nn.functional.avg_pool2d(
221 | img, kernel_size=kernel_size, stride=kernel_size, padding=0
222 | )
223 | img = torch.nn.functional.grid_sample(
224 | img, grid, mode="bilinear", padding_mode="zeros", align_corners=False
225 | )
226 | img = torch.nn.functional.interpolate(img, (size, size), mode="area")
227 | imgs.append(img)
228 |
229 | return torch.cat(imgs, dim=0)
230 |
--------------------------------------------------------------------------------
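
An illustrative sketch of the samplers above: one large image tensor is turned into a batch of augmented 224x224 views, for example for CLIP scoring. The input here is random noise and the sizes are assumptions; the grid helpers default to device='cuda', so the input is placed on the GPU.

    import torch
    from structure.sample import ImgSampleBase

    sampler = ImgSampleBase(grid_size_min=224, grid_size_max=448, downsamples=1)
    img = torch.rand(1, 3, 512, 512, device="cuda")   # image being optimised
    views = sampler(img, size=224, bs=2)              # shape (4 * bs * (1 + downsamples), 3, 224, 224)
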
/structure/stylegan_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/structure/stylegan_utils/__init__.py
--------------------------------------------------------------------------------
/structure/stylegan_utils/ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Code slightly modified from https://github.com/NVlabs/ffhq-dataset
3 | Line: https://github.com/NVlabs/ffhq-dataset/blob/193deb624543e54d93b54a296fcfc9a68a3ed8ce/download_ffhq.py#L259
4 | """
5 | import PIL.Image
6 | import scipy.ndimage
7 | import numpy as np
8 |
9 |
10 | def align_image(img,
11 | lm,
12 | output_size=1024,
13 | transform_size=4096,
14 | enable_padding=False,
15 | rotate_level=True,
16 | random_shift=0.0,
17 | retry_crops=False):
18 | """
19 | Code slightly modified from https://github.com/NVlabs/ffhq-dataset
20 | Line: https://github.com/NVlabs/ffhq-dataset/blob/193deb624543e54d93b54a296fcfc9a68a3ed8ce/download_ffhq.py#L259
21 |
22 | Expects an image and a landmarks vector as input.
23 | Return aligned image.
24 | """
25 |
26 | lm = np.array(lm)
27 | lm_chin = lm[0 : 17] # left-right
28 | lm_eyebrow_left = lm[17 : 22] # left-right
29 | lm_eyebrow_right = lm[22 : 27] # left-right
30 | lm_nose = lm[27 : 31] # top-down
31 | lm_nostrils = lm[31 : 36] # top-down
32 | lm_eye_left = lm[36 : 42] # left-clockwise
33 | lm_eye_right = lm[42 : 48] # left-clockwise
34 | lm_mouth_outer = lm[48 : 60] # left-clockwise
35 | lm_mouth_inner = lm[60 : 68] # left-clockwise
36 |
37 | # Calculate auxiliary vectors.
38 | eye_left = np.mean(lm_eye_left, axis=0)
39 | eye_right = np.mean(lm_eye_right, axis=0)
40 | eye_avg = (eye_left + eye_right) * 0.5
41 | eye_to_eye = eye_right - eye_left
42 | mouth_left = lm_mouth_outer[0]
43 | mouth_right = lm_mouth_outer[6]
44 | mouth_avg = (mouth_left + mouth_right) * 0.5
45 | eye_to_mouth = mouth_avg - eye_avg
46 |
47 | # Choose oriented crop rectangle.
48 | if rotate_level:
49 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
50 | x /= np.hypot(*x)
51 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
52 | y = np.flipud(x) * [-1, 1]
53 | c0 = eye_avg + eye_to_mouth * 0.1
54 | else:
55 | x = np.array([1, 0], dtype=np.float64)
56 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
57 | y = np.flipud(x) * [-1, 1]
58 | c0 = eye_avg + eye_to_mouth * 0.1
59 |
60 | quad = np.stack([c0 - x - y, c0 - x + y, c0 + x + y, c0 + x - y])
61 | qsize = np.hypot(*x) * 2
62 |
63 | # Keep drawing new random crop offsets until we find one that is contained in the image
64 | # and does not require padding
65 | if random_shift != 0:
66 | for _ in range(1000):
67 | # Offset the crop rectange center by a random shift proportional to image dimension
68 | # and the requested standard deviation
69 | c = (c0 + np.hypot(*x)*2 * random_shift * np.random.normal(0, 1, c0.shape))
70 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
71 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
72 | if not retry_crops or not (crop[0] < 0 or crop[1] < 0 or crop[2] >= img.width or crop[3] >= img.height):
73 | # We're happy with this crop (either it fits within the image, or retries are disabled)
74 | break
75 | else:
76 | # rejected N times, give up and move to next image
77 | # (does not happen in practice with the FFHQ data)
78 | print('rejected image')
79 | #return
80 |
81 | # Shrink.
82 | shrink = int(np.floor(qsize / output_size * 0.5))
83 | if shrink > 1:
84 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
85 | img = img.resize(rsize, PIL.Image.ANTIALIAS)
86 | quad /= shrink
87 | qsize /= shrink
88 |
89 | # Crop.
90 | border = max(int(np.rint(qsize * 0.1)), 3)
91 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
92 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
93 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
94 | img = img.crop(crop)
95 | quad -= crop[0:2]
96 |
97 | # Pad.
98 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
99 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
100 | if enable_padding and max(pad) > border - 4:
101 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
102 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
103 | h, w, _ = img.shape
104 | y, x, _ = np.ogrid[:h, :w, :1]
105 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
106 | blur = qsize * 0.02
107 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
108 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
109 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
110 | quad += pad[:2]
111 |
112 | # Transform.
113 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
114 | if output_size < transform_size:
115 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
116 |
117 | return img
--------------------------------------------------------------------------------
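
A hypothetical usage sketch for align_image. It expects a PIL image plus a 68-point facial landmark array in pixel coordinates; no landmark detector ships with this repository, so detect_landmarks below is a stand-in for e.g. dlib or face-alignment.

    import PIL.Image
    from structure.stylegan_utils.ops import align_image

    img = PIL.Image.open("face.jpg")   # path is illustrative
    lm = detect_landmarks(img)         # hypothetical helper returning a (68, 2) array
    aligned = align_image(img, lm, output_size=1024, transform_size=4096)
    aligned.save("face_aligned.png")
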
/structure/transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | @torch.jit.script
5 | def noise(
6 | input: torch.Tensor,
7 | noise: float = 0.0,
8 | noise_std: float = 0.1,
9 | p: float = 1.0,
10 | device: str = "cuda",
11 | ) -> torch.Tensor:
12 | """Apply additive RGB noise with probability (noise * p)."""
13 | if noise > 0:
14 | batch_size, num_channels, height, width = input.shape
15 | sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * noise_std
16 | sigma = torch.where(
17 | torch.rand([batch_size, 1, 1, 1], device=device) < noise * p,
18 | sigma,
19 | torch.zeros_like(sigma),
20 | )
21 | input = (
22 | input
23 | + torch.randn([batch_size, num_channels, height, width], device=device)
24 | * sigma
25 | )
26 | return input
27 |
28 |
29 | @torch.jit.script
30 | def cutout(
31 | input: torch.Tensor,
32 | cutout: float = 0.0,
33 | cutout_size: float = 0.1,
34 | p: float = 1.0,
35 | device: str = "cuda",
36 | ) -> torch.Tensor:
37 | """Apply cutout with probability (cutout * p)."""
38 | if cutout > 0:
39 | batch_size, num_channels, height, width = input.shape
40 | size = torch.full([batch_size, 2, 1, 1, 1], cutout_size, device=device)
41 | size = torch.where(
42 | torch.rand([batch_size, 1, 1, 1, 1], device=device) < cutout * p,
43 | size,
44 | torch.zeros_like(size),
45 | )
46 | center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
47 | coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
48 | coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
49 | mask_x = ((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2
50 | mask_y = ((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2
51 | mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
52 | input = input * mask
53 | return input
54 |
55 |
56 | @torch.jit.script
57 | def color(input: torch.Tensor,
58 | c: torch.Tensor,
59 | c_bias: torch.Tensor,
60 | c_scale: torch.Tensor) -> torch.Tensor:
61 | c = torch.sin(c)
62 | c = c / (c.norm() + 1e-8)
63 | c_bias = torch.sin(c_bias)
64 | c_bias = c_bias / (c_bias.norm() + 1e-8)
65 | input = input.permute(0, 2, 3, 1)
66 | input = torch.nn.functional.linear(input, c, bias=c_bias)
67 | input = input.permute(0, 3, 1, 2) * c_scale.abs()
68 | input = (torch.sin(input) + 1.)/2.
69 | return input
70 |
--------------------------------------------------------------------------------
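
A small sketch of chaining the augmentations above. Shapes and values are illustrative; the helpers default to device='cuda', so the batch lives on the GPU. The matrix, bias and scale passed to color() are random placeholders here, though they could equally be optimised parameters.

    import torch
    from structure import transform

    x = torch.rand(4, 3, 256, 256, device="cuda")
    x = transform.noise(x, noise=1.0, noise_std=0.3)
    x = transform.cutout(x, cutout=1.0, cutout_size=0.25)

    c = torch.randn(3, 3, device="cuda")      # channel mixing matrix
    c_bias = torch.randn(3, device="cuda")    # per-channel bias
    c_scale = torch.ones(1, device="cuda")    # overall scale
    x = transform.color(x, c, c_bias, c_scale)
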
/structure/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import requests
3 |
4 | import numpy as np
5 |
6 |
7 | def model_to_fp32(model):
8 | for p in model.parameters():
9 | p.data = p.data.float()
10 |         if p.grad is not None: p.grad.data = p.grad.data.float()
11 |
12 |
13 | def img_pil_to_opencv(input):
14 | """ Converts PIL image to numpy array with open cv formatting. """
15 | open_cv_image = np.array(input)
16 | open_cv_image = open_cv_image[:, :, ::-1].copy()
17 | return open_cv_image
18 |
19 |
20 | def get_img(url):
21 | """ Minimal function to download image from url. """
22 | response = requests.get(url)
23 | return io.BytesIO(response.content)
24 |
25 |
26 | class ArgDict(dict):
27 | def __setattr__(self, name, value):
28 | self[name] = value
29 |
30 | def __getattr__(self, name):
31 | try:
32 | return self[name]
33 | except KeyError:
34 | raise AttributeError(name)
35 |
36 | def __delattr__(self, name):
37 | del self[name]
38 |
--------------------------------------------------------------------------------
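
A small usage sketch for the helpers above; the URL is a placeholder.

    import PIL.Image
    from structure.utils import ArgDict, get_img, img_pil_to_opencv

    args = ArgDict()
    args.lr = 0.03                 # attribute writes go straight into the dict
    args.steps = 200
    print(args["lr"], args.steps)

    img = PIL.Image.open(get_img("https://example.com/image.jpg"))  # placeholder URL
    bgr = img_pil_to_opencv(img)   # numpy array with channels reversed (RGB -> BGR)
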