├── .gitignore ├── teaser.png ├── requirements.txt ├── bpe_simple_vocab_16e6.txt.gz ├── hyperbolic.py ├── README.md ├── tokenizer.py ├── model.py ├── LICENSE └── NOTICE /.gitignore: -------------------------------------------------------------------------------- 1 | ckpt.pt 2 | reference.pt 3 | .hypothesis/ 4 | __pycache__/ -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/hype/HEAD/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | Pillow 3 | numpy 4 | requests 5 | ftfy 6 | regex -------------------------------------------------------------------------------- /bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naver-ai/hype/HEAD/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /hyperbolic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import os 4 | 5 | ref_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reference.pt") 6 | ref = torch.load(ref_path) 7 | img_ref, txt_ref = ref["img"], ref["txt"] 8 | 9 | 10 | @torch.cuda.amp.autocast(enabled=False) 11 | def expm(v, curvature, time_keepdim=False): 12 | v, curvature = v.float(), curvature.float() 13 | x_space_temp = torch.sqrt(curvature) * torch.norm(v, dim=-1, keepdim=True) 14 | x_space = ( 15 | torch.sinh(torch.clamp(x_space_temp, min=1e-8, max=math.asinh(2**15))) * v / torch.clamp(x_space_temp, min=1e-8) 16 | ) 17 | x_time = torch.sqrt(1 / curvature + torch.sum(x_space**2, dim=-1, keepdim=time_keepdim)) # [B, D] 18 | return x_space, x_time 19 | 20 | 21 | 
@torch.cuda.amp.autocast(enabled=False) 22 | def similarity(x, y, curvature): 23 | x, y = x.float(), y.float() 24 | curvature = curvature.float() 25 | x_space, x_time = expm(x, curvature, time_keepdim=True) 26 | y_space, y_time = expm(y, curvature, time_keepdim=True) 27 | xy_inner = x_space @ y_space.T - x_time * y_time.T 28 | lorentzian_distance = torch.rsqrt(curvature) * torch.acosh(torch.clamp(-curvature * xy_inner, min=1e-8)) 29 | return -lorentzian_distance 30 | 31 | 32 | @torch.no_grad() 33 | def entailment(x, y, curvature): 34 | x_space, x_time = expm(x, curvature, time_keepdim=True) 35 | y_space, y_time = expm(y, curvature, time_keepdim=True) 36 | 37 | K = 0.1 38 | x_euc_norm = torch.norm(x_space, dim=-1, keepdim=True) 39 | denominator = torch.sqrt(curvature) * x_euc_norm + 1e-8 40 | aperture_x = torch.arcsin(torch.clamp(2 * K / denominator, -1 + 1e-8, 1 - 1e-8)) 41 | 42 | xy_inner = x_space @ y_space.T - x_time * y_time.T 43 | denominator = x_euc_norm * torch.sqrt(torch.clamp((curvature * xy_inner) ** 2 - 1, min=1e-8)) + 1e-8 44 | numerator = y_time.T + x_time * curvature * xy_inner 45 | exterior_xy = torch.arccos(torch.clamp(numerator / denominator, -1.0 + 1e-8, 1.0 - 1e-8)) 46 | 47 | return exterior_xy - aperture_x 48 | 49 | 50 | def specificity(image=None, text=None, curv=None): 51 | assert (image is not None) ^ (text is not None), "Either image or text must be provided but not both" 52 | assert curv is not None, "Curvature must be provided" 53 | 54 | global img_ref, txt_ref 55 | 56 | if image is not None: 57 | txt_ref = txt_ref.to(image.device) 58 | ient = entailment(txt_ref, image, curv) 59 | return ient.mean(dim=0) 60 | else: 61 | img_ref = img_ref.to(text.device) 62 | tent = entailment(text, img_ref, curv) 63 | return tent.mean(dim=1) 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # HYPE: Hyperbolic Entailment Filtering for Underspecified Images and Texts 4 | 5 | **[Wonjae Kim](https://wonjae.kim), [Sanghyuk Chun](https://sanghyukchun.github.io/home/), [Taekyung Kim](https://scholar.google.com/citations?user=u-9bdkwAAAAJ&hl=en), [Dongyoon Han](https://sites.google.com/site/dyhan0920/), [Sangdoo Yun](https://sangdooyun.github.io/)**
6 | 7 | [NAVER AI LAB](https://naver-career.gitbook.io/en/teams/clova-cic/ai-lab) 8 | 9 | [![Paper](https://img.shields.io/badge/Paper-arxiv-green)](https://arxiv.org/abs/2404.17507) 10 | [![Paper](https://img.shields.io/badge/Paper-ECCV_2024-blue)](https://www.ecva.net/papers/eccv_2024/papers_ECCV/html/5671_ECCV_2024_paper.php) 11 | 12 | ![teaser](teaser.png) 13 |
14 | 15 | Official PyTorch implementation of "HYPE: Hyperbolic Entailment Filtering for Underspecified Images and Texts" | [arxiv](https://arxiv.org/abs/2404.17507), [ECCV](https://www.ecva.net/papers/eccv_2024/papers_ECCV/html/5671_ECCV_2024_paper.php) 16 | 17 | ### Abstract 18 | 19 | In an era where the volume of data drives the effectiveness of self-supervised learning, the specificity and clarity of data semantics play a crucial role in model training. Addressing this, we introduce HYPerbolic Entailment filtering (HYPE), a novel methodology designed to meticulously extract modality-wise meaningful and well-aligned data from extensive, noisy image-text pair datasets. Our approach leverages hyperbolic embeddings and the concept of entailment cones to evaluate and filter out samples with meaningless or underspecified semantics, focusing on enhancing the specificity of each data sample. HYPE not only demonstrates a significant improvement in filtering efficiency but also sets a new state-of-the-art in the DataComp benchmark when combined with existing filtering techniques. This breakthrough showcases the potential of HYPE to refine the data selection process, thereby contributing to the development of more accurate and efficient self-supervised learning models. Additionally, the image specificity $\epsilon_i$ can be independently applied to induce an image-only dataset from an image-text or image-only data pool for training image-only self-supervised models, and it shows superior performance when compared to the dataset induced by CLIP score. 
20 | 21 | 22 | ## Updates 23 | 24 | - **October 2024**: Released inference code and model weights 25 | - **Jul 16, 2024**: Published paper on arXiv 26 | 27 | ## Prerequisites 28 | 29 | Download the following files to the project root: [hyperbolic CLIP weights](https://drive.google.com/file/d/1VF2g6m0tlHgzYzcMEYncchHXjhw-h5qo/view?usp=share_link) and [reference set](https://drive.google.com/file/d/1pdiFdZzcqoQ1nRtlHIpP0nu-BFRbfuYe/view?usp=share_link). 30 | 31 | - `model.py` : Implementation of Hyperbolic CLIP, which is almost identical to [MERU](https://arxiv.org/abs/2304.09172) but in [OpenCLIP](https://github.com/mlfoundations/open_clip) style. 32 | - `tokenizer.py` : Tokenizer copied from [https://github.com/openai/CLIP](https://github.com/openai/CLIP) 33 | - `hyperbolic.py` : Implementation of hyperbolic space operations. 34 | - `hyper_demo.ipynb` : Pedagogical code to show how to calculate the negative Lorentzian distance (similarity) and specificity shown in the paper. 35 | 36 | ## How to run 37 | 38 | This repository includes functionality to calculate modality-specificities and the negative Lorentzian distance only. Please refer to the [DataComp](https://github.com/mlfoundations/datacomp) repository to calculate the ImageNet clustering score and CLIP similarity for the complete composite HYPE score. However, as shown in Table 3 of the HYPE paper, using only the specificities and negative Lorentzian distance is sufficient to achieve state-of-the-art results. 39 | 40 | ## How to cite 41 | 42 | ``` 43 | @inproceedings{kim2024hype, 44 | title={HYPE: Hyperbolic Entailment Filtering for Underspecified Images and Texts}, 45 | author={Kim, Wonjae and Chun, Sanghyuk and Kim, Taekyung and Han, Dongyoon and Yun, Sangdoo}, 46 | year={2024}, 47 | booktitle={European Conference on Computer Vision (ECCV)}, 48 | } 49 | ``` 50 | 51 | ## License 52 | ``` 53 | HYPE 54 | Copyright (c) 2024-present NAVER Cloud Corp. 
55 | CC BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/) 56 | ``` 57 | -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import gzip 7 | import html 8 | import os 9 | from functools import lru_cache 10 | from typing import Union, List 11 | 12 | import ftfy 13 | import regex as re 14 | import torch 15 | 16 | # https://stackoverflow.com/q/62691279 17 | import os 18 | 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | 22 | @lru_cache() 23 | def default_bpe(): 24 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 25 | 26 | 27 | @lru_cache() 28 | def bytes_to_unicode(): 29 | """ 30 | Returns list of utf-8 byte and a corresponding list of unicode strings. 31 | The reversible bpe codes work on unicode strings. 32 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 33 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 34 | This is a significant percentage of your normal, say, 32K bpe vocab. 35 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 36 | And avoids mapping to whitespace/control characters the bpe code barfs on. 37 | """ 38 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 39 | cs = bs[:] 40 | n = 0 41 | for b in range(2**8): 42 | if b not in bs: 43 | bs.append(b) 44 | cs.append(2**8 + n) 45 | n += 1 46 | cs = [chr(n) for n in cs] 47 | return dict(zip(bs, cs)) 48 | 49 | 50 | def get_pairs(word): 51 | """Return set of symbol pairs in a word. 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 
53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | def basic_clean(text): 63 | text = ftfy.fix_text(text) 64 | text = html.unescape(html.unescape(text)) 65 | return text.strip() 66 | 67 | 68 | def whitespace_clean(text): 69 | text = re.sub(r"\s+", " ", text) 70 | text = text.strip() 71 | return text 72 | 73 | 74 | class SimpleTokenizer(object): 75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 76 | self.byte_encoder = bytes_to_unicode() 77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 79 | merges = merges[1 : 49152 - 256 - 2 + 1] 80 | merges = [tuple(merge.split()) for merge in merges] 81 | vocab = list(bytes_to_unicode().values()) 82 | vocab = vocab + [v + "" for v in vocab] 83 | for merge in merges: 84 | vocab.append("".join(merge)) 85 | if not special_tokens: 86 | special_tokens = ["", ""] 87 | else: 88 | special_tokens = ["", ""] + special_tokens 89 | vocab.extend(special_tokens) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = {t: t for t in special_tokens} 94 | special = "|".join(special_tokens) 95 | self.pat = re.compile( 96 | special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE 97 | ) 98 | 99 | self.vocab_size = len(self.encoder) 100 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 101 | 102 | def bpe(self, token): 103 | if token in self.cache: 104 | return self.cache[token] 105 | word = tuple(token[:-1]) + (token[-1] + "",) 106 | pairs = get_pairs(word) 107 | 108 | if not pairs: 109 | return token + "" 110 | 111 | while True: 112 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 113 | if bigram not 
in self.bpe_ranks: 114 | break 115 | first, second = bigram 116 | new_word = [] 117 | i = 0 118 | while i < len(word): 119 | try: 120 | j = word.index(first, i) 121 | new_word.extend(word[i:j]) 122 | i = j 123 | except: 124 | new_word.extend(word[i:]) 125 | break 126 | 127 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 128 | new_word.append(first + second) 129 | i += 2 130 | else: 131 | new_word.append(word[i]) 132 | i += 1 133 | new_word = tuple(new_word) 134 | word = new_word 135 | if len(word) == 1: 136 | break 137 | else: 138 | pairs = get_pairs(word) 139 | word = " ".join(word) 140 | self.cache[token] = word 141 | return word 142 | 143 | def encode(self, text): 144 | bpe_tokens = [] 145 | text = whitespace_clean(basic_clean(text)).lower() 146 | for token in re.findall(self.pat, text): 147 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 148 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) 149 | return bpe_tokens 150 | 151 | def decode(self, tokens): 152 | text = "".join([self.decoder[token] for token in tokens]) 153 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("", " ") 154 | return text 155 | 156 | 157 | _tokenizer = SimpleTokenizer() 158 | 159 | 160 | def decode(output_ids: torch.Tensor): 161 | output_ids = output_ids.cpu().numpy() 162 | return _tokenizer.decode(output_ids) 163 | 164 | 165 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 166 | """ 167 | Returns the tokenized representation of given input string(s) 168 | 169 | Parameters 170 | ---------- 171 | texts : Union[str, List[str]] 172 | An input string or a list of input strings to tokenize 173 | context_length : int 174 | The context length to use; all CLIP models use 77 as the context length 175 | 176 | Returns 177 | ------- 178 | A two-dimensional tensor containing the resulting tokens, shape = [number of input 
strings, context_length] 179 | """ 180 | if isinstance(texts, str): 181 | texts = [texts] 182 | 183 | sot_token = _tokenizer.encoder[""] 184 | eot_token = _tokenizer.encoder[""] 185 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 186 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 187 | 188 | for i, tokens in enumerate(all_tokens): 189 | if len(tokens) > context_length: 190 | tokens = tokens[:context_length] # Truncate 191 | tokens[-1] = eot_token 192 | result[i, : len(tokens)] = torch.tensor(tokens) 193 | 194 | return result 195 | 196 | 197 | class HFTokenizer: 198 | """HuggingFace tokenizer wrapper""" 199 | 200 | def __init__(self, tokenizer_name: str): 201 | from transformers import AutoTokenizer 202 | 203 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 204 | 205 | def save_pretrained(self, dest): 206 | self.tokenizer.save_pretrained(dest) 207 | 208 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 209 | # same cleaning as for default tokenizer, except lowercasing 210 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 214 | input_ids = self.tokenizer( 215 | texts, 216 | return_tensors="pt", 217 | max_length=context_length, 218 | padding="max_length", 219 | truncation=True, 220 | ).input_ids 221 | return input_ids 222 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from typing import Optional, Callable 7 | from collections import OrderedDict 8 | from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, 
Resize, CenterCrop 9 | from huggingface_hub import PyTorchModelHubMixin 10 | 11 | 12 | def image_transform(image_size: int): 13 | mean = (0.48145466, 0.4578275, 0.40821073) 14 | std = (0.26862954, 0.26130258, 0.27577711) 15 | normalize = Normalize(mean=mean, std=std) 16 | transforms = [ 17 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 18 | CenterCrop(image_size), 19 | ] 20 | transforms.extend( 21 | [ 22 | lambda x: x.convert("RGB"), 23 | ToTensor(), 24 | normalize, 25 | ] 26 | ) 27 | return Compose(transforms) 28 | 29 | 30 | class LayerNorm(nn.LayerNorm): 31 | def forward(self, x: torch.Tensor): 32 | orig_type = x.dtype 33 | eps = torch.finfo(orig_type).eps 34 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, eps) 35 | return x.to(orig_type) 36 | 37 | 38 | class ResidualAttentionBlock(nn.Module): 39 | def __init__( 40 | self, 41 | d_model: int, 42 | n_head: int, 43 | mlp_ratio: float = 4.0, 44 | act_layer: Callable = nn.GELU, 45 | norm_layer: Callable = LayerNorm, 46 | ): 47 | super().__init__() 48 | 49 | self.ln_1 = norm_layer(d_model) 50 | self.attn = nn.MultiheadAttention(d_model, n_head) 51 | 52 | self.ln_2 = norm_layer(d_model) 53 | mlp_width = int(d_model * mlp_ratio) 54 | self.mlp = nn.Sequential( 55 | OrderedDict( 56 | [ 57 | ("c_fc", nn.Linear(d_model, mlp_width)), 58 | ("gelu", act_layer()), 59 | ("c_proj", nn.Linear(mlp_width, d_model)), 60 | ] 61 | ) 62 | ) 63 | 64 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 65 | attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None 66 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 67 | 68 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 69 | x = x + self.attention(x=self.ln_1(x), attn_mask=attn_mask) 70 | x = x + self.mlp(self.ln_2(x)) 71 | return x 72 | 73 | 74 | class Transformer(nn.Module): 75 | def __init__( 76 | self, 77 | width: int, 78 | layers: int, 79 | heads: int, 80 
| mlp_ratio: float = 4.0, 81 | act_layer: Callable = nn.GELU, 82 | norm_layer: Callable = LayerNorm, 83 | ): 84 | super().__init__() 85 | self.width = width 86 | self.layers = layers 87 | 88 | self.resblocks = nn.ModuleList( 89 | [ 90 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) 91 | for _ in range(layers) 92 | ] 93 | ) 94 | 95 | def get_cast_dtype(self) -> torch.dtype: 96 | return self.resblocks[0].mlp.c_fc.weight.dtype 97 | 98 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 99 | for r in self.resblocks: 100 | x = r(x, attn_mask=attn_mask) 101 | return x 102 | 103 | 104 | class VisionTransformer(nn.Module): 105 | def __init__( 106 | self, 107 | image_size: int, 108 | patch_size: int, 109 | width: int, 110 | layers: int, 111 | heads: int, 112 | mlp_ratio: float, 113 | output_dim: int = 512, 114 | act_layer: Callable = nn.GELU, 115 | norm_layer: Callable = LayerNorm, 116 | ): 117 | super().__init__() 118 | image_height, image_width = self.image_size = (image_size, image_size) 119 | patch_height, patch_width = self.patch_size = (patch_size, patch_size) 120 | self.grid_size = (image_height // patch_height, image_width // patch_width) 121 | self.output_dim = output_dim 122 | 123 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 124 | 125 | scale = width**-0.5 126 | self.scale = scale 127 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 128 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) 129 | 130 | self.ln_pre = norm_layer(width) 131 | self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) 132 | 133 | self.ln_post = norm_layer(width) 134 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 135 | 136 | def forward(self, x: torch.Tensor): 137 | x = self.conv1(x) 138 | x = 
x.reshape(x.shape[0], x.shape[1], -1) 139 | x = x.permute(0, 2, 1) 140 | 141 | x = torch.cat( 142 | [ 143 | self.class_embedding.to(x.dtype) 144 | + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 145 | x, 146 | ], 147 | dim=1, 148 | ) 149 | x = x + self.positional_embedding.to(x.dtype) 150 | x = self.ln_pre(x) 151 | 152 | x = x.permute(1, 0, 2) 153 | x = self.transformer(x) 154 | x = x.permute(1, 0, 2) 155 | 156 | pooled = x[:, 0] 157 | pooled = self.ln_post(pooled) 158 | pooled = pooled @ self.proj 159 | return pooled 160 | 161 | 162 | class TextTransformer(nn.Module): 163 | def __init__( 164 | self, 165 | context_length: int = 77, 166 | vocab_size: int = 49408, 167 | width: int = 512, 168 | heads: int = 8, 169 | layers: int = 12, 170 | output_dim: int = 512, 171 | act_layer: Callable = nn.GELU, 172 | norm_layer: Callable = LayerNorm, 173 | ): 174 | super().__init__() 175 | self.num_pos = self.context_length = context_length 176 | self.vocab_size = vocab_size 177 | self.width = width 178 | self.output_dim = output_dim 179 | self.heads = heads 180 | 181 | self.text_projection = nn.Parameter(torch.empty(width, output_dim)) 182 | self.token_embedding = nn.Embedding(vocab_size, width) 183 | self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) 184 | self.transformer = Transformer( 185 | width=width, layers=layers, heads=heads, act_layer=act_layer, norm_layer=norm_layer 186 | ) 187 | self.ln_final = norm_layer(width) 188 | self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) 189 | 190 | def build_attention_mask(self): 191 | mask = torch.empty(self.num_pos, self.num_pos) 192 | mask.fill_(float("-inf")) 193 | mask.triu_(1) 194 | return mask 195 | 196 | def forward(self, text): 197 | cast_dtype = self.transformer.get_cast_dtype() 198 | seq_len = text.shape[1] 199 | 200 | x = self.token_embedding(text).to(cast_dtype) 201 | x = x + self.positional_embedding[:seq_len].to(cast_dtype) 202 | x = 
x.permute(1, 0, 2) 203 | x = self.transformer(x, attn_mask=self.attn_mask) 204 | x = x.permute(1, 0, 2) 205 | 206 | x = self.ln_final(x) 207 | pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 208 | return pooled 209 | 210 | 211 | class CLIP(nn.Module): 212 | def __init__(self, embed_dim: int, vision_cfg: dict, text_cfg: dict): 213 | super().__init__() 214 | act_layer = nn.GELU 215 | norm_layer = LayerNorm 216 | 217 | self.visual = VisionTransformer( 218 | image_size=vision_cfg["image_size"], 219 | patch_size=vision_cfg["patch_size"], 220 | width=vision_cfg["width"], 221 | layers=vision_cfg["layers"], 222 | heads=vision_cfg["width"] // 64, 223 | mlp_ratio=4.0, 224 | output_dim=embed_dim, 225 | act_layer=act_layer, 226 | norm_layer=norm_layer, 227 | ) 228 | 229 | text = TextTransformer( 230 | context_length=text_cfg["context_length"], 231 | vocab_size=text_cfg["vocab_size"], 232 | width=text_cfg["width"], 233 | heads=text_cfg["heads"], 234 | layers=text_cfg["layers"], 235 | output_dim=embed_dim, 236 | act_layer=act_layer, 237 | norm_layer=norm_layer, 238 | ) 239 | self.transformer = text.transformer 240 | self.context_length = text.context_length 241 | self.vocab_size = text.vocab_size 242 | self.token_embedding = text.token_embedding 243 | self.positional_embedding = text.positional_embedding 244 | self.ln_final = text.ln_final 245 | self.text_projection = text.text_projection 246 | self.register_buffer("attn_mask", text.attn_mask, persistent=False) 247 | 248 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 249 | self.curvature = nn.Parameter(torch.ones([]) * np.log(1.0)) 250 | self.alpha_img = nn.Parameter(torch.ones([]) * np.log(1 / np.sqrt(embed_dim))) 251 | self.alpha_txt = nn.Parameter(torch.ones([]) * np.log(1 / np.sqrt(embed_dim))) 252 | 253 | def encode_image(self, image): 254 | features = self.visual(image) 255 | return self.alpha_img.exp() * features 256 | 257 | def encode_text(self, text): 258 | 
cast_dtype = self.transformer.get_cast_dtype() 259 | x = self.token_embedding(text).to(cast_dtype) 260 | x = x + self.positional_embedding.to(cast_dtype) 261 | x = x.permute(1, 0, 2) 262 | x = self.transformer(x, attn_mask=self.attn_mask) 263 | x = x.permute(1, 0, 2) 264 | x = self.ln_final(x) 265 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 266 | return self.alpha_txt.exp() * x 267 | 268 | def forward(self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None): 269 | image_features = self.encode_image(image) if image is not None else None 270 | text_features = self.encode_text(text) if text is not None else None 271 | return (image_features, text_features) 272 | 273 | 274 | def model_init(pretrained: str): 275 | cfg = { 276 | "embed_dim": 768, 277 | "vision_cfg": {"image_size": 224, "layers": 24, "width": 1024, "patch_size": 14}, 278 | "text_cfg": {"context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12}, 279 | } 280 | model = CLIP(**cfg) 281 | 282 | state_dict = torch.load(pretrained) 283 | model.load_state_dict(state_dict, strict=False) 284 | 285 | model.visual.image_mean = (0.48145466, 0.4578275, 0.40821073) 286 | model.visual.image_std = (0.26862954, 0.26130258, 0.27577711) 287 | 288 | preprocess = image_transform(model.visual.image_size) 289 | 290 | return model, preprocess 291 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. 
Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. 
If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. 
For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. 
NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 
197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. 
indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 
288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. 
This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. 
To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 
399 | 400 | Creative Commons may be contacted at creativecommons.org. 401 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | HYPE 2 | Copyright (c) 2024-present NAVER Cloud Corp. 3 | 4 | Creative Commons Attribution-NonCommercial 4.0 International 5 | 6 | A summary of the CC BY-NC 4.0 license is located here: 7 | https://creativecommons.org/licenses/by-nc/4.0/ 8 | 9 | This project contains subcomponents with separate copyright notices and license terms. 10 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 11 | 12 | ===== 13 | 14 | facebookresearch/meru 15 | https://github.com/facebookresearch/meru 16 | 17 | 18 | Attribution-NonCommercial 4.0 International 19 | 20 | ======================================================================= 21 | 22 | Creative Commons Corporation ("Creative Commons") is not a law firm and 23 | does not provide legal services or legal advice. Distribution of 24 | Creative Commons public licenses does not create a lawyer-client or 25 | other relationship. Creative Commons makes its licenses and related 26 | information available on an "as-is" basis. Creative Commons gives no 27 | warranties regarding its licenses, any material licensed under their 28 | terms and conditions, or any related information. Creative Commons 29 | disclaims all liability for damages resulting from their use to the 30 | fullest extent possible. 31 | 32 | Using Creative Commons Public Licenses 33 | 34 | Creative Commons public licenses provide a standard set of terms and 35 | conditions that creators and other rights holders may use to share 36 | original works of authorship and other material subject to copyright 37 | and certain other rights specified in the public license below. 
The 38 | following considerations are for informational purposes only, are not 39 | exhaustive, and do not form part of our licenses. 40 | 41 | Considerations for licensors: Our public licenses are 42 | intended for use by those authorized to give the public 43 | permission to use material in ways otherwise restricted by 44 | copyright and certain other rights. Our licenses are 45 | irrevocable. Licensors should read and understand the terms 46 | and conditions of the license they choose before applying it. 47 | Licensors should also secure all rights necessary before 48 | applying our licenses so that the public can reuse the 49 | material as expected. Licensors should clearly mark any 50 | material not subject to the license. This includes other CC- 51 | licensed material, or material used under an exception or 52 | limitation to copyright. More considerations for licensors: 53 | wiki.creativecommons.org/Considerations_for_licensors 54 | 55 | Considerations for the public: By using one of our public 56 | licenses, a licensor grants the public permission to use the 57 | licensed material under specified terms and conditions. If 58 | the licensor's permission is not necessary for any reason--for 59 | example, because of any applicable exception or limitation to 60 | copyright--then that use is not regulated by the license. Our 61 | licenses grant only permissions under copyright and certain 62 | other rights that a licensor has authority to grant. Use of 63 | the licensed material may still be restricted for other 64 | reasons, including because others have copyright or other 65 | rights in the material. A licensor may make special requests, 66 | such as asking that all changes be marked or described. 67 | Although not required by our licenses, you are encouraged to 68 | respect those requests where reasonable. 
More considerations 69 | for the public: 70 | wiki.creativecommons.org/Considerations_for_licensees 71 | 72 | ======================================================================= 73 | 74 | Creative Commons Attribution-NonCommercial 4.0 International Public 75 | License 76 | 77 | By exercising the Licensed Rights (defined below), You accept and agree 78 | to be bound by the terms and conditions of this Creative Commons 79 | Attribution-NonCommercial 4.0 International Public License ("Public 80 | License"). To the extent this Public License may be interpreted as a 81 | contract, You are granted the Licensed Rights in consideration of Your 82 | acceptance of these terms and conditions, and the Licensor grants You 83 | such rights in consideration of benefits the Licensor receives from 84 | making the Licensed Material available under these terms and 85 | conditions. 86 | 87 | 88 | Section 1 -- Definitions. 89 | 90 | a. Adapted Material means material subject to Copyright and Similar 91 | Rights that is derived from or based upon the Licensed Material 92 | and in which the Licensed Material is translated, altered, 93 | arranged, transformed, or otherwise modified in a manner requiring 94 | permission under the Copyright and Similar Rights held by the 95 | Licensor. For purposes of this Public License, where the Licensed 96 | Material is a musical work, performance, or sound recording, 97 | Adapted Material is always produced where the Licensed Material is 98 | synched in timed relation with a moving image. 99 | 100 | b. Adapter's License means the license You apply to Your Copyright 101 | and Similar Rights in Your contributions to Adapted Material in 102 | accordance with the terms and conditions of this Public License. 103 | 104 | c. 
Copyright and Similar Rights means copyright and/or similar rights 105 | closely related to copyright including, without limitation, 106 | performance, broadcast, sound recording, and Sui Generis Database 107 | Rights, without regard to how the rights are labeled or 108 | categorized. For purposes of this Public License, the rights 109 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 110 | Rights. 111 | d. Effective Technological Measures means those measures that, in the 112 | absence of proper authority, may not be circumvented under laws 113 | fulfilling obligations under Article 11 of the WIPO Copyright 114 | Treaty adopted on December 20, 1996, and/or similar international 115 | agreements. 116 | 117 | e. Exceptions and Limitations means fair use, fair dealing, and/or 118 | any other exception or limitation to Copyright and Similar Rights 119 | that applies to Your use of the Licensed Material. 120 | 121 | f. Licensed Material means the artistic or literary work, database, 122 | or other material to which the Licensor applied this Public 123 | License. 124 | 125 | g. Licensed Rights means the rights granted to You subject to the 126 | terms and conditions of this Public License, which are limited to 127 | all Copyright and Similar Rights that apply to Your use of the 128 | Licensed Material and that the Licensor has authority to license. 129 | 130 | h. Licensor means the individual(s) or entity(ies) granting rights 131 | under this Public License. 132 | 133 | i. NonCommercial means not primarily intended for or directed towards 134 | commercial advantage or monetary compensation. For purposes of 135 | this Public License, the exchange of the Licensed Material for 136 | other material subject to Copyright and Similar Rights by digital 137 | file-sharing or similar means is NonCommercial provided there is 138 | no payment of monetary compensation in connection with the 139 | exchange. 140 | 141 | j. 
Share means to provide material to the public by any means or 142 | process that requires permission under the Licensed Rights, such 143 | as reproduction, public display, public performance, distribution, 144 | dissemination, communication, or importation, and to make material 145 | available to the public including in ways that members of the 146 | public may access the material from a place and at a time 147 | individually chosen by them. 148 | 149 | k. Sui Generis Database Rights means rights other than copyright 150 | resulting from Directive 96/9/EC of the European Parliament and of 151 | the Council of 11 March 1996 on the legal protection of databases, 152 | as amended and/or succeeded, as well as other essentially 153 | equivalent rights anywhere in the world. 154 | 155 | l. You means the individual or entity exercising the Licensed Rights 156 | under this Public License. Your has a corresponding meaning. 157 | 158 | 159 | Section 2 -- Scope. 160 | 161 | a. License grant. 162 | 163 | 1. Subject to the terms and conditions of this Public License, 164 | the Licensor hereby grants You a worldwide, royalty-free, 165 | non-sublicensable, non-exclusive, irrevocable license to 166 | exercise the Licensed Rights in the Licensed Material to: 167 | 168 | a. reproduce and Share the Licensed Material, in whole or 169 | in part, for NonCommercial purposes only; and 170 | 171 | b. produce, reproduce, and Share Adapted Material for 172 | NonCommercial purposes only. 173 | 174 | 2. Exceptions and Limitations. For the avoidance of doubt, where 175 | Exceptions and Limitations apply to Your use, this Public 176 | License does not apply, and You do not need to comply with 177 | its terms and conditions. 178 | 179 | 3. Term. The term of this Public License is specified in Section 180 | 6(a). 181 | 182 | 4. Media and formats; technical modifications allowed. 
The 183 | Licensor authorizes You to exercise the Licensed Rights in 184 | all media and formats whether now known or hereafter created, 185 | and to make technical modifications necessary to do so. The 186 | Licensor waives and/or agrees not to assert any right or 187 | authority to forbid You from making technical modifications 188 | necessary to exercise the Licensed Rights, including 189 | technical modifications necessary to circumvent Effective 190 | Technological Measures. For purposes of this Public License, 191 | simply making modifications authorized by this Section 2(a) 192 | (4) never produces Adapted Material. 193 | 194 | 5. Downstream recipients. 195 | 196 | a. Offer from the Licensor -- Licensed Material. Every 197 | recipient of the Licensed Material automatically 198 | receives an offer from the Licensor to exercise the 199 | Licensed Rights under the terms and conditions of this 200 | Public License. 201 | 202 | b. No downstream restrictions. You may not offer or impose 203 | any additional or different terms or conditions on, or 204 | apply any Effective Technological Measures to, the 205 | Licensed Material if doing so restricts exercise of the 206 | Licensed Rights by any recipient of the Licensed 207 | Material. 208 | 209 | 6. No endorsement. Nothing in this Public License constitutes or 210 | may be construed as permission to assert or imply that You 211 | are, or that Your use of the Licensed Material is, connected 212 | with, or sponsored, endorsed, or granted official status by, 213 | the Licensor or others designated to receive attribution as 214 | provided in Section 3(a)(1)(A)(i). 215 | 216 | b. Other rights. 217 | 218 | 1. 
Moral rights, such as the right of integrity, are not 219 | licensed under this Public License, nor are publicity, 220 | privacy, and/or other similar personality rights; however, to 221 | the extent possible, the Licensor waives and/or agrees not to 222 | assert any such rights held by the Licensor to the limited 223 | extent necessary to allow You to exercise the Licensed 224 | Rights, but not otherwise. 225 | 226 | 2. Patent and trademark rights are not licensed under this 227 | Public License. 228 | 229 | 3. To the extent possible, the Licensor waives any right to 230 | collect royalties from You for the exercise of the Licensed 231 | Rights, whether directly or through a collecting society 232 | under any voluntary or waivable statutory or compulsory 233 | licensing scheme. In all other cases the Licensor expressly 234 | reserves any right to collect such royalties, including when 235 | the Licensed Material is used other than for NonCommercial 236 | purposes. 237 | 238 | 239 | Section 3 -- License Conditions. 240 | 241 | Your exercise of the Licensed Rights is expressly made subject to the 242 | following conditions. 243 | 244 | a. Attribution. 245 | 246 | 1. If You Share the Licensed Material (including in modified 247 | form), You must: 248 | 249 | a. retain the following if it is supplied by the Licensor 250 | with the Licensed Material: 251 | 252 | i. identification of the creator(s) of the Licensed 253 | Material and any others designated to receive 254 | attribution, in any reasonable manner requested by 255 | the Licensor (including by pseudonym if 256 | designated); 257 | 258 | ii. a copyright notice; 259 | 260 | iii. a notice that refers to this Public License; 261 | 262 | iv. a notice that refers to the disclaimer of 263 | warranties; 264 | 265 | v. a URI or hyperlink to the Licensed Material to the 266 | extent reasonably practicable; 267 | 268 | b. 
indicate if You modified the Licensed Material and 269 | retain an indication of any previous modifications; and 270 | 271 | c. indicate the Licensed Material is licensed under this 272 | Public License, and include the text of, or the URI or 273 | hyperlink to, this Public License. 274 | 275 | 2. You may satisfy the conditions in Section 3(a)(1) in any 276 | reasonable manner based on the medium, means, and context in 277 | which You Share the Licensed Material. For example, it may be 278 | reasonable to satisfy the conditions by providing a URI or 279 | hyperlink to a resource that includes the required 280 | information. 281 | 282 | 3. If requested by the Licensor, You must remove any of the 283 | information required by Section 3(a)(1)(A) to the extent 284 | reasonably practicable. 285 | 286 | 4. If You Share Adapted Material You produce, the Adapter's 287 | License You apply must not prevent recipients of the Adapted 288 | Material from complying with this Public License. 289 | 290 | 291 | Section 4 -- Sui Generis Database Rights. 292 | 293 | Where the Licensed Rights include Sui Generis Database Rights that 294 | apply to Your use of the Licensed Material: 295 | 296 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 297 | to extract, reuse, reproduce, and Share all or a substantial 298 | portion of the contents of the database for NonCommercial purposes 299 | only; 300 | 301 | b. if You include all or a substantial portion of the database 302 | contents in a database in which You have Sui Generis Database 303 | Rights, then the database in which You have Sui Generis Database 304 | Rights (but not its individual contents) is Adapted Material; and 305 | 306 | c. You must comply with the conditions in Section 3(a) if You Share 307 | all or a substantial portion of the contents of the database. 
308 | 309 | For the avoidance of doubt, this Section 4 supplements and does not 310 | replace Your obligations under this Public License where the Licensed 311 | Rights include other Copyright and Similar Rights. 312 | 313 | 314 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 315 | 316 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 317 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 318 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 319 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 320 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 321 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 322 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 323 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 324 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 325 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 326 | 327 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 328 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 329 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 330 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 331 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 332 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 333 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 334 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 335 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 336 | 337 | c. The disclaimer of warranties and limitation of liability provided 338 | above shall be interpreted in a manner that, to the extent 339 | possible, most closely approximates an absolute disclaimer and 340 | waiver of all liability. 341 | 342 | 343 | Section 6 -- Term and Termination. 344 | 345 | a. 
This Public License applies for the term of the Copyright and 346 | Similar Rights licensed here. However, if You fail to comply with 347 | this Public License, then Your rights under this Public License 348 | terminate automatically. 349 | 350 | b. Where Your right to use the Licensed Material has terminated under 351 | Section 6(a), it reinstates: 352 | 353 | 1. automatically as of the date the violation is cured, provided 354 | it is cured within 30 days of Your discovery of the 355 | violation; or 356 | 357 | 2. upon express reinstatement by the Licensor. 358 | 359 | For the avoidance of doubt, this Section 6(b) does not affect any 360 | right the Licensor may have to seek remedies for Your violations 361 | of this Public License. 362 | 363 | c. For the avoidance of doubt, the Licensor may also offer the 364 | Licensed Material under separate terms or conditions or stop 365 | distributing the Licensed Material at any time; however, doing so 366 | will not terminate this Public License. 367 | 368 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 369 | License. 370 | 371 | 372 | Section 7 -- Other Terms and Conditions. 373 | 374 | a. The Licensor shall not be bound by any additional or different 375 | terms or conditions communicated by You unless expressly agreed. 376 | 377 | b. Any arrangements, understandings, or agreements regarding the 378 | Licensed Material not stated herein are separate from and 379 | independent of the terms and conditions of this Public License. 380 | 381 | 382 | Section 8 -- Interpretation. 383 | 384 | a. For the avoidance of doubt, this Public License does not, and 385 | shall not be interpreted to, reduce, limit, restrict, or impose 386 | conditions on any use of the Licensed Material that could lawfully 387 | be made without permission under this Public License. 388 | 389 | b. 
To the extent possible, if any provision of this Public License is 390 | deemed unenforceable, it shall be automatically reformed to the 391 | minimum extent necessary to make it enforceable. If the provision 392 | cannot be reformed, it shall be severed from this Public License 393 | without affecting the enforceability of the remaining terms and 394 | conditions. 395 | 396 | c. No term or condition of this Public License will be waived and no 397 | failure to comply consented to unless expressly agreed to by the 398 | Licensor. 399 | 400 | d. Nothing in this Public License constitutes or may be interpreted 401 | as a limitation upon, or waiver of, any privileges and immunities 402 | that apply to the Licensor or You, including from the legal 403 | processes of any jurisdiction or authority. 404 | 405 | ======================================================================= 406 | 407 | Creative Commons is not a party to its public 408 | licenses. Notwithstanding, Creative Commons may elect to apply one of 409 | its public licenses to material it publishes and in those instances 410 | will be considered the “Licensor.” The text of the Creative Commons 411 | public licenses is dedicated to the public domain under the CC0 Public 412 | Domain Dedication. Except for the limited purpose of indicating that 413 | material is shared under a Creative Commons public license or as 414 | otherwise permitted by the Creative Commons policies published at 415 | creativecommons.org/policies, Creative Commons does not authorize the 416 | use of the trademark "Creative Commons" or any other trademark or logo 417 | of Creative Commons without its prior written consent including, 418 | without limitation, in connection with any unauthorized modifications 419 | to any of its public licenses or any other arrangements, 420 | understandings, or agreements concerning use of licensed material. For 421 | the avoidance of doubt, this paragraph does not form part of the 422 | public licenses. 
423 | 424 | Creative Commons may be contacted at creativecommons.org. 425 | 426 | ===== 427 | 428 | mlfoundations/open_clip 429 | https://github.com/mlfoundations/open_clip 430 | 431 | 432 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 433 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 434 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 435 | Ludwig Schmidt 436 | 437 | Permission is hereby granted, free of charge, to any person obtaining 438 | a copy of this software and associated documentation files (the 439 | "Software"), to deal in the Software without restriction, including 440 | without limitation the rights to use, copy, modify, merge, publish, 441 | distribute, sublicense, and/or sell copies of the Software, and to 442 | permit persons to whom the Software is furnished to do so, subject to 443 | the following conditions: 444 | 445 | The above copyright notice and this permission notice shall be 446 | included in all copies or substantial portions of the Software. 447 | 448 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 449 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 450 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 451 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 452 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 453 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 454 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 455 | 456 | ===== 457 | --------------------------------------------------------------------------------