├── .gitignore ├── CITATION.cff ├── README.md ├── colabs ├── Structured_Dreaming_Styledreams.ipynb └── Structured_Dreaming_Styledreams_faces.ipynb ├── deprecated ├── __init__.py ├── autograd.py ├── generate.py ├── model.py ├── training.py └── utils.py ├── requirements.txt ├── res ├── styledream_face_thumb.jpeg └── styledream_thumb.jpeg ├── setup.py └── structure ├── __init__.py ├── clip.py ├── data └── bpe_simple_vocab_16e6.txt ├── ops.py ├── optim.py ├── sample.py ├── stylegan_utils ├── __init__.py └── ops.py ├── transform.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: Ekgren 5 | given-names: Ariel 6 | google-scholar-id: fhs8fggAAAAJ 7 | title: "Structured Dreaming" 8 | version: 0.0.1 9 | date-released: 2021-09-01 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Structured Dreaming 2 | ###### Note that this is a work in progress: it will keep changing, or it might freeze in time in an unfinished state. 3 | 4 | ## Colabs 5 | ![Styledream_thumb](res/styledream_thumb.jpeg) 6 | [Styledreams](https://colab.research.google.com/github/ekgren/StructuredDreaming/blob/main/colabs/Structured_Dreaming_Styledreams.ipynb) 7 | -- CLIP x StyleGAN2 notebook for fine-tuning StyleGAN2 from CLIP. 8 | 9 | ![Styledream_thumb](res/styledream_face_thumb.jpeg) 10 | [Styledreams Faces](https://colab.research.google.com/github/ekgren/StructuredDreaming/blob/main/colabs/Structured_Dreaming_Styledreams_faces.ipynb) 11 | -- CLIP x StyleGAN2 notebook for finding a face from a photo in StyleGAN latent space and then fine-tuning the model from CLIP to change the face. 12 | 13 | ## Optimizer 14 | This repo uses the ClampSGD optimizer, which also has its own repository here: [https://github.com/ekgren/open-optimizers](https://github.com/ekgren/open-optimizers). 15 | 16 | ## Introduction 17 | By now it is well known that neural networks trained to classify images also have the capacity to generate 18 | images [[1]](#1). There are many variations on 19 | this theme, with entire artistic movements based on DeepDream [[4]](#4), libraries such as Lucid [[6]](#6), 20 | and advanced feature visualization tools such as OpenAI Microscope [[7]](#7). With the release of 21 | CLIP [[5]](#5) and open research on Twitter [[8]](#8), generative exploration of image networks has gained a lot of popularity. 22 | 23 | As described in Differentiable image parameterizations [[1]](#1), all these 24 | generative techniques work in the same way: given a network used for image-related tasks such 25 | as representation learning or classification, we can backpropagate from a desired representation 26 | and optimize the input image towards a high-activation image. 27 | 28 | The simplest parametrization of the input image is in the form of RGB values for each pixel. 29 | But naively backpropagating to the image will not work, as described in the chapter 30 | [Enemy of feature visualization](https://distill.pub/2017/feature-visualization/#enemy-of-feature-vis) of Feature Visualization [[2]](#2). 31 | The network ends up “cheating” and you will end up with an image full of noise and 32 | nonsensical high-frequency patterns that the network responds strongly to. 33 | 34 | In this work we will continue to explore different techniques to avoid the "cheating" and create informative and/or 35 | visually interesting images. 36 | 37 | ## References 38 | [1] 39 | Mordvintsev, A., Pezzotti, N., Schubert, L., & Olah, C. (2018). 40 | Differentiable image parameterizations. Distill, 3(7), e12. 41 | https://distill.pub/2018/differentiable-parameterizations/ 42 | 43 | [2] 44 | Olah, C., Mordvintsev, A., & Schubert, L. (2017). 45 | Feature visualization. Distill, 2(11), e7. 46 | https://distill.pub/2017/feature-visualization/ 47 | 48 | [3] 49 | Goh, G., Cammarata, N., Voss, C., Carter, S., Petrov, M., Schubert, L., ... & Olah, C. (2021). 50 | Multimodal neurons in artificial neural networks. Distill, 6(3), e30.
51 | 52 | [4] 53 | https://ai.googleblog.com/2015/06/inceptionism-going-deeper-into-neural.html 54 | 55 | [5] 56 | https://github.com/openai/CLIP 57 | 58 | [6] 59 | https://github.com/tensorflow/lucid 60 | 61 | [7] 62 | https://microscope.openai.com/ 63 | 64 | [8] 65 | https://twitter.com/advadnoun/status/1348375026697834496 66 | -------------------------------------------------------------------------------- /deprecated/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/deprecated/__init__.py -------------------------------------------------------------------------------- /deprecated/autograd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BoostGradFunc(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, tensor, boost_val): 7 | ctx.set_materialize_grads(False) 8 | ctx.boost_val = boost_val 9 | return tensor 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): 13 | if grad_output is None: 14 | return None, None 15 | return grad_output * ctx.boost_val, None 16 | 17 | 18 | class ClampGradFunc(torch.autograd.Function): 19 | @staticmethod 20 | def forward(ctx, tensor, clamp_val): 21 | ctx.set_materialize_grads(False) 22 | ctx.clamp_val = clamp_val 23 | return tensor 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | if grad_output is None: 28 | return None, None 29 | return grad_output.clamp(-ctx.clamp_val, ctx.clamp_val), None 30 | 31 | 32 | class BoostGrad(torch.nn.Module): 33 | """ 34 | This class is a PyTorch module that takes a tensor as input and boosts its gradient. 35 | 36 | Parameters 37 | ---------- 38 | boost_val : float 39 | The value by which the gradient of the input tensor is boosted. 40 | """ 41 | 42 | def __init__(self, boost_val: float = 1e1): 43 | super().__init__() 44 | self.boost_grad = BoostGradFunc.apply 45 | self.boost_val = boost_val 46 | 47 | def forward(self, input: torch.Tensor) -> torch.Tensor: 48 | return self.boost_grad(input, self.boost_val) 49 | 50 | 51 | class ClampGrad(torch.nn.Module): 52 | """ 53 | This class is a PyTorch module that takes a tensor as input and clamps its gradient. 54 | 55 | Parameters 56 | ---------- 57 | clamp_val : float 58 | The value to clamp the gradients to. 59 | 60 | Methods 61 | ------- 62 | forward(input) 63 | Clamps the gradients of the input tensor to be between -clamp_val and clamp_val. 
64 | """ 65 | 66 | def __init__(self, clamp_val: float = 1e-6): 67 | super().__init__() 68 | self.clamp_grad = ClampGradFunc.apply 69 | self.clamp_val = clamp_val 70 | 71 | def forward(self, input: torch.Tensor) -> torch.Tensor: 72 | return self.clamp_grad(input, self.clamp_val) 73 | -------------------------------------------------------------------------------- /deprecated/generate.py: -------------------------------------------------------------------------------- 1 | """ Generate images with CLIP """ 2 | 3 | import json 4 | import os 5 | 6 | import click 7 | import numpy as np 8 | import PIL.Image 9 | import torch 10 | import torchvision 11 | from tqdm import tqdm 12 | 13 | from structure.clip import load, tokenize, convert_weights 14 | from deprecated.model import ImgBaseOld 15 | from structure.utils import ( 16 | Pipeline, 17 | Upscale, 18 | Pixelate, 19 | Dropper, 20 | Prod, 21 | SamplePatch, 22 | grad_sign, 23 | model_to_fp32, 24 | ArgDict, 25 | ) 26 | 27 | 28 | @click.command() 29 | @click.pass_context 30 | 31 | # General options. 32 | @click.option("--text", help="Text prompt", required=True) 33 | @click.option("--seed", type=int, help="Random seed", default=0) 34 | @click.option( 35 | "--outdir", 36 | help="Where to save the output images", 37 | type=str, 38 | required=True, 39 | metavar="DIR", 40 | ) 41 | 42 | # Training. 43 | @click.option( 44 | "--iterations", 45 | help="Number of iterations of generating image [default: 1000]", 46 | type=int, 47 | default=1000, 48 | ) 49 | @click.option( 50 | "--grad_acc_steps", 51 | help="Gradient accumulation steps [default: 1]", 52 | type=int, 53 | default=1, 54 | ) 55 | @click.option("--lr", help="Learning rate [default: 0.01]", type=float, default=0.01) 56 | # TODO: add betas = (0.99, 0.999) 57 | 58 | # Model. 59 | @click.option("--image_size", help="Image size [default: 512]", type=int, default=512) 60 | @click.option( 61 | "--weight_init", help="Image weight init [default: 0.05]", type=float, default=0.05 62 | ) 63 | @click.option( 64 | "--decolorize", help="Image weight init [default: 0.001]", type=float, default=0.001 65 | ) 66 | @click.option( 67 | "--darken", help="Image weight init [default: 0.005]", type=float, default=0.005 68 | ) 69 | 70 | # General img options. 71 | # TODO: add mode = 'area' 72 | 73 | # Pixelate pipeline. 74 | @click.option("--px_no", help="Number of patches [default: 32]", type=int, default=32) 75 | @click.option( 76 | "--px_patch_size_min", 77 | help="Pixelate patch min size [default: 256]", 78 | type=int, 79 | default=256, 80 | ) 81 | @click.option( 82 | "--px_patch_size_max", 83 | help="Pixelate patch max size [default: 512]", 84 | type=int, 85 | default=512, 86 | ) 87 | @click.option( 88 | "--px_size_min", help="Pixelation min size [default: 32]", type=int, default=32 89 | ) 90 | @click.option( 91 | "--px_size_max", help="Pixelation max size [default: 224]", type=int, default=224 92 | ) 93 | @click.option( 94 | "--px_drop", help="Pixelation dropout [default: 0.3]", type=float, default=0.3 95 | ) 96 | 97 | # Upscale pipeline. 
98 | @click.option("--up_no", help="Number of patches [default: 32]", type=int, default=32) 99 | @click.option( 100 | "--up_patch_size_min", 101 | help="Upscale patch min size [default: 64]", 102 | type=int, 103 | default=64, 104 | ) 105 | @click.option( 106 | "--up_patch_size_max", 107 | help="Upscale patch max size [default: 512]", 108 | type=int, 109 | default=512, 110 | ) 111 | @click.option( 112 | "--up_drop", help="Pixelation dropout [default: 0.3]", type=float, default=0.3 113 | ) 114 | 115 | # Grow parameters. 116 | @click.option( 117 | "--grow_init_res", help="Number of patches [default: 32]", type=int, default=32 118 | ) 119 | @click.option( 120 | "--grow_step_size", help="Number of patches [default: 64]", type=int, default=64 121 | ) 122 | @click.option( 123 | "--grow_step", help="Number of patches [default: 20]", type=int, default=20 124 | ) 125 | def main(ctx, **config_kwargs): 126 | """Generate images. 127 | Examples: 128 | """ 129 | args = ArgDict(**config_kwargs) 130 | # Print options. 131 | print() 132 | print("Generation options:") 133 | print(json.dumps(args, indent=2)) 134 | print() 135 | print(f"Output directory: {args.outdir}") 136 | print(f"Prompt: {args.text}") 137 | print() 138 | 139 | # Initialize 140 | device = torch.device("cuda") 141 | np.random.seed(args["seed"]) 142 | torch.manual_seed(args["seed"]) 143 | os.makedirs(args["outdir"], exist_ok=True) 144 | 145 | ################# 146 | # PARAMETERS 147 | ################# 148 | 149 | # Training 150 | betas = (0.99, 0.999) 151 | 152 | # General img options 153 | res_out = 224 154 | mode = "area" 155 | ################# 156 | 157 | print("Loading clip.") 158 | perceptor, normalize_image = load("ViT-B/32", jit=False) 159 | txt_tok = tokenize(args.text) 160 | text_latent = perceptor.encode_text(txt_tok.cuda()).detach() 161 | 162 | # Setting up image generation 163 | model = ImgBaseOld( 164 | size=args.image_size, 165 | weight_init=args.weight_init, 166 | decolorize=args.decolorize, 167 | darken=args.darken, 168 | ).cuda() 169 | normalize = torchvision.transforms.Normalize( 170 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 171 | ) 172 | optimizer = torch.optim.Adam(model.parameters(), args.lr, betas=betas) 173 | 174 | grow_res = args.grow_init_res 175 | grow_pipeline = Pipeline( 176 | Pixelate(scale_size_min=grow_res, scale_size_max=grow_res), 177 | Upscale(res_out=args.image_size, mode=mode), 178 | ) 179 | 180 | px_pipeline = Pipeline( 181 | SamplePatch(size_min=args.px_patch_size_min, size_max=args.px_patch_size_max), 182 | Dropper(drop=args.px_drop, drop2d=True), 183 | Pixelate(scale_size_min=args.px_size_min, scale_size_max=args.px_size_max), 184 | Upscale(res_out=res_out, mode=mode), 185 | Dropper(drop=args.px_drop, drop2d=True), 186 | ) 187 | 188 | up_pipeline = Pipeline( 189 | SamplePatch(args.up_patch_size_min, args.up_patch_size_max), 190 | Dropper(drop=args.up_drop, drop2d=True), 191 | Upscale(res_out=res_out, mode=mode), 192 | Dropper(drop=args.up_drop, drop2d=True), 193 | ) 194 | 195 | patches = [px_pipeline] * args.px_no + [up_pipeline] * args.up_no 196 | patches = Prod(*patches) 197 | 198 | # Image generation 199 | print("Generating image.") 200 | for i in tqdm(range(args.iterations)): 201 | optimizer.zero_grad() 202 | img = normalize_image(model()) 203 | img = grow_pipeline(img) 204 | 205 | img_processed = torch.cat(patches(img), 0) 206 | img_latents = perceptor.encode_image(img_processed) 207 | 208 | loss = ( 209 | 10 * torch.cosine_similarity(text_latent, img_latents, 
dim=-1).mean().neg() 210 | ) 211 | loss.backward() 212 | 213 | if (i + 1) % args.grad_acc_steps == 0: 214 | model_to_fp32(perceptor.visual) 215 | model_to_fp32(model) 216 | grad_sign(model.w.grad) 217 | optimizer.step() 218 | convert_weights(perceptor.visual) 219 | convert_weights(model) 220 | 221 | # Post processing 222 | model.post_process() 223 | 224 | # Update grow resolution 225 | if (i + 1) % args.grow_step == 0: 226 | grow_res = min(args.image_size, grow_res + args.grow_step_size) 227 | grow_pipeline = Pipeline( 228 | Pixelate(scale_size_min=grow_res, scale_size_max=grow_res), 229 | Upscale(res_out=args.image_size, mode=mode), 230 | ) 231 | 232 | # DEBUG 233 | if (i + 1) % 20 == 0: 234 | with torch.no_grad(): 235 | img = model() 236 | _img = (img.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8) 237 | f_name = f'{args.outdir}/{args.text.replace(" ", "_")}__seed{args.seed:04d}.png' 238 | PIL.Image.fromarray(_img[0].cpu().numpy(), "RGB").save(f_name) 239 | 240 | with torch.no_grad(): 241 | img = model() 242 | _img = (img.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8) 243 | f_name = f'{args.outdir}/{args.text.replace(" ", "_")}__seed{args.seed:04d}.png' 244 | PIL.Image.fromarray(_img[0].cpu().numpy(), "RGB").save(f_name) 245 | 246 | 247 | if __name__ == "__main__": 248 | main() 249 | -------------------------------------------------------------------------------- /deprecated/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from deprecated.autograd import BoostGrad 4 | 5 | 6 | class ImgBase(torch.nn.Module): 7 | def __init__(self, size: int = 224, k: float = 5.0, weight_init: float = 0.05): 8 | super().__init__() 9 | self.size = size 10 | self.k = k 11 | 12 | self.color = torch.nn.Parameter( 13 | torch.tensor( 14 | [ 15 | [-0.1409, 0.0855, -0.7620], 16 | [0.2596, -0.5239, 0.0996], 17 | [0.1653, -0.0719, 0.0889], 18 | ] 19 | ) 20 | ) 21 | self.w = torch.nn.Parameter( 22 | torch.randn(1, 3, size, size, requires_grad=True) * weight_init 23 | ) 24 | 25 | def forward(self) -> torch.Tensor: 26 | img = self.w 27 | color = self.color / self.color.norm(p=2) 28 | img = torch.nn.functional.linear(img.permute(0, 2, 3, 1), color).permute( 29 | 0, 3, 1, 2 30 | ) 31 | img = self.to_rgb(img, self.k) 32 | return img 33 | 34 | def to_rgb(self, input: torch.Tensor, k: float) -> torch.Tensor: 35 | return (input.clamp(-k, k) + k) / (2 * k) 36 | 37 | 38 | class ImgBaseOld(torch.nn.Module): 39 | """X""" 40 | 41 | def __init__(self, size=224, weight_init=0.05, decolorize=0.0, darken=0.0): 42 | super().__init__() 43 | self.decolorize = decolorize 44 | self.darken = darken 45 | self.w = torch.ones(1, 3, size, size, requires_grad=True) * weight_init 46 | self.w = torch.nn.Parameter(self.w.half()) 47 | 48 | def forward(self): 49 | return self.w 50 | 51 | def post_process(self): 52 | with torch.no_grad(): 53 | self.w.clamp_(0.0, 1.0) 54 | if self.decolorize > 0.0: 55 | self.w += self.decolorize * ( 56 | -self.w + self.w.mean(dim=1, keepdim=True).repeat(1, 3, 1, 1) 57 | ) 58 | if self.darken > 0.0: 59 | self.w *= 1.0 - self.darken 60 | 61 | 62 | class ImgBaseFFT(torch.nn.Module): 63 | """X""" 64 | 65 | def __init__(self, size=224, k=15.0, weight_init=0.05): 66 | super().__init__() 67 | self.size = size 68 | self.k = k 69 | self.color = torch.nn.Linear(3, 3, bias=False) 70 | w = torch.fft.rfft2( 71 | torch.randn(1, 3, size, size, requires_grad=True) * weight_init 72 | ) 73 | self.w = torch.nn.Parameter(w) 74 | self.act = torch.sin 75 
| self.bg = BoostGrad() 76 | self.norm = ChanNorm(dim=3) 77 | 78 | def forward(self): 79 | img = torch.fft.irfft2(self.w) 80 | img = self.bg.apply(img) 81 | img = self.norm(img) 82 | img = self.color(img.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 83 | return img 84 | 85 | def get_img(self, size=None): 86 | size = size if size is not None else self.size 87 | img = torch.fft.irfft2(self.w) 88 | if size != self.size: 89 | img = torch.nn.functional.interpolate(img, (size, size), mode="area") 90 | img = self.norm(img) 91 | img = self.color(img.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 92 | return (img.clamp(-self.k, self.k) + self.k) / (2 * self.k) 93 | 94 | 95 | # From: https://github.com/lucidrains/stylegan2-pytorch 96 | class ChanNorm(torch.nn.Module): 97 | def __init__(self, dim, eps=1e-5): 98 | super().__init__() 99 | self.eps = eps 100 | self.g = torch.nn.Parameter(torch.ones(1, dim, 1, 1)) 101 | self.b = torch.nn.Parameter(torch.zeros(1, dim, 1, 1)) 102 | 103 | def forward(self, x): 104 | std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt() 105 | mean = torch.mean(x, dim=1, keepdim=True) 106 | return (x - mean) / (std + self.eps) * self.g + self.b 107 | -------------------------------------------------------------------------------- /deprecated/training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def training_loop( 6 | random_seed=0, # Global random seed. 7 | ): 8 | # Initialize. 9 | device = torch.device("cuda") 10 | np.random.seed(random_seed) 11 | torch.manual_seed(random_seed) 12 | -------------------------------------------------------------------------------- /deprecated/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torch.nn.functional import dropout, dropout2d, interpolate, avg_pool2d 4 | import torch 5 | 6 | 7 | class SamplePatch(object): 8 | def __init__(self, size_min, size_max): 9 | assert ( 10 | size_min <= size_max 11 | ), "size_min should be equal or smaller than size_max." 12 | assert size_min >= 2, "size_min can't be smaller than 2." 13 | 14 | self.size_min = size_min 15 | self.size_max = size_max 16 | 17 | def __call__(self, img): 18 | """Get uniformly sampled patch from image.""" 19 | bs, ch, h, w = img.shape 20 | assert h == w, "Assumes equal width and height of image." 21 | assert ( 22 | self.size_max <= h 23 | ), "size_max has to be smaller than or equal to image height." 
24 | if self.size_min == self.size_max: 25 | patch_size = self.size_max 26 | else: 27 | patch_size = random.randint(self.size_min, self.size_max) 28 | h_offset = random.randint(0, h - patch_size) 29 | w_offset = random.randint(0, w - patch_size) 30 | patch_h_s = h_offset 31 | patch_h_e = h_offset + patch_size 32 | patch_w_s = w_offset 33 | patch_w_e = w_offset + patch_size 34 | img_out = img[:, :, patch_h_s:patch_h_e, patch_w_s:patch_w_e] 35 | return img_out 36 | 37 | 38 | class Dropper(object): 39 | def __init__(self, drop=0.0, drop2d=False): 40 | self.drop = drop 41 | self.drop2d = drop2d 42 | 43 | def __call__(self, img): 44 | if not self.drop2d: 45 | return dropout(img, p=self.drop) 46 | else: 47 | return dropout2d(img, p=self.drop) 48 | 49 | 50 | class Upscale(object): 51 | def __init__(self, res_out=224, mode="area"): 52 | self.res_out = res_out 53 | self.mode = mode 54 | 55 | def __call__(self, img): 56 | bs, ch, h, w = img.shape 57 | if h != self.res_out: 58 | return interpolate(img, (self.res_out, self.res_out), mode=self.mode) 59 | else: 60 | return img 61 | 62 | 63 | class Pixelate(object): 64 | def __init__(self, scale_size_min, scale_size_max): 65 | self.scale_size_max = scale_size_max 66 | self.scale_size_min = scale_size_min 67 | 68 | def __call__(self, img): 69 | bs, ch, h, w = img.shape 70 | downscale = random.randint(self.scale_size_min, self.scale_size_max) 71 | kernel_size = int(max(1, h / downscale)) 72 | return avg_pool2d(img, kernel_size) 73 | 74 | 75 | class RandomPool(object): 76 | """ 77 | Randomly pads the input image with a value between -1 and 1. 78 | 79 | Args: 80 | img (torch.Tensor): Input image. 81 | pad_val (float, optional): The value to pad with. Defaults to random.random() * 2. - 1. 82 | 83 | Returns: 84 | torch.Tensor: The padded image. 
85 | """ 86 | 87 | def __init__(self, kernel_min=1, kernel_max=8): 88 | self.kernel_min = kernel_min 89 | self.kernel_max = kernel_max 90 | 91 | def __call__(self, img, pad_val=None): 92 | pad_val = pad_val if pad_val is not None else random.random() * 2.0 - 1.0 93 | # Double uniform sample different distribution 94 | kernel_size = random.randint( 95 | self.kernel_min, random.randint(self.kernel_min, self.kernel_max) 96 | ) 97 | # kernel_size = random.randint(self.kernel_min, self.kernel_max) 98 | img = avg_pool2d(img, kernel_size=kernel_size, stride=None, padding=0) 99 | return img 100 | 101 | 102 | class RandomMirror(object): 103 | def __init__(self, blend=False): 104 | # Blend not properly implemented 105 | self.modes = 2 if blend is False else 4 106 | 107 | def __call__(self, img): 108 | bs, ch, h, w = img.shape 109 | augmentation = random.randint(0, self.modes) 110 | if augmentation == 0: 111 | pass 112 | elif augmentation == 1: 113 | img_l = img[:, :, :, : w // 2] 114 | img_r = torch.flip(img[:, :, :, : w // 2], [3]) 115 | img = torch.cat([img_l, img_r], dim=3) 116 | elif augmentation == 2: 117 | img_l = torch.flip(img[:, :, :, w // 2 :], [3]) 118 | img_r = img[:, :, :, w // 2 :] 119 | img = torch.cat([img_l, img_r], dim=3) 120 | elif augmentation == 3: 121 | img_l = img[:, :, :, : w // 2] 122 | img_r = ( 123 | img[:, :, :, w // 2 :] + torch.flip(img[:, :, :, : w // 2], [3]) 124 | ) / 2 125 | img = torch.cat([img_l, img_r], dim=3) 126 | elif augmentation == 4: 127 | img_l = ( 128 | img[:, :, :, : w // 2] + torch.flip(img[:, :, :, w // 2 :], [3]) 129 | ) / 2 130 | img_r = img[:, :, :, w // 2 :] 131 | img = torch.cat([img_l, img_r], dim=3) 132 | return img 133 | 134 | 135 | class RandomPad(object): 136 | def __init__(self, pad_min=0, pad_max=224, step=1): 137 | self.pad_min = pad_min 138 | self.pad_max = pad_max 139 | self.step = step 140 | 141 | def __call__(self, img): 142 | pad_modes = ["circular", "reflect", "replicate"] 143 | pad = ( 144 | random.randrange(self.pad_min, self.pad_max, self.step), 145 | random.randrange(self.pad_min, self.pad_max, self.step), 146 | random.randrange(self.pad_min, self.pad_max, self.step), 147 | random.randrange(self.pad_min, self.pad_max, self.step), 148 | ) 149 | return torch.nn.functional.pad(img, pad, mode=random.choice(pad_modes)) 150 | 151 | 152 | class Flip(object): 153 | """Horizontal random flip p=0.5""" 154 | 155 | def __init__(self): 156 | pass 157 | 158 | def __call__(self, img): 159 | if random.randint(0, 1) == 0: 160 | return img 161 | return torch.flip(img, [3]) 162 | 163 | 164 | class Pipeline(object): 165 | def __init__(self, *args): 166 | self.functions = args 167 | 168 | def __call__(self, img): 169 | tmp = img 170 | for f in self.functions: 171 | tmp = f(tmp) 172 | return tmp 173 | 174 | 175 | class Prod(object): 176 | def __init__(self, *args): 177 | self.functions = args 178 | 179 | def __call__(self, x): 180 | return [f(x) for f in self.functions] 181 | 182 | 183 | def get_sub(img, steps=2): 184 | """Get grid of image. 
185 | Only support certain image sizes.""" 186 | bs, ch, h, w = img.shape 187 | outs = [] 188 | for h_step in range(steps): 189 | for w_step in range(steps): 190 | h_ixs = torch.arange(h_step, h, steps) 191 | w_ixs = torch.arange(w_step, h, steps) 192 | _w = w_ixs.repeat(int(h / steps)) 193 | _h = h_ixs.repeat_interleave(int(h / steps)) 194 | outs.append(img[:, :, _h, _w].reshape(1, 3, int(h / steps), int(h / steps))) 195 | return outs 196 | 197 | 198 | def grad_drop(grad, drop=0.0, drop2d=True): 199 | """Dropout gradient.""" 200 | with torch.no_grad(): 201 | if not drop2d: 202 | grad += -grad + dropout(grad, drop) 203 | else: 204 | grad += -grad + dropout2d(grad, drop) 205 | 206 | 207 | def grad_sign(grad): 208 | """Convert gradient to signed gradient.""" 209 | with torch.no_grad(): 210 | grad += -grad + grad.sign() 211 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | click 5 | torch~=1.7.1 6 | torchvision~=0.8.2 -------------------------------------------------------------------------------- /res/styledream_face_thumb.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/res/styledream_face_thumb.jpeg -------------------------------------------------------------------------------- /res/styledream_thumb.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/res/styledream_thumb.jpeg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='structureddreaming', 6 | version='0.0.1', 7 | description='Structured Dreaming', 8 | author='Ariel Ekgren', 9 | author_email='', 10 | url='https://github.com/ekgren/StructuredDreaming', 11 | install_requires=[], 12 | packages=find_packages(), 13 | entry_points={} 14 | ) 15 | -------------------------------------------------------------------------------- /structure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/structure/__init__.py -------------------------------------------------------------------------------- /structure/clip.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/openai/CLIP/ 2 | 3 | from collections import OrderedDict 4 | from typing import Tuple, Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from pathlib import Path 10 | 11 | import hashlib 12 | import os 13 | import urllib 14 | import warnings 15 | from typing import Union, List 16 | 17 | import torch 18 | from PIL import Image 19 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 20 | from tqdm import tqdm 21 | 22 | _MODELS = { 23 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 24 | "RN101": 
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 25 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 26 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 27 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 28 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 29 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 30 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 31 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 32 | } 33 | 34 | 35 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 36 | os.makedirs(root, exist_ok=True) 37 | filename = os.path.basename(url) 38 | 39 | expected_sha256 = url.split("/")[-2] 40 | download_target = os.path.join(root, filename) 41 | 42 | if os.path.exists(download_target) and not os.path.isfile(download_target): 43 | raise RuntimeError(f"{download_target} exists and is not a regular file") 44 | 45 | if os.path.isfile(download_target): 46 | if ( 47 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 48 | == expected_sha256 49 | ): 50 | return download_target 51 | else: 52 | warnings.warn( 53 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 54 | ) 55 | 56 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 57 | with tqdm( 58 | total=int(source.info().get("Content-Length")), 59 | ncols=80, 60 | unit="iB", 61 | unit_scale=True, 62 | ) as loop: 63 | while True: 64 | buffer = source.read(8192) 65 | if not buffer: 66 | break 67 | 68 | output.write(buffer) 69 | loop.update(len(buffer)) 70 | 71 | if ( 72 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 73 | != expected_sha256 74 | ): 75 | raise RuntimeError( 76 | f"Model has been downloaded but the SHA256 checksum does not not match" 77 | ) 78 | 79 | return download_target 80 | 81 | 82 | def _transform(): 83 | return Compose( 84 | [ 85 | Normalize( 86 | (0.48145466, 0.4578275, 0.40821073), 87 | (0.26862954, 0.26130258, 0.27577711), 88 | ), 89 | ] 90 | ) 91 | 92 | 93 | def available_models() -> List[str]: 94 | """Returns the names of available CLIP models""" 95 | return list(_MODELS.keys()) 96 | 97 | 98 | def load( 99 | name: str, 100 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 101 | jit=True, 102 | ): 103 | """Load a CLIP model 104 | Parameters 105 | ---------- 106 | name : str 107 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 108 | device : Union[str, torch.device] 109 | The device to put the loaded model 110 | jit : bool 111 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
112 | Returns 113 | ------- 114 | model : torch.nn.Module 115 | The CLIP model 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name]) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError( 125 | f"Model {name} not found; available models = {available_models()}" 126 | ) 127 | 128 | try: 129 | # loading JIT archive 130 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 131 | state_dict = None 132 | except RuntimeError: 133 | # loading saved state dict 134 | if jit: 135 | warnings.warn( 136 | f"File {model_path} is not a JIT archive. Loading as a state dict instead" 137 | ) 138 | jit = False 139 | state_dict = torch.load(model_path, map_location="cpu") 140 | 141 | if not jit: 142 | model = build_model(state_dict or model.state_dict()).to(device) 143 | if str(device) == "cpu": 144 | model.float() 145 | return model, _transform() 146 | 147 | # patch the device names 148 | device_holder = torch.jit.trace( 149 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 150 | ) 151 | device_node = [ 152 | n 153 | for n in device_holder.graph.findAllNodes("prim::Constant") 154 | if "Device" in repr(n) 155 | ][-1] 156 | 157 | def patch_device(module): 158 | graphs = [module.graph] if hasattr(module, "graph") else [] 159 | if hasattr(module, "forward1"): 160 | graphs.append(module.forward1.graph) 161 | 162 | for graph in graphs: 163 | for node in graph.findAllNodes("prim::Constant"): 164 | if "value" in node.attributeNames() and str(node["value"]).startswith( 165 | "cuda" 166 | ): 167 | node.copyAttributes(device_node) 168 | 169 | model.apply(patch_device) 170 | patch_device(model.encode_image) 171 | patch_device(model.encode_text) 172 | 173 | # patch dtype to float32 on CPU 174 | if str(device) == "cpu": 175 | float_holder = torch.jit.trace( 176 | lambda: torch.ones([]).float(), example_inputs=[] 177 | ) 178 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 179 | float_node = float_input.node() 180 | 181 | def patch_float(module): 182 | graphs = [module.graph] if hasattr(module, "graph") else [] 183 | if hasattr(module, "forward1"): 184 | graphs.append(module.forward1.graph) 185 | 186 | for graph in graphs: 187 | for node in graph.findAllNodes("aten::to"): 188 | inputs = list(node.inputs()) 189 | for i in [ 190 | 1, 191 | 2, 192 | ]: # dtype can be the second or third argument to aten::to() 193 | if inputs[i].node()["value"] == 5: 194 | inputs[i].node().copyAttributes(float_node) 195 | 196 | model.apply(patch_float) 197 | patch_float(model.encode_image) 198 | patch_float(model.encode_text) 199 | 200 | model.float() 201 | 202 | return model, _transform() 203 | 204 | 205 | def tokenize( 206 | texts: Union[str, List[str]], context_length: int = 77 207 | ) -> torch.LongTensor: 208 | """ 209 | Returns the tokenized representation of given input string(s) 210 | Parameters 211 | ---------- 212 | texts : Union[str, List[str]] 213 | An input string or a list of input strings to tokenize 214 | context_length : int 215 | The context length to use; all CLIP models use 77 as the context length 216 | Returns 217 | ------- 218 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 219 | """ 220 | if isinstance(texts, str): 221 | texts = [texts] 222 | 223 | 
sot_token = _tokenizer.encoder["<|startoftext|>"] 224 | eot_token = _tokenizer.encoder["<|endoftext|>"] 225 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | raise RuntimeError( 231 | f"Input {texts[i]} is too long for context length {context_length}" 232 | ) 233 | result[i, : len(tokens)] = torch.tensor(tokens) 234 | 235 | return result 236 | 237 | 238 | class Bottleneck(nn.Module): 239 | expansion = 4 240 | 241 | def __init__(self, inplanes, planes, stride=1): 242 | super().__init__() 243 | 244 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 245 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 246 | self.bn1 = nn.BatchNorm2d(planes) 247 | 248 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 249 | self.bn2 = nn.BatchNorm2d(planes) 250 | 251 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 252 | 253 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 254 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 255 | 256 | self.relu = nn.ReLU(inplace=True) 257 | self.downsample = None 258 | self.stride = stride 259 | 260 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 261 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 262 | self.downsample = nn.Sequential( 263 | OrderedDict( 264 | [ 265 | ("-1", nn.AvgPool2d(stride)), 266 | ( 267 | "0", 268 | nn.Conv2d( 269 | inplanes, 270 | planes * self.expansion, 271 | 1, 272 | stride=1, 273 | bias=False, 274 | ), 275 | ), 276 | ("1", nn.BatchNorm2d(planes * self.expansion)), 277 | ] 278 | ) 279 | ) 280 | 281 | def forward(self, x: torch.Tensor): 282 | identity = x 283 | 284 | out = self.relu(self.bn1(self.conv1(x))) 285 | out = self.relu(self.bn2(self.conv2(out))) 286 | out = self.avgpool(out) 287 | out = self.bn3(self.conv3(out)) 288 | 289 | if self.downsample is not None: 290 | identity = self.downsample(x) 291 | 292 | out += identity 293 | out = self.relu(out) 294 | return out 295 | 296 | 297 | class AttentionPool2d(nn.Module): 298 | def __init__( 299 | self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None 300 | ): 301 | super().__init__() 302 | self.positional_embedding = nn.Parameter( 303 | torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5 304 | ) 305 | self.k_proj = nn.Linear(embed_dim, embed_dim) 306 | self.q_proj = nn.Linear(embed_dim, embed_dim) 307 | self.v_proj = nn.Linear(embed_dim, embed_dim) 308 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 309 | self.num_heads = num_heads 310 | 311 | def forward(self, x): 312 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( 313 | 2, 0, 1 314 | ) # NCHW -> (HW)NC 315 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 316 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 317 | x, _ = F.multi_head_attention_forward( 318 | query=x, 319 | key=x, 320 | value=x, 321 | embed_dim_to_check=x.shape[-1], 322 | num_heads=self.num_heads, 323 | q_proj_weight=self.q_proj.weight, 324 | k_proj_weight=self.k_proj.weight, 325 | v_proj_weight=self.v_proj.weight, 326 | in_proj_weight=None, 327 | in_proj_bias=torch.cat( 328 | [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] 329 | ), 330 | bias_k=None, 331 | 
bias_v=None, 332 | add_zero_attn=False, 333 | dropout_p=0, 334 | out_proj_weight=self.c_proj.weight, 335 | out_proj_bias=self.c_proj.bias, 336 | use_separate_proj_weight=True, 337 | training=self.training, 338 | need_weights=False, 339 | ) 340 | 341 | return x[0] 342 | 343 | 344 | class ModifiedResNet(nn.Module): 345 | """ 346 | A ResNet class that is similar to torchvision's but contains the following changes: 347 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 348 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 349 | - The final pooling layer is a QKV attention instead of an average pool 350 | """ 351 | 352 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 353 | super().__init__() 354 | self.output_dim = output_dim 355 | self.input_resolution = input_resolution 356 | 357 | # the 3-layer stem 358 | self.conv1 = nn.Conv2d( 359 | 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False 360 | ) 361 | self.bn1 = nn.BatchNorm2d(width // 2) 362 | self.conv2 = nn.Conv2d( 363 | width // 2, width // 2, kernel_size=3, padding=1, bias=False 364 | ) 365 | self.bn2 = nn.BatchNorm2d(width // 2) 366 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 367 | self.bn3 = nn.BatchNorm2d(width) 368 | self.avgpool = nn.AvgPool2d(2) 369 | self.relu = nn.ReLU(inplace=True) 370 | 371 | # residual layers 372 | self._inplanes = width # this is a *mutable* variable used during construction 373 | self.layer1 = self._make_layer(width, layers[0]) 374 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 375 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 376 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 377 | 378 | embed_dim = width * 32 # the ResNet feature dimension 379 | self.attnpool = AttentionPool2d( 380 | input_resolution // 32, embed_dim, heads, output_dim 381 | ) 382 | 383 | def _make_layer(self, planes, blocks, stride=1): 384 | layers = [Bottleneck(self._inplanes, planes, stride)] 385 | 386 | self._inplanes = planes * Bottleneck.expansion 387 | for _ in range(1, blocks): 388 | layers.append(Bottleneck(self._inplanes, planes)) 389 | 390 | return nn.Sequential(*layers) 391 | 392 | def forward(self, x): 393 | def stem(x): 394 | for conv, bn in [ 395 | (self.conv1, self.bn1), 396 | (self.conv2, self.bn2), 397 | (self.conv3, self.bn3), 398 | ]: 399 | x = self.relu(bn(conv(x))) 400 | x = self.avgpool(x) 401 | return x 402 | 403 | x = x.type(self.conv1.weight.dtype) 404 | x = stem(x) 405 | x = self.layer1(x) 406 | x = self.layer2(x) 407 | x = self.layer3(x) 408 | x = self.layer4(x) 409 | x = self.attnpool(x) 410 | 411 | return x 412 | 413 | 414 | class LayerNorm(nn.LayerNorm): 415 | """Subclass torch's LayerNorm to handle fp16.""" 416 | 417 | def forward(self, x: torch.Tensor): 418 | orig_type = x.dtype 419 | ret = super().forward(x.type(torch.float32)) 420 | return ret.type(orig_type) 421 | 422 | 423 | class QuickGELU(nn.Module): 424 | def forward(self, x: torch.Tensor): 425 | return x * torch.sigmoid(1.702 * x) 426 | 427 | 428 | class ResidualAttentionBlock(nn.Module): 429 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 430 | super().__init__() 431 | 432 | self.attn = nn.MultiheadAttention(d_model, n_head) 433 | self.ln_1 = LayerNorm(d_model) 434 | self.mlp = nn.Sequential( 435 | OrderedDict( 436 | [ 437 | ("c_fc", nn.Linear(d_model, d_model * 4)), 438 | 
("gelu", QuickGELU()), 439 | ("c_proj", nn.Linear(d_model * 4, d_model)), 440 | ] 441 | ) 442 | ) 443 | self.ln_2 = LayerNorm(d_model) 444 | self.attn_mask = attn_mask 445 | 446 | def attention(self, x: torch.Tensor): 447 | self.attn_mask = ( 448 | self.attn_mask.to(dtype=x.dtype, device=x.device) 449 | if self.attn_mask is not None 450 | else None 451 | ) 452 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 453 | 454 | def forward(self, x: torch.Tensor): 455 | x = x + self.attention(self.ln_1(x)) 456 | x = x + self.mlp(self.ln_2(x)) 457 | return x 458 | 459 | 460 | class Transformer(nn.Module): 461 | def __init__( 462 | self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None 463 | ): 464 | super().__init__() 465 | self.width = width 466 | self.layers = layers 467 | self.resblocks = nn.Sequential( 468 | *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] 469 | ) 470 | 471 | def forward(self, x: torch.Tensor): 472 | return self.resblocks(x) 473 | 474 | 475 | class VisualTransformer(nn.Module): 476 | def __init__( 477 | self, 478 | input_resolution: int, 479 | patch_size: int, 480 | width: int, 481 | layers: int, 482 | heads: int, 483 | output_dim: int, 484 | ): 485 | super().__init__() 486 | self.input_resolution = input_resolution 487 | self.output_dim = output_dim 488 | self.conv1 = nn.Conv2d( 489 | in_channels=3, 490 | out_channels=width, 491 | kernel_size=patch_size, 492 | stride=patch_size, 493 | bias=False, 494 | ) 495 | 496 | scale = width ** -0.5 497 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 498 | self.positional_embedding = nn.Parameter( 499 | scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) 500 | ) 501 | self.ln_pre = LayerNorm(width) 502 | 503 | self.transformer = Transformer(width, layers, heads) 504 | 505 | self.ln_post = LayerNorm(width) 506 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 507 | 508 | def forward(self, x: torch.Tensor): 509 | x = self.conv1(x) # shape = [*, width, grid, grid] 510 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 511 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 512 | x = torch.cat( 513 | [ 514 | self.class_embedding.to(x.dtype) 515 | + torch.zeros( 516 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device 517 | ), 518 | x, 519 | ], 520 | dim=1, 521 | ) # shape = [*, grid ** 2 + 1, width] 522 | x = x + self.positional_embedding.to(x.dtype) 523 | x = self.ln_pre(x) 524 | 525 | x = x.permute(1, 0, 2) # NLD -> LND 526 | x = self.transformer(x) 527 | x = x.permute(1, 0, 2) # LND -> NLD 528 | 529 | x = self.ln_post(x[:, 0, :]) 530 | 531 | if self.proj is not None: 532 | x = x @ self.proj 533 | 534 | return x 535 | 536 | 537 | class CLIP(nn.Module): 538 | def __init__( 539 | self, 540 | embed_dim: int, 541 | # vision 542 | image_resolution: int, 543 | vision_layers: Union[Tuple[int, int, int, int], int], 544 | vision_width: int, 545 | vision_patch_size: int, 546 | # text 547 | context_length: int, 548 | vocab_size: int, 549 | transformer_width: int, 550 | transformer_heads: int, 551 | transformer_layers: int, 552 | ): 553 | super().__init__() 554 | 555 | self.context_length = context_length 556 | 557 | if isinstance(vision_layers, (tuple, list)): 558 | vision_heads = vision_width * 32 // 64 559 | self.visual = ModifiedResNet( 560 | layers=vision_layers, 561 | output_dim=embed_dim, 562 | heads=vision_heads, 563 | input_resolution=image_resolution, 564 | width=vision_width, 565 | 
) 566 | else: 567 | vision_heads = vision_width // 64 568 | self.visual = VisualTransformer( 569 | input_resolution=image_resolution, 570 | patch_size=vision_patch_size, 571 | width=vision_width, 572 | layers=vision_layers, 573 | heads=vision_heads, 574 | output_dim=embed_dim, 575 | ) 576 | 577 | self.transformer = Transformer( 578 | width=transformer_width, 579 | layers=transformer_layers, 580 | heads=transformer_heads, 581 | attn_mask=self.build_attention_mask(), 582 | ) 583 | 584 | self.vocab_size = vocab_size 585 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 586 | self.positional_embedding = nn.Parameter( 587 | torch.empty(self.context_length, transformer_width) 588 | ) 589 | self.ln_final = LayerNorm(transformer_width) 590 | 591 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 592 | self.logit_scale = nn.Parameter(torch.ones([])) 593 | 594 | self.initialize_parameters() 595 | 596 | def initialize_parameters(self): 597 | nn.init.normal_(self.token_embedding.weight, std=0.02) 598 | nn.init.normal_(self.positional_embedding, std=0.01) 599 | 600 | if isinstance(self.visual, ModifiedResNet): 601 | if self.visual.attnpool is not None: 602 | std = self.visual.attnpool.c_proj.in_features ** -0.5 603 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 604 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 605 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 606 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 607 | 608 | for resnet_block in [ 609 | self.visual.layer1, 610 | self.visual.layer2, 611 | self.visual.layer3, 612 | self.visual.layer4, 613 | ]: 614 | for name, param in resnet_block.named_parameters(): 615 | if name.endswith("bn3.weight"): 616 | nn.init.zeros_(param) 617 | 618 | proj_std = (self.transformer.width ** -0.5) * ( 619 | (2 * self.transformer.layers) ** -0.5 620 | ) 621 | attn_std = self.transformer.width ** -0.5 622 | fc_std = (2 * self.transformer.width) ** -0.5 623 | for block in self.transformer.resblocks: 624 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 625 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 626 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 627 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 628 | 629 | if self.text_projection is not None: 630 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 631 | 632 | def build_attention_mask(self): 633 | # lazily create causal attention mask, with full attention between the vision tokens 634 | # pytorch uses additive attention mask; fill with -inf 635 | mask = torch.empty(self.context_length, self.context_length) 636 | mask.fill_(float("-inf")) 637 | mask.triu_(1) # zero out the lower diagonal 638 | return mask 639 | 640 | @property 641 | def dtype(self): 642 | return self.visual.conv1.weight.dtype 643 | 644 | def encode_image(self, image): 645 | return self.visual(image.type(self.dtype)) 646 | 647 | def encode_text(self, text): 648 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 649 | 650 | x = x + self.positional_embedding.type(self.dtype) 651 | x = x.permute(1, 0, 2) # NLD -> LND 652 | x = self.transformer(x) 653 | x = x.permute(1, 0, 2) # LND -> NLD 654 | x = self.ln_final(x).type(self.dtype) 655 | 656 | # x.shape = [batch_size, n_ctx, transformer.width] 657 | # take features from the eot embedding (eot_token is the highest number in each sequence) 658 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ 
self.text_projection 659 | 660 | return x 661 | 662 | def forward(self, image, text): 663 | image_features = self.encode_image(image) 664 | text_features = self.encode_text(text) 665 | 666 | # normalized features 667 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 668 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 669 | 670 | # cosine similarity as logits 671 | logit_scale = self.logit_scale.exp() 672 | logits_per_image = logit_scale * image_features @ text_features.t() 673 | logits_per_text = logit_scale * text_features @ image_features.t() 674 | 675 | # shape = [global_batch_size, global_batch_size] 676 | return logits_per_image, logits_per_text 677 | 678 | 679 | def convert_weights(model: nn.Module): 680 | """Convert applicable model parameters to fp16""" 681 | 682 | def _convert_weights_to_fp16(l): 683 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 684 | l.weight.data = l.weight.data.half() 685 | if l.bias is not None: 686 | l.bias.data = l.bias.data.half() 687 | 688 | if isinstance(l, nn.MultiheadAttention): 689 | for attr in [ 690 | *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], 691 | "in_proj_bias", 692 | "bias_k", 693 | "bias_v", 694 | ]: 695 | tensor = getattr(l, attr) 696 | if tensor is not None: 697 | tensor.data = tensor.data.half() 698 | 699 | for name in ["text_projection", "proj"]: 700 | if hasattr(l, name): 701 | attr = getattr(l, name) 702 | if attr is not None: 703 | attr.data = attr.data.half() 704 | 705 | model.apply(_convert_weights_to_fp16) 706 | 707 | 708 | def build_model(state_dict: dict): 709 | vit = "visual.proj" in state_dict 710 | 711 | if vit: 712 | vision_width = state_dict["visual.conv1.weight"].shape[0] 713 | vision_layers = len( 714 | [ 715 | k 716 | for k in state_dict.keys() 717 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") 718 | ] 719 | ) 720 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 721 | grid_size = round( 722 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 723 | ) 724 | image_resolution = vision_patch_size * grid_size 725 | else: 726 | counts: list = [ 727 | len( 728 | set( 729 | k.split(".")[2] 730 | for k in state_dict 731 | if k.startswith(f"visual.layer{b}") 732 | ) 733 | ) 734 | for b in [1, 2, 3, 4] 735 | ] 736 | vision_layers = tuple(counts) 737 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 738 | output_width = round( 739 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 740 | ) 741 | vision_patch_size = None 742 | assert ( 743 | output_width ** 2 + 1 744 | == state_dict["visual.attnpool.positional_embedding"].shape[0] 745 | ) 746 | image_resolution = output_width * 32 747 | 748 | embed_dim = state_dict["text_projection"].shape[1] 749 | context_length = state_dict["positional_embedding"].shape[0] 750 | vocab_size = state_dict["token_embedding.weight"].shape[0] 751 | transformer_width = state_dict["ln_final.weight"].shape[0] 752 | transformer_heads = transformer_width // 64 753 | transformer_layers = len( 754 | set( 755 | k.split(".")[2] 756 | for k in state_dict 757 | if k.startswith(f"transformer.resblocks") 758 | ) 759 | ) 760 | 761 | model = CLIP( 762 | embed_dim, 763 | image_resolution, 764 | vision_layers, 765 | vision_width, 766 | vision_patch_size, 767 | context_length, 768 | vocab_size, 769 | transformer_width, 770 | transformer_heads, 771 | transformer_layers, 772 | ) 773 | 774 | for key in ["input_resolution", "context_length", "vocab_size"]: 775 | if 
key in state_dict: 776 | del state_dict[key] 777 | 778 | convert_weights(model) 779 | model.load_state_dict(state_dict) 780 | return model.eval() 781 | 782 | 783 | import gzip 784 | import html 785 | import os 786 | from functools import lru_cache 787 | 788 | import ftfy 789 | import regex as re 790 | 791 | 792 | @lru_cache() 793 | def default_bpe(): 794 | return os.path.join( 795 | os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt" 796 | ) 797 | 798 | 799 | @lru_cache() 800 | def bytes_to_unicode(): 801 | """ 802 | Returns list of utf-8 byte and a corresponding list of unicode strings. 803 | The reversible bpe codes work on unicode strings. 804 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 805 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 806 | This is a significant percentage of your normal, say, 32K bpe vocab. 807 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 808 | And avoids mapping to whitespace/control characters the bpe code barfs on. 809 | """ 810 | bs = ( 811 | list(range(ord("!"), ord("~") + 1)) 812 | + list(range(ord("¡"), ord("¬") + 1)) 813 | + list(range(ord("®"), ord("ÿ") + 1)) 814 | ) 815 | cs = bs[:] 816 | n = 0 817 | for b in range(2 ** 8): 818 | if b not in bs: 819 | bs.append(b) 820 | cs.append(2 ** 8 + n) 821 | n += 1 822 | cs = [chr(n) for n in cs] 823 | return dict(zip(bs, cs)) 824 | 825 | 826 | def get_pairs(word): 827 | """Return set of symbol pairs in a word. 828 | Word is represented as tuple of symbols (symbols being variable-length strings). 829 | """ 830 | pairs = set() 831 | prev_char = word[0] 832 | for char in word[1:]: 833 | pairs.add((prev_char, char)) 834 | prev_char = char 835 | return pairs 836 | 837 | 838 | def basic_clean(text): 839 | text = ftfy.fix_text(text) 840 | text = html.unescape(html.unescape(text)) 841 | return text.strip() 842 | 843 | 844 | def whitespace_clean(text): 845 | text = re.sub(r"\s+", " ", text) 846 | text = text.strip() 847 | return text 848 | 849 | 850 | class SimpleTokenizer(object): 851 | def __init__(self, bpe_path: str = default_bpe()): 852 | self.byte_encoder = bytes_to_unicode() 853 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 854 | merges = Path(bpe_path).read_text(encoding="utf8").split("\n") 855 | merges = merges[1 : 49152 - 256 - 2 + 1] 856 | merges = [tuple(merge.split()) for merge in merges] 857 | vocab = list(bytes_to_unicode().values()) 858 | vocab = vocab + [v + "</w>" for v in vocab] 859 | for merge in merges: 860 | vocab.append("".join(merge)) 861 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 862 | self.encoder = dict(zip(vocab, range(len(vocab)))) 863 | self.decoder = {v: k for k, v in self.encoder.items()} 864 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 865 | self.cache = { 866 | "<|startoftext|>": "<|startoftext|>", 867 | "<|endoftext|>": "<|endoftext|>", 868 | } 869 | self.pat = re.compile( 870 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 871 | re.IGNORECASE, 872 | ) 873 | 874 | def bpe(self, token): 875 | if token in self.cache: 876 | return self.cache[token] 877 | word = tuple(token[:-1]) + (token[-1] + "</w>",) 878 | pairs = get_pairs(word) 879 | 880 | if not pairs: 881 | return token + "</w>" 882 | 883 | while True: 884 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 885 | if bigram not in self.bpe_ranks: 886 | break 887 | 
887 |             first, second = bigram
888 |             new_word = []
889 |             i = 0
890 |             while i < len(word):
891 |                 try:
892 |                     j = word.index(first, i)
893 |                     new_word.extend(word[i:j])
894 |                     i = j
895 |                 except:
896 |                     new_word.extend(word[i:])
897 |                     break
898 |
899 |                 if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
900 |                     new_word.append(first + second)
901 |                     i += 2
902 |                 else:
903 |                     new_word.append(word[i])
904 |                     i += 1
905 |             new_word = tuple(new_word)
906 |             word = new_word
907 |             if len(word) == 1:
908 |                 break
909 |             else:
910 |                 pairs = get_pairs(word)
911 |         word = " ".join(word)
912 |         self.cache[token] = word
913 |         return word
914 |
915 |     def encode(self, text):
916 |         bpe_tokens = []
917 |         text = whitespace_clean(basic_clean(text)).lower()
918 |         for token in re.findall(self.pat, text):
919 |             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
920 |             bpe_tokens.extend(
921 |                 self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
922 |             )
923 |         return bpe_tokens
924 |
925 |     def decode(self, tokens):
926 |         text = "".join([self.decoder[token] for token in tokens])
927 |         text = (
928 |             bytearray([self.byte_decoder[c] for c in text])
929 |             .decode("utf-8", errors="replace")
930 |             .replace("</w>", " ")
931 |         )
932 |         return text
933 |
934 |
935 |
936 | _tokenizer = SimpleTokenizer()
937 |
--------------------------------------------------------------------------------
/structure/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # From the paper "CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP"
4 | # Arxiv link: https://arxiv.org/abs/2110.11316
5 | # Code: https://github.com/ml-jku/cloob/
6 | def infoLOOB_loss(x, y, i, inv_tau):
7 |     tau = 1 / inv_tau
8 |     k = x @ y.T / tau
9 |     positives = -torch.mean(torch.sum(k * i, dim=1))
10 |
11 |     # For logsumexp the zero entries must be equal to a very large negative number
12 |     large_neg = -1000.0
13 |     arg_lse = k * torch.logical_not(i) + i * large_neg
14 |     negatives = torch.mean(torch.logsumexp(arg_lse, dim=1))
15 |
16 |     return tau * (positives + negatives)
--------------------------------------------------------------------------------
/structure/optim.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import Tensor
4 | from torch.optim import Optimizer
5 | from typing import List, Optional
6 |
7 |
8 | class SimpleSGD(Optimizer):
9 |     def __init__(self, params, lr=0.1):
10 |         if not 0.0 <= lr:
11 |             raise ValueError("Invalid learning rate: {}".format(lr))
12 |         defaults = dict(lr=lr)
13 |         super().__init__(params, defaults)
14 |
15 |     def step(self, closure=None):
16 |         loss = None
17 |         if closure is not None:
18 |             loss = closure()
19 |         for group in self.param_groups:
20 |             for p in group["params"]:
21 |                 if p.grad is None:
22 |                     continue
23 |                 lr = group["lr"]
24 |                 grad = p.grad.data
25 |                 p.data.add_(grad, alpha=-lr)
26 |         return loss
27 |
28 |
29 | class ClampSGD(Optimizer):
30 |     r"""Clamp SGD
31 |     Args:
32 |         params (iterable): iterable of parameters to optimize or dicts defining
33 |             parameter groups
34 |         lr (float, optional): learning rate (default: 1e-1)
35 |         clamp (float, optional): clamping of gradient (default: 1e-30)
36 |         drop (float, optional): dropout of gradient (default: 0)
37 |     """
38 |
39 |     def __init__(self, params, lr=1e-1, clamp=1e-30, drop=0.0):
40 |         if not 0.0 <= lr:
41 |             raise ValueError("Invalid learning rate: {}".format(lr))
42 |         if not 0.0 <= drop < 1.0:
raise ValueError("Invalid dropout value: {}".format(drop)) 44 | defaults = dict(lr=lr, clamp=clamp, drop=drop) 45 | super().__init__(params, defaults) 46 | 47 | @torch.no_grad() 48 | def step(self, closure=None): 49 | loss = None 50 | if closure is not None: 51 | loss = closure() 52 | for group in self.param_groups: 53 | for p in group["params"]: 54 | if p.grad is None: 55 | continue 56 | lr = group["lr"] 57 | clamp = group["clamp"] 58 | drop = group["drop"] 59 | grad = torch.nn.functional.dropout( 60 | p.grad.data.clamp_(-clamp, clamp).div_(clamp), p=drop 61 | ) 62 | p.data.add_(grad, alpha=-lr) 63 | return loss 64 | 65 | 66 | class SignSGD(Optimizer): 67 | r"""Sign SGD 68 | Args: 69 | params (iterable): iterable of parameters to optimize or dicts defining 70 | parameter groups 71 | lr (float, optional): learning rate (default: 1e-1) 72 | drop (float, optional): dropout of gradient (default: 0) 73 | """ 74 | 75 | def __init__(self, params, lr=1e-1, drop=0.0): 76 | if not 0.0 <= lr: 77 | raise ValueError("Invalid learning rate: {}".format(lr)) 78 | if not 0.0 <= drop < 1.0: 79 | raise ValueError("Invalid dropout value: {}".format(drop)) 80 | defaults = dict(lr=lr, drop=drop) 81 | super().__init__(params, defaults) 82 | 83 | @torch.no_grad() 84 | def step(self, closure=None): 85 | loss = None 86 | if closure is not None: 87 | loss = closure() 88 | for group in self.param_groups: 89 | for p in group["params"]: 90 | if p.grad is None: 91 | continue 92 | lr = group["lr"] 93 | clamp = group["clamp"] 94 | drop = group["drop"] 95 | grad = torch.nn.functional.dropout( 96 | p.grad.data.sign_(), p=drop 97 | ) 98 | p.data.add_(grad, alpha=-lr) 99 | return loss 100 | 101 | 102 | class AdamW(Optimizer): 103 | r"""Implements AdamW algorithm. 104 | 105 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 106 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 107 | 108 | Args: 109 | params (iterable): iterable of parameters to optimize or dicts defining 110 | parameter groups 111 | lr (float, optional): learning rate (default: 1e-3) 112 | betas (Tuple[float, float], optional): coefficients used for computing 113 | running averages of gradient and its square (default: (0.9, 0.999)) 114 | eps (float, optional): term added to the denominator to improve 115 | numerical stability (default: 1e-8) 116 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 117 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 118 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 119 | (default: False) 120 | 121 | .. _Adam\: A Method for Stochastic Optimization: 122 | https://arxiv.org/abs/1412.6980 123 | .. _Decoupled Weight Decay Regularization: 124 | https://arxiv.org/abs/1711.05101 125 | .. 
_On the Convergence of Adam and Beyond: 126 | https://openreview.net/forum?id=ryQu7f-RZ 127 | """ 128 | 129 | def __init__( 130 | self, 131 | params, 132 | lr=1e-3, 133 | betas=(0.9, 0.999), 134 | eps=1e-8, 135 | weight_decay=1e-2, 136 | amsgrad=False, 137 | clamp=1e-30, 138 | drop=0.0, 139 | ): 140 | if not 0.0 <= lr: 141 | raise ValueError("Invalid learning rate: {}".format(lr)) 142 | if not 0.0 <= eps: 143 | raise ValueError("Invalid epsilon value: {}".format(eps)) 144 | if not 0.0 <= betas[0] < 1.0: 145 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 146 | if not 0.0 <= betas[1] < 1.0: 147 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 148 | if not 0.0 <= weight_decay: 149 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 150 | if not 0.0 <= drop < 1.0: 151 | raise ValueError("Invalid dropout value: {}".format(drop)) 152 | defaults = dict( 153 | lr=lr, 154 | betas=betas, 155 | eps=eps, 156 | weight_decay=weight_decay, 157 | amsgrad=amsgrad, 158 | clamp=clamp, 159 | drop=drop, 160 | ) 161 | super(AdamW, self).__init__(params, defaults) 162 | 163 | def __setstate__(self, state): 164 | super(AdamW, self).__setstate__(state) 165 | for group in self.param_groups: 166 | group.setdefault("amsgrad", False) 167 | 168 | @torch.no_grad() 169 | def step(self, closure=None): 170 | """Performs a single optimization step. 171 | 172 | Args: 173 | closure (callable, optional): A closure that reevaluates the model 174 | and returns the loss. 175 | """ 176 | loss = None 177 | if closure is not None: 178 | with torch.enable_grad(): 179 | loss = closure() 180 | 181 | for group in self.param_groups: 182 | params_with_grad = [] 183 | grads = [] 184 | exp_avgs = [] 185 | exp_avg_sqs = [] 186 | state_sums = [] 187 | max_exp_avg_sqs = [] 188 | state_steps = [] 189 | amsgrad = group["amsgrad"] 190 | beta1, beta2 = group["betas"] 191 | clamp = group["clamp"] 192 | drop = group["drop"] 193 | 194 | for p in group["params"]: 195 | if p.grad is None: 196 | continue 197 | params_with_grad.append(p) 198 | if p.grad.is_sparse: 199 | raise RuntimeError("AdamW does not support sparse gradients") 200 | grads.append(p.grad) 201 | 202 | state = self.state[p] 203 | 204 | # State initialization 205 | if len(state) == 0: 206 | state["step"] = 0 207 | # Exponential moving average of gradient values 208 | state["exp_avg"] = torch.zeros_like( 209 | p, memory_format=torch.preserve_format 210 | ) 211 | # Exponential moving average of squared gradient values 212 | state["exp_avg_sq"] = torch.zeros_like( 213 | p, memory_format=torch.preserve_format 214 | ) 215 | if amsgrad: 216 | # Maintains max of all exp. moving avg. of sq. grad. 
values 217 | state["max_exp_avg_sq"] = torch.zeros_like( 218 | p, memory_format=torch.preserve_format 219 | ) 220 | 221 | exp_avgs.append(state["exp_avg"]) 222 | exp_avg_sqs.append(state["exp_avg_sq"]) 223 | 224 | if amsgrad: 225 | max_exp_avg_sqs.append(state["max_exp_avg_sq"]) 226 | 227 | # update the steps for each param group update 228 | state["step"] += 1 229 | # record the step after step update 230 | state_steps.append(state["step"]) 231 | 232 | adamw( 233 | params_with_grad, 234 | grads, 235 | exp_avgs, 236 | exp_avg_sqs, 237 | max_exp_avg_sqs, 238 | state_steps, 239 | amsgrad=amsgrad, 240 | beta1=beta1, 241 | beta2=beta2, 242 | lr=group["lr"], 243 | weight_decay=group["weight_decay"], 244 | eps=group["eps"], 245 | clamp=clamp, 246 | drop=drop, 247 | ) 248 | 249 | return loss 250 | 251 | 252 | def adamw( 253 | params: List[Tensor], 254 | grads: List[Tensor], 255 | exp_avgs: List[Tensor], 256 | exp_avg_sqs: List[Tensor], 257 | max_exp_avg_sqs: List[Tensor], 258 | state_steps: List[int], 259 | *, 260 | amsgrad: bool, 261 | beta1: float, 262 | beta2: float, 263 | lr: float, 264 | weight_decay: float, 265 | eps: float, 266 | clamp: float, 267 | drop: float 268 | ): 269 | r"""Functional API that performs AdamW algorithm computation. 270 | See :class:`~torch.optim.AdamW` for details. 271 | """ 272 | for i, param in enumerate(params): 273 | grad = torch.nn.functional.dropout( 274 | grads[i].clamp_(-clamp, clamp).div_(clamp), p=drop 275 | ) 276 | exp_avg = exp_avgs[i] 277 | exp_avg_sq = exp_avg_sqs[i] 278 | step = state_steps[i] 279 | 280 | # Perform stepweight decay 281 | param.mul_(1 - lr * weight_decay) 282 | 283 | bias_correction1 = 1 - beta1 ** step 284 | bias_correction2 = 1 - beta2 ** step 285 | 286 | # Decay the first and second moment running average coefficient 287 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 288 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 289 | if amsgrad: 290 | # Maintains the maximum of all 2nd moment running avg. till now 291 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 292 | # Use the max. for normalizing running avg. of gradient 293 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) 294 | else: 295 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 296 | 297 | step_size = lr / bias_correction1 298 | 299 | param.addcdiv_(exp_avg, denom, value=-step_size) 300 | -------------------------------------------------------------------------------- /structure/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | from . import transform 5 | 6 | 7 | @torch.jit.script 8 | def random_generate_grid( 9 | l: int = 2, mode: str = "all", device: str = "cuda" 10 | ) -> torch.Tensor: 11 | """ 12 | Generates a grid of coordinates in the range [-1, 1] to be used with pytorch 13 | grid_sample. 14 | 15 | Parameters 16 | ---------- 17 | l : int 18 | The number of coordinates along each dimension. 19 | mode : str 20 | 'all', 'even' or 'repeat'. If 'all', generates a grid of size l*l. 21 | If 'even', generates a grid of size l*l. 22 | If 'repeat', generates a grid of size 1*l. 23 | device : str 24 | 'cuda' or 'cpu'. 25 | 26 | Returns 27 | ------- 28 | xy : torch.Tensor 29 | A tensor of shape (l*l, 2) if mode == 'all', or (l*l, 2) if mode == 'repeat'. 
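
    Examples
    --------
    The grid is returned with shape (1, l, l, 2) and can be passed directly to
    torch.nn.functional.grid_sample; tensors are created on the default 'cuda' device::

        grid = random_generate_grid(8, mode="even")                              # shape (1, 8, 8, 2)
        patch = torch.nn.functional.grid_sample(img, grid, align_corners=False)  # img: (1, C, H, W) -> (1, C, 8, 8)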
30 | """ 31 | if mode != "all" and mode != "even" and mode != "repeat": 32 | raise ValueError( 33 | "random_generate_grid(): expected mode to be " 34 | "'all' or 'repeat', but got: '{}'".format(mode) 35 | ) 36 | 37 | if mode == "all": 38 | x, _ = torch.rand(l, l, device=device).sort() 39 | y, _ = torch.rand(l, l, device=device).sort() 40 | x = (x - x.min(dim=1, keepdim=True).values) / ( 41 | x.max(dim=1, keepdim=True).values - x.min(dim=1, keepdim=True).values 42 | ) 43 | y = (y - y.min(dim=1, keepdim=True).values) / ( 44 | y.max(dim=1, keepdim=True).values - y.min(dim=1, keepdim=True).values 45 | ) 46 | x = x.view(-1) 47 | y = y.t().reshape(-1) 48 | xy = torch.stack([x, y], dim=1).view(1, l, l, 2) * 2.0 - 1.0 49 | elif mode == "even": 50 | coords = torch.arange(l, device=device) 51 | coords = (coords.float() / (l - 1.0) - 0.5) * 2.0 52 | offset = torch.rand(1, device=device) * 0.6 53 | x_offset = offset * (1.0 + 0.05 * (torch.rand(1, device=device) * 2.0 - 1.0)) 54 | y_offset = offset * (1.0 + 0.05 * (torch.rand(1, device=device) * 2.0 - 1.0)) 55 | x = ( 56 | coords * (1.0 - x_offset) 57 | + (torch.rand(1, device=device) * 2.0 - 1.0) * x_offset 58 | ) 59 | y = ( 60 | coords * (1.0 - y_offset) 61 | + (torch.rand(1, device=device) * 2.0 - 1.0) * y_offset 62 | ) 63 | grid_y, grid_x = torch.meshgrid(y, x) 64 | xy = torch.stack([grid_x, grid_y], dim=2).view(1, l, l, 2) 65 | else: # mode == 'repeat' 66 | x, _ = torch.rand(l, device=device).sort() 67 | y, _ = torch.rand(l, device=device).sort() 68 | x = (x - x.min()) / (x.max() - x.min()) 69 | y = (y - y.min()) / (y.max() - y.min()) 70 | grid_y, grid_x = torch.meshgrid(y, x) 71 | xy = torch.stack([grid_x, grid_y], dim=2).view(1, l, l, 2) * 2.0 - 1.0 72 | return xy 73 | 74 | 75 | class ImgSampleBase(torch.nn.Module): 76 | def __init__( 77 | self, 78 | kernel_min: int = 1, 79 | kernel_max: int = 8, 80 | grid_size_min: int = 224, 81 | grid_size_max: int = 448, 82 | noise: float = 1.0, 83 | noise_std: float = 0.3, 84 | cutout: float = 1.0, 85 | cutout_size: float = 0.25, 86 | distortion_scale: float = 0.5, 87 | perspective: float = 1.0, 88 | downsamples: int = 1, 89 | ): 90 | super().__init__() 91 | self.kernel_min = kernel_min 92 | self.kernel_max = kernel_max 93 | self.grid_size_min = grid_size_min 94 | self.grid_size_max = grid_size_max 95 | self.noise = noise 96 | self.noise_std = noise_std 97 | self.cutout = cutout 98 | self.cutout_size = cutout_size 99 | self.downsamples = downsamples 100 | self.perspective_transformer = torchvision.transforms.RandomPerspective( 101 | distortion_scale=distortion_scale, 102 | p=perspective, 103 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR, 104 | ) 105 | 106 | def forward( 107 | self, input: torch.Tensor, size: int = 224, bs: int = 1 108 | ) -> torch.Tensor: 109 | imgs = [] 110 | 111 | for _ in range(bs): 112 | # Generate grid for sampling 113 | grid_mode_idx = int(torch.randint(1, 3, ()).item()) 114 | grid_mode = ("all", "even", "repeat") 115 | grid_size = int( 116 | self.grid_size_min 117 | + torch.rand(()).item() * (self.grid_size_max - self.grid_size_min) 118 | ) 119 | grid = random_generate_grid(grid_size, mode=grid_mode[grid_mode_idx]) 120 | 121 | # Sample original input 122 | img = torch.nn.functional.grid_sample( 123 | input, grid, mode="bilinear", padding_mode="zeros", align_corners=False 124 | ) 125 | img = torch.nn.functional.interpolate(img, (size, size), mode="area") 126 | imgs.append(img) 127 | imgs.append(self.perspective_transformer(img)) 128 | img = torch.flip(img, [3]) 129 
| imgs.append(img) 130 | imgs.append(self.perspective_transformer(img)) 131 | 132 | for i in range(self.downsamples): 133 | # Transform, downsize and sample original input 134 | img = transform.noise(input, noise=self.noise, noise_std=self.noise_std) 135 | img = transform.cutout( 136 | img, cutout=self.cutout, cutout_size=self.cutout_size 137 | ) 138 | # Draw kernel size from uniform x uniform distribution 139 | kernel_size = int( 140 | torch.randint( 141 | self.kernel_min, 142 | int(torch.randint(self.kernel_min + 1, self.kernel_max, ())), 143 | (), 144 | ).item() 145 | ) 146 | img = torch.nn.functional.avg_pool2d( 147 | img, kernel_size=kernel_size, stride=kernel_size, padding=0 148 | ) 149 | img = torch.nn.functional.grid_sample( 150 | img, grid, mode="bilinear", padding_mode="zeros", align_corners=False 151 | ) 152 | img = torch.nn.functional.interpolate(img, (size, size), mode="area") 153 | imgs.append(img) 154 | imgs.append(self.perspective_transformer(img)) 155 | img = torch.flip(img, [3]) 156 | imgs.append(img) 157 | imgs.append(self.perspective_transformer(img)) 158 | 159 | return torch.cat(imgs, dim=0) 160 | 161 | 162 | class ImgSampleStylegan(torch.nn.Module): 163 | def __init__( 164 | self, 165 | kernel_min: int = 1, 166 | kernel_max: int = 8, 167 | grid_size_min: int = 224, 168 | grid_size_max: int = 448, 169 | noise: float = 1.0, 170 | noise_std: float = 0.3, 171 | cutout: float = 1.0, 172 | cutout_size: float = 0.25, 173 | ): 174 | super().__init__() 175 | self.kernel_min = kernel_min 176 | self.kernel_max = kernel_max 177 | self.grid_size_min = grid_size_min 178 | self.grid_size_max = grid_size_max 179 | self.noise = noise 180 | self.noise_std = noise_std 181 | self.cutout = cutout 182 | self.cutout_size = cutout_size 183 | 184 | def forward( 185 | self, input: torch.Tensor, size: int = 224, bs: int = 1 186 | ) -> torch.Tensor: 187 | imgs = [] 188 | 189 | for _ in range(bs): 190 | # Draw kernel size from uniform x uniform distribution 191 | kernel_size = int( 192 | torch.randint( 193 | self.kernel_min, 194 | int(torch.randint(self.kernel_min + 1, self.kernel_max, ())), 195 | (), 196 | ).item() 197 | ) 198 | 199 | # Generate grid for sampling 200 | grid_mode_idx = int(torch.randint(1, 3, ()).item()) 201 | grid_mode = ("all", "even", "repeat") 202 | grid_size = int( 203 | self.grid_size_min 204 | + torch.rand(()).item() * (self.grid_size_max - self.grid_size_min) 205 | ) 206 | grid = random_generate_grid(grid_size, mode=grid_mode[grid_mode_idx]) 207 | 208 | # Sample original input 209 | img = torch.nn.functional.grid_sample( 210 | input, grid, mode="bilinear", padding_mode="zeros", align_corners=False 211 | ) 212 | img = torch.nn.functional.interpolate(img, (size, size), mode="area") 213 | imgs.append(img) 214 | 215 | # Transform, downsize and sample original input 216 | img = transform.noise(input, noise=self.noise, noise_std=self.noise_std) 217 | img = transform.cutout( 218 | img, cutout=self.cutout, cutout_size=self.cutout_size 219 | ) 220 | img = torch.nn.functional.avg_pool2d( 221 | img, kernel_size=kernel_size, stride=kernel_size, padding=0 222 | ) 223 | img = torch.nn.functional.grid_sample( 224 | img, grid, mode="bilinear", padding_mode="zeros", align_corners=False 225 | ) 226 | img = torch.nn.functional.interpolate(img, (size, size), mode="area") 227 | imgs.append(img) 228 | 229 | return torch.cat(imgs, dim=0) 230 | -------------------------------------------------------------------------------- /structure/stylegan_utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ekgren/StructuredDreaming/d9040bda41fdc634cd4e653207b9938fa5c3f947/structure/stylegan_utils/__init__.py -------------------------------------------------------------------------------- /structure/stylegan_utils/ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code slightly modified from https://github.com/NVlabs/ffhq-dataset 3 | Line: https://github.com/NVlabs/ffhq-dataset/blob/193deb624543e54d93b54a296fcfc9a68a3ed8ce/download_ffhq.py#L259 4 | """ 5 | import PIL 6 | import scipy 7 | import numpy as np 8 | 9 | 10 | def align_image(img, 11 | lm, 12 | output_size=1024, 13 | transform_size=4096, 14 | enable_padding=False, 15 | rotate_level=True, 16 | random_shift=0.0, 17 | retry_crops=False): 18 | """ 19 | Code slightly modified from https://github.com/NVlabs/ffhq-dataset 20 | Line: https://github.com/NVlabs/ffhq-dataset/blob/193deb624543e54d93b54a296fcfc9a68a3ed8ce/download_ffhq.py#L259 21 | 22 | Expects an image and a landmarks vector as input. 23 | Return aligned image. 24 | """ 25 | 26 | lm = np.array(lm) 27 | lm_chin = lm[0 : 17] # left-right 28 | lm_eyebrow_left = lm[17 : 22] # left-right 29 | lm_eyebrow_right = lm[22 : 27] # left-right 30 | lm_nose = lm[27 : 31] # top-down 31 | lm_nostrils = lm[31 : 36] # top-down 32 | lm_eye_left = lm[36 : 42] # left-clockwise 33 | lm_eye_right = lm[42 : 48] # left-clockwise 34 | lm_mouth_outer = lm[48 : 60] # left-clockwise 35 | lm_mouth_inner = lm[60 : 68] # left-clockwise 36 | 37 | # Calculate auxiliary vectors. 38 | eye_left = np.mean(lm_eye_left, axis=0) 39 | eye_right = np.mean(lm_eye_right, axis=0) 40 | eye_avg = (eye_left + eye_right) * 0.5 41 | eye_to_eye = eye_right - eye_left 42 | mouth_left = lm_mouth_outer[0] 43 | mouth_right = lm_mouth_outer[6] 44 | mouth_avg = (mouth_left + mouth_right) * 0.5 45 | eye_to_mouth = mouth_avg - eye_avg 46 | 47 | # Choose oriented crop rectangle. 
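    # The crop is defined by a centre c0 (just below the eye midpoint, shifted towards
    # the mouth) and two half-axes x and y: x follows the eye line (tilted by the
    # eye-to-mouth direction when rotate_level is set, axis-aligned otherwise) and is
    # proportional to the larger of the eye-to-eye and eye-to-mouth spans, while y is
    # its perpendicular. quad collects the four corners and qsize is the side length.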
48 | if rotate_level: 49 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 50 | x /= np.hypot(*x) 51 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 52 | y = np.flipud(x) * [-1, 1] 53 | c0 = eye_avg + eye_to_mouth * 0.1 54 | else: 55 | x = np.array([1, 0], dtype=np.float64) 56 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 57 | y = np.flipud(x) * [-1, 1] 58 | c0 = eye_avg + eye_to_mouth * 0.1 59 | 60 | quad = np.stack([c0 - x - y, c0 - x + y, c0 + x + y, c0 + x - y]) 61 | qsize = np.hypot(*x) * 2 62 | 63 | # Keep drawing new random crop offsets until we find one that is contained in the image 64 | # and does not require padding 65 | if random_shift != 0: 66 | for _ in range(1000): 67 | # Offset the crop rectange center by a random shift proportional to image dimension 68 | # and the requested standard deviation 69 | c = (c0 + np.hypot(*x)*2 * random_shift * np.random.normal(0, 1, c0.shape)) 70 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 71 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 72 | if not retry_crops or not (crop[0] < 0 or crop[1] < 0 or crop[2] >= img.width or crop[3] >= img.height): 73 | # We're happy with this crop (either it fits within the image, or retries are disabled) 74 | break 75 | else: 76 | # rejected N times, give up and move to next image 77 | # (does not happen in practice with the FFHQ data) 78 | print('rejected image') 79 | #return 80 | 81 | # Shrink. 82 | shrink = int(np.floor(qsize / output_size * 0.5)) 83 | if shrink > 1: 84 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 85 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 86 | quad /= shrink 87 | qsize /= shrink 88 | 89 | # Crop. 90 | border = max(int(np.rint(qsize * 0.1)), 3) 91 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 92 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 93 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 94 | img = img.crop(crop) 95 | quad -= crop[0:2] 96 | 97 | # Pad. 98 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 99 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 100 | if enable_padding and max(pad) > border - 4: 101 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 102 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 103 | h, w, _ = img.shape 104 | y, x, _ = np.ogrid[:h, :w, :1] 105 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 106 | blur = qsize * 0.02 107 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 108 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 109 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 110 | quad += pad[:2] 111 | 112 | # Transform. 
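    # Warp the (possibly rotated and shifted) source quad onto an axis-aligned square of
    # transform_size pixels with a bilinear QUAD transform, then downsample to output_size.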
113 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 114 | if output_size < transform_size: 115 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 116 | 117 | return img -------------------------------------------------------------------------------- /structure/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.script 5 | def noise( 6 | input: torch.Tensor, 7 | noise: float = 0.0, 8 | noise_std: float = 0.1, 9 | p: float = 1.0, 10 | device: str = "cuda", 11 | ) -> torch.Tensor: 12 | """Apply additive RGB noise with probability (noise * p).""" 13 | if noise > 0: 14 | batch_size, num_channels, height, width = input.shape 15 | sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * noise_std 16 | sigma = torch.where( 17 | torch.rand([batch_size, 1, 1, 1], device=device) < noise * p, 18 | sigma, 19 | torch.zeros_like(sigma), 20 | ) 21 | input = ( 22 | input 23 | + torch.randn([batch_size, num_channels, height, width], device=device) 24 | * sigma 25 | ) 26 | return input 27 | 28 | 29 | @torch.jit.script 30 | def cutout( 31 | input: torch.Tensor, 32 | cutout: float = 0.0, 33 | cutout_size: float = 0.1, 34 | p: float = 1.0, 35 | device: str = "cuda", 36 | ) -> torch.Tensor: 37 | """Apply cutout with probability (cutout * p).""" 38 | if cutout > 0: 39 | batch_size, num_channels, height, width = input.shape 40 | size = torch.full([batch_size, 2, 1, 1, 1], cutout_size, device=device) 41 | size = torch.where( 42 | torch.rand([batch_size, 1, 1, 1, 1], device=device) < cutout * p, 43 | size, 44 | torch.zeros_like(size), 45 | ) 46 | center = torch.rand([batch_size, 2, 1, 1, 1], device=device) 47 | coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) 48 | coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1]) 49 | mask_x = ((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2 50 | mask_y = ((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2 51 | mask = torch.logical_or(mask_x, mask_y).to(torch.float32) 52 | input = input * mask 53 | return input 54 | 55 | 56 | @torch.jit.script 57 | def color(input: torch.Tensor, 58 | c: torch.Tensor, 59 | c_bias: torch.Tensor, 60 | c_scale: torch.Tensor) -> torch.Tensor: 61 | c = torch.sin(c) 62 | c = c / (c.norm() + 1e-8) 63 | c_bias = torch.sin(c_bias) 64 | c_bias = c_bias / (c_bias.norm() + 1e-8) 65 | input = input.permute(0, 2, 3, 1) 66 | input = torch.nn.functional.linear(input, c, bias=c_bias) 67 | input = input.permute(0, 3, 1, 2) * c_scale.abs() 68 | input = (torch.sin(input) + 1.)/2. 69 | return input 70 | -------------------------------------------------------------------------------- /structure/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import requests 3 | 4 | import numpy as np 5 | 6 | 7 | def model_to_fp32(model): 8 | for p in model.parameters(): 9 | p.data = p.data.float() 10 | p.grad.data = p.grad.data.float() 11 | 12 | 13 | def img_pil_to_opencv(input): 14 | """ Converts PIL image to numpy array with open cv formatting. """ 15 | open_cv_image = np.array(input) 16 | open_cv_image = open_cv_image[:, :, ::-1].copy() 17 | return open_cv_image 18 | 19 | 20 | def get_img(url): 21 | """ Minimal function to download image from url. 
""" 22 | response = requests.get(url) 23 | return io.BytesIO(response.content) 24 | 25 | 26 | class ArgDict(dict): 27 | def __setattr__(self, name, value): 28 | self[name] = value 29 | 30 | def __getattr__(self, name): 31 | try: 32 | return self[name] 33 | except KeyError: 34 | raise AttributeError(name) 35 | 36 | def __delattr__(self, name): 37 | del self[name] 38 | --------------------------------------------------------------------------------