├── CLIP_ ├── .DS_Store ├── LICENSE ├── MANIFEST.in ├── README.md ├── astronaut.png ├── clip │ ├── .DS_Store │ ├── __init__.py │ ├── auxilary.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── example.py ├── model-card.md ├── notebooks │ ├── Interacting_with_CLIP.ipynb │ └── Prompt_Engineering_for_ImageNet.ipynb ├── requirements.txt ├── setup.py └── tests │ └── test_consistency.py ├── LICENSE ├── README.md ├── basic_diffvg.py ├── custom_parser.py ├── datasets ├── datasets_metadata │ ├── new_emojis_list │ └── old_emojis_list ├── from-free-svg.zip └── nft-apes.zip ├── docker ├── Dockerfile └── README.md ├── figures ├── 005-1.png └── 005.pdf ├── geometric_loss.py ├── histogram_loss.py ├── models ├── .DS_Store ├── decomp.py ├── edge.py ├── histogram.py ├── loss.py ├── painter_params.py ├── pyramid.py └── structure.py ├── reduce_and_optimize.py ├── reduce_or_add_and_optimize.py ├── target_images └── 083.png ├── test └── config_init.npy └── utils.py /CLIP_/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/CLIP_/.DS_Store -------------------------------------------------------------------------------- /CLIP_/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /CLIP_/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include clip/bpe_simple_vocab_16e6.txt.gz 2 | -------------------------------------------------------------------------------- /CLIP_/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | [[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb) 4 | 5 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. 
We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision. 6 | 7 | 8 | 9 | ## Approach 10 | 11 | ![CLIP](CLIP.png) 12 | 13 | 14 | 15 | ## Usage 16 | 17 | First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick: 18 | 19 | ```bash 20 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0 21 | $ pip install ftfy regex tqdm 22 | $ pip install git+https://github.com/openai/CLIP.git 23 | ``` 24 | 25 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU. 26 | 27 | ```python 28 | import torch 29 | import clip 30 | from PIL import Image 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | model, preprocess = clip.load("ViT-B/32", device=device) 34 | 35 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device) 36 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 37 | 38 | with torch.no_grad(): 39 | image_features = model.encode_image(image) 40 | text_features = model.encode_text(text) 41 | 42 | logits_per_image, logits_per_text = model(image, text) 43 | probs = logits_per_image.softmax(dim=-1).cpu().numpy() 44 | 45 | print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]] 46 | ``` 47 | 48 | 49 | ## API 50 | 51 | The CLIP module `clip` provides the following methods: 52 | 53 | #### `clip.available_models()` 54 | 55 | Returns the names of the available CLIP models. 56 | 57 | #### `clip.load(name, device=..., jit=True)` 58 | 59 | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint. 60 | 61 | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded. 62 | 63 | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` 64 | 65 | Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model 66 | 67 | --- 68 | 69 | The model returned by `clip.load()` supports the following methods: 70 | 71 | #### `model.encode_image(image: Tensor)` 72 | 73 | Given a batch of images, returns the image features encoded by the vision portion of the CLIP model. 74 | 75 | #### `model.encode_text(text: Tensor)` 76 | 77 | Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model. 78 | 79 | #### `model(image: Tensor, text: Tensor)` 80 | 81 | Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100. 82 | 83 | 84 | 85 | ## More Examples 86 | 87 | ### Zero-Shot Prediction 88 | 89 | The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. 
This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset. 90 | 91 | ```python 92 | import os 93 | import clip 94 | import torch 95 | from torchvision.datasets import CIFAR100 96 | 97 | # Load the model 98 | device = "cuda" if torch.cuda.is_available() else "cpu" 99 | model, preprocess = clip.load('ViT-B/32', device) 100 | 101 | # Download the dataset 102 | cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False) 103 | 104 | # Prepare the inputs 105 | image, class_id = cifar100[3637] 106 | image_input = preprocess(image).unsqueeze(0).to(device) 107 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) 108 | 109 | # Calculate features 110 | with torch.no_grad(): 111 | image_features = model.encode_image(image_input) 112 | text_features = model.encode_text(text_inputs) 113 | 114 | # Pick the top 5 most similar labels for the image 115 | image_features /= image_features.norm(dim=-1, keepdim=True) 116 | text_features /= text_features.norm(dim=-1, keepdim=True) 117 | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) 118 | values, indices = similarity[0].topk(5) 119 | 120 | # Print the result 121 | print("\nTop predictions:\n") 122 | for value, index in zip(values, indices): 123 | print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%") 124 | ``` 125 | 126 | The output will look like the following (the exact numbers may be slightly different depending on the compute device): 127 | 128 | ``` 129 | Top predictions: 130 | 131 | snake: 65.31% 132 | turtle: 12.29% 133 | sweet_pepper: 3.83% 134 | lizard: 1.88% 135 | crocodile: 1.75% 136 | ``` 137 | 138 | Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs. 139 | 140 | 141 | ### Linear-probe evaluation 142 | 143 | The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features. 
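In outline, the train and test images are encoded once with the frozen CLIP image encoder under `torch.no_grad()`, the resulting features and labels are cached as NumPy arrays, and an L2-regularized logistic-regression classifier is then fit on top of them; only the regularization strength `C` is tuned.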
144 | 145 | ```python 146 | import os 147 | import clip 148 | import torch 149 | 150 | import numpy as np 151 | from sklearn.linear_model import LogisticRegression 152 | from torch.utils.data import DataLoader 153 | from torchvision.datasets import CIFAR100 154 | from tqdm import tqdm 155 | 156 | # Load the model 157 | device = "cuda" if torch.cuda.is_available() else "cpu" 158 | model, preprocess = clip.load('ViT-B/32', device) 159 | 160 | # Load the dataset 161 | root = os.path.expanduser("~/.cache") 162 | train = CIFAR100(root, download=True, train=True, transform=preprocess) 163 | test = CIFAR100(root, download=True, train=False, transform=preprocess) 164 | 165 | 166 | def get_features(dataset): 167 | all_features = [] 168 | all_labels = [] 169 | 170 | with torch.no_grad(): 171 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)): 172 | features = model.encode_image(images.to(device)) 173 | 174 | all_features.append(features) 175 | all_labels.append(labels) 176 | 177 | return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy() 178 | 179 | # Calculate the image features 180 | train_features, train_labels = get_features(train) 181 | test_features, test_labels = get_features(test) 182 | 183 | # Perform logistic regression 184 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) 185 | classifier.fit(train_features, train_labels) 186 | 187 | # Evaluate using the logistic regression classifier 188 | predictions = classifier.predict(test_features) 189 | accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100. 190 | print(f"Accuracy = {accuracy:.3f}") 191 | ``` 192 | 193 | Note that the `C` value should be determined via a hyperparameter sweep using a validation split. 194 | -------------------------------------------------------------------------------- /CLIP_/astronaut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/CLIP_/astronaut.png -------------------------------------------------------------------------------- /CLIP_/clip/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/CLIP_/clip/.DS_Store -------------------------------------------------------------------------------- /CLIP_/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /CLIP_/clip/auxilary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | from typing import Tuple, Optional 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn.init import xavier_uniform_ 8 | from torch.nn.init import constant_ 9 | from torch.nn.init import xavier_normal_ 10 | from torch.nn.parameter import Parameter 11 | from torch.nn import functional as F 12 | 13 | # We define this function as _pad because it takes an argument 14 | # named pad, which clobbers the recursive reference to the pad 15 | # function needed for __torch_function__ support 16 | pad = F.pad 17 | 18 | 19 | # This class exists solely for Transformer; it has an annotation stating 20 | # that bias is never None, which appeases TorchScript 21 | class _LinearWithBias(torch.nn.Linear): 
22 | bias: Tensor 23 | 24 | def __init__(self, in_features: int, out_features: int) -> None: 25 | super().__init__(in_features, out_features, bias=True) 26 | 27 | 28 | def multi_head_attention_forward(query: Tensor, 29 | key: Tensor, 30 | value: Tensor, 31 | embed_dim_to_check: int, 32 | num_heads: int, 33 | in_proj_weight: Tensor, 34 | in_proj_bias: Tensor, 35 | bias_k: Optional[Tensor], 36 | bias_v: Optional[Tensor], 37 | add_zero_attn: bool, 38 | dropout_p: float, 39 | out_proj_weight: Tensor, 40 | out_proj_bias: Tensor, 41 | training: bool = True, 42 | key_padding_mask: Optional[Tensor] = None, 43 | need_weights: bool = True, 44 | attn_mask: Optional[Tensor] = None, 45 | use_separate_proj_weight: bool = False, 46 | q_proj_weight: Optional[Tensor] = None, 47 | k_proj_weight: Optional[Tensor] = None, 48 | v_proj_weight: Optional[Tensor] = None, 49 | static_k: Optional[Tensor] = None, 50 | static_v: Optional[Tensor] = None, 51 | attention_probs_forward_hook=None, 52 | attention_probs_backwards_hook=None, 53 | ) -> Tuple[Tensor, Optional[Tensor]]: 54 | if not torch.jit.is_scripting(): 55 | tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, 56 | out_proj_weight, out_proj_bias) 57 | if any([type(t) is not Tensor for t in tens_ops]) and F.has_torch_function(tens_ops): 58 | return F.handle_torch_function( 59 | multi_head_attention_forward, tens_ops, query, key, value, 60 | embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, 61 | bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, 62 | out_proj_bias, training=training, key_padding_mask=key_padding_mask, 63 | need_weights=need_weights, attn_mask=attn_mask, 64 | use_separate_proj_weight=use_separate_proj_weight, 65 | q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, 66 | v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) 67 | tgt_len, bsz, embed_dim = query.size() 68 | assert embed_dim == embed_dim_to_check 69 | # allow MHA to have different sizes for the feature dimension 70 | assert key.size(0) == value.size(0) and key.size(1) == value.size(1) 71 | 72 | head_dim = embed_dim // num_heads 73 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 74 | scaling = float(head_dim) ** -0.5 75 | 76 | if not use_separate_proj_weight: 77 | if torch.equal(query, key) and torch.equal(key, value): 78 | # self-attention 79 | q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 80 | 81 | elif torch.equal(key, value): 82 | # encoder-decoder attention 83 | # This is inline in_proj function with in_proj_weight and in_proj_bias 84 | _b = in_proj_bias 85 | _start = 0 86 | _end = embed_dim 87 | _w = in_proj_weight[_start:_end, :] 88 | if _b is not None: 89 | _b = _b[_start:_end] 90 | q = F.linear(query, _w, _b) 91 | 92 | if key is None: 93 | assert value is None 94 | k = None 95 | v = None 96 | else: 97 | 98 | # This is inline in_proj function with in_proj_weight and in_proj_bias 99 | _b = in_proj_bias 100 | _start = embed_dim 101 | _end = None 102 | _w = in_proj_weight[_start:, :] 103 | if _b is not None: 104 | _b = _b[_start:] 105 | k, v = F.linear(key, _w, _b).chunk(2, dim=-1) 106 | 107 | else: 108 | # This is inline in_proj function with in_proj_weight and in_proj_bias 109 | _b = in_proj_bias 110 | _start = 0 111 | _end = embed_dim 112 | _w = in_proj_weight[_start:_end, :] 113 | if _b is not None: 114 | _b = _b[_start:_end] 115 | q = F.linear(query, _w, _b) 116 | 117 | # This is inline in_proj function with in_proj_weight and in_proj_bias 118 | 
_b = in_proj_bias 119 | _start = embed_dim 120 | _end = embed_dim * 2 121 | _w = in_proj_weight[_start:_end, :] 122 | if _b is not None: 123 | _b = _b[_start:_end] 124 | k = F.linear(key, _w, _b) 125 | 126 | # This is inline in_proj function with in_proj_weight and in_proj_bias 127 | _b = in_proj_bias 128 | _start = embed_dim * 2 129 | _end = None 130 | _w = in_proj_weight[_start:, :] 131 | if _b is not None: 132 | _b = _b[_start:] 133 | v = F.linear(value, _w, _b) 134 | else: 135 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 136 | len1, len2 = q_proj_weight_non_opt.size() 137 | assert len1 == embed_dim and len2 == query.size(-1) 138 | 139 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 140 | len1, len2 = k_proj_weight_non_opt.size() 141 | assert len1 == embed_dim and len2 == key.size(-1) 142 | 143 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 144 | len1, len2 = v_proj_weight_non_opt.size() 145 | assert len1 == embed_dim and len2 == value.size(-1) 146 | 147 | if in_proj_bias is not None: 148 | q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 149 | k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 150 | v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 151 | else: 152 | q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) 153 | k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) 154 | v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) 155 | q = q * scaling 156 | 157 | if attn_mask is not None: 158 | assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ 159 | attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ 160 | 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) 161 | if attn_mask.dtype == torch.uint8: 162 | warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 163 | attn_mask = attn_mask.to(torch.bool) 164 | 165 | if attn_mask.dim() == 2: 166 | attn_mask = attn_mask.unsqueeze(0) 167 | if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: 168 | raise RuntimeError('The size of the 2D attn_mask is not correct.') 169 | elif attn_mask.dim() == 3: 170 | if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: 171 | raise RuntimeError('The size of the 3D attn_mask is not correct.') 172 | else: 173 | raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) 174 | # attn_mask's dim is 3 now. 175 | 176 | # convert ByteTensor key_padding_mask to bool 177 | if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: 178 | warnings.warn( 179 | "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 180 | key_padding_mask = key_padding_mask.to(torch.bool) 181 | 182 | if bias_k is not None and bias_v is not None: 183 | if static_k is None and static_v is None: 184 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 185 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 186 | if attn_mask is not None: 187 | attn_mask = pad(attn_mask, (0, 1)) 188 | if key_padding_mask is not None: 189 | key_padding_mask = pad(key_padding_mask, (0, 1)) 190 | else: 191 | assert static_k is None, "bias cannot be added to static key." 192 | assert static_v is None, "bias cannot be added to static value." 
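    # otherwise no learned bias rows are appended to the key/value sequences; bias_k and bias_v must both be None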
193 | else: 194 | assert bias_k is None 195 | assert bias_v is None 196 | 197 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 198 | if k is not None: 199 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 200 | if v is not None: 201 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 202 | 203 | if static_k is not None: 204 | assert static_k.size(0) == bsz * num_heads 205 | assert static_k.size(2) == head_dim 206 | k = static_k 207 | 208 | if static_v is not None: 209 | assert static_v.size(0) == bsz * num_heads 210 | assert static_v.size(2) == head_dim 211 | v = static_v 212 | 213 | src_len = k.size(1) 214 | 215 | if key_padding_mask is not None: 216 | assert key_padding_mask.size(0) == bsz 217 | assert key_padding_mask.size(1) == src_len 218 | 219 | if add_zero_attn: 220 | src_len += 1 221 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 222 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 223 | if attn_mask is not None: 224 | attn_mask = pad(attn_mask, (0, 1)) 225 | if key_padding_mask is not None: 226 | key_padding_mask = pad(key_padding_mask, (0, 1)) 227 | 228 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 229 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 230 | 231 | if attn_mask is not None: 232 | if attn_mask.dtype == torch.bool: 233 | attn_output_weights.masked_fill_(attn_mask, float('-inf')) 234 | else: 235 | attn_output_weights += attn_mask 236 | 237 | if key_padding_mask is not None: 238 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 239 | attn_output_weights = attn_output_weights.masked_fill( 240 | key_padding_mask.unsqueeze(1).unsqueeze(2), 241 | float('-inf'), 242 | ) 243 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 244 | 245 | attn_output_weights = F.softmax( 246 | attn_output_weights, dim=-1) 247 | attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training) 248 | 249 | # use hooks for the attention weights if necessary 250 | if attention_probs_forward_hook is not None and attention_probs_backwards_hook is not None: 251 | attention_probs_forward_hook(attn_output_weights) 252 | attn_output_weights.register_hook(attention_probs_backwards_hook) 253 | 254 | attn_output = torch.bmm(attn_output_weights, v) 255 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 256 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 257 | attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) 258 | 259 | if need_weights: 260 | # average attention weights over heads 261 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 262 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 263 | else: 264 | return attn_output, None 265 | 266 | 267 | class MultiheadAttention(torch.nn.Module): 268 | r"""Allows the model to jointly attend to information 269 | from different representation subspaces. 270 | See reference: Attention Is All You Need 271 | 272 | .. math:: 273 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 274 | \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 275 | 276 | Args: 277 | embed_dim: total dimension of the model. 278 | num_heads: parallel attention heads. 279 | dropout: a Dropout layer on attn_output_weights. Default: 0.0. 
280 | bias: add bias as module parameter. Default: True. 281 | add_bias_kv: add bias to the key and value sequences at dim=0. 282 | add_zero_attn: add a new batch of zeros to the key and 283 | value sequences at dim=1. 284 | kdim: total number of features in key. Default: None. 285 | vdim: total number of features in value. Default: None. 286 | 287 | Note: if kdim and vdim are None, they will be set to embed_dim such that 288 | query, key, and value have the same number of features. 289 | 290 | Examples:: 291 | 292 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 293 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 294 | """ 295 | bias_k: Optional[torch.Tensor] 296 | bias_v: Optional[torch.Tensor] 297 | 298 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, 299 | vdim=None): 300 | super(MultiheadAttention, self).__init__() 301 | self.embed_dim = embed_dim 302 | self.kdim = kdim if kdim is not None else embed_dim 303 | self.vdim = vdim if vdim is not None else embed_dim 304 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 305 | 306 | self.num_heads = num_heads 307 | self.dropout = dropout 308 | self.head_dim = embed_dim // num_heads 309 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 310 | 311 | if self._qkv_same_embed_dim is False: 312 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) 313 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) 314 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) 315 | self.register_parameter('in_proj_weight', None) 316 | else: 317 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 318 | self.register_parameter('q_proj_weight', None) 319 | self.register_parameter('k_proj_weight', None) 320 | self.register_parameter('v_proj_weight', None) 321 | 322 | if bias: 323 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) 324 | else: 325 | self.register_parameter('in_proj_bias', None) 326 | self.out_proj = _LinearWithBias(embed_dim, embed_dim) 327 | 328 | if add_bias_kv: 329 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) 330 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) 331 | else: 332 | self.bias_k = self.bias_v = None 333 | 334 | self.add_zero_attn = add_zero_attn 335 | 336 | self._reset_parameters() 337 | 338 | def _reset_parameters(self): 339 | if self._qkv_same_embed_dim: 340 | xavier_uniform_(self.in_proj_weight) 341 | else: 342 | xavier_uniform_(self.q_proj_weight) 343 | xavier_uniform_(self.k_proj_weight) 344 | xavier_uniform_(self.v_proj_weight) 345 | 346 | if self.in_proj_bias is not None: 347 | constant_(self.in_proj_bias, 0.) 348 | constant_(self.out_proj.bias, 0.) 349 | if self.bias_k is not None: 350 | xavier_normal_(self.bias_k) 351 | if self.bias_v is not None: 352 | xavier_normal_(self.bias_v) 353 | 354 | def __setstate__(self, state): 355 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 356 | if '_qkv_same_embed_dim' not in state: 357 | state['_qkv_same_embed_dim'] = True 358 | 359 | super(MultiheadAttention, self).__setstate__(state) 360 | 361 | def forward(self, query, key, value, key_padding_mask=None, 362 | need_weights=True, attn_mask=None, attention_probs_forward_hook=None, 363 | attention_probs_backwards_hook=None): 364 | r""" 365 | Args: 366 | query, key, value: map a query and a set of key-value pairs to an output. 
367 | See "Attention Is All You Need" for more details. 368 | key_padding_mask: if provided, specified padding elements in the key will 369 | be ignored by the attention. When given a binary mask and a value is True, 370 | the corresponding value on the attention layer will be ignored. When given 371 | a byte mask and a value is non-zero, the corresponding value on the attention 372 | layer will be ignored 373 | need_weights: output attn_output_weights. 374 | attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all 375 | the batches while a 3D mask allows to specify a different mask for the entries of each batch. 376 | 377 | Shape: 378 | - Inputs: 379 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 380 | the embedding dimension. 381 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 382 | the embedding dimension. 383 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 384 | the embedding dimension. 385 | - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 386 | If a ByteTensor is provided, the non-zero positions will be ignored while the position 387 | with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the 388 | value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 389 | - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 390 | 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, 391 | S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked 392 | positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend 393 | while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` 394 | is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor 395 | is provided, it will be added to the attention weight. 396 | 397 | - Outputs: 398 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 399 | E is the embedding dimension. 400 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 401 | L is the target sequence length, S is the source sequence length. 
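        Note:
            This modified implementation additionally accepts two optional callables,
            ``attention_probs_forward_hook`` and ``attention_probs_backwards_hook``, which are used only
            when both are provided: the forward hook is called with the attention probabilities
            (after softmax and dropout) of shape :math:`(N*num_heads, L, S)`, and the backwards hook is
            registered on that tensor so its gradient can be captured during the backward pass. The
            residual blocks in ``model.py`` pass ``set_attn_probs`` / ``set_attn_grad`` here, and
            ``example.py`` reads the stored maps and gradients for relevance visualization.

        Example (a minimal sketch of capturing attention maps; assumes ``attn`` is a
        ``MultiheadAttention`` from this module, ``x`` is an ``(L, N, E)`` tensor, and grad mode is
        enabled)::

            >>> captured = {}
            >>> out, _ = attn(x, x, x, need_weights=False,
            ...               attention_probs_forward_hook=lambda p: captured.update(probs=p),
            ...               attention_probs_backwards_hook=lambda g: captured.update(grad=g))
            >>> out.sum().backward()       # the registered hook fills captured["grad"]
            >>> captured["probs"].shape    # (bsz * num_heads, tgt_len, src_len)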
402 | """ 403 | if not self._qkv_same_embed_dim: 404 | return multi_head_attention_forward( 405 | query, key, value, self.embed_dim, self.num_heads, 406 | self.in_proj_weight, self.in_proj_bias, 407 | self.bias_k, self.bias_v, self.add_zero_attn, 408 | self.dropout, self.out_proj.weight, self.out_proj.bias, 409 | training=self.training, 410 | key_padding_mask=key_padding_mask, need_weights=need_weights, 411 | attn_mask=attn_mask, use_separate_proj_weight=True, 412 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 413 | v_proj_weight=self.v_proj_weight, 414 | attention_probs_forward_hook=attention_probs_forward_hook, 415 | attention_probs_backwards_hook=attention_probs_backwards_hook) 416 | else: 417 | return multi_head_attention_forward( 418 | query, key, value, self.embed_dim, self.num_heads, 419 | self.in_proj_weight, self.in_proj_bias, 420 | self.bias_k, self.bias_v, self.add_zero_attn, 421 | self.dropout, self.out_proj.weight, self.out_proj.bias, 422 | training=self.training, 423 | key_padding_mask=key_padding_mask, need_weights=need_weights, 424 | attn_mask=attn_mask, 425 | attention_probs_forward_hook=attention_probs_forward_hook, 426 | attention_probs_backwards_hook=attention_probs_backwards_hook) 427 | -------------------------------------------------------------------------------- /CLIP_/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/CLIP_/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /CLIP_/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | __all__ = ["available_models", "load", "tokenize"] 16 | _tokenizer = _Tokenizer() 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 21 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 22 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 23 | } 24 | 25 | 26 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 27 | os.makedirs(root, exist_ok=True) 28 | filename = os.path.basename(url) 29 | 30 | expected_sha256 = url.split("/")[-2] 31 | download_target = os.path.join(root, filename) 32 | 33 | if os.path.exists(download_target) and not os.path.isfile(download_target): 34 | raise RuntimeError(f"{download_target} exists and is not a regular file") 35 | 36 | if os.path.isfile(download_target): 37 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 38 | return download_target 39 | else: 40 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 
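    # stream the download in 8 KiB chunks with a tqdm progress bar; the SHA256 checksum is verified again below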
41 | 42 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 43 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 44 | while True: 45 | buffer = source.read(8192) 46 | if not buffer: 47 | break 48 | 49 | output.write(buffer) 50 | loop.update(len(buffer)) 51 | 52 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 53 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 54 | 55 | return download_target 56 | 57 | 58 | def _transform(n_px): 59 | return Compose([ 60 | Resize(n_px, interpolation=Image.BICUBIC), 61 | CenterCrop(n_px), 62 | lambda image: image.convert("RGB"), 63 | ToTensor(), 64 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 65 | ]) 66 | 67 | 68 | def available_models() -> List[str]: 69 | """Returns the names of available CLIP models""" 70 | return list(_MODELS.keys()) 71 | 72 | 73 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 74 | """Load a CLIP model 75 | 76 | Parameters 77 | ---------- 78 | name : str 79 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 80 | 81 | device : Union[str, torch.device] 82 | The device to put the loaded model 83 | 84 | jit : bool 85 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 86 | 87 | Returns 88 | ------- 89 | model : torch.nn.Module 90 | The CLIP model 91 | 92 | preprocess : Callable[[PIL.Image], torch.Tensor] 93 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 94 | """ 95 | if name in _MODELS: 96 | model_path = _download(_MODELS[name]) 97 | elif os.path.isfile(name): 98 | model_path = name 99 | else: 100 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 101 | 102 | try: 103 | # loading JIT archive 104 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 105 | state_dict = None 106 | except RuntimeError: 107 | # loading saved state dict 108 | if jit: 109 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 110 | jit = False 111 | state_dict = torch.load(model_path, map_location="cpu") 112 | 113 | if not jit: 114 | model = build_model(state_dict or model.state_dict()).to(device) 115 | if str(device) == "cpu": 116 | model.float() 117 | return model, _transform(model.visual.input_resolution) 118 | 119 | # patch the device names 120 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 121 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 122 | 123 | def patch_device(module): 124 | graphs = [module.graph] if hasattr(module, "graph") else [] 125 | if hasattr(module, "forward1"): 126 | graphs.append(module.forward1.graph) 127 | 128 | for graph in graphs: 129 | for node in graph.findAllNodes("prim::Constant"): 130 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 131 | node.copyAttributes(device_node) 132 | 133 | model.apply(patch_device) 134 | patch_device(model.encode_image) 135 | patch_device(model.encode_text) 136 | 137 | # patch dtype to float32 on CPU 138 | if str(device) == "cpu": 139 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 140 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 141 | float_node = float_input.node() 142 | 143 | def patch_float(module): 144 | graphs = [module.graph] if hasattr(module, "graph") else [] 145 | if hasattr(module, "forward1"): 146 | graphs.append(module.forward1.graph) 147 | 148 | for graph in graphs: 149 | for node in graph.findAllNodes("aten::to"): 150 | inputs = list(node.inputs()) 151 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 152 | if inputs[i].node()["value"] == 5: 153 | inputs[i].node().copyAttributes(float_node) 154 | 155 | model.apply(patch_float) 156 | patch_float(model.encode_image) 157 | patch_float(model.encode_text) 158 | 159 | model.float() 160 | 161 | return model, _transform(model.input_resolution.item()) 162 | 163 | 164 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 165 | """ 166 | Returns the tokenized representation of given input string(s) 167 | 168 | Parameters 169 | ---------- 170 | texts : Union[str, List[str]] 171 | An input string or a list of input strings to tokenize 172 | 173 | context_length : int 174 | The context length to use; all CLIP models use 77 as the context length 175 | 176 | Returns 177 | ------- 178 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 179 | """ 180 | if isinstance(texts, str): 181 | texts = [texts] 182 | 183 | sot_token = _tokenizer.encoder["<|startoftext|>"] 184 | eot_token = _tokenizer.encoder["<|endoftext|>"] 185 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 186 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 187 | 188 | for i, tokens in enumerate(all_tokens): 189 | if len(tokens) > context_length: 190 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 191 | result[i, :len(tokens)] = torch.tensor(tokens) 192 | 193 | return result 194 | -------------------------------------------------------------------------------- /CLIP_/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 
6 | import torch.nn.functional as F 7 | from torch import nn 8 | from .auxilary import * 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | self.attn_probs = None 182 | self.attn_grad = None 183 | 184 | def set_attn_probs(self, attn_probs): 185 | self.attn_probs = attn_probs 186 | 187 | def set_attn_grad(self, attn_grad): 188 | self.attn_grad = attn_grad 189 | 190 | def attention(self, x: torch.Tensor): 191 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 192 | return self.attn(x, x, x, need_weights=False, 
attn_mask=self.attn_mask, attention_probs_forward_hook=self.set_attn_probs, 193 | attention_probs_backwards_hook=self.set_attn_grad)[0] 194 | 195 | def forward(self, x: torch.Tensor): 196 | x = x + self.attention(self.ln_1(x)) 197 | x = x + self.mlp(self.ln_2(x)) 198 | return x 199 | 200 | 201 | class Transformer(nn.Module): 202 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 203 | super().__init__() 204 | self.width = width 205 | self.layers = layers 206 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 207 | 208 | def forward(self, x: torch.Tensor): 209 | return self.resblocks(x) 210 | 211 | 212 | class VisualTransformer(nn.Module): 213 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 214 | super().__init__() 215 | self.input_resolution = input_resolution 216 | self.output_dim = output_dim 217 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 218 | 219 | scale = width ** -0.5 220 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 221 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 222 | self.ln_pre = LayerNorm(width) 223 | 224 | self.transformer = Transformer(width, layers, heads) 225 | 226 | self.ln_post = LayerNorm(width) 227 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 228 | 229 | def forward(self, x: torch.Tensor): 230 | x = self.conv1(x) # shape = [*, width, grid, grid] 231 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 232 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 233 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 234 | x = x + self.positional_embedding.to(x.dtype) 235 | x = self.ln_pre(x) 236 | 237 | x = x.permute(1, 0, 2) # NLD -> LND 238 | x = self.transformer(x) 239 | x = x.permute(1, 0, 2) # LND -> NLD 240 | 241 | x = self.ln_post(x[:, 0, :]) 242 | 243 | if self.proj is not None: 244 | x = x @ self.proj 245 | 246 | return x 247 | 248 | 249 | class CLIP(nn.Module): 250 | def __init__(self, 251 | embed_dim: int, 252 | # vision 253 | image_resolution: int, 254 | vision_layers: Union[Tuple[int, int, int, int], int], 255 | vision_width: int, 256 | vision_patch_size: int, 257 | # text 258 | context_length: int, 259 | vocab_size: int, 260 | transformer_width: int, 261 | transformer_heads: int, 262 | transformer_layers: int 263 | ): 264 | super().__init__() 265 | 266 | self.context_length = context_length 267 | 268 | if isinstance(vision_layers, (tuple, list)): 269 | vision_heads = vision_width * 32 // 64 270 | self.visual = ModifiedResNet( 271 | layers=vision_layers, 272 | output_dim=embed_dim, 273 | heads=vision_heads, 274 | input_resolution=image_resolution, 275 | width=vision_width 276 | ) 277 | else: 278 | vision_heads = vision_width // 64 279 | self.visual = VisualTransformer( 280 | input_resolution=image_resolution, 281 | patch_size=vision_patch_size, 282 | width=vision_width, 283 | layers=vision_layers, 284 | heads=vision_heads, 285 | output_dim=embed_dim 286 | ) 287 | 288 | self.transformer = Transformer( 289 | width=transformer_width, 290 | layers=transformer_layers, 291 | heads=transformer_heads, 292 | attn_mask=self.build_attention_mask() 293 | ) 294 | 295 | self.vocab_size = 
vocab_size 296 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 297 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 298 | self.ln_final = LayerNorm(transformer_width) 299 | 300 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 301 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 302 | 303 | self.initialize_parameters() 304 | 305 | def initialize_parameters(self): 306 | nn.init.normal_(self.token_embedding.weight, std=0.02) 307 | nn.init.normal_(self.positional_embedding, std=0.01) 308 | 309 | if isinstance(self.visual, ModifiedResNet): 310 | if self.visual.attnpool is not None: 311 | std = self.visual.attnpool.c_proj.in_features ** -0.5 312 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 313 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 314 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 315 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 316 | 317 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 318 | for name, param in resnet_block.named_parameters(): 319 | if name.endswith("bn3.weight"): 320 | nn.init.zeros_(param) 321 | 322 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 323 | attn_std = self.transformer.width ** -0.5 324 | fc_std = (2 * self.transformer.width) ** -0.5 325 | for block in self.transformer.resblocks: 326 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 327 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 328 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 329 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 330 | 331 | if self.text_projection is not None: 332 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 333 | 334 | def build_attention_mask(self): 335 | # lazily create causal attention mask, with full attention between the vision tokens 336 | # pytorch uses additive attention mask; fill with -inf 337 | mask = torch.empty(self.context_length, self.context_length) 338 | mask.fill_(float("-inf")) 339 | mask.triu_(1) # zero out the lower diagonal 340 | return mask 341 | 342 | @property 343 | def dtype(self): 344 | return self.visual.conv1.weight.dtype 345 | 346 | def encode_image(self, image): 347 | return self.visual(image.type(self.dtype)) 348 | 349 | def encode_text(self, text): 350 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 351 | 352 | x = x + self.positional_embedding.type(self.dtype) 353 | x = x.permute(1, 0, 2) # NLD -> LND 354 | x = self.transformer(x) 355 | x = x.permute(1, 0, 2) # LND -> NLD 356 | x = self.ln_final(x).type(self.dtype) 357 | 358 | # x.shape = [batch_size, n_ctx, transformer.width] 359 | # take features from the eot embedding (eot_token is the highest number in each sequence) 360 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 361 | 362 | return x 363 | 364 | def forward(self, image, text): 365 | image_features = self.encode_image(image) 366 | text_features = self.encode_text(text) 367 | 368 | # normalized features 369 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 370 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 371 | 372 | # cosine similarity as logits 373 | logit_scale = self.logit_scale.exp() 374 | logits_per_image = logit_scale * image_features @ text_features.t() 375 | logits_per_text 
= logit_scale * text_features @ image_features.t() 376 | 377 | # shape = [global_batch_size, global_batch_size] 378 | return logits_per_image, logits_per_text 379 | 380 | 381 | def convert_weights(model: nn.Module): 382 | """Convert applicable model parameters to fp16""" 383 | 384 | def _convert_weights_to_fp16(l): 385 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 386 | l.weight.data = l.weight.data.half() 387 | if l.bias is not None: 388 | l.bias.data = l.bias.data.half() 389 | 390 | if isinstance(l, MultiheadAttention): 391 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 392 | tensor = getattr(l, attr) 393 | if tensor is not None: 394 | tensor.data = tensor.data.half() 395 | 396 | for name in ["text_projection", "proj"]: 397 | if hasattr(l, name): 398 | attr = getattr(l, name) 399 | if attr is not None: 400 | attr.data = attr.data.half() 401 | 402 | model.apply(_convert_weights_to_fp16) 403 | 404 | 405 | def build_model(state_dict: dict): 406 | vit = "visual.proj" in state_dict 407 | 408 | if vit: 409 | vision_width = state_dict["visual.conv1.weight"].shape[0] 410 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 411 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 412 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 413 | image_resolution = vision_patch_size * grid_size 414 | else: 415 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 416 | vision_layers = tuple(counts) 417 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 418 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 419 | vision_patch_size = None 420 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 421 | image_resolution = output_width * 32 422 | 423 | embed_dim = state_dict["text_projection"].shape[1] 424 | context_length = state_dict["positional_embedding"].shape[0] 425 | vocab_size = state_dict["token_embedding.weight"].shape[0] 426 | transformer_width = state_dict["ln_final.weight"].shape[0] 427 | transformer_heads = transformer_width // 64 428 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 429 | 430 | model = CLIP( 431 | embed_dim, 432 | image_resolution, vision_layers, vision_width, vision_patch_size, 433 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 434 | ) 435 | 436 | for key in ["input_resolution", "context_length", "vocab_size"]: 437 | if key in state_dict: 438 | del state_dict[key] 439 | 440 | convert_weights(model) 441 | model.load_state_dict(state_dict) 442 | return model.eval() 443 | -------------------------------------------------------------------------------- /CLIP_/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in 
token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /CLIP_/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | from PIL import Image 4 | import numpy as np 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | 8 | def interpret(image, text, model, device, index=None): 9 | logits_per_image, logits_per_text = model(image, text) 10 | probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy() 11 | if index is None: 12 | index = np.argmax(logits_per_image.cpu().data.numpy(), axis=-1) 13 | one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32) 14 | one_hot[0, index] = 1 15 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 16 | one_hot = torch.sum(one_hot.cuda() * logits_per_image) 17 | model.zero_grad() 18 | one_hot.backward(retain_graph=True) 19 | 20 | image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values()) 21 | num_tokens = image_attn_blocks[0].attn_probs.shape[-1] 22 | R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device) 23 | for blk in image_attn_blocks: 24 | grad = blk.attn_grad 25 | cam = blk.attn_probs 26 | cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1]) 27 | grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1]) 28 | cam = grad * cam 29 | cam = cam.clamp(min=0).mean(dim=0) 30 | R += torch.matmul(cam, R) 31 | R[0, 0] = 0 32 | image_relevance = R[0, 1:] 33 | 34 | # create heatmap from mask on image 35 | def show_cam_on_image(img, mask): 36 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 37 | heatmap = np.float32(heatmap) / 255 38 | cam = heatmap + np.float32(img) 39 | cam = cam / np.max(cam) 40 | return cam 41 | 42 | image_relevance = image_relevance.reshape(1, 1, 7, 7) 43 | image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear') 44 | image_relevance = image_relevance.reshape(224, 224).cuda().data.cpu().numpy() 45 | image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min()) 46 | image = image[0].permute(1, 2, 0).data.cpu().numpy() 47 | image = (image - image.min()) / (image.max() - image.min()) 48 | vis = show_cam_on_image(image, image_relevance) 49 | vis = np.uint8(255 * vis) 50 | vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR) 51 | 52 | plt.imshow(vis) 53 | plt.show() 54 | 55 | print("Label probs:", probs) 56 | 57 | def main(): 58 | device = "cuda" if torch.cuda.is_available() else "cpu" 59 | model, preprocess = clip.load("ViT-B/32", device=device, jit=False) 60 | 61 | image = preprocess(Image.open("catdog.png")).unsqueeze(0).to(device) 62 | text = clip.tokenize(["a dog", "a cat"]).to(device) 63 | interpret(model=model, image=image, text=text, device=device, index=0) 64 | interpret(model=model, image=image, text=text, device=device, index=1) 65 | 66 | image = preprocess(Image.open("el1.png")).unsqueeze(0).to(device) 67 | text = clip.tokenize(["an elephant", "a zebra"]).to(device) 68 | interpret(model=model, image=image, text=text, device=device, index=0) 69 | interpret(model=model, 
image=image, text=text, device=device, index=1) 70 | 71 | image = preprocess(Image.open("el2.png")).unsqueeze(0).to(device) 72 | text = clip.tokenize(["an elephant", "a zebra"]).to(device) 73 | interpret(model=model, image=image, text=text, device=device, index=0) 74 | interpret(model=model, image=image, text=text, device=device, index=1) 75 | 76 | image = preprocess(Image.open("el3.png")).unsqueeze(0).to(device) 77 | text = clip.tokenize(["an elephant", "a zebra"]).to(device) 78 | interpret(model=model, image=image, text=text, device=device, index=0) 79 | interpret(model=model, image=image, text=text, device=device, index=1) 80 | 81 | image = preprocess(Image.open("el4.png")).unsqueeze(0).to(device) 82 | text = clip.tokenize(["an elephant", "a zebra"]).to(device) 83 | interpret(model=model, image=image, text=text, device=device, index=0) 84 | interpret(model=model, image=image, text=text, device=device, index=1) 85 | 86 | image = preprocess(Image.open("dogbird.png")).unsqueeze(0).to(device) 87 | text = clip.tokenize(["a basset hound", "a parrot"]).to(device) 88 | interpret(model=model, image=image, text=text, device=device, index=0) 89 | interpret(model=model, image=image, text=text, device=device, index=1) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | 95 | -------------------------------------------------------------------------------- /CLIP_/model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: CLIP 2 | 3 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying information about the multimodal model. 4 | 5 | ## Model Details 6 | 7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within. 8 | 9 | ### Model Date 10 | 11 | January 2021 12 | 13 | ### Model Type 14 | 15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer. 16 | 17 | ### Model Version 18 | 19 | Initially, we’ve released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50. 20 | 21 | As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. 22 | 23 | Please see the paper linked below for further details about their specification. 24 | 25 | ### Documents 26 | 27 | - [Blog Post](https://openai.com/blog/clip/) 28 | - [CLIP Paper](https://arxiv.org/abs/2103.00020) 29 | 30 | 31 | 32 | ## Model Use 33 | 34 | ### Intended Use 35 | 36 | The model is intended as a research output for research communities. 
We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis. 37 | 38 | #### Primary intended uses 39 | 40 | The primary intended users of these models are AI researchers. 41 | 42 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models. 43 | 44 | ### Out-of-Scope Use Cases 45 | 46 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task specific testing especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful. 47 | 48 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use. 49 | 50 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases. 51 | 52 | 53 | 54 | ## Data 55 | 56 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet which tend to skew towards more developed nations, and younger, male users. 57 | 58 | ### Data Mission Statement 59 | 60 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset. 61 | 62 | 63 | 64 | ## Performance and Limitations 65 | 66 | ### Performance 67 | 68 | We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets such as OCR to texture recognition to fine-grained classification. 
The paper describes model performance on the following datasets: 69 | 70 | - Food101 71 | - CIFAR10 72 | - CIFAR100 73 | - Birdsnap 74 | - SUN397 75 | - Stanford Cars 76 | - FGVC Aircraft 77 | - VOC2007 78 | - DTD 79 | - Oxford-IIIT Pet dataset 80 | - Caltech101 81 | - Flowers102 82 | - MNIST 83 | - SVHN 84 | - IIIT5K 85 | - Hateful Memes 86 | - SST-2 87 | - UCF101 88 | - Kinetics700 89 | - Country211 90 | - CLEVR Counting 91 | - KITTI Distance 92 | - STL-10 93 | - RareAct 94 | - Flickr30 95 | - MSCOCO 96 | - ImageNet 97 | - ImageNet-A 98 | - ImageNet-R 99 | - ImageNet Sketch 100 | - ObjectNet (ImageNet Overlap) 101 | - Youtube-BB 102 | - ImageNet-Vid 103 | 104 | ## Limitations 105 | 106 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with respect to certain tasks such as fine grained classification and counting objects. CLIP also poses issues with regards to fairness and bias which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP also has an important limitation- in many cases we have used linear probes to evaluate the performance of CLIP and there is evidence suggesting that linear probes can underestimate model performance. 107 | 108 | ### Bias and Fairness 109 | 110 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper). 111 | 112 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks. 
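As a concrete (and deliberately innocuous) illustration of the class-design sensitivity described above, the sketch below classifies a single image under two different candidate label sets using the `clip` API documented in this repository's README; the image path and label sets are placeholders, and this is not the evaluation protocol used in the paper.

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Placeholder image path; any RGB image will do.
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)

# The same image under two different "class designs".
label_sets = [
    ["a diagram", "a dog", "a cat"],
    ["a chart", "a sketch", "a photo of an animal"],
]

for labels in label_sets:
    text = clip.tokenize(labels).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]
    # Which label "wins" (and by how much) depends on the classes offered.
    print({label: round(float(p), 3) for label, p in zip(labels, probs)})
```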
113 | 114 | 115 | 116 | ## Feedback 117 | 118 | ### Where to send questions or comments about the model 119 | 120 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9) 121 | -------------------------------------------------------------------------------- /CLIP_/requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torch~=1.7.1 5 | torchvision~=0.8.2 6 | -------------------------------------------------------------------------------- /CLIP_/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="clip", 8 | py_modules=["clip"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={'dev': ['pytest']}, 21 | ) 22 | -------------------------------------------------------------------------------- /CLIP_/tests/test_consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | import clip 7 | 8 | 9 | @pytest.mark.parametrize('model_name', clip.available_models()) 10 | def test_consistency(model_name): 11 | device = "cpu" 12 | jit_model, transform = clip.load(model_name, device=device) 13 | py_model, _ = clip.load(model_name, device=device, jit=False) 14 | 15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device) 16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 17 | 18 | with torch.no_grad(): 19 | logits_per_image, _ = jit_model(image, text) 20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 21 | 22 | logits_per_image, _ = py_model(image, text) 23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 24 | 25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 PDaddi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimize and Reduce: A Top-Down Approach for Image Vectorization 2 | ![](https://img.shields.io/badge/version-1.0.0-blue) 3 | [![Pytorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?e&logo=PyTorch&logoColor=white)](https://pytorch.org/) 4 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/zjukg/DUET/blob/main/licence) 5 | [![AAAI](https://img.shields.io/badge/AAAI-2024-%23f1592a?labelColor=%23003973&color=%23be1c1a)](https://aaai.org/Conferences/AAAI-24/) 6 | 7 | 8 | ## 🔔 News 9 | - **`2023-12`** Our paper: Optimize and Reduce: A Top-Down Approach for Image Vectorization was accepted by **`AAAI 2024`** 10 | 11 | 12 | We propose Optimize & Reduce (O&R), a top-down approach to vectorization that is both fast and domain-agnostic. O&R aims to attain a *compact* representation of input images by iteratively optimizing Bézier curve parameters and significantly reducing the number of shapes, using a devised importance measure. 13 | 14 | ![title](figures/005-1.png) 15 | 16 | By [Or Hirschorn*](https://scholar.google.co.il/citations?user=GgFuT_QAAAAJ&hl=iw&oi=ao), [Amir Jevnisek*](https://scholar.google.com/citations?user=czm6bkUAAAAJ&hl=en&oi=ao), and [Shai Avidan](https://scholar.google.co.il/citations?hl=iw&user=hpItE1QAAAAJ) 17 | 18 | Where * denotes equal contribution. 19 | 20 | ## 📕 Setup 21 | ```shell 22 | cd docker 23 | docker build -t optimize_and_reduce_aaai . 24 | cd .. 25 | docker run -v $(pwd):/home/code -it optimize_and_reduce_aaai /bin/bash 26 | ``` 27 | 28 | 29 | ## 🚀 Run 30 | 1) Running O&R: 31 | 32 | ```shell 33 | python reduce_or_add_and_optimize.py --target target_images/083.png \ 34 | --scheduler 256 128 64 --num_iter 100 100 100 \ 35 | --recons_loss_type l1_and_clip_mix --l1_and_clip_alpha 0.95 \ 36 | --geometric_loss_type geometric --ranking_loss_type mse \ 37 | --canvas_width 256 --canvas_height 256 --advanced_logging 38 | ``` 39 | 2) Running the baseline DiffVG: 40 | ```shell 41 | python basic_diffvg.py --target target_images/083.png \ 42 | --num_paths 64 --num_epochs 1 --num_iter 400 \ 43 | --recons_loss_type l1 --geometric_loss_type none \ 44 | --canvas_width 256 --canvas_height 256 --scheduler 400 \ 45 | --init_type random 46 | ``` 47 | 48 | ## 📚 Dataset Download 49 | 1. [Old Emojis](https://github.com/googlefonts/noto-emoji/releases/tag/v2015-09-29-license-apache), take the images from [this](datasets/datasets_metadata/old_emojis_list) list 50 | 2. [New Emojis](https://github.com/googlefonts/noto-emoji/tree/main/png), take the images from [this](datasets/datasets_metadata/new_emojis_list) list 51 | 3. [Free-SVG](datasets/from-free-svg.zip) 52 | 4. [NFT-Apes](datasets/nft-apes.zip) 53 | 5. [Midjourney Images](https://drive.google.com/file/d/1U9Vjz5ULUzFE-ythlupVodYIp0E9AERC/view?usp=drive_link) 54 | 55 | 56 | 57 | ## 🌈 Cite: 58 | Please consider citing this paper if you found the ```code``` or ```data``` useful. 
59 | 60 | ```bibtex 61 | @inproceedings{DBLP:conf/aaai/OptimizeReduce, 62 | author = {Or Hirschorn and 63 | Amir Jevnisek and 64 | Shai Avidan}, 65 | title = {Optimize and Reduce: A Top-Down Approach for Image Vectorization}, 66 | booktitle = {{AAAI}}, 67 | publisher = {{AAAI} Press}, 68 | year = {2024} 69 | } 70 | ``` -------------------------------------------------------------------------------- /custom_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | SUPPORTED_RECONSTRUCTION_LOSSES = ['mse', 'l1', 'pyramid-mse', 'pyramid-l1', 5 | 'clip', 'l1_and_clip_mix'] 6 | SUPPORTED_RANKING_LOSSES = SUPPORTED_RECONSTRUCTION_LOSSES + ['histogram'] 7 | SUPPORTED_GEOMETRIC_LOSSES = ['none', 'geometric', 8 | # 'xing' 9 | ] 10 | SUPPORTED_INIT_TYPES = ['random', 'custom'] 11 | SUPPORTED_INIT_SHAPES_TYPES = ['circle', 'random'] 12 | 13 | 14 | def parse_arguments(): 15 | parser = argparse.ArgumentParser("Basic DiffVG runner.") 16 | parser.add_argument("--target", help="target PNG image path", 17 | default=osp.join('target_images', 'tiger.png')) 18 | parser.add_argument("--input_svg", help="initial svg figure", 19 | default=None) 20 | parser.add_argument("--results_dir", help="root results directory", 21 | default='results') 22 | parser.add_argument("--num_paths", type=int, default=512, 23 | help='number of bezier curves in final image') 24 | parser.add_argument("--num_iter", nargs='*', action='store', 25 | default=[100, 100, 100, 100, 100, 500], 26 | help='number of optimization iterations in every epoch') 27 | parser.add_argument("--num_epochs", type=int, default=1, 28 | help='number of epochs, such that totally we have: ' 29 | 'num_epochs * num_iter optimization steps') 30 | parser.add_argument("--recons_loss_type", type=str, default='l1', 31 | choices=SUPPORTED_RECONSTRUCTION_LOSSES) 32 | parser.add_argument("--geometric_loss_type", type=str, default='geometric', 33 | choices=SUPPORTED_GEOMETRIC_LOSSES) 34 | parser.add_argument("--geometric_loss_lamda_geometric_punish", type=float, default=10.0, 35 | help="The punishment for non-convex behaving shapes.") 36 | parser.add_argument("--lambda_geometric", type=float, default=0.01, help="Lambda geometric loss") 37 | parser.add_argument("--sample_beta", type=float, default=5e-3, help="Beta for shape importance sampling (see --sample_importance)") 38 | parser.add_argument("--l1_and_clip_alpha", type=float, default=1.0, 39 | help="The extent to which we take L1 in the convex " 40 | "sum of L1&Clip losses.\n" 41 | "1.0 = only L1, 0.0 = only Clip") 42 | parser.add_argument("--clip_config_file", type=str, 43 | default='test/config_init.npy', 44 | help="clip loss configuration file " 45 | "for the RECONSTRUCTION loss " 46 | "(cf. --ranking_clip_config_file)") 47 | parser.add_argument('--scheduler', nargs='*', action='store', 48 | default=[256, 128, 64, 32, 16, 8], 49 | help='num of shapes schedule in descending order.') 50 | parser.add_argument("--ranking_loss_type", type=str, default='l1', 51 | choices=SUPPORTED_RANKING_LOSSES) 52 | parser.add_argument("--ranking_l1_and_clip_alpha", type=float, default=1.0, 53 | help="The extent to which we take L1 in the convex " 54 | "sum of L1&Clip losses for RANKING.\n" 55 | "1.0 = only L1, 0.0 = only Clip") 56 | parser.add_argument("--ranking_clip_config_file", type=str, 57 | default='test/config_init.npy', 58 | help="clip loss configuration file for RANKING") 59 | parser.add_argument('--canvas_width', type=int, default=256) 60 | parser.add_argument('--canvas_height',
type=int, default=256) 61 | parser.add_argument('--advanced_logging', action='store_true', 62 | help='generate advanced logging') 63 | parser.add_argument('--sample_importance', action='store_true', 64 | help='Sample shapes during reduce phase') 65 | parser.add_argument('--experiment_name', type=str, 66 | default='', 67 | help='specify experiment name, default is:' 68 | '{image_name}_{num_shapes}_{rec_loss}_{geom_loss}') 69 | parser.add_argument('--text_prompt', type=str, default=None, help='Clip text loss') 70 | parser.add_argument('--init_type', type=str, default='custom', 71 | choices=SUPPORTED_INIT_TYPES, 72 | help='Shapes initialization method') 73 | parser.add_argument('--init_shape', type=str, default='circle', 74 | choices=SUPPORTED_INIT_SHAPES_TYPES, 75 | help='Shapes initialization shape') 76 | parser.add_argument('--early_stopping', action='store_true', 77 | help='if set, we use early stopping.') 78 | arguments = parser.parse_args() 79 | 80 | assert len(arguments.scheduler) == len(arguments.num_iter) 81 | 82 | return arguments 83 | -------------------------------------------------------------------------------- /datasets/datasets_metadata/new_emojis_list: -------------------------------------------------------------------------------- 1 | 32/emoji_u1f609 2 | 512/emoji_u1f60c 3 | 32/emoji_u1f637 4 | 72/emoji_u1f619 5 | 512/emoji_u1f620 6 | 512/emoji_u1f630 7 | 128/emoji_u1f620 8 | 128/emoji_u1f913 9 | 32/emoji_u1f62f 10 | 72/emoji_u1f62b 11 | 32/emoji_u1f629 12 | 32/emoji_u1f627 13 | 32/emoji_u1f62a 14 | 32/emoji_u1f603 15 | 72/emoji_u1f603 16 | 72/emoji_u1f62a 17 | 72/emoji_u1f612 18 | 512/emoji_u1f605 19 | 32/emoji_u1f605 20 | 72/emoji_u1f970 21 | 72/emoji_u1f623 22 | 512/emoji_u1f61f 23 | 512/emoji_u1f606 24 | 128/emoji_u1f62e 25 | 128/emoji_u1f627 26 | 128/emoji_u1f912 27 | 32/emoji_u1f635 28 | 32/emoji_u1f643 29 | 512/emoji_u1f631 30 | 72/emoji_u1f632 31 | 128/emoji_u1f61c 32 | 128/emoji_u1f612 33 | 128/emoji_u1f60a 34 | 512/emoji_u1f912 35 | 72/emoji_u1f610 36 | 128/emoji_u1f613 37 | 32/emoji_u1f60f 38 | 32/emoji_u1f61f 39 | 32/emoji_u1f917 40 | 72/emoji_u1f625 41 | 72/emoji_u1f616 42 | 72/emoji_u1f975 43 | 128/emoji_u1f925 44 | 128/emoji_u1f917 45 | 512/emoji_u1f611 46 | 32/emoji_u263a 47 | 128/emoji_u1f632 48 | 512/emoji_u1f641 49 | 128/emoji_u1f62b 50 | 128/emoji_u1f644 51 | 512/emoji_u1f61d 52 | 512/emoji_u1f642 53 | 512/emoji_u1f602 54 | 32/emoji_u1f621 55 | 512/emoji_u1f615 56 | 128/emoji_u1f630 57 | 128/emoji_u1f607 58 | 32/emoji_u1f620 59 | 32/emoji_u1f607 60 | 32/emoji_u1f613 61 | 72/emoji_u1f60b 62 | 128/emoji_u1f636_200d_1f32b 63 | 32/emoji_u1f611 64 | 128/emoji_u1f61f 65 | 32/emoji_u1f911 66 | 32/emoji_u1f618 67 | 32/emoji_u1f62e 68 | 32/emoji_u1f600 69 | 128/emoji_u1f606 70 | 32/emoji_u1f622 71 | 128/emoji_u1f634 72 | 72/emoji_u1f62f 73 | 72/emoji_u1f617 74 | 512/emoji_u1fae2 75 | 128/emoji_u1f61a 76 | 72/emoji_u1f62c 77 | 72/emoji_u1f609 78 | 72/emoji_u1f605 79 | 72/emoji_u1f971 80 | 512/emoji_u1f637 81 | 512/emoji_u1f913 82 | 128/emoji_u1f619 83 | 32/emoji_u1f633 84 | 32/emoji_u1fae8 85 | 512/emoji_u1f914 86 | 128/emoji_u1f636 87 | 72/emoji_u1f60f 88 | 72/emoji_u1f600 89 | 72/emoji_u1f635_200d_1f4ab 90 | 32/emoji_u1f913 91 | 72/emoji_u1f642 92 | 512/emoji_u1f973 93 | 128/emoji_u1f605 94 | 72/emoji_u1f613 95 | 512/emoji_u1f970 96 | 128/emoji_u1f629 97 | 72/emoji_u1f620 98 | 72/emoji_u1f611 99 | 512/emoji_u1f627 100 | 32/emoji_u1f636_200d_1f32b 101 | 32/emoji_u1f631 102 | 32/emoji_u1fae0 103 | 512/emoji_u1f618 104 | 512/emoji_u1f62d 105 | 512/emoji_u1f634 106 | 
512/emoji_u1f636 107 | 512/emoji_u1f62f 108 | 32/emoji_u1f610 109 | 72/emoji_u1fae0 110 | 512/emoji_u1f62b 111 | 512/emoji_u1f910 112 | 72/emoji_u1f60c 113 | 512/emoji_u1f600 114 | 32/emoji_u1f62c 115 | 512/emoji_u1f60d 116 | 512/emoji_u1f628 117 | 32/emoji_u1f614 118 | 128/emoji_u1f910 119 | 32/emoji_u1f628 120 | 32/emoji_u1f616 121 | 32/emoji_u1f60b 122 | 128/emoji_u1f626 123 | 72/emoji_u1f915 124 | 512/emoji_u1f915 125 | 512/emoji_u1f925 126 | 128/emoji_u1f642 127 | 72/emoji_u1f634 128 | 72/emoji_u1f644 129 | 72/emoji_u1f636 130 | 128/emoji_u1f974 131 | 72/emoji_u1f622 132 | 128/emoji_u1f621 133 | 128/emoji_u1f975 134 | 32/emoji_u1f972 135 | 72/emoji_u1f637 136 | 32/emoji_u1f61a 137 | 512/emoji_u1fae0 138 | 72/emoji_u1f629 139 | 32/emoji_u1f974 140 | 512/emoji_u1f632 141 | 512/emoji_u1f644 142 | 128/emoji_u1f631 143 | 512/emoji_u1f61c 144 | 512/emoji_u1f60e 145 | 512/emoji_u1f60f 146 | 72/emoji_u1f633 147 | 512/emoji_u1f923 148 | 128/emoji_u1f633 149 | 128/emoji_u1f617 150 | 72/emoji_u1f604 151 | -------------------------------------------------------------------------------- /datasets/datasets_metadata/old_emojis_list: -------------------------------------------------------------------------------- 1 | emoji_u1f6af 2 | emoji_u1f3a7 3 | emoji_u1f5ff 4 | emoji_u1f631 5 | emoji_u1f6a4 6 | emoji_u1f6ae 7 | emoji_u1f627 8 | emoji_u1f6a2 9 | emoji_u1f478 10 | emoji_u1f474 11 | emoji_u1f618 12 | emoji_u1f637 13 | emoji_u1f6b0 14 | emoji_u1f1ec_1f1e7 15 | emoji_u1f30e 16 | emoji_u1f4fc 17 | emoji_u1f622 18 | emoji_u1f31e 19 | emoji_u1f3a8 20 | emoji_u1f6a8 21 | emoji_u1f602 22 | emoji_u1f1ea_1f1f8 23 | emoji_u1f473 24 | emoji_u1f612 25 | emoji_u1f6a3 26 | emoji_u1f605 27 | emoji_u1f6a0 28 | emoji_u1f6b2 29 | emoji_u1f4fb 30 | emoji_u1f6a6 31 | emoji_u1f639 32 | emoji_u1f6b1 33 | emoji_u1f635 34 | emoji_u1f472 35 | emoji_u1f3a4 36 | emoji_u1f640 37 | emoji_u1f33c 38 | emoji_u1f603 39 | emoji_u1f479 40 | emoji_u1f611 41 | emoji_u1f33a 42 | emoji_u1f6ab 43 | emoji_u1f31b 44 | emoji_u1f620 45 | emoji_u1f604 46 | emoji_u1f1ee_1f1f9 47 | emoji_u1f609 48 | emoji_u1f31f 49 | emoji_u1f636 50 | emoji_u1f632 51 | emoji_u1f614 52 | emoji_u1f621 53 | emoji_u1f5fe 54 | emoji_u1f638 55 | emoji_u1f476 56 | emoji_u1f467 57 | emoji_u1f30f 58 | emoji_u1f628 59 | emoji_u1f33d 60 | emoji_u1f31a 61 | emoji_u1f634 62 | emoji_u1f6aa 63 | emoji_u1f633 64 | emoji_u1f648 65 | emoji_u1f3a1 66 | emoji_u1f6a9 67 | emoji_u1f33e 68 | emoji_u1f624 69 | emoji_u1f645 70 | emoji_u1f626 71 | emoji_u1f475 72 | emoji_u1f3a3 73 | emoji_u1f470 74 | emoji_u1f623 75 | emoji_u1f6ac 76 | emoji_u1f30c 77 | emoji_u1f601 78 | emoji_u1f468 79 | emoji_u1f617 80 | emoji_u1f3a5 81 | emoji_u1f1ef_1f1f5 82 | emoji_u1f33f 83 | emoji_u1f469 84 | emoji_u1f608 85 | emoji_u1f471 86 | emoji_u1f616 87 | emoji_u1f649 88 | emoji_u1f3a2 89 | emoji_u1f5fd 90 | emoji_u1f629 91 | emoji_u1f646 92 | emoji_u1f477 93 | emoji_u1f6a7 94 | emoji_u1f1eb_1f1f7 95 | emoji_u1f466 96 | emoji_u1f31d 97 | emoji_u1f34b 98 | emoji_u1f30d 99 | emoji_u1f610 100 | emoji_u1f34a 101 | emoji_u1f31c 102 | emoji_u1f5fb 103 | emoji_u1f6a1 104 | emoji_u1f606 105 | emoji_u1f6ad 106 | emoji_u1f647 107 | emoji_u1f30b 108 | emoji_u1f5fc 109 | emoji_u1f6b3 110 | emoji_u1f625 111 | emoji_u1f615 112 | emoji_u1f600 113 | emoji_u1f6a5 114 | emoji_u1f3a0 115 | emoji_u1f33b 116 | emoji_u1f619 117 | emoji_u1f613 118 | emoji_u1f3a6 119 | emoji_u1f630 120 | emoji_u1f607 -------------------------------------------------------------------------------- /datasets/from-free-svg.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/datasets/from-free-svg.zip -------------------------------------------------------------------------------- /datasets/nft-apes.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/datasets/nft-apes.zip -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM yaelvinker/clipasso_docker 2 | RUN echo 'alias ll="ls -l"' >> ~/.bashrc 3 | RUN mkdir /home/code 4 | WORKDIR /home/code 5 | RUN pip install cairosvg 6 | RUN apt-get update -y 7 | RUN apt-get install python3-cffi python3-brotli libpango-1.0-0 libharfbuzz0b libpangoft2-1.0-0 libgtk-3-dev gcc -y 8 | RUN pip3 install -U scikit-learn scipy matplotlib 9 | RUN pip install ipdb 10 | RUN pip install webp 11 | RUN pip install kornia==0.5.0 12 | RUN pip install opencv-python==4.5.4.60 # avoid connected components segfault -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # How to run it? 2 | ```code 3 | docker build -t optimize_and_reduce_aaai . 4 | cd .. 5 | docker run -v $(pwd):/home/code -it optimize_and_reduce_aaai /bin/bash 6 | ``` 7 | 8 | -------------------------------------------------------------------------------- /figures/005-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/figures/005-1.png -------------------------------------------------------------------------------- /figures/005.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/figures/005.pdf -------------------------------------------------------------------------------- /geometric_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import pydiffvg 5 | 6 | 7 | class GeometryLoss: 8 | def __init__(self, device='cpu', pathObj=None, xyalign=False, parallel=False, smooth_node=False, 9 | lamda_geometric_punish=10): 10 | self.orientation = 0.01 11 | self.device = device 12 | if pathObj is not None: 13 | self.pathObj = pathObj 14 | self.pathId = pathObj.id 15 | self.get_segments(pathObj) 16 | if xyalign: 17 | self.make_hor_ver_constraints(pathObj) 18 | if parallel: 19 | self.make_parallel_constraints(pathObj) 20 | if smooth_node: 21 | self.make_smoothness_constraints(pathObj) 22 | self.xyalign = xyalign 23 | self.parallel = parallel 24 | self.smooth_node = smooth_node 25 | 26 | self.lamda_geometric_punish = lamda_geometric_punish 27 | self.smooth_nodes = [] 28 | 29 | def make_smoothness_constraints(self, pathObj): 30 | 31 | for idx, node in enumerate(self.iterate_nodes()): 32 | sm, t0, t1 = self.node_smoothness(node, pathObj) 33 | self.smooth_nodes.append((node, ((t0.norm() / self.segment_approx_length(node[0], pathObj)).item(), 34 | (t1.norm() / self.segment_approx_length(node[1], pathObj)).item()))) 35 | # if abs(sm) < 1e-2: 
36 | # self.smooth_nodes.append((node, ((t0.norm() / self.segment_approx_length(node[0], pathObj)).item(), 37 | # (t1.norm() / self.segment_approx_length(node[1], pathObj)).item()))) 38 | # # print("Node {} is smooth (smoothness {})".format(idx,sm)) 39 | # else: 40 | # pass 41 | 42 | def node_smoothness(self, node, pathObj): 43 | t0 = self.tangent_out(node[0], pathObj) 44 | t1 = self.tangent_in(node[1], pathObj) 45 | t1rot = torch.stack((-t1[1], t1[0])) 46 | smoothness = t0.dot(t1rot) / (t0.norm() * t1.norm()) 47 | 48 | return smoothness, t0, t1 49 | 50 | def segment_approx_length(self, segment, pathObj): 51 | if segment[0] == 0: 52 | # line 53 | idxs = self.segList[segment[0]][segment[1]] 54 | # should have a pair of indices now 55 | length = (pathObj.points[idxs[1], :] - pathObj.points[idxs[0], :]).norm() 56 | return length 57 | elif segment[0] == 1: 58 | # quadric 59 | idxs = self.segList[segment[0]][segment[1]] 60 | # should have a pair of indices now 61 | length = (pathObj.points[idxs[1], :] - pathObj.points[idxs[0], :]).norm() + ( 62 | pathObj.points[idxs[2], :] - pathObj.points[idxs[1], :]).norm() 63 | return length 64 | elif segment[0] == 2: 65 | # cubic 66 | idxs = self.segList[segment[0]][segment[1]] 67 | # should have a pair of indices now 68 | length = (pathObj.points[idxs[1], :] - pathObj.points[idxs[0], :]).norm() + ( 69 | pathObj.points[idxs[2], :] - pathObj.points[idxs[1], :]).norm() + ( 70 | pathObj.points[idxs[3], :] - pathObj.points[idxs[2], :]).norm() 71 | return length 72 | 73 | def tangent_in(self, segment, pathObj): 74 | if segment[0] == 0: 75 | # line 76 | idxs = self.segList[segment[0]][segment[1]] 77 | # should have a pair of indices now 78 | tangent = (pathObj.points[idxs[1], :] - pathObj.points[idxs[0], :]) / 2 79 | return tangent 80 | elif segment[0] == 1: 81 | # quadric 82 | idxs = self.segList[segment[0]][segment[1]] 83 | # should have a pair of indices now 84 | tangent = (pathObj.points[idxs[1], :] - pathObj.points[idxs[0], :]) 85 | return tangent 86 | elif segment[0] == 2: 87 | # cubic 88 | idxs = self.segList[segment[0]][segment[1]] 89 | # should have a pair of indices now 90 | tangent = (pathObj.points[idxs[1], :] - pathObj.points[idxs[0], :]) 91 | return tangent 92 | 93 | assert (False) 94 | 95 | def tangent_out(self, segment, pathObj): 96 | if segment[0] == 0: 97 | # line 98 | idxs = self.segList[segment[0]][segment[1]] 99 | # should have a pair of indices now 100 | tangent = (pathObj.points[idxs[0], :] - pathObj.points[idxs[1], :]) / 2 101 | return tangent 102 | elif segment[0] == 1: 103 | # quadric 104 | idxs = self.segList[segment[0]][segment[1]] 105 | # should have a pair of indices now 106 | tangent = (pathObj.points[idxs[1], :] - pathObj.points[idxs[2], :]) 107 | return tangent 108 | elif segment[0] == 2: 109 | # cubic 110 | idxs = self.segList[segment[0]][segment[1]] 111 | # should have a pair of indices now 112 | tangent = (pathObj.points[idxs[2], :] - pathObj.points[idxs[3], :]) 113 | return tangent 114 | 115 | assert False 116 | 117 | def get_segments(self, pathObj): 118 | self.segments = [] 119 | self.lines = [] 120 | self.quadrics = [] 121 | self.cubics = [] 122 | self.segList = (self.lines, self.quadrics, self.cubics) 123 | idx = 0 124 | total_points = pathObj.points.shape[0] 125 | for ncp in pathObj.num_control_points.numpy(): 126 | if ncp == 0: 127 | self.segments.append((0, len(self.lines))) 128 | self.lines.append((idx, (idx + 1) % total_points)) 129 | idx += 1 130 | elif ncp == 1: 131 | self.segments.append((1, len(self.quadrics))) 132 
| self.quadrics.append((idx, (idx + 1), (idx + 2) % total_points)) 133 | idx += ncp + 1 134 | elif ncp == 2: 135 | self.segments.append((2, len(self.cubics))) 136 | self.cubics.append((idx, (idx + 1), (idx + 2), (idx + 3) % total_points)) 137 | idx += ncp + 1 138 | 139 | def iterate_nodes(self): 140 | for prev, next in zip([self.segments[-1]] + self.segments[:-1], self.segments): 141 | yield (prev, next) 142 | 143 | def make_hor_ver_constraints(self, pathObj): 144 | self.horizontals = [] 145 | self.verticals = [] 146 | for idx, line in enumerate(self.lines): 147 | startPt = pathObj.points[line[0], :] 148 | endPt = pathObj.points[line[1], :] 149 | 150 | dif = endPt - startPt 151 | 152 | if abs(dif[0]) < 1e-6: 153 | # is horizontal 154 | self.horizontals.append(idx) 155 | 156 | if abs(dif[1]) < 1e-6: 157 | # is vertical 158 | self.verticals.append(idx) 159 | 160 | def make_parallel_constraints(self, pathObj): 161 | slopes = [] 162 | for lidx, line in enumerate(self.lines): 163 | startPt = pathObj.points[line[0], :] 164 | endPt = pathObj.points[line[1], :] 165 | 166 | dif = endPt - startPt 167 | 168 | slope = math.atan2(dif[1], dif[0]) 169 | if slope < 0: 170 | slope += math.pi 171 | 172 | minidx = -1 173 | for idx, s in enumerate(slopes): 174 | if abs(s[0] - slope) < 1e-3: 175 | minidx = idx 176 | break 177 | 178 | if minidx >= 0: 179 | slopes[minidx][1].append(lidx) 180 | else: 181 | slopes.append((slope, [lidx])) 182 | 183 | self.parallel_groups = [sgroup[1] for sgroup in slopes if len(sgroup[1]) > 1 and ( 184 | not self.xyalign or (sgroup[0] > 1e-3 and abs(sgroup[0] - (math.pi / 2)) > 1e-3))] 185 | 186 | def make_line_diff(self, pathObj, lidx): 187 | line = self.lines[lidx] 188 | startPt = pathObj.points[line[0], :] 189 | endPt = pathObj.points[line[1], :] 190 | 191 | dif = endPt - startPt 192 | return dif 193 | 194 | def calc_hor_ver_loss(self, loss, pathObj): 195 | for lidx in self.horizontals: 196 | dif = self.make_line_diff(pathObj, lidx) 197 | loss += dif[0].pow(2) 198 | 199 | for lidx in self.verticals: 200 | dif = self.make_line_diff(pathObj, lidx) 201 | loss += dif[1].pow(2) 202 | return loss 203 | 204 | def calc_parallel_loss(self, loss, pathObj): 205 | for group in self.parallel_groups: 206 | diffs = [self.make_line_diff(pathObj, lidx) for lidx in group] 207 | difmat = torch.stack(diffs, 1) 208 | lengths = difmat.pow(2).sum(dim=0).sqrt() 209 | difmat = difmat / lengths 210 | difmat = torch.cat((difmat, torch.zeros(1, difmat.shape[1]))) 211 | rotmat = difmat[:, list(range(1, difmat.shape[1])) + [0]] 212 | cross = difmat.cross(rotmat) 213 | ploss = cross.pow(2).sum() * lengths.sum() * 10 214 | loss += ploss 215 | return loss 216 | 217 | def calc_smoothness_loss(self, loss, pathObj): 218 | for node, tlengths in self.smooth_nodes: 219 | sl, t0, t1 = self.node_smoothness(node, pathObj) 220 | # add smoothness loss 221 | loss += sl.pow(2) * t0.norm().sqrt() * t1.norm().sqrt() 222 | tl = ((t0.norm() / self.segment_approx_length(node[0], pathObj)) - tlengths[0]).pow(2) + ( 223 | (t1.norm() / self.segment_approx_length(node[1], pathObj)) - tlengths[1]).pow(2) 224 | loss += tl * 10 225 | return loss 226 | 227 | def compute(self, pathObjs): 228 | loss = torch.tensor(0., device=self.device) 229 | # For Straight Lines 230 | if self.xyalign: 231 | for pathObj in pathObjs: 232 | loss += self.calc_hor_ver_loss(loss, pathObj) 233 | if self.parallel: 234 | for pathObj in pathObjs: 235 | loss += self.calc_parallel_loss(loss, pathObj) 236 | 237 | # Smoothness 238 | curves = 
torch.stack([pathObj.points.to(self.device).view(-1, 2) for pathObj in pathObjs]) 239 | total_loss = self.control_geometric_loss(curves) 240 | loss += total_loss.mean() 241 | 242 | return loss 243 | 244 | def control_geometric_loss(self, curves, temperature=10): # x[npoints,2] 245 | # segments - Shapes x Segments x 4(Quadratic Bezier) x 2 246 | segments = curves.unfold(1, 4, 3).permute(0, 1, 3, 2) 247 | A = segments[..., 0, :] 248 | B = segments[..., 1, :] 249 | C = segments[..., 2, :] 250 | D = segments[..., 3, :] 251 | # Whether AB intersects CD 252 | intersect, orient = doIntersect(A, B, C, D) 253 | AB = (A - B) / (torch.norm((A - B), dim=-1)[..., None]) 254 | BC = (B - C) / (torch.norm((B - C), dim=-1)[..., None]) 255 | CD = (C - D) / (torch.norm((C - D), dim=-1)[..., None]) 256 | dot_product = dot(AB, CD) 257 | cross_product = cross(AB, CD) 258 | signed_angle = torch.atan2(cross_product, dot_product) 259 | 260 | # If intersect - Small angles are preferred, lamda self loop to prevent this intersection 261 | # If not - Large angles are preferred 262 | intersect_loss = intersect * (self.lamda_geometric_punish-1 * torch.cos(signed_angle)) 263 | # + (1 - intersect) * -1 * torch.abs(torch.cos(dot_product)) 264 | # Require ABC and BCD to have the same orientation - 265 | orient_loss = self.orientation * orient * self.lamda_geometric_punish 266 | angle_loss = torch.relu(-1*dot(AB, BC)) + torch.relu(-1*dot(BC, CD)) 267 | # orient_loss = (1 - orient) * self.lamda_geometric_punish + orient * (torch.relu(cross_product)) 268 | 269 | total_loss = intersect_loss + orient_loss + angle_loss 270 | 271 | return total_loss 272 | 273 | def create_shape(self, pathObj, alpha=0.05): 274 | path = pathObj 275 | path.is_closed = False 276 | path.stroke_width = torch.tensor([1]) 277 | shapes = [pathObj] 278 | groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]), 279 | fill_color=torch.tensor([0, 0, 0, 0]), 280 | stroke_color=torch.tensor([0, 0, 0, alpha]))] 281 | shapes.append(pydiffvg.Path(num_control_points=torch.tensor([0]), 282 | points=torch.stack([path.points[0], path.points[-1]]), 283 | is_closed=False, stroke_width=torch.tensor([1]))) 284 | groups.append(pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]), 285 | fill_color=torch.tensor([0, 0, 0, 0.1]), 286 | stroke_color=torch.tensor([0, 0, 0, alpha]))) 287 | scene_args = pydiffvg.RenderFunction.serialize_scene(256, 256, shapes, groups) 288 | render = pydiffvg.RenderFunction.apply 289 | img = render(256, 256, 2, 2, 0, None, *scene_args) 290 | alpha_img = torch.clamp(img[..., -1] - alpha, 0) / alpha 291 | return alpha_img.sum() 292 | 293 | def create_ploygons(self, segments, alpha=0.05): 294 | shapes = [] 295 | groups = [] 296 | for s in range(segments.shape[0]): 297 | pts = segments[s] 298 | inner_sorting = torch.argsort(pts[:, 1]) 299 | pts_inner_sorted = pts[inner_sorting] 300 | outer_sorting = torch.argsort(pts_inner_sorted[:, 0], stable=True) 301 | sorted_pts = pts_inner_sorted[outer_sorting] 302 | shapes.append(pydiffvg.Polygon( 303 | torch.stack([sorted_pts[0], sorted_pts[1], sorted_pts[3], sorted_pts[2]]), 304 | is_closed=True)) 305 | groups.append(pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]), 306 | fill_color=torch.tensor([0, 0, 0, alpha]))) 307 | shapes.append(pydiffvg.Path(num_control_points=torch.tensor([0]), 308 | points=torch.stack([segments[0][0], segments[-1][-1]]), 309 | is_closed=False, stroke_width=torch.tensor([1]))) 310 | 
groups.append(pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]), 311 | fill_color=torch.tensor([0, 0, 0, 0.1]), 312 | stroke_color=torch.tensor([0, 0, 0, alpha]))) 313 | scene_args = pydiffvg.RenderFunction.serialize_scene(256, 256, shapes, groups) 314 | render = pydiffvg.RenderFunction.apply 315 | img = render(256, 256, 2, 2, 0, None, *scene_args) 316 | alpha_img = torch.clamp(img[..., -1] - alpha, 0) / alpha 317 | return alpha_img.sum() 318 | 319 | def __call__(self, shapes): 320 | return self.compute(shapes) 321 | 322 | 323 | def orientation(p, q, r): 324 | # to find the orientation of an ordered triplet (p,q,r) 325 | # function returns the following values: 326 | # 0 : Collinear points 327 | # 1 : Clockwise points 328 | # 2 : Counterclockwise 329 | # See https://www.geeksforgeeks.org/orientation-3-ordered-points/amp/ for details of below formula. 330 | val = (q[..., 1] - p[..., 1]) * (r[..., 0] - q[..., 0]) - \ 331 | (q[..., 0] - p[..., 0]) * (r[..., 1] - q[..., 1]) 332 | return torch.tanh(val) 333 | 334 | 335 | def xor_gate(a, b): 336 | return a + b - (2 * a * b) 337 | 338 | 339 | def or_gate(a, b): 340 | return a + b - (a * b) 341 | 342 | 343 | def and_gate(a, b): 344 | return a * b 345 | 346 | 347 | # The main function that returns true if 348 | # the line segment 'p1,q1' and 'p2,q2' intersect. 349 | def doIntersect(p1, q1, p2, q2): 350 | # Find the 4 orientations required for the general and special cases 351 | # 0 : Clockwise points 352 | # 1 : Counterclockwise 353 | o1 = (1 + orientation(p1, q1, p2)) * 0.5 354 | o2 = (1 + orientation(p1, q1, q2)) * 0.5 355 | o3 = (1 + orientation(p2, q2, p1)) * 0.5 356 | o4 = (1 + orientation(p2, q2, q1)) * 0.5 357 | o5 = (1 + orientation(q1, p2, q2)) * 0.5 358 | return ( 359 | and_gate(xor_gate(o1, o2), xor_gate(o3, o4)).squeeze(-1), 360 | and_gate(o1, o5).squeeze(-1) 361 | ) 362 | 363 | 364 | def dot(v1, v2): 365 | return v1[..., 0] * v2[..., 0] + v1[..., 1] * v2[..., 1] 366 | 367 | 368 | def cross(v1, v2): 369 | return v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0] 370 | 371 | if __name__ == "__main__": 372 | loss = GeometryLoss() 373 | _, _, shapes, _ = pydiffvg.svg_to_scene("results_geometric/geo_0.1/eye.svg") 374 | loss1 = loss.compute(shapes) 375 | # _, _, shapes, _ = pydiffvg.svg_to_scene("test_xing/xing2.svg") 376 | # loss2 = loss.compute(shapes) 377 | _, _, shapes, _ = pydiffvg.svg_to_scene("results_geometric/geo_0.1/good_eye.svg") 378 | loss3 = loss.compute(shapes) 379 | # _, _, shapes, _ = pydiffvg.svg_to_scene("test_xing/no_xing2.svg") 380 | # loss4 = loss.compute(shapes) 381 | print(loss1, "loss2", loss3, "loss4") -------------------------------------------------------------------------------- /histogram_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import skimage 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from enum import Enum 7 | from collections import Counter 8 | 9 | from PIL import Image 10 | from scipy.ndimage import median_filter 11 | from sklearn.metrics import pairwise_distances_argmin 12 | from sklearn.cluster import DBSCAN, SpectralClustering, KMeans 13 | 14 | 15 | def get_num_clusters(clustering_obj): 16 | return len(np.unique(clustering_obj.labels_)) 17 | 18 | 19 | def count_clusterless_pixels(labels_img): 20 | return (labels_img == -1).sum() 21 | 22 | 23 | def bhattacharyya_distance(p,q): 24 | bc = np.sum(np.sqrt(p * q)) 25 | return -np.log(bc) 26 | 27 | 28 | class ClusteringMethod(Enum): 29 | DBSCAN = 1 30 | KMEANS 
= 2 31 | SPECTRAL_CLUSTERING = 3 32 | 33 | 34 | class NonDifferentiableHistogramLoss: 35 | def __init__(self, base_image_tensor: torch.Tensor, 36 | method: ClusteringMethod = ClusteringMethod.DBSCAN): 37 | self.base_image_tensor = base_image_tensor 38 | self.base_labels_image, self.base_hist, self.base_cluster_centers = \ 39 | self.base_torch_image_to_clusters(base_image_tensor, method) 40 | 41 | 42 | @staticmethod 43 | def base_torch_image_to_clusters(normalized_tensor_image: torch.tensor, 44 | method: ClusteringMethod = 45 | ClusteringMethod.DBSCAN): 46 | resized_tensor = torch.nn.functional.interpolate( 47 | normalized_tensor_image.permute(2, 0, 1).unsqueeze(0), 48 | (100, 100)).squeeze(0).permute(1, 2, 0) 49 | numpy_image = resized_tensor.cpu().numpy() * 255.0 50 | # get sizes: 51 | (h, w) = numpy_image.shape[:-1] 52 | # prepare data: 53 | X = numpy_image[..., :3].reshape(h * w, -1) 54 | # cluster 55 | clustering_method = { 56 | ClusteringMethod.DBSCAN: DBSCAN(eps=5, min_samples=20, ), 57 | ClusteringMethod.KMEANS: KMeans(n_clusters=20), 58 | ClusteringMethod.SPECTRAL_CLUSTERING: SpectralClustering( 59 | n_clusters=20), 60 | } 61 | clustering = clustering_method[method].fit(X) 62 | # extract cluster centers: 63 | labels = clustering.labels_ 64 | labels_image = labels.reshape(h, w) 65 | sorted_labels = np.unique(labels_image) 66 | sorted_labels.sort() 67 | 68 | cluster_centers = [] 69 | for label in sorted_labels: 70 | indices = (clustering.labels_ == label) 71 | cluster_centers.append( 72 | np.array([ 73 | [X[indices, 0].mean(), X[indices, 1].mean(), 74 | X[indices, 2].mean()] 75 | ]) 76 | ) 77 | 78 | if -1 in sorted_labels: 79 | cluster_centers = cluster_centers[1:] + [cluster_centers[0]] 80 | 81 | cluster_centers = np.concatenate(cluster_centers, axis=0) 82 | 83 | # handle unclustered labels 84 | if -1 in sorted_labels: 85 | kernel_size = 3 86 | 87 | while count_clusterless_pixels(labels_image) > 0: 88 | labels_filtered = median_filter(labels_image, size=kernel_size) 89 | new_labels_image = labels_image.copy() 90 | new_labels_image[labels_image == -1] = labels_filtered[ 91 | labels_image == -1] 92 | if count_clusterless_pixels(new_labels_image) >= \ 93 | count_clusterless_pixels(labels_image): 94 | assert False, "Does not converge" 95 | labels_image = new_labels_image 96 | kernel_size += 2 97 | 98 | # create histogram: 99 | hist = Counter(labels_image.reshape(h * w)) 100 | if -1 in sorted_labels: 101 | cluster_centers_to_return = cluster_centers[:-1] / 255.0 102 | else: 103 | cluster_centers_to_return = cluster_centers / 255.0 104 | return labels_image, hist, cluster_centers_to_return 105 | 106 | @staticmethod 107 | def show_histogram_of_cluster_centers(hist, cluster_centers, path=''): 108 | plt.clf() 109 | plt.bar(hist.keys(), hist.values(), color=cluster_centers) 110 | x_legend = ['(' + ', '.join([f'{int(x * 255)}' for x in list(cc)]) + ')' 111 | for cc in cluster_centers] 112 | plt.xticks(np.arange(len(hist.keys())), x_legend, rotation=300) 113 | plt.title( 114 | 'how many pixels per cluster center?\n count vs cluster center \n' 115 | f'total clusters: {len(hist.keys())}') 116 | plt.ylabel('count') 117 | plt.xlabel('cluster center') 118 | plt.yscale('log') 119 | plt.tight_layout() 120 | if path == '': 121 | plt.show() 122 | else: 123 | plt.savefig(path) 124 | 125 | def new_image_histogram(self, new_image: torch.Tensor): 126 | """ 127 | 128 | """ 129 | h, w = new_image.shape[0], new_image.shape[1] 130 | X = new_image.reshape(h * w, -1).cpu() 131 | 132 | L = 
pairwise_distances_argmin(X[..., :3], self.base_cluster_centers) 133 | # L = L.reshape(h, w) 134 | # colored_labels_new_image = cluster_centers[L] 135 | # plt.imshow(colored_labels_new_image);plt.show() 136 | # L = L.reshape(h * w) 137 | return Counter(L) 138 | 139 | def align_histogram_of_non_base(self, new_image_histogram: torch.Tensor): 140 | # put zeros in for cluster centers which do not exist in the new 141 | # image histogram: 142 | for cluster_center_index in self.base_hist: 143 | if cluster_center_index not in new_image_histogram: 144 | new_image_histogram[cluster_center_index] = 0 145 | return new_image_histogram 146 | 147 | def rank(self, new_image: torch.Tensor): 148 | new_image_histogram = self.new_image_histogram(new_image) 149 | aligned_new_image_histogram = self.align_histogram_of_non_base( 150 | new_image_histogram) 151 | 152 | # convert histogram to normalized histograms 153 | q = np.array([self.base_hist[k] for k in sorted(self.base_hist.keys())]) 154 | q = q / sum(q) 155 | p = np.array([aligned_new_image_histogram[k] 156 | for k in sorted(aligned_new_image_histogram.keys())]) 157 | p = p / sum(p) 158 | return bhattacharyya_distance(p, q) 159 | 160 | def __call__(self, base_image, other_image): 161 | return self.rank(other_image) 162 | 163 | 164 | def read_image_from_path_to_normalized_tensor(path): 165 | numpy_image = skimage.io.imread(path) 166 | normalized_tensor_image = torch.from_numpy(numpy_image).to( 167 | torch.float32) / 255.0 168 | return normalized_tensor_image 169 | 170 | 171 | def main(): 172 | path = 'target_images/apes/23126082.png' 173 | # path = 'target_images/tiger.png' 174 | normalized_tensor_image = read_image_from_path_to_normalized_tensor(path) 175 | hist_loss = NonDifferentiableHistogramLoss(normalized_tensor_image, 176 | method=ClusteringMethod.DBSCAN) 177 | 178 | colored_labels = hist_loss.base_cluster_centers[hist_loss.base_labels_image 179 | ] * 255.0 180 | plt.subplot(1, 2, 1) 181 | plt.imshow(normalized_tensor_image) 182 | plt.title('image') 183 | plt.subplot(1, 2, 2) 184 | plt.imshow(colored_labels.astype(np.uint8)) 185 | plt.title('labels') 186 | plt.show() 187 | hist_loss.show_histogram_of_cluster_centers(hist_loss.base_hist, 188 | hist_loss.base_cluster_centers) 189 | images_to_path = { 190 | 'some noisy tiger': 'results/diffvg_tiger_l1_and_clip_mix_alpha_0/' 191 | 'result.png', 192 | 'an ape': 'target_images/apes/23126082.png'} 193 | import ipdb; ipdb.set_trace() 194 | for description in images_to_path: 195 | new_image_path = images_to_path[description] 196 | new_normalized_tensor_image = read_image_from_path_to_normalized_tensor( 197 | new_image_path) 198 | new_histogram = hist_loss.new_image_histogram( 199 | new_normalized_tensor_image) 200 | aligned_new_histogram = hist_loss.align_histogram_of_non_base( 201 | new_histogram) 202 | hist_loss.show_histogram_of_cluster_centers( 203 | aligned_new_histogram, hist_loss.base_cluster_centers) 204 | distance = hist_loss.rank(new_normalized_tensor_image) 205 | print(f"Bhattacharyya distance to {description}: " 206 | f" {distance:.2f}") 207 | 208 | 209 | if __name__ == "__main__": 210 | main() 211 | -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/models/.DS_Store -------------------------------------------------------------------------------- /models/decomp.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from PIL import Image 6 | from collections import Counter 7 | from sklearn.cluster import KMeans 8 | import torch 9 | import random as rng 10 | import numpy.random as npr 11 | import copy 12 | from collections import OrderedDict 13 | from scipy.ndimage.filters import gaussian_filter 14 | import networkx as nx 15 | from tqdm import tqdm 16 | 17 | 18 | def maskLoss(recon, trgt, masks): 19 | loss = 0 20 | for i, mask in enumerate(masks): 21 | h, w = mask.shape 22 | masked_recon = recon[i][3] * recon[i][:3] + (1 - recon[i][3]) * torch.ones_like(recon[i][:3]) 23 | masked_trgt = mask * trgt + (1 - mask) * torch.ones_like(trgt) 24 | loss += ((masked_recon - masked_trgt) ** 2).mean() / mask.sum() * h * w 25 | return loss 26 | 27 | 28 | def create_masks(img, fg_mask, num_shapes=4, num_cluster_centers=20, ignore_clusters_smaller_than=0.01, sigma_const=0.001, device='cpu', output_dir=''): 29 | img = np.transpose(img.squeeze().cpu().numpy(), (1, 2, 0)) 30 | H, W, C = img.shape 31 | fg_mask = cv2.resize(fg_mask, dsize=(H, W), interpolation=cv2.INTER_CUBIC) 32 | X = img.reshape(H * W, C) 33 | cluster = KMeans(num_cluster_centers).fit(X) 34 | color_clusters = cluster.labels_.reshape(H, W) 35 | cluster_indices = np.unique(color_clusters) 36 | relevant_cluster_indices = [ 37 | c for c in cluster_indices 38 | if (color_clusters == c).sum() > H * W * ignore_clusters_smaller_than] 39 | if len(relevant_cluster_indices) != num_cluster_centers: 40 | print(f'Narrowed down cluster number from: {num_cluster_centers} to:' 41 | f'{len(relevant_cluster_indices)} clusters.') 42 | new_K = len(relevant_cluster_indices) 43 | cluster = KMeans(new_K).fit(X) 44 | map = cluster.labels_.reshape(H, W) 45 | idcnt = {} 46 | for idi in range(new_K): 47 | idcnt[idi] = (((map == idi) * fg_mask).sum(), cluster.cluster_centers_[idi]) 48 | counter = 0 49 | shapes = [] 50 | for i, (_, color) in tqdm(sorted(idcnt.items(), key=lambda item: item[1][0], reverse=True)): 51 | sigmas = sigma_const * idcnt[i][0] 52 | mask = (map == i).astype(np.float32) 53 | downsampled_mask = cv2.resize(mask, dsize=(128, 128), interpolation=cv2.INTER_CUBIC) * \ 54 | cv2.resize(fg_mask, dsize=(128, 128), interpolation=cv2.INTER_CUBIC) 55 | blurred_mask = gaussian_filter(downsampled_mask, sigma=sigmas) 56 | # blurred_mask = np.pad(blurred_mask, [(pad_width, pad_width), (pad_width, pad_width)], mode='constant') 57 | mrf_mask = run_mrf_on_graph(blurred_mask) 58 | _, component, cstats, ccenter = cv2.connectedComponentsWithStats(mrf_mask.astype(np.uint8), connectivity=4) 59 | for j in np.unique(component): 60 | if j != 0: 61 | comp_mask = cv2.resize((component == j).astype(np.uint8), dsize=(H, W), interpolation=cv2.INTER_CUBIC) 62 | mean_row, mean_col = [int(x.mean()) for x in np.where(comp_mask > 0.5)] 63 | shapes.append((comp_mask, (component == j).sum(), counter, color, (mean_row, mean_col))) 64 | counter += 1 65 | 66 | shapes_by_size = sorted(shapes, key=lambda shape: shape[1], reverse=True) 67 | if len(shapes) > num_shapes: 68 | shapes_by_size = shapes_by_size[:num_shapes] 69 | # shapes_by_layer = sorted(shapes_by_size, key=lambda shape: shape[2]) 70 | 71 | result = np.ones((H, W, 3)) 72 | for mask in shapes_by_size: 73 | result[mask[0] > 0.5, ...] 
= mask[3] 74 | 75 | cv2.imwrite("{}/{}".format(output_dir, "masks.png"), (result*255).astype(np.uint8)[..., ::-1]) 76 | 77 | masks = [shape[0] for shape in shapes_by_size] 78 | coords = [shape[-1] for shape in shapes_by_size] 79 | 80 | trimmed_object_masks_tensor = [torch.tensor(x).to(device) for x in masks] 81 | coords_tensor = [torch.cat([torch.tensor(x).to(device)[None], torch.tensor(y).to(device)[None]]) for (x, y) in coords] 82 | return trimmed_object_masks_tensor, coords_tensor 83 | 84 | 85 | def run_mrf_on_graph(blurred_mask, binary_term_weight=0.5): 86 | H, W = blurred_mask.shape 87 | unary_term_weights_connected_to_source = blurred_mask 88 | unary_term_weights_connected_to_sink = 1 - blurred_mask 89 | G = nx.DiGraph() 90 | for h in range(H): 91 | for w in range(W): 92 | for i in [1]: 93 | for j in [1]: 94 | if H > h + i > 0 and W > w + j > 0: 95 | G.add_edge(str((h, w)), str((h + i, w)), 96 | capacity=np.absolute( 97 | blurred_mask[h + i, w] - blurred_mask[h, w]) * binary_term_weight) 98 | G.add_edge(str((h, w)), str((h, w + j)), 99 | capacity=np.absolute( 100 | blurred_mask[h, w + j] - blurred_mask[h, w]) * binary_term_weight) 101 | G.add_edge("SOURCE", str((h, w)), 102 | capacity=unary_term_weights_connected_to_source[h, w]) 103 | 104 | G.add_edge(str((h, w)), "SINK", 105 | capacity=unary_term_weights_connected_to_sink[h, w]) 106 | cut_value, partition = nx.minimum_cut(G, "SOURCE", "SINK") 107 | mask = np.zeros((H, W)) 108 | for item in partition[0]: 109 | if item != "SOURCE": 110 | i, j = item[1:-1].split(",") 111 | mask[int(i), int(j)] = 1 112 | return mask 113 | 114 | 115 | class Sparse_coord_init: 116 | def __init__(self, pred, gt, format='[2D x c]', num_cluster_centers=10, ignore_clusters_smaller_than=0.5 / 100., 117 | quantile_interval=20, nodiff_thres=0.1): 118 | if isinstance(pred, torch.Tensor): 119 | pred = pred.detach().cpu().numpy() 120 | if isinstance(gt, torch.Tensor): 121 | gt = gt.detach().cpu().numpy() 122 | if format == '[bs x c x 2D]': 123 | self.map = ((pred[0] - gt[0]) ** 2).sum(0) 124 | self.reference_gt = copy.deepcopy( 125 | np.transpose(gt[0], (1, 2, 0))) 126 | elif format == '[2D x c]': 127 | self.map = (np.abs(pred - gt)).sum(-1) 128 | self.reference_gt = copy.deepcopy(gt[0]) 129 | else: 130 | raise ValueError 131 | self.num_cluster_centers = num_cluster_centers 132 | # OptionA: Zero too small errors to avoid the error too small deadloop 133 | H, W, C = gt.shape 134 | X = gt.reshape(H * W, C) 135 | cluster = KMeans(self.num_cluster_centers).fit(X) 136 | color_clusters = cluster.labels_.reshape(H, W) 137 | cluster_indices = np.unique(color_clusters) 138 | relevant_cluster_indices = [ 139 | c for c in cluster_indices 140 | if (color_clusters == c).sum() > H * W * ignore_clusters_smaller_than] 141 | if len(relevant_cluster_indices) != self.num_cluster_centers: 142 | print(f'Narrowed down cluster number from: {self.num_cluster_centers} to:' 143 | f'{len(relevant_cluster_indices)} clusters.') 144 | cluster = KMeans(len(relevant_cluster_indices)).fit(X) 145 | self.map = cluster.labels_.reshape(H, W) 146 | self.idcnt = {} 147 | for idi in sorted(np.unique(self.map)): 148 | self.idcnt[idi] = (self.map == idi).sum() 149 | self.idcnt.pop(min(self.idcnt.keys())) 150 | # remove smallest one to remove the correct region 151 | 152 | def __call__(self): 153 | if len(self.idcnt) == 0: 154 | h, w = self.map.shape 155 | return [npr.uniform(0, 1) * w, npr.uniform(0, 1) * h] 156 | target_id = max(self.idcnt, key=self.idcnt.get) 157 | _, component, cstats, ccenter = 
cv2.connectedComponentsWithStats( 158 | (self.map == target_id).astype(np.uint8), connectivity=4) 159 | 160 | # remove cid = 0, it is the invalid area 161 | csize = [ci[-1] for ci in cstats[1:]] 162 | target_cid = csize.index(max(csize)) + 1 163 | center = ccenter[target_cid][::-1] 164 | coord = np.stack(np.where(component == target_cid)).T 165 | dist = np.linalg.norm(coord - center, axis=1) 166 | target_coord_id = np.argmin(dist) 167 | coord_h, coord_w = coord[target_coord_id] 168 | # replace_sampling 169 | self.idcnt[target_id] -= max(csize) 170 | if self.idcnt[target_id] == 0: 171 | self.idcnt.pop(target_id) 172 | self.map[component == target_cid] = 0 173 | return (component == target_cid).astype(np.uint8), coord_w, coord_h 174 | 175 | 176 | class Decomp: 177 | def __init__(self, num_cluster_centers=8, 178 | ignore_clusters_smaller_than=0.1 / 100.0, 179 | ignore_shapes_smaller_than=0.1 / 100., 180 | add_positional_encoding=False, 181 | device='cpu'): 182 | super(Decomp, self).__init__() 183 | self.add_positional_encoding = add_positional_encoding 184 | self.num_cluster_centers = num_cluster_centers 185 | self.ignore_clusters_smaller_than = ignore_clusters_smaller_than 186 | self.ignore_shapes_smaller_than = ignore_shapes_smaller_than 187 | self.device = device 188 | 189 | def decomp(self, img): 190 | img = np.transpose(img.squeeze().cpu().numpy(), (1, 2, 0)) 191 | H, W, C = img.shape 192 | masks, coords = [], [] 193 | sparse_coord_init = Sparse_coord_init(np.ones_like(img), img) 194 | for shape in range(self.num_cluster_centers): 195 | mask, coord_w, coord_h = sparse_coord_init() 196 | masks.append(mask) 197 | coords.append((coord_h, coord_w)) 198 | 199 | trimmed_object_masks_tensor = [torch.tensor(x).to(self.device) for x in masks] 200 | coords_tensor = [torch.cat([torch.tensor(x).to(self.device)[None], torch.tensor(y).to(self.device)[None]]) 201 | for (x, y) in coords] 202 | return trimmed_object_masks_tensor, coords_tensor 203 | 204 | if self.add_positional_encoding: 205 | cols, rows = np.meshgrid(np.arange(W), np.arange(H)) 206 | img = np.concatenate([img, cols[..., np.newaxis], rows[..., np.newaxis]], 207 | axis=2) 208 | C += 2 209 | 210 | X = img.reshape(H * W, C) 211 | cluster = KMeans(self.num_cluster_centers).fit(X) 212 | color_clusters = cluster.labels_.reshape(H, W) 213 | 214 | # plt.subplot(1, 2, 1) 215 | # plt.title('input image') 216 | # plt.imshow(Image.open(img_path)) 217 | # plt.subplot(1, 2, 2) 218 | # plt.title('color clusters') 219 | # plt.imshow(color_clusters, cmap=plt.colormaps.get('Pastel1')) 220 | # plt.colorbar() 221 | # fig = plt.gcf() 222 | # fig.set_size_inches((8, 8)) 223 | # plt.savefig(os.path.join(RESULTS_DIR, f'color_clusters_for_{image_name}.png')) 224 | 225 | cluster_indices = np.unique(color_clusters) 226 | relevant_cluster_indices = [ 227 | c for c in cluster_indices 228 | if (color_clusters == c).sum() > H * W * self.ignore_clusters_smaller_than] 229 | 230 | if len(relevant_cluster_indices) != self.num_cluster_centers: 231 | print(f'Narrowed down cluster number from: {self.num_cluster_centers} to:' 232 | f'{len(relevant_cluster_indices)} clusters.') 233 | 234 | print(f"Creating connected components from clusters...") 235 | object_masks = [] 236 | for cluster_index in relevant_cluster_indices: 237 | 238 | # canny_output = cv2.Canny((color_clusters == cluster_index).astype(np.uint8), 0, 1) 239 | # contours, hierarchy = cv2.findContours(canny_output, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 240 | # # Draw contours 241 | # drawing = 
np.zeros((canny_output.shape[0], canny_output.shape[1], 3), dtype=np.uint8) 242 | # contours = sorted(contours, key=cv2.contourArea, reverse=True) 243 | # for i in range(min(4, len(contours))): 244 | # color = (rng.randint(0, 256), rng.randint(0, 256), rng.randint(0, 256)) 245 | # cv2.drawContours(drawing, contours, i, color, 2, cv2.LINE_8, hierarchy, 0) 246 | 247 | num_labels, components_image = cv2.connectedComponents( 248 | (color_clusters == cluster_index).astype(np.uint8)) 249 | for label in range(num_labels): 250 | cluster_histogram_in_curr_component = Counter( 251 | color_clusters[components_image == label]) 252 | if cluster_index not in cluster_histogram_in_curr_component: 253 | continue 254 | most_common = cluster_histogram_in_curr_component.most_common(1)[0][0] 255 | if cluster_index != most_common: 256 | continue 257 | object_masks.append((components_image == label).astype(np.uint8)) 258 | 259 | print(f"Sorting connected components according to area...") 260 | object_masks = sorted(object_masks, key=lambda x: x.sum(), reverse=True) 261 | 262 | print(f"Ignoring connected components with area < " 263 | f"{self.ignore_shapes_smaller_than * 100.0:.2f} of image size...") 264 | trimmed_object_masks = [mask for mask in object_masks 265 | if mask.sum() > H * W * self.ignore_shapes_smaller_than] 266 | trimmed_object_masks_tensor = [torch.tensor(x).to(self.device) for x in trimmed_object_masks] 267 | return trimmed_object_masks_tensor 268 | 269 | # print(f"Plotting clusters and connected components.") 270 | # from math import floor, sqrt 271 | # plot_rows = 1 + int(floor(sqrt(len(object_masks)))) 272 | # plot_rows = 3 273 | # original_image = Image.open(img_path) 274 | # 275 | # plt.subplot(3, 2, 1) 276 | # plt.title('input image') 277 | # plt.imshow(original_image) 278 | # plt.subplot(3, 2, 2) 279 | # plt.title('color clusters') 280 | # plt.imshow(color_clusters, cmap=plt.colormaps.get('Pastel1')) 281 | # plt.colorbar() 282 | # 283 | # for i in range(6): 284 | # if i < len(object_masks): 285 | # mask = object_masks[i] 286 | # plt.subplot(plot_rows, plot_rows, plot_rows + i+1) 287 | # plt.imshow(object_masks[i], cmap=plt.colormaps.get('Pastel1')) 288 | # 289 | # fig = plt.gcf() 290 | # fig.set_size_inches((8, 8)) 291 | # plt.savefig(os.path.join(RESULTS_DIR, 292 | # f'connected_components_for_{image_name}.png')) 293 | -------------------------------------------------------------------------------- /models/edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GradLayer(nn.Module): 7 | 8 | def __init__(self): 9 | super(GradLayer, self).__init__() 10 | kernel_v = [[0, -1, 0], 11 | [0, 0, 0], 12 | [0, 1, 0]] 13 | kernel_h = [[0, 0, 0], 14 | [-1, 0, 1], 15 | [0, 0, 0]] 16 | kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) 17 | kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) 18 | self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) 19 | self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) 20 | 21 | def get_gray(self, x): 22 | ''' 23 | Convert image to its gray one. 
24 | ''' 25 | gray_coeffs = [65.738, 129.057, 25.064] 26 | convert = x.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 27 | x_gray = x.mul(convert).sum(dim=1) 28 | return x_gray.unsqueeze(1) 29 | 30 | def forward(self, x): 31 | # x_list = [] 32 | # for i in range(x.shape[1]): 33 | # x_i = x[:, i] 34 | # x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1) 35 | # x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1) 36 | # x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6) 37 | # x_list.append(x_i) 38 | 39 | # x = torch.cat(x_list, dim=1) 40 | if x.shape[1] == 3: 41 | x = self.get_gray(x) 42 | 43 | x_v = F.conv2d(x, self.weight_v, padding=1) 44 | x_h = F.conv2d(x, self.weight_h, padding=1) 45 | x = torch.sqrt(torch.pow(x_v, 2) + torch.pow(x_h, 2) + 1e-6) 46 | 47 | return x 48 | 49 | 50 | class GradLoss(nn.Module): 51 | 52 | def __init__(self): 53 | super(GradLoss, self).__init__() 54 | self.loss = nn.L1Loss() 55 | self.grad_layer = GradLayer() 56 | 57 | def forward(self, output, gt_img, mode): 58 | output_grad = self.grad_layer(output) 59 | gt_grad = self.grad_layer(gt_img) 60 | return self.loss(output_grad, gt_grad) 61 | 62 | 63 | if __name__ == "__main__": 64 | import cv2 65 | import numpy as np 66 | 67 | net = GradLayer() 68 | img = cv2.imread('example.JPEG') 69 | a = img.shape # (256, 256, 3) 70 | 71 | img = (img / 255.0).astype(np.float32) 72 | img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) 73 | img = net(img) # input img: data range [0, 1]; data type torch.float32; data shape [1, 3, 256, 256] 74 | b = img.shape # torch.Size([1, 1, 256, 256]) 75 | img = (img[0, :, :, :].permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) 76 | 77 | c = img.shape # (256, 256, 1) 78 | cv2.imshow('pytorch sobel', img) 79 | cv2.waitKey(0) 80 | -------------------------------------------------------------------------------- /models/histogram.py: -------------------------------------------------------------------------------- 1 | from torch import nn, sigmoid 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def phi_k(x, L, W): 9 | return sigmoid((x + (L / 2)) / W) - sigmoid((x - (L / 2)) / W) 10 | 11 | 12 | def compute_pj(x, mu_k, K, L, W): 13 | # we assume that x has only one channel already 14 | # flatten spatial dims 15 | x = x.reshape(x.size(0), 1, -1) 16 | x = x.repeat(1, K, 1) # construct K channels 17 | 18 | # apply activation functions 19 | return phi_k(x - mu_k, L, W) 20 | 21 | 22 | class EarthMoversDistanceLoss(nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | 26 | def forward(self, x, y): 27 | # input has dims: (Batch x Bins) 28 | bins = x.size(1) 29 | r = torch.arange(bins).cuda() 30 | s, t = torch.meshgrid(r, r) 31 | tt = t >= s 32 | 33 | cdf_x = torch.matmul(x, tt.float()) 34 | cdf_y = torch.matmul(y, tt.float()) 35 | 36 | return torch.sum(torch.square(cdf_x - cdf_y), dim=1) 37 | 38 | 39 | class MutualInformationLoss(nn.Module): 40 | def __init__(self): 41 | super().__init__() 42 | 43 | def forward(self, p1, p2, p12): 44 | # input p12 has dims: (Batch x Bins x Bins) 45 | # input p1 & p2 has dims: (Batch x Bins) 46 | 47 | product_p = torch.matmul(torch.transpose(p1.unsqueeze(1), 1, 2), p2.unsqueeze(1)) + torch.finfo(p1.dtype).eps 48 | mi = torch.sum(p12 * torch.log(p12 / product_p + torch.finfo(p1.dtype).eps), dim=(1, 2)) 49 | h = -torch.sum(p12 * torch.log(p12 + torch.finfo(p1.dtype).eps), dim=(1, 2)) 50 | 51 | return 1 - (mi / h) 52 | 53 | 54 | class HistLayerBase(nn.Module): 
55 | def __init__(self): 56 | super().__init__() 57 | 58 | self.K = 256 59 | self.L = 1 / self.K # 2 / K -> if values in [-1,1] (Paper) 60 | self.W = self.L / 2.5 61 | 62 | self.mu_k = (self.L * (torch.arange(self.K) + 0.5)).view(-1, 1).cuda() 63 | 64 | 65 | class SingleDimHistLayer(HistLayerBase): 66 | def __init__(self): 67 | super().__init__() 68 | 69 | def forward(self, x): 70 | N = x.size(1) * x.size(2) 71 | pj = compute_pj(x, self.mu_k, self.K, self.L, self.W) 72 | return pj.sum(dim=2) / N 73 | 74 | 75 | class JointHistLayer(HistLayerBase): 76 | def __init__(self): 77 | super().__init__() 78 | 79 | def forward(self, x, y): 80 | N = x.size(1) * x.size(2) 81 | p1 = compute_pj(x, self.mu_k, self.K, self.L, self.W) 82 | p2 = compute_pj(y, self.mu_k, self.K, self.L, self.W) 83 | return torch.matmul(p1, torch.transpose(p2, 1, 2)) / N 84 | 85 | 86 | """ 87 | ##### Copyright 2021 Mahmoud Afifi. 88 | 89 | If you find this code useful, please cite our paper: 90 | 91 | Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN: 92 | Controlling Colors of GAN-Generated and Real Images via Color Histograms." 93 | In CVPR, 2021. 94 | 95 | @inproceedings{afifi2021histogan, 96 | title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via 97 | Color Histograms}, 98 | author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.}, 99 | booktitle={CVPR}, 100 | year={2021} 101 | } 102 | #### 103 | """ 104 | 105 | EPS = 1e-6 106 | 107 | 108 | class RGBuvHistBlock(nn.Module): 109 | def __init__(self, h=64, insz=150, resizing='interpolation', 110 | method='inverse-quadratic', sigma=0.02, intensity_scale=True, 111 | hist_boundary=None, green_only=False, device='cuda'): 112 | """ Computes the RGB-uv histogram feature of a given image. 113 | Args: 114 | h: histogram dimension size (scalar). The default value is 64. 115 | insz: maximum size of the input image; if it is larger than this size, the 116 | image will be resized (scalar). Default value is 150 (i.e., 150 x 150 117 | pixels). 118 | resizing: resizing method if applicable. Options are: 'interpolation' or 119 | 'sampling'. Default is 'interpolation'. 120 | method: the method used to count the number of pixels for each bin in the 121 | histogram feature. Options are: 'thresholding', 'RBF' (radial basis 122 | function), or 'inverse-quadratic'. Default value is 'inverse-quadratic'. 123 | sigma: if the method value is 'RBF' or 'inverse-quadratic', then this is 124 | the sigma parameter of the kernel function. The default value is 0.02. 125 | intensity_scale: boolean variable to use the intensity scale (I_y in 126 | Equation 2). Default value is True. 127 | hist_boundary: a list of histogram boundary values. Default is [-3, 3]. 128 | green_only: boolean variable to use only the log(g/r), log(g/b) channels. 129 | Default is False. 130 | 131 | Methods: 132 | forward: accepts input image and returns its histogram feature. Note that 133 | unless the method is 'thresholding', this is a differentiable function 134 | and can be easily integrated with the loss function. As mentioned in the 135 | paper, the 'inverse-quadratic' was found more stable than 'RBF' in our 136 | training. 
137 | """ 138 | super(RGBuvHistBlock, self).__init__() 139 | self.h = h 140 | self.insz = insz 141 | self.device = device 142 | self.resizing = resizing 143 | self.method = method 144 | self.intensity_scale = intensity_scale 145 | self.green_only = green_only 146 | if hist_boundary is None: 147 | hist_boundary = [-3, 3] 148 | hist_boundary.sort() 149 | self.hist_boundary = hist_boundary 150 | if self.method == 'thresholding': 151 | self.eps = (abs(hist_boundary[0]) + abs(hist_boundary[1])) / h 152 | else: 153 | self.sigma = sigma 154 | 155 | def forward(self, x): 156 | x = torch.clamp(x, 0, 1) 157 | if x.shape[2] > self.insz or x.shape[3] > self.insz: 158 | if self.resizing == 'interpolation': 159 | x_sampled = F.interpolate(x, size=(self.insz, self.insz), 160 | mode='bilinear', align_corners=False) 161 | elif self.resizing == 'sampling': 162 | inds_1 = torch.LongTensor( 163 | np.linspace(0, x.shape[2], self.h, endpoint=False)).to( 164 | device=self.device) 165 | inds_2 = torch.LongTensor( 166 | np.linspace(0, x.shape[3], self.h, endpoint=False)).to( 167 | device=self.device) 168 | x_sampled = x.index_select(2, inds_1) 169 | x_sampled = x_sampled.index_select(3, inds_2) 170 | else: 171 | raise Exception( 172 | f'Wrong resizing method. It should be: interpolation or sampling. ' 173 | f'But the given value is {self.resizing}.') 174 | else: 175 | x_sampled = x 176 | 177 | L = x_sampled.shape[0] # size of mini-batch 178 | if x_sampled.shape[1] > 3: 179 | x_sampled = x_sampled[:, :3, :, :] 180 | X = torch.unbind(x_sampled, dim=0) 181 | hists = torch.zeros((x_sampled.shape[0], 1 + int(not self.green_only) * 2, 182 | self.h, self.h)).to(device=self.device) 183 | for l in range(L): 184 | I = torch.t(torch.reshape(X[l], (3, -1))) 185 | II = torch.pow(I, 2) 186 | if self.intensity_scale: 187 | Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS), 188 | dim=1) 189 | else: 190 | Iy = 1 191 | if not self.green_only: 192 | Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] + 193 | EPS), dim=1) 194 | Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] + 195 | EPS), dim=1) 196 | diff_u0 = abs( 197 | Iu0 - torch.unsqueeze(torch.tensor(np.linspace( 198 | self.hist_boundary[0], self.hist_boundary[1], num=self.h)), 199 | dim=0).to(self.device)) 200 | diff_v0 = abs( 201 | Iv0 - torch.unsqueeze(torch.tensor(np.linspace( 202 | self.hist_boundary[0], self.hist_boundary[1], num=self.h)), 203 | dim=0).to(self.device)) 204 | if self.method == 'thresholding': 205 | diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2 206 | diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2 207 | elif self.method == 'RBF': 208 | diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)), 209 | 2) / self.sigma ** 2 210 | diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)), 211 | 2) / self.sigma ** 2 212 | diff_u0 = torch.exp(-diff_u0) # Radial basis function 213 | diff_v0 = torch.exp(-diff_v0) 214 | elif self.method == 'inverse-quadratic': 215 | diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)), 216 | 2) / self.sigma ** 2 217 | diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)), 218 | 2) / self.sigma ** 2 219 | diff_u0 = 1 / (1 + diff_u0) # Inverse quadratic 220 | diff_v0 = 1 / (1 + diff_v0) 221 | else: 222 | raise Exception( 223 | f'Wrong kernel method. It should be either thresholding, RBF,' 224 | f' inverse-quadratic. 
But the given value is {self.method}.') 225 | diff_u0 = diff_u0.type(torch.float32) 226 | diff_v0 = diff_v0.type(torch.float32) 227 | a = torch.t(Iy * diff_u0) 228 | hists[l, 0, :, :] = torch.mm(a, diff_v0) 229 | 230 | Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS), 231 | dim=1) 232 | Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS), 233 | dim=1) 234 | diff_u1 = abs( 235 | Iu1 - torch.unsqueeze(torch.tensor(np.linspace( 236 | self.hist_boundary[0], self.hist_boundary[1], num=self.h)), 237 | dim=0).to(self.device)) 238 | diff_v1 = abs( 239 | Iv1 - torch.unsqueeze(torch.tensor(np.linspace( 240 | self.hist_boundary[0], self.hist_boundary[1], num=self.h)), 241 | dim=0).to(self.device)) 242 | 243 | if self.method == 'thresholding': 244 | diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2 245 | diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2 246 | elif self.method == 'RBF': 247 | diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)), 248 | 2) / self.sigma ** 2 249 | diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)), 250 | 2) / self.sigma ** 2 251 | diff_u1 = torch.exp(-diff_u1) # Gaussian 252 | diff_v1 = torch.exp(-diff_v1) 253 | elif self.method == 'inverse-quadratic': 254 | diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)), 255 | 2) / self.sigma ** 2 256 | diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)), 257 | 2) / self.sigma ** 2 258 | diff_u1 = 1 / (1 + diff_u1) # Inverse quadratic 259 | diff_v1 = 1 / (1 + diff_v1) 260 | 261 | diff_u1 = diff_u1.type(torch.float32) 262 | diff_v1 = diff_v1.type(torch.float32) 263 | a = torch.t(Iy * diff_u1) 264 | if not self.green_only: 265 | hists[l, 1, :, :] = torch.mm(a, diff_v1) 266 | else: 267 | hists[l, 0, :, :] = torch.mm(a, diff_v1) 268 | 269 | if not self.green_only: 270 | Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] + 271 | EPS), dim=1) 272 | Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] + 273 | EPS), dim=1) 274 | diff_u2 = abs( 275 | Iu2 - torch.unsqueeze(torch.tensor(np.linspace( 276 | self.hist_boundary[0], self.hist_boundary[1], num=self.h)), 277 | dim=0).to(self.device)) 278 | diff_v2 = abs( 279 | Iv2 - torch.unsqueeze(torch.tensor(np.linspace( 280 | self.hist_boundary[0], self.hist_boundary[1], num=self.h)), 281 | dim=0).to(self.device)) 282 | if self.method == 'thresholding': 283 | diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2 284 | diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2 285 | elif self.method == 'RBF': 286 | diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)), 287 | 2) / self.sigma ** 2 288 | diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)), 289 | 2) / self.sigma ** 2 290 | diff_u2 = torch.exp(-diff_u2) # Gaussian 291 | diff_v2 = torch.exp(-diff_v2) 292 | elif self.method == 'inverse-quadratic': 293 | diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)), 294 | 2) / self.sigma ** 2 295 | diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)), 296 | 2) / self.sigma ** 2 297 | diff_u2 = 1 / (1 + diff_u2) # Inverse quadratic 298 | diff_v2 = 1 / (1 + diff_v2) 299 | diff_u2 = diff_u2.type(torch.float32) 300 | diff_v2 = diff_v2.type(torch.float32) 301 | a = torch.t(Iy * diff_u2) 302 | hists[l, 2, :, :] = torch.mm(a, diff_v2) 303 | 304 | # normalization 305 | hists_normalized = hists / ( 306 | ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS) 307 | 308 | return hists_normalized 309 | 310 | 311 | class HistLoss(torch.nn.Module): 
312 | def __init__(self, device='cpu', intensity_scale=True, histogram_size=64, 313 | max_input_size=256, hist_boundary=[-3, 3], method='inverse-quadratic'): 314 | super().__init__() 315 | self.device = device 316 | self.emd = EarthMoversDistanceLoss().to(device) 317 | # create a histogram block 318 | self.hist = RGBuvHistBlock(insz=max_input_size, h=histogram_size, 319 | intensity_scale=intensity_scale, 320 | method=method, hist_boundary=hist_boundary, 321 | device=device) 322 | # self.hist = SingleDimHistLayer().to(device) 323 | 324 | def forward(self, source, target, mode): 325 | input_hist = self.hist(source) 326 | target_hist = self.hist(target) 327 | histogram_loss = (1 / np.sqrt(2.0) * (torch.sqrt(torch.sum( 328 | torch.pow(torch.sqrt(target_hist) - torch.sqrt(input_hist), 2)))) / 329 | input_hist.shape[0]) 330 | return histogram_loss 331 | # return self.emd(hist1, hist2) 332 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import CLIP_.clip as clip 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models, transforms 6 | import numpy as np 7 | 8 | from models.decomp import maskLoss 9 | from models.histogram import HistLoss 10 | from models.structure import SuperPixel 11 | from models.pyramid import MSEPyramidLoss 12 | from models.edge import GradLoss 13 | 14 | 15 | def compute_sine_theta(s1, s2): # s1 and s2 aret two segments to be uswed 16 | # s1, s2 (2, 2) 17 | v1 = s1[1, :] - s1[0, :] 18 | v2 = s2[1, :] - s2[0, :] 19 | # print(v1, v2) 20 | sine_theta = (v1[0] * v2[1] - v1[1] * v2[0]) / (torch.norm(v1) * torch.norm(v2)) 21 | return sine_theta 22 | 23 | 24 | def get_sdf(phi, method='skfmm', **kwargs): 25 | if method == 'skfmm': 26 | import skfmm 27 | phi = (phi - 0.5) * 2 28 | if (phi.max() <= 0) or (phi.min() >= 0): 29 | return np.zeros(phi.shape).astype(np.float32) 30 | sdf = [] 31 | for img in phi: 32 | sd = skfmm.distance(img, dx=1) 33 | 34 | flip_negative = kwargs.get('flip_negative', True) 35 | if flip_negative: 36 | sd = np.abs(sd) 37 | 38 | truncate = kwargs.get('truncate', 10) 39 | sd = np.clip(sd, -truncate, truncate) 40 | # print(f"max sd value is: {sd.max()}") 41 | 42 | zero2max = kwargs.get('zero2max', True) 43 | if zero2max and flip_negative: 44 | sd = sd.max() - sd 45 | elif zero2max: 46 | raise ValueError 47 | 48 | normalize = kwargs.get('normalize', 'sum') 49 | if normalize == 'sum': 50 | sd /= sd.sum() 51 | elif normalize == 'to1': 52 | sd /= sd.max() 53 | sdf.append(sd[None]) 54 | return torch.FloatTensor(np.clip(np.concatenate(sdf, axis=0), 0, 1)) 55 | 56 | 57 | class Loss(nn.Module): 58 | def __init__(self, args): 59 | super(Loss, self).__init__() 60 | self.args = args 61 | self.percep_loss = args.percep_loss 62 | 63 | self.train_with_clip = args.train_with_clip 64 | self.clip_weight = args.clip_weight 65 | self.start_clip = args.start_clip 66 | 67 | self.clip_conv_loss = args.clip_conv_loss 68 | self.xing_loss_weight = args.xing_loss_weight 69 | self.clip_fc_loss_weight = args.clip_fc_loss_weight 70 | self.clip_text_guide = args.clip_text_guide 71 | 72 | self.losses_to_apply = self.get_losses_to_apply() 73 | 74 | self.loss_mapper = \ 75 | { 76 | "clip": CLIPLoss(args), 77 | "clip_conv_loss": CLIPConvLoss(args), 78 | "L2": MSEPyramidLoss(device=args.device), 79 | "Hist": HistLoss(device=args.device), 80 | "SDF": torch.nn.MSELoss(reduction='none'), 81 | "Xing_Loss": self.Xing_Loss, 82 
| "Edge": GradLoss().to(args.device), 83 | "Mask": maskLoss 84 | } 85 | self.structure = SuperPixel(args.device, mode='sscolor') 86 | 87 | def distance(self, shapes): 88 | loss = 0 89 | for path1 in shapes: 90 | x1 = path1.points 91 | for path2 in shapes: 92 | x2 = path2.points 93 | loss += ((x1 - x2) ** 2).mean() 94 | return loss 95 | 96 | def Xing_Loss(self, shapes, scale=1): # x[ npoints,2] 97 | loss = 0. 98 | # print(len(x_list)) 99 | for shape in shapes: 100 | x = shape.points 101 | seg_loss = 0. 102 | N = x.size()[0] 103 | x = torch.cat([x, x[0, :].unsqueeze(0)], dim=0) # (N+1,2) 104 | segments = torch.cat([x[:-1, :].unsqueeze(1), x[1:, :].unsqueeze(1)], dim=1) # (N, start/end, 2) 105 | assert N % 3 == 0, 'The segment number is not correct!' 106 | segment_num = int(N / 3) 107 | for i in range(segment_num): 108 | cs1 = segments[i * 3, :, :] # start control segs 109 | cs2 = segments[i * 3 + 1, :, :] # middle control segs 110 | cs3 = segments[i * 3 + 2, :, :] # end control segs 111 | # print('the direction of the vectors:') 112 | # print(compute_sine_theta(cs1, cs2)) 113 | direct = (compute_sine_theta(cs1, cs2) >= 0).float() 114 | opst = 1 - direct # another direction 115 | sina = compute_sine_theta(cs1, cs3) # the angle between cs1 and cs3 116 | seg_loss += direct * torch.relu(- sina) + opst * torch.relu(sina) 117 | # print(direct, opst, sina) 118 | seg_loss /= segment_num 119 | 120 | templ = seg_loss 121 | loss += templ * scale # area_loss * scale 122 | 123 | return loss / (len(shapes)) 124 | 125 | def get_losses_to_apply(self): 126 | losses_to_apply = [] 127 | if self.percep_loss != "none": 128 | losses_to_apply.append(self.percep_loss) 129 | if self.train_with_clip and self.start_clip == 0: 130 | losses_to_apply.append("clip") 131 | if self.clip_conv_loss: 132 | losses_to_apply.append("clip_conv_loss") 133 | if self.clip_text_guide: 134 | losses_to_apply.append("clip_text") 135 | losses_to_apply.append("L2") 136 | 137 | # losses_to_apply.append("Mask") 138 | # losses_to_apply.append("SDF") 139 | # losses_to_apply.append("Sizes") 140 | # losses_to_apply.append("Hist") 141 | # losses_to_apply.append("Edge") 142 | # losses_to_apply.append("Distance") 143 | # losses_to_apply.append("Xing_Loss") 144 | return losses_to_apply 145 | 146 | def update_losses_to_apply(self, epoch): 147 | if "clip" not in self.losses_to_apply: 148 | if self.train_with_clip: 149 | if epoch > self.start_clip: 150 | self.losses_to_apply.append("clip") 151 | 152 | def forward(self, svg_img, targets, masks, epoch, stacked=None, mode="train", points=None, sizes=None, 153 | im_forsdf=None): 154 | loss = 0 155 | self.update_losses_to_apply(epoch) 156 | 157 | losses_dict = dict.fromkeys( 158 | self.losses_to_apply, torch.tensor([0.0]).to(self.args.device)) 159 | loss_coeffs = dict.fromkeys(self.losses_to_apply, 1.0) 160 | loss_coeffs["clip"] = self.clip_weight 161 | loss_coeffs["clip_text"] = self.clip_text_guide 162 | loss_coeffs["Xing_Loss"] = self.xing_loss_weight 163 | loss_coeffs["Mask"] = 1 164 | loss_coeffs["SDF"] = 100 165 | loss_coeffs["Sizes"] = 1e-1 166 | loss_coeffs["Edge"] = 10 167 | 168 | for loss_name in self.losses_to_apply: 169 | if loss_name in ["clip_conv_loss"]: 170 | conv_loss = self.loss_mapper[loss_name]( 171 | svg_img, targets, mode) 172 | for layer in conv_loss.keys(): 173 | losses_dict[layer] = conv_loss[layer] 174 | elif loss_name == "L2": 175 | losses_dict[loss_name] = self.loss_mapper[loss_name]( 176 | svg_img, targets).mean() 177 | elif loss_name == "Xing_Loss": 178 | losses_dict[loss_name] 
= self.loss_mapper[loss_name]( 179 | points, scale=1) 180 | elif loss_name == "Sizes": 181 | losses_dict[loss_name] = sizes.sum() 182 | elif loss_name == "SDF": 183 | sdf = (im_forsdf[:, 3]).detach().cpu().numpy() 184 | losses_dict[loss_name] = ( 185 | get_sdf(sdf, normalize='to1').to(self.args.device) * self.loss_mapper[loss_name]( 186 | svg_img, targets).sum(1)).mean() 187 | elif loss_name == "Mask": 188 | losses_dict[loss_name] = self.loss_mapper[loss_name](stacked, targets, masks) 189 | else: 190 | losses_dict[loss_name] = (self.loss_mapper[loss_name]( 191 | svg_img, targets, mode)).mean() 192 | 193 | for key in self.losses_to_apply: 194 | if key != "L2": 195 | losses_dict[key] = losses_dict[key] * loss_coeffs[key] 196 | # print(losses_dict) 197 | return losses_dict 198 | 199 | 200 | class CLIPLoss(torch.nn.Module): 201 | def __init__(self, args, text_prompt=None): 202 | super(CLIPLoss, self).__init__() 203 | 204 | self.args = args 205 | self.model, clip_preprocess = clip.load( 206 | 'ViT-B/32', args.device, jit=False) 207 | self.model.eval() 208 | self.preprocess = transforms.Compose( 209 | [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]]) # clip normalisation 210 | # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]]) # clip normalisation 211 | self.device = args.device 212 | self.NUM_AUGS = args.num_aug_clip 213 | augemntations = [] 214 | if "affine" in args.augemntations: 215 | augemntations.append(transforms.RandomPerspective( 216 | fill=0, p=1.0, distortion_scale=0.5)) 217 | augemntations.append(transforms.RandomResizedCrop( 218 | 224, scale=(0.8, 0.8), ratio=(1.0, 1.0))) 219 | augemntations.append( 220 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) 221 | self.augment_trans = transforms.Compose(augemntations) 222 | 223 | self.calc_target = True 224 | self.include_target_in_aug = args.include_target_in_aug 225 | self.counter = 0 226 | self.augment_both = args.augment_both 227 | self.text_prompt = text_prompt 228 | 229 | def forward(self, sketches, targets, mode="train"): 230 | if self.calc_target: 231 | targets_ = self.preprocess(targets).to(self.device) 232 | self.targets_features = self.model.encode_image(targets_).detach() 233 | self.calc_target = False 234 | if self.text_prompt is not None: 235 | text_input = clip.tokenize([self.text_prompt]).to(self.device) 236 | text_features = self.model.encode_text(text_input).detach() 237 | 238 | if mode == "eval": 239 | # for regular clip distance, no augmentations 240 | with torch.no_grad(): 241 | sketches = self.preprocess(sketches).to(self.device) 242 | sketches_features = self.model.encode_image(sketches) 243 | return 1. - torch.cosine_similarity(sketches_features, self.targets_features) 244 | 245 | loss_clip = 0 246 | sketch_augs = [] 247 | for n in range(self.NUM_AUGS): 248 | augmented_pair = self.augment_trans(torch.cat([sketches, targets])) 249 | sketch_augs.append(augmented_pair[0].unsqueeze(0)) 250 | 251 | sketch_batch = torch.cat(sketch_augs) 252 | sketch_features = self.model.encode_image(sketch_batch) 253 | 254 | for n in range(self.NUM_AUGS): 255 | loss_clip += (1. - torch.cosine_similarity( 256 | sketch_features[n:n + 1], self.targets_features, dim=1)) 257 | if self.text_prompt is not None: 258 | for n in range(self.NUM_AUGS): 259 | loss_clip += (1. 
- torch.cosine_similarity( 260 | sketch_features[n:n + 1], text_features, dim=1)) 261 | self.counter += 1 262 | return loss_clip 263 | 264 | 265 | class LPIPS(torch.nn.Module): 266 | def __init__(self, pretrained=True, normalize=True, pre_relu=True, device=None): 267 | """ 268 | Args: 269 | pre_relu(bool): if True, selects features **before** reLU activations 270 | """ 271 | super(LPIPS, self).__init__() 272 | # VGG using perceptually-learned weights (LPIPS metric) 273 | self.normalize = normalize 274 | self.pretrained = pretrained 275 | augemntations = [] 276 | augemntations.append(transforms.RandomPerspective( 277 | fill=0, p=1.0, distortion_scale=0.5)) 278 | augemntations.append(transforms.RandomResizedCrop( 279 | 224, scale=(0.8, 0.8), ratio=(1.0, 1.0))) 280 | self.augment_trans = transforms.Compose(augemntations) 281 | self.feature_extractor = LPIPS._FeatureExtractor( 282 | pretrained, pre_relu).to(device) 283 | 284 | def _l2_normalize_features(self, x, eps=1e-10): 285 | nrm = torch.sqrt(torch.sum(x * x, dim=1, keepdim=True)) 286 | return x / (nrm + eps) 287 | 288 | def forward(self, pred, target, mode="train"): 289 | """Compare VGG features of two inputs.""" 290 | 291 | # Get VGG features 292 | 293 | sketch_augs, img_augs = [pred], [target] 294 | if mode == "train": 295 | for n in range(4): 296 | augmented_pair = self.augment_trans(torch.cat([pred, target])) 297 | sketch_augs.append(augmented_pair[0].unsqueeze(0)) 298 | img_augs.append(augmented_pair[1].unsqueeze(0)) 299 | 300 | xs = torch.cat(sketch_augs, dim=0) 301 | ys = torch.cat(img_augs, dim=0) 302 | 303 | pred = self.feature_extractor(xs) 304 | target = self.feature_extractor(ys) 305 | 306 | # L2 normalize features 307 | if self.normalize: 308 | pred = [self._l2_normalize_features(f) for f in pred] 309 | target = [self._l2_normalize_features(f) for f in target] 310 | 311 | # TODO(mgharbi) Apply Richard's linear weights? 
312 | 313 | if self.normalize: 314 | diffs = [torch.sum((p - t) ** 2, 1) 315 | for (p, t) in zip(pred, target)] 316 | else: 317 | # mean instead of sum to avoid super high range 318 | diffs = [torch.mean((p - t) ** 2, 1) 319 | for (p, t) in zip(pred, target)] 320 | 321 | # Spatial average 322 | diffs = [diff.mean([1, 2]) for diff in diffs] 323 | 324 | return sum(diffs) 325 | 326 | class _FeatureExtractor(torch.nn.Module): 327 | def __init__(self, pretrained, pre_relu): 328 | super(LPIPS._FeatureExtractor, self).__init__() 329 | vgg_pretrained = models.vgg16(pretrained=pretrained).features 330 | 331 | self.breakpoints = [0, 4, 9, 16, 23, 30] 332 | if pre_relu: 333 | for i, _ in enumerate(self.breakpoints[1:]): 334 | self.breakpoints[i + 1] -= 1 335 | 336 | # Split at the maxpools 337 | for i, b in enumerate(self.breakpoints[:-1]): 338 | ops = torch.nn.Sequential() 339 | for idx in range(b, self.breakpoints[i + 1]): 340 | op = vgg_pretrained[idx] 341 | ops.add_module(str(idx), op) 342 | # print(ops) 343 | self.add_module("group{}".format(i), ops) 344 | 345 | # No gradients 346 | for p in self.parameters(): 347 | p.requires_grad = False 348 | 349 | # Torchvision's normalization: 350 | self.register_buffer("shift", torch.Tensor( 351 | [0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 352 | self.register_buffer("scale", torch.Tensor( 353 | [0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 354 | 355 | def forward(self, x): 356 | feats = [] 357 | x = (x - self.shift) / self.scale 358 | for idx in range(len(self.breakpoints) - 1): 359 | m = getattr(self, "group{}".format(idx)) 360 | x = m(x) 361 | feats.append(x) 362 | return feats 363 | 364 | 365 | class L2_(torch.nn.Module): 366 | def __init__(self): 367 | """ 368 | Args: 369 | pre_relu(bool): if True, selects features **before** reLU activations 370 | """ 371 | super(L2_, self).__init__() 372 | # VGG using perceptually-learned weights (LPIPS metric) 373 | augemntations = [] 374 | augemntations.append(transforms.RandomPerspective( 375 | fill=0, p=1.0, distortion_scale=0.5)) 376 | augemntations.append(transforms.RandomResizedCrop( 377 | 224, scale=(0.8, 0.8), ratio=(1.0, 1.0))) 378 | augemntations.append( 379 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) 380 | self.augment_trans = transforms.Compose(augemntations) 381 | # LOG.warning("LPIPS is untested") 382 | 383 | def forward(self, pred, target, mode="train"): 384 | """Compare VGG features of two inputs.""" 385 | 386 | # Get VGG features 387 | 388 | sketch_augs, img_augs = [pred], [target] 389 | if mode == "train": 390 | for n in range(4): 391 | augmented_pair = self.augment_trans(torch.cat([pred, target])) 392 | sketch_augs.append(augmented_pair[0].unsqueeze(0)) 393 | img_augs.append(augmented_pair[1].unsqueeze(0)) 394 | 395 | pred = torch.cat(sketch_augs, dim=0) 396 | target = torch.cat(img_augs, dim=0) 397 | diffs = [torch.square(p - t).mean() for (p, t) in zip(pred, target)] 398 | return sum(diffs) 399 | 400 | 401 | class CLIPVisualEncoder(nn.Module): 402 | def __init__(self, clip_model, device, mask_cls="none", apply_mask=False, mask_attention=False): 403 | super().__init__() 404 | self.clip_model = clip_model 405 | self.featuremaps = None 406 | self.device = device 407 | self.n_channels = 3 408 | self.kernel_h = 32 409 | self.kernel_w = 32 410 | self.step = 32 411 | self.num_patches = 49 412 | self.mask_cls = mask_cls 413 | self.apply_mask = apply_mask 414 | self.mask_attention = mask_attention 415 | 416 | for i in range(12): # 12 resblocks in VIT visual 
transformer 417 | self.clip_model.visual.transformer.resblocks[i].register_forward_hook( 418 | self.make_hook(i)) 419 | 420 | def make_hook(self, name): 421 | def hook(module, input, output): 422 | if len(output.shape) == 3: 423 | self.featuremaps[name] = output.permute( 424 | 1, 0, 2) # LND -> NLD bs, smth, 768 425 | else: 426 | self.featuremaps[name] = output 427 | 428 | return hook 429 | 430 | def forward(self, x, masks=None, mode="train"): 431 | masks_flat = torch.ones((x.shape[0], 50, 768)).to(self.device) # without any effect 432 | attn_map = None 433 | if masks is not None and self.apply_mask: 434 | x_copy = x.detach().clone() 435 | 436 | patches_x = x_copy.unfold(2, self.kernel_h, self.step).unfold(3, self.kernel_w, self.step).reshape(-1, 437 | self.n_channels, 438 | self.num_patches, 439 | 32, 32) 440 | # split the masks into patches (the same input patches to the transformer) 441 | # shape is (batch_size, channel, num_patches, patch_size, patch_size) = (5, 3, 49, 32, 32) 442 | patches_mask = masks.unfold(2, self.kernel_h, self.step).unfold(3, self.kernel_w, self.step).reshape(-1, 443 | self.n_channels, 444 | self.num_patches, 445 | 32, 32) 446 | # masks_ is a binary mask (batch_size, 1, 7, ,7) to say which patch should be masked out 447 | masks_ = torch.ones((x.shape[0], 1, 7, 7)).to(self.device) 448 | for i in range(masks.shape[0]): 449 | for j in range(self.num_patches): 450 | # we mask a patch if more than 20% of the patch is masked 451 | zeros = (patches_mask[i, 0, j] == 0).sum() / (self.kernel_w * self.kernel_h) 452 | if zeros > 0.2: 453 | masks_[i, :, j // 7, j % 7] = 0 454 | 455 | if self.mask_attention: 456 | mask2 = masks_[:, 0].reshape(-1, 49).to(self.device) # .to(device) shape (5, 49) 457 | mask2 = torch.cat([torch.ones(mask2.shape[0], 1).to(self.device), mask2], dim=-1) 458 | mask2 = mask2.unsqueeze(1) 459 | attn_map = mask2.repeat(1, 50, 1).to(self.device) # 5, 50, 50 460 | # attn_map = torch.bmm(mask2.permute(0,2,1), mask2) # 5, 50, 50 461 | attn_map[:, 0, 0] = 1 462 | # attn_map[:,:,0] = 1 463 | attn_map = 1 - attn_map 464 | indixes = (attn_map == 0).nonzero() # shape [136, 2] [[aug_im],[index]] 465 | attn_map = attn_map.repeat(12, 1, 1).bool() # [60, 50, 50] 466 | # attn_map = attn_map.repeat(12,1,1).bool() 467 | 468 | # masks_ = torch.nn.functional.interpolate(masks, size=7, mode='nearest') 469 | # masks_[masks_ < 0.5] = 0 470 | # masks_[masks_ >=0.5] = 1 471 | 472 | # masks_flat's shape is (5, 49), for each image in the batch we have 49 flags indicating if to mask the i'th patch or not 473 | masks_flat = masks_[:, 0].reshape(-1, self.num_patches) 474 | # indixes = (masks_flat == 0).nonzero() # shape [136, 2] [[aug_im],[index]] 475 | # for t in indixes: 476 | # b_num, y, x_ = t[0], t[1] // 7, t[1] % 7 477 | # x_copy[b_num, :, 32 * y: 32 * y + 32, 32 * x_: 32 * x_ + 32] = 0 478 | # now we add the cls token mask, it's all ones for now since we want to leave it 479 | # now the shape is (5, 50) where the first number in each of the 5 rows is 1 (meaning - son't mask the cls token) 480 | masks_flat = torch.cat([torch.ones(masks_flat.shape[0], 1).to(self.device), masks_flat], 481 | dim=1) # include cls by default 482 | # now we duplicate this from (5, 50) to (5, 50, 768) to match the tokens dimentions 483 | masks_flat = masks_flat.unsqueeze(2).repeat(1, 1, 768) # shape is (5, 50, 768) 484 | 485 | # masks_flat = masks_[:,0].reshape(-1, 49)#.to(device) shape (5, 49) 486 | 487 | 488 | elif self.mask_cls != "none": 489 | if self.mask_cls == "only_cls": 490 | masks_flat = 
torch.zeros((5, 50, 768)).to(self.device) 491 | masks_flat[:, 0, :] = 1 492 | elif self.mask_cls == "cls_out": 493 | masks_flat[:, 0, :] = 0 494 | 495 | self.featuremaps = collections.OrderedDict() 496 | fc_features = self.clip_model.encode_image(x).float() 497 | # fc_features = self.clip_model.encode_image(x, attn_map, mode).float() 498 | # Each featuremap is in shape (5,50,768) - 5 is the batchsize(augment), 50 is cls + 49 patches, 768 is the dimension of the features 499 | # for each k (each of the 12 layers) we only take the vectors 500 | # if masks is not None and self.apply_mask: 501 | featuremaps = [self.featuremaps[k] * masks_flat for k in range(12)] 502 | # featuremaps = [self.featuremaps[k][masks_flat == 1] for k in range(12)] 503 | 504 | # else: 505 | # featuremaps = [self.featuremaps[k] for k in range(12)] 506 | 507 | return fc_features, featuremaps 508 | 509 | 510 | def l2_layers(xs_conv_features, ys_conv_features, clip_model_name): 511 | return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in 512 | zip(xs_conv_features, ys_conv_features)] 513 | 514 | 515 | def l1_layers(xs_conv_features, ys_conv_features, clip_model_name): 516 | return [torch.abs(x_conv - y_conv).mean() for x_conv, y_conv in 517 | zip(xs_conv_features, ys_conv_features)] 518 | 519 | 520 | def l2_layers(xs_conv_features, ys_conv_features, clip_model_name): 521 | return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in 522 | zip(xs_conv_features, ys_conv_features)] 523 | 524 | 525 | def l1_layers(xs_conv_features, ys_conv_features, clip_model_name): 526 | return [torch.abs(x_conv - y_conv).mean() for x_conv, y_conv in 527 | zip(xs_conv_features, ys_conv_features)] 528 | 529 | 530 | def cos_layers(xs_conv_features, ys_conv_features, clip_model_name): 531 | if "RN" in clip_model_name: 532 | return [torch.square(x_conv, y_conv, dim=1).mean() for x_conv, y_conv in 533 | zip(xs_conv_features, ys_conv_features)] 534 | return [(1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() for x_conv, y_conv in 535 | zip(xs_conv_features, ys_conv_features)] 536 | 537 | 538 | class CLIPConvLoss(torch.nn.Module): 539 | def __init__(self, args, clip_model_name="ViT-B/32"): 540 | super(CLIPConvLoss, self).__init__() 541 | self.clip_model_name = clip_model_name 542 | # 0 1 2 3 4 5 6 7 8 9 10 11 0 543 | self.clip_conv_layer_weights = [0, 0, 0.35, 0, 0, 0, 0, 0.45, 0, 0, 0.5, 0.9, 0] 544 | assert self.clip_model_name in [ 545 | "RN50", 546 | "RN101", 547 | "RN50x4", 548 | "RN50x16", 549 | "ViT-B/32", 550 | "ViT-B/16", 551 | ] 552 | 553 | self.clip_conv_loss_type = args.clip_conv_loss_type 554 | self.clip_fc_loss_type = "Cos" # args.clip_fc_loss_type 555 | assert self.clip_conv_loss_type in [ 556 | "L2", "Cos", "L1", 557 | ] 558 | assert self.clip_fc_loss_type in [ 559 | "L2", "Cos", "L1", 560 | ] 561 | 562 | self.distance_metrics = \ 563 | { 564 | "L2": l2_layers, 565 | "L1": l1_layers, 566 | "Cos": cos_layers 567 | } 568 | 569 | self.model, clip_preprocess = clip.load( 570 | self.clip_model_name, args.device, jit=False) 571 | 572 | if self.clip_model_name.startswith("ViT"): 573 | self.visual_encoder = CLIPVisualEncoder(self.model, args.device) 574 | 575 | else: 576 | self.visual_model = self.model.visual 577 | layers = list(self.model.visual.children()) 578 | init_layers = torch.nn.Sequential(*layers)[:8] 579 | self.layer1 = layers[8] 580 | self.layer2 = layers[9] 581 | self.layer3 = layers[10] 582 | self.layer4 = layers[11] 583 | self.att_pool2d = layers[12] 584 | 585 | self.args = args 586 | 587 | 
self.img_size = clip_preprocess.transforms[1].size 588 | self.model.eval() 589 | self.target_transform = transforms.Compose([ 590 | transforms.ToTensor(), 591 | ]) # clip normalisation 592 | self.normalize_transform = transforms.Compose([ 593 | clip_preprocess.transforms[0], # Resize 594 | clip_preprocess.transforms[1], # CenterCrop 595 | clip_preprocess.transforms[-1], # Normalize 596 | ]) 597 | 598 | self.model.eval() 599 | self.device = args.device 600 | self.num_augs = self.args.num_aug_clip 601 | 602 | augemntations = [] 603 | if "affine" in args.augemntations: 604 | augemntations.append(transforms.RandomPerspective( 605 | fill=0, p=1.0, distortion_scale=0.5)) 606 | augemntations.append(transforms.RandomResizedCrop( 607 | 224, scale=(0.8, 0.8), ratio=(1.0, 1.0))) 608 | augemntations.append( 609 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) 610 | self.augment_trans = transforms.Compose(augemntations) 611 | 612 | self.clip_fc_layer_dims = None # self.args.clip_fc_layer_dims 613 | self.clip_conv_layer_dims = None # self.args.clip_conv_layer_dims 614 | self.clip_fc_loss_weight = args.clip_fc_loss_weight 615 | self.counter = 0 616 | 617 | def forward(self, sketch, target, mode="train"): 618 | """ 619 | Parameters 620 | ---------- 621 | sketch: Torch Tensor [1, C, H, W] 622 | target: Torch Tensor [1, C, H, W] 623 | """ 624 | # y = self.target_transform(target).to(self.args.device) 625 | conv_loss_dict = {} 626 | x = sketch.to(self.device) 627 | y = target.to(self.device) 628 | sketch_augs, img_augs = [self.normalize_transform(x)], [ 629 | self.normalize_transform(y)] 630 | if mode == "train": 631 | for n in range(self.num_augs): 632 | augmented_pair = self.augment_trans(torch.cat([x, y])) 633 | sketch_augs.append(augmented_pair[0].unsqueeze(0)) 634 | img_augs.append(augmented_pair[1].unsqueeze(0)) 635 | 636 | xs = torch.cat(sketch_augs, dim=0).to(self.device) 637 | ys = torch.cat(img_augs, dim=0).to(self.device) 638 | 639 | if self.clip_model_name.startswith("RN"): 640 | xs_fc_features, xs_conv_features = self.forward_inspection_clip_resnet( 641 | xs.contiguous()) 642 | ys_fc_features, ys_conv_features = self.forward_inspection_clip_resnet( 643 | ys.detach()) 644 | 645 | else: 646 | xs_fc_features, xs_conv_features = self.visual_encoder(xs) 647 | ys_fc_features, ys_conv_features = self.visual_encoder(ys) 648 | 649 | conv_loss = self.distance_metrics[self.clip_conv_loss_type]( 650 | xs_conv_features, ys_conv_features, self.clip_model_name) 651 | 652 | for layer, w in enumerate(self.clip_conv_layer_weights): 653 | if w: 654 | conv_loss_dict[f"clip_conv_loss_layer{layer}"] = conv_loss[layer] * w 655 | 656 | if self.clip_fc_loss_weight: 657 | # fc distance is always cos 658 | fc_loss = (1 - torch.cosine_similarity(xs_fc_features, 659 | ys_fc_features, dim=1)).mean() 660 | conv_loss_dict["fc"] = fc_loss * self.clip_fc_loss_weight 661 | 662 | self.counter += 1 663 | return conv_loss_dict 664 | 665 | def forward_inspection_clip_resnet(self, x): 666 | def stem(m, x): 667 | for conv, bn in [(m.conv1, m.bn1), (m.conv2, m.bn2), (m.conv3, m.bn3)]: 668 | x = m.relu(bn(conv(x))) 669 | x = m.avgpool(x) 670 | return x 671 | 672 | x = x.type(self.visual_model.conv1.weight.dtype) 673 | x = stem(self.visual_model, x) 674 | x1 = self.layer1(x) 675 | x2 = self.layer2(x1) 676 | x3 = self.layer3(x2) 677 | x4 = self.layer4(x3) 678 | y = self.att_pool2d(x4) 679 | return y, [x, x1, x2, x3, x4] 680 | 
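
The loss terms referenced by `models/loss.py` are ordinary `nn.Module`s, so they can be exercised in isolation before wiring them into the full optimization loop. Below is a minimal, hypothetical sketch (not part of the repo) that combines the multi-scale "L2" term from `models/pyramid.py` with the Sobel-gradient "Edge" term from `models/edge.py` on random tensors. It assumes the repository root is on `PYTHONPATH`; the `10` weight mirrors the `loss_coeffs["Edge"] = 10` coefficient set in `models/loss.py`, and the tensor shapes are illustrative only.

```python
# Hypothetical standalone usage sketch; only the classes defined in
# models/pyramid.py and models/edge.py are assumed.
import torch

from models.pyramid import MSEPyramidLoss
from models.edge import GradLoss

device = torch.device('cpu')

# Stand-ins for a rasterized SVG and its target image (B, C, H, W in [0, 1]).
render = torch.rand(1, 3, 224, 224, device=device, requires_grad=True)
target = torch.rand(1, 3, 224, 224, device=device)

l2_pyramid = MSEPyramidLoss(device=device)  # Gaussian-pyramid MSE ("L2" term)
edge_loss = GradLoss().to(device)           # L1 on Sobel gradients ("Edge" term)

# Weighted sum, analogous to how Loss.forward scales each term by loss_coeffs.
loss = l2_pyramid(render, target) + 10 * edge_loss(render, target, mode="train")
loss.backward()  # in the real pipeline the gradient flows to the SVG parameters
```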
-------------------------------------------------------------------------------- /models/pyramid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | def gaussian_kernel(size=5, device=torch.device('cpu'), channels=3, sigma=1, dtype=torch.float): 6 | # Create Gaussian Kernel. In Numpy 7 | interval = ( 2 *sigma +1) / (size) 8 | ax = np.linspace(-(size - 1) / 2., (size - 1) / 2., size) 9 | xx, yy = np.meshgrid(ax, ax) 10 | kernel = np.exp(-0.5 * (np.square(xx) + np.square(yy)) / np.square(sigma)) 11 | kernel /= np.sum(kernel) 12 | # Change kernel to PyTorch. reshapes to (channels, 1, size, size) 13 | kernel_tensor = torch.as_tensor(kernel, dtype=dtype) 14 | kernel_tensor = kernel_tensor.repeat(channels, 1, 1, 1) 15 | kernel_tensor.to(device) 16 | return kernel_tensor 17 | 18 | 19 | def gaussian_conv2d(x, g_kernel, dtype=torch.float): 20 | # Assumes input of x is of shape: (minibatch, depth, height, width) 21 | # Infer depth automatically based on the shape 22 | channels = g_kernel.shape[0] 23 | padding = g_kernel.shape[-1] // 2 # Kernel size needs to be odd number 24 | if len(x.shape) != 4: 25 | raise IndexError('Expected input tensor to be of shape: (batch, depth, height, width) but got: ' + str(x.shape)) 26 | y = F.conv2d(x, weight=g_kernel, stride=1, padding=padding, groups=channels) 27 | return y 28 | 29 | 30 | def downsample(x): 31 | # Downsamples along image (H,W). Takes every 2 pixels. output (H, W) = input (H/2, W/2) 32 | return x[:, :, ::2, ::2] 33 | 34 | 35 | def create_laplacian_pyramid(x, kernel, levels): 36 | # upsample = torch.nn.Upsample(scale_factor=2) # Default mode is nearest: [[1 2],[3 4]] -> [[1 1 2 2],[3 3 4 4]] 37 | pyramids = [] 38 | current_x = x 39 | for level in range(0, levels): 40 | gauss_filtered_x = gaussian_conv2d(current_x, kernel) 41 | down = downsample(gauss_filtered_x) 42 | # laplacian = current_x - upsample(down) 43 | pyramids.append(down) 44 | current_x = down 45 | pyramids.append(current_x) 46 | return pyramids 47 | 48 | 49 | class MSEPyramidLoss(torch.nn.Module): 50 | def __init__(self, max_levels=4, channels=3, kernel_size=5, sigma=1, device=torch.device('cpu'), dtype=torch.float): 51 | super(MSEPyramidLoss, self).__init__() 52 | self.max_levels = max_levels 53 | self.kernel = gaussian_kernel(size=kernel_size, channels=channels, sigma=sigma, dtype=dtype).to(device) 54 | 55 | def forward(self, x, target): 56 | input_pyramid = create_laplacian_pyramid(x, self.kernel, self.max_levels) 57 | target_pyramid = create_laplacian_pyramid(target, self.kernel, self.max_levels) 58 | return sum(torch.nn.functional.mse_loss(x, y) for x, y in zip(input_pyramid, target_pyramid)) 59 | -------------------------------------------------------------------------------- /models/structure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import segmentation 3 | from skimage.feature import local_binary_pattern 4 | from scipy.ndimage import find_objects 5 | from skimage.segmentation import find_boundaries 6 | from skimage.color import rgb2lab 7 | from joblib import Parallel, delayed 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class SuperPixel: 13 | def __init__(self, device: torch.device = 'cpu', mode='simple'): 14 | self.device = device 15 | self.mode = mode 16 | 17 | def process(self, x: torch.Tensor): 18 | # B, C, H, W => B, H, W, C 19 | # Torch => Numpy 20 | 
skimage_format_tensor = x.permute((0, 2, 3, 1)).cpu().numpy() 21 | 22 | if (self.mode == 'simple'): 23 | skimage_format_tensor = simple_superpixel(skimage_format_tensor) 24 | elif (self.mode == 'sscolor'): 25 | skimage_format_tensor = selective_adacolor(skimage_format_tensor) 26 | 27 | # B, H, W, C => B, C, H, W 28 | # Numpy => Torch 29 | return torch.from_numpy(skimage_format_tensor).to(self.device).permute((0, 3, 1, 2)) 30 | 31 | 32 | # Adaptive Coloring 33 | def label2rgb(label_field, image, kind='mix', bg_label=-1, bg_color=(0, 0, 0)): 34 | out = np.zeros_like(image) 35 | labels = np.unique(label_field) 36 | bg = (labels == bg_label) 37 | if bg.any(): 38 | labels = labels[labels != bg_label] 39 | mask = (label_field == bg_label).nonzero() 40 | out[mask] = bg_color 41 | for label in labels: 42 | mask = (label_field == label).nonzero() 43 | color: np.ndarray = None 44 | if kind == 'avg': 45 | color = image[mask].mean(axis=0) 46 | elif kind == 'median': 47 | color = np.median(image[mask], axis=0) 48 | elif kind == 'mix': 49 | std = np.std(image[mask]) 50 | if std < 20: 51 | color = image[mask].mean(axis=0) 52 | elif 20 < std < 40: 53 | mean = image[mask].mean(axis=0) 54 | median = np.median(image[mask], axis=0) 55 | color = 0.5 * mean + 0.5 * median 56 | elif 40 < std: 57 | color = np.median(image[mask], axis=0) 58 | out[mask] = color 59 | return out 60 | 61 | 62 | # Simple Linear Iterative Clustering 63 | def slic(image, seg_num=200, kind='mix'): 64 | seg_label = segmentation.slic(image, n_segments=seg_num, sigma=1, compactness=10, convert2lab=True) 65 | image = label2rgb(seg_label, image, kind=kind, bg_label=-1) 66 | return image 67 | 68 | 69 | # Apply slic to batches 70 | def simple_superpixel(batch_image, seg_num=200, kind='mix'): 71 | num_job = np.shape(batch_image)[0] 72 | batch_out = Parallel(n_jobs=num_job)(delayed(slic) \ 73 | (image, seg_num, kind) for image in batch_image) 74 | return np.array(batch_out) 75 | 76 | 77 | # Felzenszwalb algorithm + Selective Search 78 | def color_ss_map(image, seg_num=200, power=1.2, k=10, sim_strategy='CTSF'): 79 | img_seg = segmentation.felzenszwalb(image, scale=k, sigma=0.8, min_size=100) 80 | img_cvtcolor = label2rgb(img_seg, image, kind='mix') 81 | 82 | img_cvtcolor = rgb2lab(img_cvtcolor) 83 | S = HierarchicalGrouping(img_cvtcolor, img_seg, sim_strategy) 84 | S.build_regions() 85 | S.build_region_pairs() 86 | 87 | # Start hierarchical grouping 88 | while S.num_regions() > seg_num: 89 | i, j = S.get_highest_similarity() 90 | S.merge_region(i, j) 91 | S.remove_similarities(i, j) 92 | S.calculate_similarity_for_new_region() 93 | 94 | image = label2rgb(S.img_seg, image, kind='mix') 95 | image = (image + 1) / 2 96 | image = image ** power 97 | if (not np.max(image) == 0): 98 | image = image / np.max(image) 99 | image = image * 2 - 1 100 | return image 101 | 102 | 103 | # Apply color_ss_map to batches 104 | def selective_adacolor(batch_image, seg_num=200, power=1.2): 105 | num_job = np.shape(batch_image)[0] 106 | batch_out = Parallel(n_jobs=num_job)(delayed(color_ss_map) \ 107 | (image, seg_num, power) for image in batch_image) 108 | return np.array(batch_out) 109 | 110 | 111 | class HierarchicalGrouping(object): 112 | def __init__(self, img, img_seg, sim_strategy): 113 | self.img = img 114 | self.sim_strategy = sim_strategy 115 | self.img_seg = img_seg.copy() 116 | self.labels = np.unique(self.img_seg).tolist() 117 | 118 | def build_regions(self): 119 | self.regions = {} 120 | lbp_img = generate_lbp_image(self.img) 121 | for label in self.labels: 
122 | size = (self.img_seg == label).sum() 123 | region_slice = find_objects(self.img_seg == label)[0] 124 | box = tuple([region_slice[i].start for i in (1, 0)] + 125 | [region_slice[i].stop for i in (1, 0)]) 126 | 127 | mask = self.img_seg == label 128 | color_hist = calculate_color_hist(mask, self.img) 129 | texture_hist = calculate_texture_hist(mask, lbp_img) 130 | 131 | self.regions[label] = { 132 | 'size': size, 133 | 'box': box, 134 | 'color_hist': color_hist, 135 | 'texture_hist': texture_hist 136 | } 137 | 138 | def build_region_pairs(self): 139 | self.s = {} 140 | for i in self.labels: 141 | neighbors = self._find_neighbors(i) 142 | for j in neighbors: 143 | if i < j: 144 | self.s[(i, j)] = calculate_sim(self.regions[i], 145 | self.regions[j], 146 | self.img.size, 147 | self.sim_strategy) 148 | 149 | def _find_neighbors(self, label): 150 | """ 151 | Parameters 152 | ---------- 153 | label : int 154 | label of the region 155 | Returns 156 | ------- 157 | neighbors : list 158 | list of labels of neighbors 159 | """ 160 | 161 | boundary = find_boundaries(self.img_seg == label, 162 | mode='outer') 163 | neighbors = np.unique(self.img_seg[boundary]).tolist() 164 | 165 | return neighbors 166 | 167 | def get_highest_similarity(self): 168 | return sorted(self.s.items(), key=lambda i: i[1])[-1][0] 169 | 170 | def merge_region(self, i, j): 171 | 172 | # generate a unique label and put in the label list 173 | new_label = max(self.labels) + 1 174 | self.labels.append(new_label) 175 | 176 | # merge blobs and update blob set 177 | ri, rj = self.regions[i], self.regions[j] 178 | 179 | new_size = ri['size'] + rj['size'] 180 | new_box = (min(ri['box'][0], rj['box'][0]), 181 | min(ri['box'][1], rj['box'][1]), 182 | max(ri['box'][2], rj['box'][2]), 183 | max(ri['box'][3], rj['box'][3])) 184 | value = { 185 | 'box': new_box, 186 | 'size': new_size, 187 | 'color_hist': 188 | (ri['color_hist'] * ri['size'] 189 | + rj['color_hist'] * rj['size']) / new_size, 190 | 'texture_hist': 191 | (ri['texture_hist'] * ri['size'] 192 | + rj['texture_hist'] * rj['size']) / new_size, 193 | } 194 | 195 | self.regions[new_label] = value 196 | 197 | # update segmentation mask 198 | self.img_seg[self.img_seg == i] = new_label 199 | self.img_seg[self.img_seg == j] = new_label 200 | 201 | def remove_similarities(self, i, j): 202 | 203 | # mark keys for region pairs to be removed 204 | key_to_delete = [] 205 | for key in self.s.keys(): 206 | if (i in key) or (j in key): 207 | key_to_delete.append(key) 208 | 209 | for key in key_to_delete: 210 | del self.s[key] 211 | 212 | # remove old labels in label list 213 | self.labels.remove(i) 214 | self.labels.remove(j) 215 | 216 | def calculate_similarity_for_new_region(self): 217 | i = max(self.labels) 218 | neighbors = self._find_neighbors(i) 219 | 220 | for j in neighbors: 221 | # i is larger than j, so use (j,i) instead 222 | self.s[(j, i)] = calculate_sim(self.regions[i], 223 | self.regions[j], 224 | self.img.size, 225 | self.sim_strategy) 226 | 227 | def is_empty(self): 228 | return True if not self.s.keys() else False 229 | 230 | def num_regions(self): 231 | return len(self.s.keys()) 232 | 233 | 234 | def calculate_color_hist(mask, img): 235 | """ 236 | Calculate colour histogram for the region. 237 | The output will be an array with n_BINS * n_color_channels. 238 | The number of channels varies because of different 239 | colour spaces.
240 | """ 241 | 242 | BINS = 25 243 | if len(img.shape) == 2: 244 | img = img.reshape(img.shape[0], img.shape[1], 1) 245 | 246 | channel_nums = img.shape[2] 247 | hist = np.array([]) 248 | 249 | for channel in range(channel_nums): 250 | layer = img[:, :, channel][mask] 251 | hist = np.concatenate([hist] + [np.histogram(layer, BINS)[0]]) 252 | 253 | # L1 normalize 254 | hist = hist / np.sum(hist) 255 | 256 | return hist 257 | 258 | 259 | def generate_lbp_image(img): 260 | if len(img.shape) == 2: 261 | img = img.reshape(img.shape[0], img.shape[1], 1) 262 | channel_nums = img.shape[2] 263 | 264 | lbp_img = np.zeros(img.shape) 265 | for channel in range(channel_nums): 266 | layer = img[:, :, channel] 267 | lbp_img[:, :, channel] = local_binary_pattern(layer, 8, 1) 268 | 269 | return lbp_img 270 | 271 | 272 | def calculate_texture_hist(mask, lbp_img): 273 | """ 274 | Use LBP for now, enlightened by AlpacaDB's implementation. 275 | Plan to switch to Gaussian derivatives as the paper in future 276 | version. 277 | """ 278 | 279 | BINS = 10 280 | channel_nums = lbp_img.shape[2] 281 | hist = np.array([]) 282 | 283 | for channel in range(channel_nums): 284 | layer = lbp_img[:, :, channel][mask] 285 | hist = np.concatenate([hist] + [np.histogram(layer, BINS)[0]]) 286 | 287 | # L1 normalize 288 | hist = hist / np.sum(hist) 289 | 290 | return hist 291 | 292 | 293 | def calculate_sim(ri, rj, imsize, sim_strategy): 294 | """ 295 | Calculate similarity between region ri and rj using diverse 296 | combinations of similarity measures. 297 | C: color, T: texture, S: size, F: fill. 298 | """ 299 | sim = 0 300 | 301 | if 'C' in sim_strategy: 302 | sim += _calculate_color_sim(ri, rj) 303 | if 'T' in sim_strategy: 304 | sim += _calculate_texture_sim(ri, rj) 305 | if 'S' in sim_strategy: 306 | sim += _calculate_size_sim(ri, rj, imsize) 307 | if 'F' in sim_strategy: 308 | sim += _calculate_fill_sim(ri, rj, imsize) 309 | 310 | return sim 311 | 312 | 313 | def _calculate_color_sim(ri, rj): 314 | """ 315 | Calculate color similarity using histogram intersection 316 | """ 317 | return sum([min(a, b) for a, b in zip(ri["color_hist"], rj["color_hist"])]) 318 | 319 | 320 | def _calculate_texture_sim(ri, rj): 321 | """ 322 | Calculate texture similarity using histogram intersection 323 | """ 324 | return sum([min(a, b) for a, b in zip(ri["texture_hist"], rj["texture_hist"])]) 325 | 326 | 327 | def _calculate_size_sim(ri, rj, imsize): 328 | """ 329 | Size similarity boosts joint between small regions, which prevents 330 | a single region from engulfing other blobs one by one. 331 | size (ri, rj) = 1 − [size(ri) + size(rj)] / size(image) 332 | """ 333 | return 1.0 - (ri['size'] + rj['size']) / imsize 334 | 335 | 336 | def _calculate_fill_sim(ri, rj, imsize): 337 | """ 338 | Fill similarity measures how well ri and rj fit into each other. 339 | BBij is the bounding box around ri and rj. 
340 | fill(ri, rj) = 1 − [size(BBij) − size(ri) − size(ri)] / size(image) 341 | """ 342 | 343 | bbsize = (max(ri['box'][2], rj['box'][2]) - min(ri['box'][0], rj['box'][0])) * ( 344 | max(ri['box'][3], rj['box'][3]) - min(ri['box'][1], rj['box'][1])) 345 | 346 | return 1.0 - (bbsize - ri['size'] - rj['size']) / imsize 347 | 348 | 349 | if __name__ == "__main__": 350 | super_pixel = SuperPixel(mode='sscolor') 351 | input = torch.randn(5, 3, 256, 256) 352 | result = super_pixel.process(input) 353 | print(result.shape) -------------------------------------------------------------------------------- /reduce_and_optimize.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | import os.path as osp 4 | from numpy.random import choice 5 | import torch 6 | import pydiffvg 7 | import numpy as np 8 | 9 | from tqdm import tqdm 10 | 11 | import custom_parser 12 | from utils import bcolors 13 | from histogram_loss import NonDifferentiableHistogramLoss 14 | 15 | from basic_diffvg import (VanillaDiffVG, 16 | compose_image_with_white_background, 17 | render_based_on_shapes_and_shape_groups) 18 | 19 | 20 | class ReduceAndOptimize(VanillaDiffVG): 21 | def __init__(self, shapes_num_scheduler: list, 22 | ranking_loss_type: str = 'mse', 23 | ranking_l1_extent_in_convex_sum_l1_and_clip: float = 1.0, 24 | ranking_clip_loss_config_file: str = 'test/config_init.npy', 25 | sample_importance=True, 26 | sample_beta=1, 27 | *args, **kwargs): 28 | super(ReduceAndOptimize, self).__init__(*args, **kwargs) 29 | self.sample_importance = sample_importance 30 | self.sample_beta = sample_beta 31 | self.shapes_num_scheduler = shapes_num_scheduler 32 | self.ranking_loss_type = ranking_loss_type 33 | if ranking_loss_type == 'histogram': 34 | self.ranking_loss_fn = NonDifferentiableHistogramLoss( 35 | self.image_for_diffvg) 36 | else: 37 | self.ranking_loss_fn = self.get_reconstruction_loss( 38 | loss_type=ranking_loss_type, 39 | alpha=ranking_l1_extent_in_convex_sum_l1_and_clip, 40 | clip_loss_config_file=ranking_clip_loss_config_file) 41 | 42 | @staticmethod 43 | def create_shape_group(shape_id: int, fill_color: torch.tensor = torch.zeros(4)) -> pydiffvg.ShapeGroup: 44 | return pydiffvg.ShapeGroup(shape_ids=torch.tensor([shape_id]), 45 | fill_color=fill_color) 46 | 47 | def shapes_to_importance(self, shapes, shape_groups): 48 | # create dummy Bezier-curve and dummy shape group 49 | dummy_path = pydiffvg.Path( 50 | num_control_points=shapes[0].num_control_points, 51 | points=torch.zeros_like(shapes[0].points), 52 | stroke_width=torch.tensor(0.0), is_closed=True) 53 | 54 | shapes_importance = [] 55 | target_image_with_white_bg = compose_image_with_white_background(self.image_for_diffvg) 56 | 57 | for i in tqdm(range(len(shapes))): 58 | if i == len(shapes) - 1: 59 | all_shapes_but_one = shapes[:-1] 60 | all_shape_groups_but_one = shape_groups[:-1] 61 | else: 62 | all_shapes_but_one = shapes[:i] + [dummy_path] + shapes[i + 1:] 63 | dummy_path_group = self.create_shape_group(i) 64 | all_shape_groups_but_one = shape_groups[:i] + [dummy_path_group] + shape_groups[i + 1:] 65 | 66 | image_without_one_shape = render_based_on_shapes_and_shape_groups( 67 | all_shapes_but_one, all_shape_groups_but_one, no_grad=True, 68 | canvas_width=self.canvas_width, 69 | canvas_height=self.canvas_height) 70 | image_without_one_shape_white_bg = \ 71 | compose_image_with_white_background(image_without_one_shape) 72 | # NOTE: decide if this needs to be self.image_for_diffvg OR the 73 | # 
image create from PNG2SVG_using_diffVG(target_PNG). 74 | loss = self.ranking_loss_fn( 75 | target_image_with_white_bg, 76 | image_without_one_shape_white_bg) 77 | shapes_importance.append(loss.item()) 78 | # validate lengths 79 | assert len(shapes) == len(shape_groups) == len(shapes_importance) 80 | if self.sample_importance: 81 | sample_shape_importance = self.sample_according_to_importance(shapes_importance, self.sample_beta) 82 | else: 83 | sample_shape_importance = sorted(range(len(shapes_importance)), key=shapes_importance.__getitem__) 84 | return sample_shape_importance 85 | 86 | @staticmethod 87 | def sample_according_to_importance(shape_importance, sample_beta): 88 | shape_importance_tensor = torch.tensor(shape_importance[1:], dtype=torch.float32) 89 | shapes_weight = torch.nn.functional.softmax(shape_importance_tensor / sample_beta, dim=0) 90 | shapes_weight = (shapes_weight / sum(shapes_weight)).tolist() 91 | if sum(shapes_weight) != 1: 92 | diff = sum(shapes_weight) - 1 93 | shapes_weight[np.argmax(shapes_weight)] -= diff 94 | sample_shapes = choice(range(1, len(shape_importance)), len(shape_importance)-1, p=shapes_weight, replace=False) 95 | # Add BG at the end - so it won't be reduced 96 | return np.append(sample_shapes[::-1], 0) 97 | 98 | @staticmethod 99 | def carve_shapes(shapes, shape_groups, shapes_index_sample, 100 | how_many_to_carve): 101 | shapes_subset = [] 102 | shape_groups_subset = [] 103 | new_shape_id = 0 104 | 105 | shapes_indices_to_remove = shapes_index_sample[:how_many_to_carve + 1] 106 | for pos, (shape, shape_group) in enumerate(zip(shapes, shape_groups)): 107 | if pos not in shapes_indices_to_remove: 108 | shapes_subset.append(shape) 109 | path_group = pydiffvg.ShapeGroup( 110 | shape_ids=torch.tensor([new_shape_id]), 111 | fill_color=shape_group.fill_color) 112 | new_shape_id += 1 113 | shape_groups_subset.append(path_group) 114 | return shapes_subset, shape_groups_subset 115 | 116 | def run(self): 117 | # init: 118 | shapes, shape_groups = self.get_initial_shapes(self.shapes_num_scheduler[0], 119 | self.canvas_width, 120 | self.canvas_height, ) 121 | self.save_intermediate_image(shapes, shape_groups, epoch=-1) 122 | # for loop on reduce steps: 123 | for epoch, (curr_num_shapes, next_num_shapes, num_iters) in enumerate(zip( 124 | self.shapes_num_scheduler[:-1], self.shapes_num_scheduler[1:], self.num_iterations[:-1])): 125 | # optimize: 126 | shapes, shape_groups = self.optimize_shapes( 127 | self.image_for_diffvg, shapes, shape_groups, num_iters, epoch) 128 | self.save_intermediate_image(shapes, shape_groups, 2 * epoch) 129 | self.save_svg_image_by_name(shapes, shape_groups, 130 | f"after_optimization_{len(shapes):04d}") 131 | assert len(shapes) == curr_num_shapes 132 | # rank: 133 | shapes_rank = self.shapes_to_importance(shapes, shape_groups) 134 | assert len(shapes) == curr_num_shapes 135 | # reduce: 136 | how_many_to_carve = curr_num_shapes - next_num_shapes - 1 137 | shapes, shape_groups = self.carve_shapes(shapes, shape_groups, 138 | shapes_rank, 139 | how_many_to_carve) 140 | assert len(shapes) == next_num_shapes 141 | self.save_intermediate_image(shapes, shape_groups, 2 * epoch + 1) 142 | 143 | # lastly, we need to optimize: 144 | total_num_epochs = 2 * len(self.shapes_num_scheduler) 145 | shapes, shape_groups = self.optimize_shapes( 146 | self.image_for_diffvg, shapes, shape_groups, self.num_iterations[-1], 147 | 2 * total_num_epochs) 148 | self.save_final_result(shapes, shape_groups) 149 | 150 | 151 | def main(script_args): 152 | image_name = 
script_args.target.split("/")[-1].split(".")[0] 153 | scheduler = [int(x) for x in script_args.scheduler] 154 | iterations = [int(x) for x in script_args.num_iter] 155 | num_paths = scheduler[0] 156 | recons_loss_type = script_args.recons_loss_type 157 | geometric_loss_type = script_args.geometric_loss_type 158 | if script_args.experiment_name != '': 159 | experiment_name = script_args.experiment_name 160 | else: 161 | experiment_name = f'reduce_and_optimize_' \ 162 | f'{image_name}_{num_paths}_' \ 163 | f'rec_{recons_loss_type}_' \ 164 | f'geom_{geometric_loss_type}' 165 | if script_args.advanced_logging: 166 | experiment_name += "_advanced_logging" 167 | 168 | root_out_dir = osp.join(script_args.results_dir, experiment_name) 169 | 170 | diffvg_runner = ReduceAndOptimize( 171 | shapes_num_scheduler=scheduler, 172 | ranking_loss_type=script_args.ranking_loss_type, 173 | ranking_l1_extent_in_convex_sum_l1_and_clip=script_args.ranking_l1_and_clip_alpha, 174 | ranking_clip_loss_config_file=script_args.ranking_clip_config_file, 175 | path_to_png_image=script_args.target, 176 | num_paths=num_paths, 177 | num_iterations=iterations, 178 | epochs=1, 179 | root_output_directory=root_out_dir, 180 | canvas_height=script_args.canvas_height, 181 | canvas_width=script_args.canvas_width, 182 | reconstruction_loss_type=recons_loss_type, 183 | l1_extent_in_convex_sum_l1_and_clip=script_args.l1_and_clip_alpha, 184 | lambda_geometric=script_args.lambda_geometric, 185 | geometric_loss_lamda_geometric_punish=script_args.geometric_loss_lamda_geometric_punish, 186 | clip_config_file=script_args.clip_config_file, 187 | geometric_loss_type=geometric_loss_type, 188 | is_advanced_logging=script_args.advanced_logging, 189 | sample_importance=script_args.sample_importance, 190 | text_prompt=script_args.text_prompt, 191 | init_type=script_args.init_type, 192 | init_shape=script_args.init_shape, 193 | ) 194 | diffvg_runner.run() 195 | 196 | 197 | if __name__ == "__main__": 198 | args = custom_parser.parse_arguments() 199 | print(args) 200 | main(args) 201 | -------------------------------------------------------------------------------- /reduce_or_add_and_optimize.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | import pydiffvg 7 | from basic_diffvg import compose_image_with_white_background, add_to_file 8 | from custom_parser import parse_arguments 9 | from reduce_and_optimize import ReduceAndOptimize 10 | from utils import bcolors 11 | import yaml 12 | import os 13 | import time 14 | 15 | 16 | class ReduceOrAddAndOptimize(ReduceAndOptimize): 17 | def __init__(self, *args, **kwargs): 18 | super(ReduceOrAddAndOptimize, self).__init__(*args, **kwargs) 19 | 20 | @staticmethod 21 | def interleave_lists(list1, indices, list3): 22 | """Interleave lists and get new indices. 23 | 24 | Interleave the elements of list3 into list1 at the positions 25 | indicated by indices. Returns the interleaved list and an index map 26 | from the original positions of list1 to the new positions of the 27 | returned list. 28 | 29 | Args: 30 | list1: list. The base list of items. 31 | indices: list. The indices at which you want to interleave list3 32 | items into list1. 33 | list3: list. The list of items to put in the indices of list1. 34 | Returns: tuple. The first item is the interleaved list. 35 | The second item is a map from old indices of list1 to the new 36 | indices of list1.
37 | The third item is a map from old indices of list3 to the new 38 | indices of list3. 39 | 40 | >>> list1 = [1, 2, 3, 4, 5] 41 | >>> indices = [1, 3] 42 | >>> list3 = [6, 7] 43 | >>> merged_list, new_locs_list1, new_locs_list3 = interleave_lists(list1, indices, list3) 44 | >>> print(merged_list) 45 | [1, 6, 2, 3, 7, 4, 5] 46 | >>> print(new_locs_list1) 47 | {0: 0, 1: 2, 2: 3, 3: 5, 4: 6} 48 | >>> print(new_locs_list3) 49 | {0: 1, 1: 4} 50 | """ 51 | result = [] 52 | i = 0 53 | j = 0 54 | index_map_first_list = {} 55 | index_map_second_list = {} 56 | while i < len(list1): 57 | if i in indices: 58 | result.append(list3[j]) 59 | index_map_second_list[j] = len(result) - 1 60 | j += 1 61 | result.append(list1[i]) 62 | index_map_first_list[i] = len(result) - 1 63 | i += 1 64 | while j < len(list3): 65 | result.append(list3[j]) 66 | index_map_second_list[j] = len(result) - 1 67 | j += 1 68 | return result, index_map_first_list, index_map_second_list 69 | 70 | def add_shapes(self, shapes, shape_groups, how_many_to_add): 71 | current_image = self.render_tensor_image_from_shapes(shapes, shape_groups) 72 | shapes_to_add, shape_groups_to_add = self.get_initial_shapes(how_many_to_add, 73 | self.canvas_width, 74 | self.canvas_height, 75 | current_image) 76 | # NOTE: for now, we scatter the shapes on top of the existing ones. 77 | # Consider adding also to background shapes. 78 | where_to_add = [len(shapes)] * how_many_to_add 79 | new_shapes, old_shapes_ids_to_new_shape_ids, new_shapes_ids = \ 80 | self.interleave_lists(shapes, where_to_add, shapes_to_add) 81 | new_shape_groups = [] 82 | for shg in shape_groups: 83 | new_shape_id = old_shapes_ids_to_new_shape_ids[shg.shape_ids.item()] 84 | new_shape_groups.append(self.create_shape_group( 85 | new_shape_id, fill_color=shg.fill_color)) 86 | for shg in shape_groups_to_add: 87 | new_shape_id = new_shapes_ids[shg.shape_ids.item()] 88 | new_shape_groups.append(self.create_shape_group( 89 | new_shape_id, fill_color=shg.fill_color)) 90 | new_shape_groups.sort(key=lambda x: x.shape_ids.item()) 91 | return new_shapes, new_shape_groups 92 | 93 | def run(self): 94 | # init: 95 | add_to_file({'1_start': time.time()}, self.timing_file) 96 | if self.input_svg is None: 97 | shapes, shape_groups = self.get_initial_shapes(self.shapes_num_scheduler[0], 98 | self.canvas_width, 99 | self.canvas_height, ) 100 | else: 101 | canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(self.input_svg) 102 | self.save_intermediate_image(shapes, shape_groups, epoch=-1) 103 | add_to_file({'1_end': time.time()}, self.timing_file) 104 | # for loop on reduce steps: 105 | for epoch, (curr_num_shapes, next_num_shapes, num_iters) in enumerate(zip( 106 | self.shapes_num_scheduler[:-1], self.shapes_num_scheduler[1:], self.num_iterations[:-1])): 107 | # optimize: 108 | add_to_file({f'{epoch+2}_start': time.time()}, self.timing_file) 109 | shapes, shape_groups = self.optimize_shapes( 110 | self.image_for_diffvg, shapes, shape_groups, 111 | num_iters, epoch, early_stopping=self.early_stopping) 112 | self.save_intermediate_image(shapes, shape_groups, 2 * epoch) 113 | self.save_svg_image_by_name(shapes, shape_groups, 114 | f"after_optimization_{len(shapes):04d}") 115 | assert len(shapes) == curr_num_shapes 116 | # (rank+)reduce or add: 117 | how_many_to_carve = curr_num_shapes - next_num_shapes - 1 118 | if how_many_to_carve > 0: # reduce 119 | # rank: 120 | shapes_importance = self.shapes_to_importance(shapes, 121 | shape_groups) 122 | assert len(shapes) ==
curr_num_shapes 123 | shapes, shape_groups = self.carve_shapes(shapes, shape_groups, 124 | shapes_importance, 125 | how_many_to_carve) 126 | assert len(shapes) == next_num_shapes 127 | else: # add shapes 128 | how_many_to_add = next_num_shapes - curr_num_shapes 129 | shapes, shape_groups = self.add_shapes(shapes, shape_groups, 130 | how_many_to_add) 131 | assert len(shapes) == next_num_shapes 132 | self.save_intermediate_image(shapes, shape_groups, 2 * epoch + 1) 133 | add_to_file({f'{epoch+2}_end': time.time()}, self.timing_file) 134 | 135 | # lastly, we need to optimize: 136 | add_to_file({f'{len(self.shapes_num_scheduler[:-1]) + 2}_start': time.time()}, self.timing_file) 137 | total_num_epochs = 2 * len(self.shapes_num_scheduler) 138 | shapes, shape_groups = self.optimize_shapes( 139 | self.image_for_diffvg, shapes, shape_groups, self.num_iterations[-1], 140 | 2 * total_num_epochs) 141 | self.save_final_result(shapes, shape_groups) 142 | add_to_file({f'{len(self.shapes_num_scheduler[:-1]) + 2}_total_end': time.time()}, self.timing_file) 143 | 144 | def set_seed(seed): 145 | random.seed(seed) 146 | np.random.seed(seed) 147 | os.environ['PYTHONHASHSEED'] = str(seed) 148 | torch.manual_seed(seed) 149 | torch.cuda.manual_seed(seed) 150 | torch.cuda.manual_seed_all(seed) 151 | 152 | def main(script_args): 153 | # set_seed(0) 154 | image_name = script_args.target.split("/")[-1].split(".")[0] 155 | scheduler = [int(x) for x in script_args.scheduler] 156 | iterations = [int(x) for x in script_args.num_iter] 157 | num_paths = scheduler[0] 158 | recons_loss_type = script_args.recons_loss_type 159 | geometric_loss_type = script_args.geometric_loss_type 160 | if script_args.experiment_name != '': 161 | experiment_name = script_args.experiment_name 162 | else: 163 | experiment_name = f'reduce_or_add_and_optimize_' \ 164 | f'{image_name}_{num_paths}_' \ 165 | f'rec_{recons_loss_type}_' \ 166 | f'geom_{geometric_loss_type}' 167 | if script_args.advanced_logging: 168 | experiment_name += "_advanced_logging" 169 | 170 | root_out_dir = osp.join(script_args.results_dir, experiment_name) 171 | 172 | print(f"{bcolors.OKGREEN}running experiment {experiment_name}... 
{bcolors.ENDC}") 173 | 174 | print(f'{bcolors.OKCYAN}experiment args: {yaml.dump(vars(script_args))} {bcolors.ENDC}') 175 | config_dir = os.path.join(root_out_dir, 'config') 176 | os.makedirs(config_dir, exist_ok=True) 177 | config_file = os.path.join(config_dir, 'config.yaml') 178 | with open(config_file, 'w') as f: 179 | f.write(yaml.dump(vars(script_args))) 180 | 181 | timing_dir = os.path.join(root_out_dir, 'timing') 182 | os.makedirs(timing_dir, exist_ok=True) 183 | timing_file = os.path.join(timing_dir, 'timing.json') 184 | loss_dir = os.path.join(root_out_dir, 'loss') 185 | os.makedirs(loss_dir, exist_ok=True) 186 | final_loss_path = os.path.join(loss_dir, 'loss.json') 187 | 188 | add_to_file({'0_total_start': time.time()}, timing_file) 189 | 190 | diffvg_runner = ReduceOrAddAndOptimize( 191 | shapes_num_scheduler=scheduler, 192 | ranking_loss_type=script_args.ranking_loss_type, 193 | ranking_l1_extent_in_convex_sum_l1_and_clip=script_args.ranking_l1_and_clip_alpha, 194 | ranking_clip_loss_config_file=script_args.ranking_clip_config_file, 195 | path_to_png_image=script_args.target, 196 | num_paths=num_paths, 197 | num_iterations=iterations, 198 | epochs=1, 199 | root_output_directory=root_out_dir, 200 | canvas_height=script_args.canvas_height, 201 | canvas_width=script_args.canvas_width, 202 | reconstruction_loss_type=recons_loss_type, 203 | l1_extent_in_convex_sum_l1_and_clip=script_args.l1_and_clip_alpha, 204 | lambda_geometric=script_args.lambda_geometric, 205 | geometric_loss_lamda_geometric_punish=script_args.geometric_loss_lamda_geometric_punish, 206 | clip_config_file=script_args.clip_config_file, 207 | geometric_loss_type=geometric_loss_type, 208 | is_advanced_logging=script_args.advanced_logging, 209 | init_type=script_args.init_type, 210 | init_shape=script_args.init_shape, 211 | timing_json_path=timing_file, 212 | final_loss_path=final_loss_path, 213 | early_stopping=script_args.early_stopping, 214 | sample_beta=script_args.sample_beta, 215 | input_svg=script_args.input_svg, 216 | ) 217 | diffvg_runner.run() 218 | 219 | 220 | if __name__ == "__main__": 221 | args = parse_arguments() 222 | print(args) 223 | main(args) 224 | -------------------------------------------------------------------------------- /target_images/083.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/target_images/083.png -------------------------------------------------------------------------------- /test/config_init.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajevnisek/optimize-and-reduce/f621b8eac4830d213d889f9c78d1bca1e47b0e7c/test/config_init.npy -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import torchvision.transforms.functional as F 6 | 7 | 8 | class bcolors: 9 | HEADER = '\033[95m' 10 | OKBLUE = '\033[94m' 11 | OKCYAN = '\033[96m' 12 | OKGREEN = '\033[92m' 13 | WARNING = '\033[93m' 14 | FAIL = '\033[91m' 15 | ENDC = '\033[0m' 16 | BOLD = '\033[1m' 17 | UNDERLINE = '\033[4m' 18 | 19 | 20 | def save_grid_to_path(imgs, path): 21 | if not isinstance(imgs, list): 22 | imgs = [imgs] 23 | fig, axs = plt.subplots(ncols=len(imgs), squeeze=False) 24 | for i, img in 
enumerate(imgs): 25 | img = img.detach() 26 | img = F.to_pil_image(img) 27 | axs[0, i].imshow(np.asarray(img)) 28 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 29 | plt.savefig(path) 30 | 31 | 32 | def create_fig_from_images(root_dir_images, path_video): 33 | filenames = [os.path.join(root_dir_images, x) for x in os.listdir(root_dir_images) if x.endswith('.png')] 34 | filenames.sort() 35 | import imageio 36 | images = [] 37 | for filename in filenames: 38 | images.append(imageio.imread(filename)) 39 | imageio.mimsave(path_video, images) 40 | 41 | --------------------------------------------------------------------------------
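
The following short examples are illustrative sketches keyed to the helpers above, with made-up inputs throughout. First, the `kind='mix'` branch of `label2rgb` picks the mean colour, the median colour, or an even blend, depending on the intensity spread inside each region. The sketch below mirrors those thresholds on toy 1-D regions (the 20/40 bands are written as closed ranges here so that a spread of exactly 20 or 40 is still handled):

```python
import numpy as np


def mix_color(values):
    """Pick a representative value the way label2rgb(kind='mix') chooses colours."""
    std = np.std(values)
    if std < 20:
        return values.mean()  # flat region: the mean is stable
    if std < 40:
        return 0.5 * values.mean() + 0.5 * np.median(values)  # blend mean and median
    return np.median(values)  # noisy region: the median resists outliers


rng = np.random.default_rng(0)
flat = 120 + rng.normal(0, 5, size=500)    # spread well below 20
noisy = 120 + rng.normal(0, 80, size=500)  # spread well above 40
for region in (flat, noisy):
    print(round(float(np.std(region)), 1), "->", round(float(mix_color(region)), 1))
```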
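`calculate_sim` adds up to four terms selected by `sim_strategy` (C, T, S, F). A small numeric sketch of how the histogram-intersection, size, and fill terms behave, using made-up histograms, region sizes, and boxes on a 100x100 image:

```python
import numpy as np

# Toy L1-normalized colour histograms for two regions (5 bins each).
hist_a = np.array([0.4, 0.3, 0.2, 0.1, 0.0])
hist_b = np.array([0.1, 0.3, 0.3, 0.2, 0.1])

# Histogram intersection (as in _calculate_color_sim): sum of bin-wise minima.
color_sim = np.minimum(hist_a, hist_b).sum()
print(color_sim)  # ~0.7

# Size similarity on a 100 x 100 image: small regions score close to 1,
# which keeps one region from swallowing every neighbour in turn.
imsize = 100 * 100
size_a, size_b = 300, 500
size_sim = 1.0 - (size_a + size_b) / imsize
print(size_sim)  # ~0.92

# Fill similarity: how tightly the joint bounding box wraps both regions.
# Boxes are (x0, y0, x1, y1), matching the layout built in build_regions.
box_a = (10, 10, 30, 30)
box_b = (25, 25, 50, 50)
bb_w = max(box_a[2], box_b[2]) - min(box_a[0], box_b[0])
bb_h = max(box_a[3], box_b[3]) - min(box_a[1], box_b[1])
fill_sim = 1.0 - (bb_w * bb_h - size_a - size_b) / imsize
print(fill_sim)  # ~0.92
```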
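`ReduceAndOptimize.shapes_to_importance` ranks shapes with a leave-one-out test: render the scene with one shape blanked out, measure how much the ranking loss grows against the target, and treat shapes whose removal barely matters as the first candidates to carve. The sketch below captures only that idea; rendering is faked by summing image layers and a mean absolute difference stands in for the configurable ranking loss, so none of the names here belong to the repo's API:

```python
import numpy as np


def leave_one_out_ranking(target, layers):
    """Rank layers from least to most important for reconstructing `target`.

    `layers` is a list of images whose sum is the current render (a crude
    stand-in for diffvg); the score of layer i is the loss of the render
    with layer i removed, so a low score means the layer is expendable.
    """
    full = np.sum(layers, axis=0)
    scores = []
    for i in range(len(layers)):
        without_i = full - layers[i]                       # "blank out" layer i
        scores.append(np.abs(target - without_i).mean())   # L1 proxy for the ranking loss
    return sorted(range(len(layers)), key=scores.__getitem__)


# Toy usage: the faint layer (index 1) contributes least and is ranked first.
target = np.ones((4, 4, 3))
layers = [0.70 * np.ones((4, 4, 3)),
          0.05 * np.ones((4, 4, 3)),
          0.25 * np.ones((4, 4, 3))]
print(leave_one_out_ranking(target, layers))  # [1, 2, 0]
```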
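`sample_according_to_importance` turns the per-shape losses into sampling probabilities with a temperature-scaled softmax, so `sample_beta` controls how greedy the reduction is. A toy demonstration of that effect with arbitrary scores (torch only):

```python
import torch
import torch.nn.functional as F

# Made-up per-shape loss increases (higher means more important to keep).
scores = torch.tensor([0.05, 0.10, 0.40, 0.90])

for beta in (0.1, 1.0, 10.0):
    probs = F.softmax(scores / beta, dim=0)
    print(beta, [round(p, 3) for p in probs.tolist()])

# Small beta concentrates the probability mass on the most important shape,
# so after the reversal in sample_according_to_importance the least important
# shapes sit at the front of the returned order and are carved almost
# deterministically; large beta approaches uniform sampling and explores
# more diverse shape subsets.
```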