├── tldr
│   ├── __init__.py
│   ├── loss.py
│   ├── optimizer.py
│   ├── utils.py
│   └── tldr.py
├── scripts
│   └── dummy_example.py
├── setup.py
├── LICENSE_Barlow_Twins
├── LICENSE
├── .gitignore
└── README.md

/tldr/__init__.py:
--------------------------------------------------------------------------------
from .tldr import TLDR  # noqa
--------------------------------------------------------------------------------
/tldr/loss.py:
--------------------------------------------------------------------------------
import torch


def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


def BarlowTwinsLoss(z1, z2, batch_size, scale_loss=1.0 / 32, lambd=3.9e-3):
    """
    Zbontar et al., Barlow Twins: Self-Supervised Learning via Redundancy Reduction
    https://arxiv.org/abs/2103.03230

    Implementation from https://github.com/facebookresearch/barlowtwins
    Copyright (c) Facebook, Inc. and its affiliates.
    """
    # empirical cross-correlation matrix, normalized by the batch size
    c = z1.T @ z2
    c.div_(batch_size)
    # in the original multi-GPU code the matrix is summed across processes
    # (disabled in this single-process version):
    # torch.distributed.all_reduce(c)

    # use --scale-loss to multiply the loss by a constant factor
    on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(scale_loss)
    off_diag = off_diagonal(c).pow_(2).sum().mul(scale_loss)
    loss = on_diag + lambd * off_diag

    return loss
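

# A minimal, illustrative smoke test (not part of the original file): with two
# identical, batch-normalized views the on-diagonal term of the loss should be
# close to its minimum.
if __name__ == "__main__":
    torch.manual_seed(0)
    z = torch.randn(1024, 128)
    z = (z - z.mean(0)) / z.std(0)  # normalize per dimension, as the model does before the loss
    print(BarlowTwinsLoss(z, z, batch_size=1024).item())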
Reduction", 20 | url="https://github.com/naver/tldr", 21 | long_description=readme(), 22 | long_description_content_type="text/x-rst", 23 | author="Jon Almazan", 24 | author_email="jon.almazan@naverlabs.com", 25 | license="CC BY-NCA-SA 4.0", 26 | packages=["tldr"], 27 | install_requires=[ 28 | "rich", 29 | "numpy", 30 | "faiss>=1.7.0", 31 | "torch>=1.8.0", 32 | ], 33 | classifiers=[ 34 | "Development Status :: 3 - Beta", 35 | "Intended Audience :: Science/Research", 36 | "License :: OSI Approved :: Common Public License", 37 | "Environment :: GPU :: NVIDIA CUDA :: 11.2", 38 | "Programming Language :: Python :: 3.6", 39 | ], 40 | ) 41 | -------------------------------------------------------------------------------- /LICENSE_Barlow_Twins: -------------------------------------------------------------------------------- 1 | License and Copyright with respect to content in the files tldr/loss.py and tldr/optimizer.py: 2 | 3 | This software is being redistributed in its original form and license that is available here: 4 | https://github.com/facebookresearch/barlowtwins/ 5 | 6 | ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: 7 | https://github.com/facebookresearch/barlowtwins/blob/main/LICENSE 8 | 9 | MIT License 10 | 11 | Copyright (c) Facebook, Inc. and its affiliates. 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Twin Learning for Dimensionality Reduction, Copyright (c) 2021 Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0. 2 | 3 | A summary of the CC BY-NC-SA 4.0 license is located here: 4 | https://creativecommons.org/licenses/by-nc-sa/4.0/ 5 | 6 | The CC BY-NC-SA 4.0 license is located here: 7 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 8 | 9 | 10 | ATTRIBUTIONS: 11 | 12 | The class TLDRNetwork in the file tldr/tldr.py is derived from the class BarlowTwins in the file main.py available here: 13 | https://github.com/facebookresearch/barlowtwins/, which was made available under the MIT License available here: 14 | https://github.com/facebookresearch/barlowtwins/blob/main/LICENSE, 15 | which is reproduced below: 16 | 17 | MIT License 18 | 19 | Copyright (c) Facebook, Inc. and its affiliates. 
20 | 21 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 24 | 25 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /tldr/optimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.optim as optim 5 | 6 | 7 | def adjust_learning_rate(epochs, optimizer, n_data, step, learning_rate, batch_size, warmup_epochs): 8 | max_steps = epochs * n_data // batch_size 9 | warmup_steps = int(warmup_epochs * n_data // batch_size) 10 | base_lr = learning_rate * batch_size / 256 11 | if step < warmup_steps: 12 | lr = base_lr * step / warmup_steps 13 | else: 14 | step -= warmup_steps 15 | max_steps -= warmup_steps 16 | q = 0.5 * (1 + math.cos(math.pi * step / max_steps)) 17 | end_lr = base_lr * 0.001 18 | lr = base_lr * q + end_lr * (1 - q) 19 | for param_group in optimizer.param_groups: 20 | param_group["lr"] = lr 21 | return lr 22 | 23 | 24 | class LARS(optim.Optimizer): 25 | """ 26 | Large Batch Training of Convolutional Networks 27 | https://arxiv.org/abs/1708.03888 28 | 29 | Implementation from: https://github.com/facebookresearch/barlowtwins 30 | Copyright (c) Facebook, Inc. and its affiliates. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | params, 36 | lr, 37 | weight_decay=0, 38 | momentum=0.9, 39 | eta=0.001, 40 | weight_decay_filter=None, 41 | lars_adaptation_filter=None, 42 | ): 43 | defaults = dict( 44 | lr=lr, 45 | weight_decay=weight_decay, 46 | momentum=momentum, 47 | eta=eta, 48 | weight_decay_filter=weight_decay_filter, 49 | lars_adaptation_filter=lars_adaptation_filter, 50 | ) 51 | super().__init__(params, defaults) 52 | 53 | @torch.no_grad() 54 | def step(self): 55 | for g in self.param_groups: 56 | for p in g["params"]: 57 | dp = p.grad 58 | 59 | if dp is None: 60 | continue 61 | 62 | if g["weight_decay_filter"] is None or not g["weight_decay_filter"](p): 63 | dp = dp.add(p, alpha=g["weight_decay"]) 64 | 65 | if g["lars_adaptation_filter"] is None or not g["lars_adaptation_filter"](p): 66 | param_norm = torch.norm(p) 67 | update_norm = torch.norm(dp) 68 | one = torch.ones_like(param_norm) 69 | q = torch.where( 70 | param_norm > 0.0, 71 | torch.where(update_norm > 0, (g["eta"] * param_norm / update_norm), one), 72 | one, 73 | ) 74 | dp = dp.mul(q) 75 | 76 | param_state = self.state[p] 77 | if "mu" not in param_state: 78 | param_state["mu"] = torch.zeros_like(p) 79 | mu = param_state["mu"] 80 | mu.mul_(g["momentum"]).add_(dp) 81 | 82 | p.add_(mu, alpha=-g["lr"]) 83 | -------------------------------------------------------------------------------- /tldr/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NAVER and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
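

# Illustrative only (not part of the original file): trace the warmup + cosine
# schedule produced by adjust_learning_rate for a small dummy run.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(8, 2)
    opt = LARS(model.parameters(), lr=0)
    for step in (0, 5, 10, 50, 100):
        lr = adjust_learning_rate(
            epochs=10, optimizer=opt, n_data=2560, step=step,
            learning_rate=0.2, batch_size=256, warmup_epochs=1,
        )
        print(f"step {step:3d}: lr = {lr:.5f}")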
--------------------------------------------------------------------------------
/tldr/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) NAVER and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from time import time

import faiss
import numpy as np
import torch
import torch.nn.functional as F
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeRemainingColumn


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


def get_knn_graph(
    X, n_neighbors, l2_norm_graph=False, device="cuda", metric="IP", verbose=0, knn_approximation=None
):
    """Computes and returns the k nearest neighbors of each sample

    Parameters
    ----------
    X : ndarray
        N x D array containing N samples of dimension D
    n_neighbors : int
        Number of nearest neighbors
    l2_norm_graph : bool
        L2 normalize samples
    device : str
        Selects the device [cpu, cuda]
    metric : str
        Selects the similarity metric [IP, L2]
    verbose : int
        Selects verbosity level [0, 1, 2]
    knn_approximation : str
        Enables nearest neighbor approximation [None, low, medium, high]

    Returns
    -------
    knn_graph : ndarray
        Array containing the indices of the k nearest neighbors of each sample
    """
    knn_graph = None
    metric = metric.upper()
    if metric not in ["IP", "L2"]:
        raise ValueError(f"similarity metric {metric} not supported. Metrics supported are 'L2' and 'IP'")

    if verbose > 1:
        print(f" - Creating {n_neighbors}-NN graph for training data")
    if isinstance(X, torch.Tensor):
        X = X.cpu().numpy()
    X = X.astype(np.float32)
    X = np.ascontiguousarray(X)
    if l2_norm_graph:
        X = torch.nn.functional.normalize(torch.tensor(X), dim=1, p=2).cpu().numpy()

    split_train = np.array_split(X, 100)
    all_neighbors = list()
    all_dists = list()
    dimensions = X.shape[1]
    used_dimensions = dimensions
    faiss_type = "Flat"
    if metric == "IP":
        faiss_metric = faiss.METRIC_INNER_PRODUCT
    elif metric == "L2":
        faiss_metric = faiss.METRIC_L2
    if knn_approximation is not None:
        if knn_approximation not in ["low", "medium", "high"]:
            raise ValueError(
                f"knn_approximation should be one of None, low, medium or high and it was {knn_approximation}"
            )
        approx_mapping = {"low": 32, "medium": 16, "high": 8}
        used_dimensions = approx_mapping[knn_approximation]
        if (dimensions % used_dimensions) != 0:
            raise ValueError(
                f"Number of dimensions of training data must be divisible by {used_dimensions} to allow {knn_approximation} knn_approximation"
            )
        faiss_type = f"IVF1,PQ{used_dimensions}"
    faiss_index = faiss.index_factory(dimensions, faiss_type, faiss_metric)
    if "cuda" in device and faiss.get_num_gpus() > 0:
        faiss_index = faiss.index_cpu_to_all_gpus(faiss_index)

    t0 = time()
    if knn_approximation is not None:
        if verbose > 1:
            print(
                "Training product quantization for knn approximation with 10% of total data. "
                "Note that for small datasets training + computing neighbors could be slower than brute force computation"
            )
        faiss_index.train(X[: 1 + X.shape[0] // 10])
        if verbose > 1:
            print("Finished training product quantization")
    faiss_index.add(X)  # add vectors to the index
    with get_progress_bar() as progress:
        if verbose > 0:
            task = progress.add_task(description="[green]Computing KNN", total=len(split_train), info="-")
        for splitted in split_train:
            D, Idx = faiss_index.search(
                splitted, k=n_neighbors + 1
            )  # n_neighbors+1 because the first one is always yourself...
            all_neighbors.append(Idx[:, 1:])
            all_dists.append(D[:, 1:])
            if verbose > 0:
                progress.update(task, advance=1)
    t1 = time()
    if verbose > 1:
        print(" - KNN computation took %.2g sec" % (t1 - t0))
    knn_graph = np.concatenate(all_neighbors, axis=0)

    knn_graph = knn_graph[:, :n_neighbors]

    return knn_graph


def tonumpy(x):
    """Converts a tensor to a numpy array"""
    if type(x).__module__ == torch.__name__:
        return x.cpu().numpy()
    else:
        return x


def whiten(X, fudge=1e-18):
    """Applies whitening to an NxD array of N samples with dimensionality D"""
    # the matrix X should be observations-by-components

    # get the covariance matrix
    Xcov = np.dot(X.T, X)
    # eigenvalue decomposition of the covariance matrix
    d, V = np.linalg.eigh(Xcov)

    # a fudge factor can be used so that eigenvectors associated with
    # small eigenvalues do not get overamplified.
    D = np.diag(1.0 / np.sqrt(d + fudge))

    # whitening matrix
    W = np.dot(np.dot(V, D), V.T)

    # multiply by the whitening matrix
    X_white = np.dot(X, W)

    return X_white, W


def l2_normalize(x, axis=-1):
    """L2 normalizes an NxD array of N samples with dimensionality D"""
    x = F.normalize(x, p=2, dim=axis)
    return x


def parse_net_config(net_config: str):
    """Parses an architecture configuration string and returns the corresponding network type, number of hidden layers and their dimensionality"""
    config = net_config.split("-")

    net_type = config[0].lower()
    if net_type not in ["linear", "flinear", "mlp"]:
        raise ValueError(
            f"Incorrect network configuration format '{net_config}': incorrect network type '{net_type}', currently supported types are 'linear', 'flinear', and 'mlp'"
        )

    if len(config) == 1:
        if net_type not in ["linear"]:
            raise ValueError(
                f"Incorrect network configuration format '{net_config}': you need to specify the number of layers and dimensionality of each layer `{net_type}-[NUM_HLAYERS]-[HDIMS]`"
            )
        return net_type, 0, []

    num_hidden_layers = int(config[1])
    if len(config) == 2:
        if num_hidden_layers == 0:
            return net_type, num_hidden_layers, []
        raise ValueError(
            f"Incorrect network configuration format '{net_config}': you need to specify the dimensionality of each hidden layer using `{net_type}-{num_hidden_layers}-[HDIMS]`"
        )

    hidden_layers_dim = [int(e) for e in config[2:]]
    return net_type, num_hidden_layers, hidden_layers_dim


def get_progress_bar():
    """Returns a progress bar using the rich library"""
    return Progress(
        "[progress.description]{task.description}",
        SpinnerColumn(),
        BarColumn(),
        TextColumn("[bold blue]{task.fields[info]}", justify="right"),
        TimeRemainingColumn(),
    )
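

# Illustrative only (not part of the original file): build a small 5-NN graph on
# CPU and verify that whitening decorrelates a centered random matrix.
if __name__ == "__main__":
    X = np.random.rand(1000, 64).astype(np.float32)
    knn = get_knn_graph(X, 5, device="cpu", verbose=0)
    print(knn.shape)  # (1000, 5)
    X_white, W = whiten(X - X.mean(axis=0))
    print(np.allclose(X_white.T @ X_white, np.eye(64), atol=1e-2))  # True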
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TLDR: Twin Learning for Dimensionality Reduction

[TLDR](https://openreview.net/forum?id=86fhqdBUbx) (Twin Learning for Dimensionality Reduction) is an unsupervised dimensionality reduction method that combines neighborhood embedding learning with the simplicity and effectiveness of recent self-supervised learning losses.

Inspired by manifold learning, TLDR uses nearest neighbors as a way to build pairs from a training set and a redundancy reduction loss to learn an encoder that produces representations invariant across such pairs. Like other neighborhood embeddings, TLDR effectively learns, without supervision, low-dimensional spaces where local neighborhoods of the input space are preserved; unlike other manifold learning methods, it simply consists of an offline nearest neighbor computation step and a straightforward learning process that does not require mining negative samples to contrast, eigendecompositions, or cumbersome optimization solvers.

More details and evaluation can be found in [our TMLR paper](https://openreview.net/forum?id=86fhqdBUbx).

![diagram](https://user-images.githubusercontent.com/228798/137484016-7cf1c255-0182-46c6-849b-76281fadb251.png)

***Overview of TLDR**: Given a set of feature vectors in a generic input space, we use nearest neighbors to define a set of feature pairs whose proximity we want to preserve. We then learn a dimensionality-reduction function (the encoder) by encouraging neighbors in the input space to have similar representations. We learn it jointly with an auxiliary projector that produces high-dimensional representations, where we compute the [Barlow Twins](https://arxiv.org/abs/2103.03230) loss over the (d′ × d′) cross-correlation matrix averaged over the batch.*


**Contents**:
- [Installing the TLDR library](#installing-the-tldr-library)
- [Using the TLDR library](#using-the-tldr-library)
- [Documentation](#documentation)
- [Citation](#citation)
- [Contributors](#contributors)

## Installing the TLDR library

Requirements:
- Python 3.6 or greater
- PyTorch 1.8 or greater
- numpy
- [FAISS](https://github.com/facebookresearch/faiss)
- [rich](https://github.com/willmcgugan/rich)

In order to install the TLDR library, one should first make sure that [FAISS](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md) and [PyTorch](https://pytorch.org/get-started/locally/) are installed. We recommend using a new [conda](https://www.anaconda.com/products/individual) environment:

```bash
conda create --name ENV_NAME python=3.6.8
conda activate ENV_NAME
conda install -c pytorch faiss-gpu cudatoolkit=10.2
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
```

After ensuring that you have installed both FAISS and PyTorch, you can install TLDR by using the two commands below:

```bash
git clone git@github.com:naver/tldr.git
python3 -m pip install -e tldr
```

## Using the TLDR library

The `TLDR` library can be used to learn dimensionality reduction models using an API and functionality that mimic similar methods in the [scikit-learn library](https://scikit-learn.org/stable/modules/unsupervised_reduction.html), _i.e._ you can learn a dimensionality reduction on your training data using `fit()` and you can project new data using `transform()`.

To illustrate the different functionalities we present a dummy example on randomly generated data. Let's import the library and generate some random training data (we will use 100K training examples with a dimensionality of 2048), _i.e._:

```python
import numpy as np
from tldr import TLDR

# Generate random data
X = np.random.rand(100000, 2048)  # replace with training (N x D) array
```

### Instantiating a TLDR model

When instantiating a `TLDR` model one has to specify the output dimension (`n_components`), the number of nearest neighbors to use (`n_neighbors`), as well as the encoder and projector architectures, which are specified as strings.

For this example we will learn a dimensionality reduction to 32 components, we will use the 10 nearest neighbors to sample positive pairs, and we will use a linear encoder and a multi-layer perceptron with one hidden layer of 2048 dimensions as a projector:

```python
tldr = TLDR(n_components=32, n_neighbors=10, encoder='linear', projector='mlp-1-2048', device='cuda', verbose=2)
```

For a more detailed list of optional arguments please refer to the function [documentation](#documentation) below; the architecture specification string format is described in [this section](#architecture-specification-strings) below.

### Learning and applying the TLDR model

We learn the parameters of the dimensionality reduction model by using the `fit()` method:

```python
tldr.fit(X, epochs=100, batch_size=1024, output_folder='data/', print_every=50)
```

By default, `fit()` first collects the `k` nearest neighbors of each training data point using [FAISS](https://github.com/facebookresearch/faiss) and then optimizes the Barlow Twins loss using the batch size and number of epochs provided. Note that, apart from the dimensionality reduction function (the _encoder_), a _projector_ function that is part of the training process is also learned (see also the Figure above); the projector is by default discarded after training.

Once the model has been trained we can use `transform()` to project the training data to the newly learned space:

```python
Z = tldr.transform(X, l2_norm=True)  # Returns (N x n_components) matrix
```

The optional `l2_norm=True` argument of `transform()` further applies L2 normalization to all features after projection.

Again, we refer the user to the functions' [documentation](#documentation) below for argument details.


### Saving/loading the model

The TLDR model and the array of nearest neighbors per training data point can be saved using the `save()` and `save_knn()` functions, respectively:

```python
tldr.save("data/inference_model.pth")
tldr.save_knn("data/knn.npy")
```

Note that by default the projector weights will _not_ be saved. To also save the projector (_e.g._ for subsequent fine-tuning of the model) one must set the `retain_projector=True` argument when calling `fit()`.

One can use the `load()` method to load a pre-trained model from disk. Using the `init=True` argument when loading also loads the hyper-parameters of the model:

```python
X = np.random.rand(5000, 2048)
tldr = TLDR()
tldr.load("data/inference_model.pth", init=True)  # Loads both model parameters and weights
Z = tldr.transform(X, l2_norm=True)  # Returns (N x n_components) matrix
```

You can find this full example in [scripts/dummy_example.py](scripts/dummy_example.py).

## Documentation

#### TLDR(n_components, encoder, projector, n_neighbors=5, device='cpu', pin_memory=False)

Description of selected arguments (see code for full list):
* `n_components`: output dimension
* `encoder`: encoder network architecture specification string--[see formatting guide](#architecture-specification-strings) (Default: `'linear'`).
* `projector`: projector network architecture specification string--[see formatting guide](#architecture-specification-strings) (Default: `'mlp-2-2048'`).
* `n_neighbors`: number of nearest neighbors used to sample training pairs (Default: `5`).
* `device`: selects the device ['cpu', 'cuda'] (Default: `'cpu'`).
* `pin_memory`: pin all data to the memory of the device (Default: `False`).
* `random_state`: sets the random seed (Default: `None`).
* `knn_approximation`: amount of approximation to use during the knn computation; accepted values are [None, "low", "medium" and "high"] (Default: `None`). No approximation will calculate exact neighbors, while setting the approximation to low, medium or high will use product quantization and create the FAISS index using the index_factory with an `"IVF1,PQ[X]"` string, where X={32,16,8} for {"low","medium","high"}. The PQ parameters are learned using 10% of the training data. A raw-FAISS sketch of what this corresponds to is shown below.
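
For reference, this approximation roughly corresponds to the following raw FAISS calls (a minimal, illustrative sketch of the `"medium"` setting with made-up array sizes; `TLDR` builds and queries the index internally, so you never call FAISS yourself):

```python
import faiss
import numpy as np

X = np.random.rand(20000, 2048).astype(np.float32)
# "medium" approximation -> product quantization with 16 sub-quantizers
index = faiss.index_factory(X.shape[1], "IVF1,PQ16", faiss.METRIC_INNER_PRODUCT)
index.train(X[: 1 + X.shape[0] // 10])  # PQ codebooks learned on ~10% of the data
index.add(X)
D, I = index.search(X[:5], 6)  # 6 hits per query; the first one is the query itself
```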

```python
from tldr import TLDR

tldr = TLDR(n_components=128, encoder='linear', projector='mlp-2-2048', n_neighbors=3, device='cuda')
```

#### fit(X, epochs=100, batch_size=1024, knn_graph=None, output_folder=None, snapshot_freq=None)

Parameters:
* `X`: NxD training data array containing N training samples of dimension D.
* `epochs`: number of training epochs (Default: `100`).
* `batch_size`: size of the training mini-batch (Default: `1024`).
* `knn_graph`: `N`x`n_neighbors` array containing the indices of the nearest neighbors of each sample; if None it will be computed (Default: `None`).
* `output_folder`: folder where the final model (and also the snapshots if snapshot_freq > 1) will be saved (Default: `None`).
* `snapshot_freq`: number of epochs between snapshots (Default: `None`).
* `print_every`: prints useful training information every given number of steps (Default: `0`).
* `retain_projector`: flag so that the projector parameters are retained after training (Default: `False`).

```python
from tldr import TLDR
import numpy as np

tldr = TLDR(n_components=32, encoder='linear', projector='mlp-2-2048')
X = np.random.rand(10000, 2048)
tldr.fit(X, epochs=50, batch_size=512, output_folder='data/', snapshot_freq=5, print_every=50)
```

#### transform(X, l2_norm=False)

Parameters:
* `X`: NxD array containing N samples of dimension D.
* `l2_norm`: L2 normalizes the features after projection (Default: `False`).

Output:
* `Z`: Nxn_components array.

```python
tldr.fit(X, epochs=100)
Z = tldr.transform(X, l2_norm=True)
```

#### save(path) and load(path)

* `save()` saves to disk both model hyper-parameters and weights.
* `load()` loads the weights of the model. If `init=True` it initializes the model with the hyper-parameters found in the file.

```python
tldr = TLDR(n_components=32, encoder='linear', projector='mlp-2-2048')
tldr.fit(X, epochs=50, batch_size=512)
tldr.save("data/model.pth")  # Saves weights and params

tldr = TLDR()
tldr.load("data/model.pth", init=True)  # Initializes the model with the params in the file and loads the weights
```

#### remove_projector()

Removes the projector head from the model. Useful for reducing the size of the model before saving it to disk. Note that you will need the projection head if you want to resume training, so call `fit()` with `retain_projector=True` if you plan to fine-tune later; a typical flow is sketched below.
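
This save-for-inference pattern relies only on the calls documented above (the file names are illustrative):

```python
tldr.fit(X, epochs=50, retain_projector=True)  # keep the projector for later fine-tuning
tldr.save("data/full_model.pth")               # full model, projector included

tldr.remove_projector()                        # drop the projector head...
tldr.save("data/inference_model.pth")          # ...and save a smaller inference-only model
```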

#### compute_knn(), save_knn() and load_knn()

`compute_knn()` pre-computes the nearest-neighbor graph of the training data, and `save_knn()`/`load_knn()` store it on disk and load it back, so it can be reused across trainings on the same data:

```python
tldr = TLDR(n_components=128, encoder='linear', projector='mlp-2-2048')
tldr.compute_knn(X)
tldr.fit(X, epochs=100)
tldr.save_knn("knn.npy")
```

```python
tldr = TLDR(n_components=128, encoder='linear', projector='mlp-2-2048')
tldr.load_knn("knn.npy")
tldr.fit(X, epochs=100)
```

### Architecture Specification Strings

You can specify the network configuration using a string with the following format:

```'[NETWORK_TYPE]-[NUM_HIDDEN_LAYERS]-[NUM_DIMENSIONS_PER_LAYER]'```

- `NETWORK_TYPE`: three network types are currently available:
  - `linear`: a linear function parametrized by a weight matrix W of size `input_dim X num_components`.
  - `flinear`: a factorized linear model, i.e. a sequence of blocks, each composed of a linear layer followed by a batch normalization layer.
  - `mlp`: a multi-layer perceptron (MLP) with batch normalization and rectified linear units (ReLUs) as non-linearities.
- `NUM_HIDDEN_LAYERS`: selects the number of hidden (i.e. intermediate) layers for the factorized linear model and the MLP.
- `NUM_DIMENSIONS_PER_LAYER`: selects the dimensionality of the hidden layers.

For example, `linear` will use a single linear layer; `flinear-1-512` will use a factorized linear model with one hidden layer of 512 dimensions; and `mlp-2-4096` will select an MLP composed of two hidden layers of 4096 dimensions each.
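
These strings are parsed by `parse_net_config` in `tldr/utils.py`, which you can call directly to sanity-check a configuration before training (expected results shown as comments):

```python
from tldr.utils import parse_net_config

print(parse_net_config("linear"))         # ('linear', 0, [])
print(parse_net_config("flinear-1-512"))  # ('flinear', 1, [512])
print(parse_net_config("mlp-2-4096"))     # ('mlp', 2, [4096])
```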

## Citation

Please consider citing the following paper in your publications if this helps your research.

```
@article{kalantidis2022tldr,
  title = {TLDR: Twin Learning for Dimensionality Reduction},
  author = {Kalantidis, Y. and Lassance, C. and Almaz\'an, J. and Larlus, D.},
  journal = {Transactions of Machine Learning Research},
  year = {2022},
  url = {https://openreview.net/forum?id=86fhqdBUbx},
}
```

## Contributors

This code has been developed by Jon Almazan, Carlos Lassance, Yannis Kalantidis and Diane Larlus at [NAVER Labs Europe](https://europe.naverlabs.com).
--------------------------------------------------------------------------------
/tldr/tldr.py:
--------------------------------------------------------------------------------
# Copyright (c) NAVER and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import random
import pathlib
from pathlib import Path
from time import time
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import BatchSampler, RandomSampler

from tldr.loss import BarlowTwinsLoss
from tldr.optimizer import LARS, adjust_learning_rate
from tldr.utils import AverageMeter, get_knn_graph, get_progress_bar, parse_net_config


class TLDR_Module(nn.Module):
    def __init__(
        self,
        inputdim: int,
        n_components: int,
        encoder: str = "linear",
        projector: str = "mlp-2-2048",
        batch_size: int = 1024,
        scale_loss: float = 1.0 / 32,
        lambd: float = 3.9e-3,
        norm_layer: str = "BN",
        loss: str = "BT",
    ):
        super().__init__()

        self.batch_size = batch_size
        self.scale_loss = scale_loss
        self.lambd = lambd
        self.loss = loss

        if norm_layer == "BN":
            self.norm_layer = nn.BatchNorm1d
        elif norm_layer == "LN":
            self.norm_layer = nn.LayerNorm

        # Encoder
        encoder_type, num_hlayers_encoder, hdims_encoder = parse_net_config(encoder)
        hdims = hdims_encoder * num_hlayers_encoder
        hdims = [inputdim] + hdims + [n_components]
        layers = []
        if encoder_type in ["linear", "flinear"]:
            for i in range(len(hdims) - 2):
                layers.append(nn.Linear(hdims[i], hdims[i + 1]))
                layers.append(self.norm_layer(hdims[i + 1]))
            layers.append(nn.Linear(hdims[-2], hdims[-1]))
        elif encoder_type == "mlp":
            for i in range(len(hdims) - 2):
                layers.append(nn.Linear(hdims[i], hdims[i + 1], bias=False))
                layers.append(self.norm_layer(hdims[i + 1]))
                layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Linear(hdims[-2], hdims[-1], bias=False))
        else:
            raise ValueError(f"Incorrect network type {encoder_type}")
        self.encoder = nn.Sequential(*layers)

        # Projector
        if projector is not None:
            projector_type, num_hlayers_projector, hdims_projector = parse_net_config(projector)
            sizes = [n_components] + hdims_projector * (num_hlayers_projector + 1)
            layers = []
            if projector_type in ["linear", "flinear"]:
                for i in range(len(sizes) - 2):
                    layers.append(nn.Linear(sizes[i], sizes[i + 1]))
                    layers.append(self.norm_layer(sizes[i + 1]))
                layers.append(nn.Linear(sizes[-2], sizes[-1]))
            elif projector_type == "mlp":
                for i in range(len(sizes) - 2):
                    layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
                    layers.append(self.norm_layer(sizes[i + 1]))
                    layers.append(nn.ReLU(inplace=True))
                layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
            else:
                raise ValueError(f"Incorrect network type {projector_type}")
            self.projector = nn.Sequential(*layers)
            bn_size = sizes[-1]
        else:
            bn_size = n_components
            self.projector = None

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(bn_size, affine=False)

    def forward(self, X: torch.Tensor):
        """Performs a forward pass over the encoder, projecting the input features to the learnt space

        Parameters
        ----------
        X : tensor
            N x D input tensor containing N samples of dimension D

        Returns
        -------
        Z : tensor
            Projected output tensor of size N x n_components
        """
        return self.encoder(X)

    def match(self, x1: torch.Tensor, x2: torch.Tensor):
        """Computes the matching loss over two sets of sample pairs

        Parameters
        ----------
        x1, x2 : tensor
            Two tensors of size N x D where each row represents a matching pair

        Returns
        -------
        loss : tensor
            Aggregated matching loss over all training pairs
        """
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)
        if self.projector is not None:
            z1 = self.projector(z1)
            z2 = self.projector(z2)

        if self.loss in ["BT", "BarlowTwins"]:
            loss = BarlowTwinsLoss(self.bn(z1), self.bn(z2), self.batch_size, self.scale_loss, self.lambd)
        elif self.loss in ["MSE", "MeanSquaredError"]:
            loss = nn.MSELoss(reduction="mean")(torch.vstack([x1, x2]), torch.vstack([z1, z2])).mul(self.scale_loss)
        elif self.loss == "Contrastive":
            raise ValueError("Contrastive loss temporarily removed :_( (WIP)")
        return loss.unsqueeze(0)

    def set_device(self, device: torch.device):
        """Moves the encoder and projector to the given device"""
        self.encoder.to(device)
        if self.projector is not None:  # the projector may have been removed
            self.projector.to(device)


class TLDR:
    def __init__(
        self,
        n_components: int = 32,
        encoder: str = "linear",
        projector: str = "mlp-2-2048",
        n_neighbors: int = 5,
        pin_memory: bool = False,
        knn_approximation: Optional[str] = None,
        knn_graph: Optional[np.ndarray] = None,
        inputdim: Optional[int] = None,
        batch_size: int = 1024,
        scale_loss: float = 1.0 / 32,
        lambd: float = 3.9e-3,
        epochs: int = 100,
        learning_rate: float = 0.2,
        warmup_epochs: int = 10,
        norm_layer: str = "BN",
        loss: str = "BT",
        gaussian: bool = False,
        output_folder: Optional[str] = None,
        snapshot_freq: int = 0,
        resume: bool = False,
        save_best: bool = False,
        verbose: int = 0,
        random_state: Optional[int] = None,
        device: Union[str, torch.device] = "cpu",
        writer=None,
    ):
        """Constructor method of the TLDR class

        Parameters
        ----------
        n_components : int
            Output dimension
        encoder : str
            Encoder network architecture specification string (see README)
        projector : str
            Projector network architecture specification string (see README)
        n_neighbors : int
            Number of nearest neighbors used to sample training pairs
        knn_approximation : str (optional)
            Amount of approximation to use during the knn computation [None, low, medium, high]
        pin_memory : bool
            Pin all data to the memory of the device
        knn_graph : np.ndarray (optional)
            Array containing the indices of the nearest neighbors of each sample
        inputdim : int (optional)
            Input dimension
        batch_size : int
            Batch size
        scale_loss : float
            Loss scaling parameter of the LARS optimizer
        lambd : float
            Lambda parameter of the BarlowTwins loss
        epochs : int
            Number of training epochs
        learning_rate : float
            Learning rate
        warmup_epochs : int
            Warm-up epochs
        norm_layer : str
            Type of normalization layer used [BN, LN]
        loss : str
            Training loss [BarlowTwins, MeanSquaredError, Contrastive]
        gaussian : bool
            Uses random Gaussian noise instead of nearest neighbors to generate training pairs
        output_folder : str (optional)
            Local folder where the snapshots and final model will be saved
        snapshot_freq : int
            Number of epochs between snapshots
        resume : bool
            Enables auto-resuming from the snapshots in `output_folder`
        save_best : bool
            Saves the best intermediate model
        verbose : int
            Verbosity level [0, 1, 2]
        random_state : int (optional)
            Fixes the random seed
        device : str, torch.device
            Selects the device [cpu, cuda]
        writer : TBWriter (optional)
            TensorBoard writer
        """
        self.architecture = {
            "inputdim": inputdim,
            "n_components": n_components,
            "encoder": encoder,
            "projector": projector,
            "batch_size": batch_size,
            "scale_loss": scale_loss,
            "lambd": lambd,
            "norm_layer": norm_layer,
            "loss": loss,
        }
        self.model = None
        self.device = torch.device(device) if type(device) == str else device
        self.batch_size = batch_size
        self.epochs = epochs
        self.start_epoch = 0
        self.n_neighbors = n_neighbors
        self.learning_rate = learning_rate
        self.knn_graph = knn_graph
        self.warmup_epochs = warmup_epochs
        self.snapshot_freq = snapshot_freq
        self.output_folder = output_folder
        self.resume = resume
        self.pin_memory = pin_memory
        self.gaussian = gaussian
        self.writer = writer
        self.save_best = save_best
        self.verbose = verbose
        self.random_state = random_state
        self.knn_approximation = knn_approximation
        if knn_approximation not in [None, "low", "medium", "high"]:
            raise ValueError(
                f"knn_approximation should be either None or low, medium, or high and it was {knn_approximation}"
            )
        if self.random_state is not None:
            torch.manual_seed(self.random_state)
            np.random.seed(self.random_state)
            random.seed(self.random_state)

    def initialize_model(self):
        """Initializes the TLDR module using the hyper-parameters in self.architecture"""
        self.model = TLDR_Module(**self.architecture)
        self.model.to(self.device)

    def parameters(self):
        """Returns the parameters of the model"""
        if self.model is None:
            raise RuntimeError("model not initialized")

        def concat_generators(*args):
            for gen in args:
                yield from gen

        if self.model.projector is not None:
            return concat_generators(self.model.encoder.parameters(), self.model.projector.parameters())
        else:
            return self.model.encoder.parameters()

    def fit(
        self,
        X: Union[torch.Tensor, np.ndarray],
        epochs: Optional[int] = None,
        warmup_epochs: Optional[int] = None,
        batch_size: Optional[int] = None,
        knn_graph: Optional[np.ndarray] = None,
        output_folder: Optional[str] = None,
        snapshot_freq: Optional[int] = None,
        print_every: Optional[int] = None,
        retain_projector: bool = False,
        dataset_val=None,
        l2_norm_eval: Optional[bool] = False,
        eval_every: Optional[int] = None,
    ):
        """Trains a model on the input data

        Parameters
        ----------
        X : torch.Tensor, np.ndarray
            N x D input array containing N samples of dimension D
        epochs : int
            Number of training epochs
        warmup_epochs : int
            Warm-up epochs
        batch_size : int
            Batch size
        knn_graph : np.ndarray (optional)
            Array containing the indices of the nearest neighbors of each sample
        output_folder : str (optional)
            Local folder where the snapshots and final model will be saved
        snapshot_freq : int
            Number of epochs between snapshots
        print_every : int
            Prints useful training information every given number of steps
        retain_projector : bool
            Flag so that the projector parameters are retained after training
        dataset_val : torch.utils.data.Dataset (optional)
            A dataset class containing evaluation data and code
        l2_norm_eval : bool
            Enables L2 normalization before evaluation (optional)
        eval_every : int (optional)
            Runs evaluation every given number of epochs
        """
        self.architecture["inputdim"] = X.shape[1]
        if epochs is not None:
            self.epochs = epochs
        if warmup_epochs is not None:
            self.warmup_epochs = warmup_epochs
        if batch_size is not None:
            self.batch_size = batch_size
            self.architecture["batch_size"] = batch_size
        if output_folder is not None:
            self.output_folder = Path(output_folder)
        if self.output_folder is not None:
            self.output_folder.mkdir(parents=True, exist_ok=True)
        if snapshot_freq is not None:
            self.snapshot_freq = snapshot_freq

        self.initialize_model()
        if self.model is None:
            raise RuntimeError("model not initialized")

        if self.verbose > 1:
            if "cuda" in self.device.type:
                print(" - Using GPU")
            else:
                print(" - Using CPU")
        n_data = X.shape[0]

        if knn_graph is not None:
            self.knn_graph = knn_graph
        elif self.knn_graph is None:
            self.compute_knn(X)

        # Resuming options
        if self.resume:
            path = self.output_folder / "final_model.pth"
            if path.is_file():
                self.load(path)
                print(" * Final model found. Skipping training.")
                return
            path = self.output_folder / "latest_snapshot.pth"
            if path.is_file():
                self.load(path)
        self.model.train()

        if isinstance(X, np.ndarray):  # if data is not a tensor convert it
            X = torch.Tensor(X)
        X = X.float()
        if self.pin_memory:
            X = X.to(self.device)

        def exclude_bias_and_norm(p):
            return p.ndim == 1

        optimizer = LARS(
            self.model.parameters(),
            lr=0,
            weight_decay=1e-6,
            weight_decay_filter=exclude_bias_and_norm,
            lars_adaptation_filter=exclude_bias_and_norm,
        )

        losses = AverageMeter("Loss", ":.4e")
        batch_sampler = BatchSampler(RandomSampler(range(n_data)), batch_size=self.batch_size, drop_last=True)
        step = self.start_epoch * len(batch_sampler)
        best_eval = 0
        t0 = time()
        with get_progress_bar() as progress:
            task = (
                progress.add_task(
                    description="[green]Training TLDR", total=(len(batch_sampler) * self.epochs), info="-"
                )
                if self.verbose > 0
                else None
            )
            for epoch in range(self.start_epoch, self.epochs):
                if self.verbose > 0:
                    progress.update(task, info=f"epoch {epoch+1} (of {self.epochs})")
                for i, ind in enumerate(batch_sampler):
                    step += 1
                    if type(self.knn_graph) == dict:  # Oracle
                        ind_nn = []
                        for j in ind:
                            ind_nn.append(random.choices(self.knn_graph[j])[0])
                        y1 = X[ind, :]
                        y2 = X[ind_nn, :]
                    else:
                        if self.gaussian:  # Synthetic neighbors
                            y1 = X[ind, :]
                            y2 = y1 + (torch.std(y1) ** 0.5) * torch.randn(y1.shape).to(self.device) * 0.1
                        else:  # Randomly select m neighbors as training pair(s)
                            y1 = X[ind, :]
                            ind_nn = np.random.randint(self.n_neighbors, size=self.batch_size)
                            y2 = X[self.knn_graph[ind, ind_nn], :]

                    if not self.pin_memory:
                        y1 = y1.to(self.device)
                        y2 = y2.to(self.device)

                    lr = adjust_learning_rate(
                        self.epochs,
                        optimizer,
                        n_data,
                        step,
                        self.learning_rate,
                        self.batch_size,
                        self.warmup_epochs,
                    )
                    optimizer.zero_grad()
                    loss = self.model.match(y1, y2).mean()
                    losses.update(loss.item(), y1.size(0))
                    loss.mean().backward()
                    optimizer.step()
                    if print_every and step % print_every == 0:
                        if self.verbose > 1:
                            progress.console.print(f" * {losses}, LR = {lr:.5f}")
                        if self.writer:
                            self.writer.add_scalar(
                                f'n{self.architecture["n_components"]}/train/loss',
                                losses.val,
                                epoch + (i / len(batch_sampler)),
                            )
                    if self.verbose > 0:
                        progress.update(task, advance=1)
                checkpoint = {
                    "epoch": epoch + 1,
                    "state_dict": self._get_state_dict(),
                    "architecture": self.architecture,
                }

                # eval_every is also checked so that a dataset_val without an interval does not crash
                if dataset_val is not None and eval_every and (epoch + 1) % eval_every == 0:
                    res = self.evaluate(dataset_val, l2_norm_eval)
                    checkpoint["val"] = res
                    if self.writer:
                        self.writer.add_scalar(
                            f'n{self.architecture["n_components"]}/val/acc',
                            res,
                            epoch + 1,
                        )
                    if res > best_eval and self.output_folder and self.save_best:
                        torch.save(checkpoint, self.output_folder / "best.pth")
                        best_eval = res
                    self.model.train()

                if self.output_folder:
                    if self.snapshot_freq and (epoch + 1) % self.snapshot_freq == 0:
                        torch.save(checkpoint, self.output_folder / f"snapshot_{epoch+1}.pth")
                    torch.save(checkpoint, self.output_folder / "latest_snapshot.pth")

        if self.output_folder:
            torch.save(checkpoint, self.output_folder / "final_model.pth")
        t1 = time()
        if self.verbose > 1:
            print(" - Fit took %.2g sec" % (t1 - t0))
        if not retain_projector:
            self.remove_projector()

    def transform(
        self,
        X: Union[torch.Tensor, np.ndarray],
        l2_norm: bool = False,
        batching_threshold: int = 10000,
        amount_batches: int = 1000,
    ):
        """Projects the input data to the learnt space

        Parameters
        ----------
        X : torch.Tensor, np.ndarray
            N x D input array containing N samples of dimension D
        l2_norm : bool
            L2 normalizes the output representation after projection
        batching_threshold : int
            Inputs with more samples than this are projected in chunks
        amount_batches : int
            Chunk size (in samples) used when splitting the input; passed to `torch.split`

        Returns
        -------
        Z : torch.Tensor, np.ndarray
            Projected output tensor of size N x n_components
        """
        if self.model is None:
            raise RuntimeError("model not initialized")

        with torch.no_grad():  # Avoid computing gradients when we do not need them
            self.model.eval()
            self.model.to(self.device)
            if isinstance(X, np.ndarray):  # If data is not a tensor convert it
                X = torch.Tensor(X)
                to_numpy = True  # Output type same as input
                input_device = "cpu"
            elif isinstance(X, torch.Tensor):
                to_numpy = False  # Output type same as input
                input_device = X.device.type
            else:
                raise ValueError(f"unknown input type {type(X)}. Input must be a numpy array or a torch tensor")

            if (
                X.shape[0] > batching_threshold
            ):  # If there are more than batching_threshold samples do batched transformation
                splitted_dataset = torch.split(X, amount_batches)
                all_data = list()
                for batch in splitted_dataset:
                    all_data.append(self.forward(batch, l2_norm=l2_norm))
                Z = torch.vstack(all_data)
            else:
                Z = self.forward(X, l2_norm=l2_norm)
            if input_device == "cpu":
                Z = Z.cpu()
            if to_numpy:
                Z = Z.detach().numpy()
            return Z

    def fit_transform(self, X: Union[torch.Tensor, np.ndarray], **kwargs):
        """
        See documentation of methods fit() and transform()
        """
        l2_norm = kwargs.pop("l2_norm", False)
        batching_threshold = kwargs.pop("batching_threshold", 10000)
        amount_batches = kwargs.pop("amount_batches", 1000)
        self.fit(X, **kwargs)
        return self.transform(X, l2_norm=l2_norm, batching_threshold=batching_threshold, amount_batches=amount_batches)

    def forward(self, X: Union[torch.Tensor, np.ndarray], l2_norm: bool = False):
        """Performs a forward pass over the encoder, projecting the input features to the learnt space

        Parameters
        ----------
        X : tensor
            N x D input tensor containing N samples of dimension D
        l2_norm : bool
            L2 normalizes the output representation after projection

        Returns
        -------
        Z : tensor
            Projected output tensor of size N x n_components
        """
        Z = self.model.forward(X.float().to(self.device))
        if l2_norm:
            Z = Z / torch.linalg.norm(Z, 2, dim=1, keepdim=True)
        return Z

    def save(self, path: Union[pathlib.PosixPath, str]):
        """Saves both the weights and hyper-parameters of the model to disk"""
        path = Path(path)
        if not path.parent.is_dir():
            path.parent.mkdir(parents=True, exist_ok=True)

        architecture = self.architecture
        if self.model.projector is None:
            architecture["projector"] = None
        checkpoint = {
            "state_dict": self._get_state_dict(),
            "architecture": architecture,
        }
        torch.save(checkpoint, path)

    def load(self, path: Union[pathlib.PosixPath, str], init: bool = False, strict: bool = False):
        """Loads a model from disk

        Parameters
        ----------
        path : Path, str
            Location of the model on disk
        init : bool
            Forces the model to initialize with the hyper-parameters found in the file
        strict : bool
            If set to False it ignores non-matching keys
        """

        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        architecture = checkpoint.pop("architecture", None)
        if init:
            self.architecture = architecture
            self.initialize_model()
        else:
            if architecture != self.architecture:
                raise ValueError(
                    f"Parameters in {path} do not match. Use load(path, init=True) to initialize the model with the parameters in this file."
                )

        self.model.load_state_dict(checkpoint["state_dict"], strict=strict)
        if "epoch" in checkpoint:
            self.start_epoch = checkpoint["epoch"]
        self.model.to(self.device)

    def compute_knn(self, X: Union[torch.Tensor, np.ndarray]):
        """Computes the k nearest neighbors of each sample

        Parameters
        ----------
        X : ndarray
            N x D array containing N samples of dimension D
        """
        self.knn_graph = get_knn_graph(
            X, self.n_neighbors, device=self.device.type, verbose=self.verbose, knn_approximation=self.knn_approximation
        )

    def get_knn(self):
        """Returns the graph of k nearest neighbors"""
        return self.knn_graph

    def save_knn(self, path: Union[pathlib.PosixPath, str]):
        """Saves the k nearest neighbors graph to disk"""
        path = Path(path)
        if not path.parent.is_dir():
            path.parent.mkdir(parents=True, exist_ok=True)
        np.save(path, self.knn_graph)

    def load_knn(self, path: Union[pathlib.PosixPath, str]):
        """Loads the k nearest neighbors graph from disk"""
        self.knn_graph = np.load(path)

    def _get_state_dict(self):
        """Returns the model parameters"""
        return self.model.state_dict()

    def to(self, device: Union[str, torch.device]):
        """Moves computation to the given device"""
        self.device = torch.device(device)
        self.model.to(self.device)

    def remove_projector(self):
        """Removes the projector head from the model"""
        self.model.projector = None
        self.model.bn = None

    def evaluate(self, dataset, l2_norm_eval: bool = True, whiten_eval: bool = False, metric: str = "mAP-medium"):
        self.model.eval()
        dataset.transform(self)
        return dataset.evaluate(l2_norm=l2_norm_eval, whiten=whiten_eval, metric=metric)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        s = f"{self.__class__.__name__}\n"
        s += f"{json.dumps(self.architecture, indent=2)}\n\n{self.model}"
        return s
--------------------------------------------------------------------------------