├── .gitignore
├── LICENSE
├── readme.md
├── setup.py
├── test.py
└── umap_pytorch
    ├── __init__.py
    ├── data.py
    ├── main.py
    ├── model.py
    └── modules.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.ipynb
2 | data/
3 | build/
4 | __pycache__/
5 | .ipynb_checkpoints/
6 | lightning_logs/
7 | *.png
8 | umap_pytorch.egg-info/
9 | dist/
10 | *.pkl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 elyxlz
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | Parametric UMAP port for PyTorch, using PyTorch Lightning for the training loop.
2 | 
3 | ## Install
4 | ```bash
5 | pip install umap-pytorch
6 | ```
7 | 
8 | ## Usage
9 | 
10 | ```py
11 | import torch
12 | import torch.nn.functional as F
13 | from umap_pytorch import PUMAP
14 | 
15 | pumap = PUMAP(
16 |     encoder=None,          # nn.Module, None for default
17 |     decoder=None,          # nn.Module, True for default, None for encoder only
18 |     n_neighbors=10,
19 |     min_dist=0.1,
20 |     metric="euclidean",
21 |     n_components=2,
22 |     beta=1.0,              # weight of the reconstruction loss when a decoder is used
23 |     reconstruction_loss=F.binary_cross_entropy_with_logits,  # custom reconstruction loss function
24 |     random_state=None,
25 |     lr=1e-3,
26 |     epochs=10,
27 |     batch_size=64,
28 |     num_workers=1,
29 |     num_gpus=1,
30 |     match_nonparametric_umap=False  # train the network to match embeddings from non-parametric UMAP
31 | )
32 | 
33 | data = torch.randn((50000, 512))
34 | pumap.fit(data)
35 | embedding = pumap.transform(data)  # (50000, 2)
36 | 
37 | # if decoder enabled
38 | recon = pumap.inverse_transform(embedding)  # (50000, 512)
39 | ```
40 | 
41 | ## Saving and Loading
42 | ```py
43 | # Saving
44 | path = 'some/path/hello.pkl'
45 | pumap.save(path)
46 | 
47 | # Loading
48 | from umap_pytorch import load_pumap
49 | pumap = load_pumap(path)
50 | ```
51 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name = 'umap_pytorch',
5 |     packages = find_packages(exclude=[]),
6 |     include_package_data = True,
7 |     version = '0.0.06',
8 |     license='MIT',
9 |     description = 'UMAP port for PyTorch',
10 |     author = 'Elio Pascarelli',
11 |     author_email = 'elio@pascarelli.com',
12 |     url = 'https://github.com/elyxlz/umap_pytorch',
13 |     long_description_content_type = 'text/markdown',
14 |     keywords = [
15 |         'artificial intelligence',
16 |         'deep learning',
17 |         'dimensionality reduction',
18 |         'UMAP',
19 |     ],
20 |     install_requires=[
21 |         'einops>=0.3',
22 |         'pynndescent',
23 |         'llvmlite>=0.34.0',
24 |         'torch>=1.6',
25 |         'scikit-learn',
26 |         'umap-learn',
27 |         'pytorch_lightning',
28 |         'dill',
29 |     ],
30 |     classifiers=[
31 |         'Development Status :: 4 - Beta',
32 |         'Intended Audience :: Developers',
33 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
34 |         'License :: OSI Approved :: MIT License',
35 |         'Programming Language :: Python :: 3.9',
36 |     ],
37 | )
38 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torchvision.transforms import transforms
3 | import matplotlib.pyplot as plt
4 | from umap_pytorch import PUMAP, load_pumap
5 | import seaborn as sns
6 | import torch
7 | from PIL import Image
8 | import numpy as np
9 | import torch.nn.functional as F
10 | 
11 | 
12 | train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
13 | train_tensor = torch.stack([example[0] for example in train_dataset])  # already (60000, 1, 28, 28)
14 | labels = [str(example[1]) for example in train_dataset]
15 | X = train_tensor
16 | 
17 | pumap = PUMAP(epochs=5, min_dist=1, n_neighbors=50, num_workers=8, decoder=True, beta=0.01, match_nonparametric_umap=True)
18 | pumap.fit(X)
19 | pumap.save('yo.pkl')
20 | pumap = load_pumap('yo.pkl')
21 | embedding = pumap.transform(X)
22 | print(embedding.shape, embedding)
23 | sns.scatterplot(x=embedding[:, 0], y=embedding[:, 1], hue=labels, s=0.4)
24 | plt.savefig('test4.png')
25 | 
26 | 
27 | def regenerate_and_plot(n=6):
28 |     some_points = embedding[np.random.choice(embedding.shape[0], n)]
29 |     regenerated = pumap.inverse_transform(torch.Tensor(some_points))
30 | 
31 |     for i in range(n):
32 |         img = regenerated[i, 0]  # decoder outputs logits; squash to [0, 1] and scale before saving
33 |         img = Image.fromarray(np.uint8(255 / (1 + np.exp(-img))))
34 |         img.save("image_{}.png".format(i))
35 | 
36 | regenerate_and_plot()
37 | 
--------------------------------------------------------------------------------
/umap_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .main import PUMAP, load_pumap
--------------------------------------------------------------------------------
/umap_pytorch/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | 
5 | def get_graph_elements(graph_, n_epochs):
6 | 
7 |     graph = graph_.tocoo()
8 |     # eliminate duplicate entries by summing them together
9 |     graph.sum_duplicates()
10 |     # number of vertices in dataset
11 |     n_vertices = graph.shape[1]
12 |     # get the number of epochs based on the size of the dataset
13 |     if n_epochs is None:
14 |         # For smaller datasets we can use more epochs
15 |         if graph.shape[0] <= 10000:
16 |             n_epochs = 500
17 |         else:
18 |             n_epochs = 200
19 |     # remove elements with very low probability
20 |     graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0
21 |     graph.eliminate_zeros()
22 |     # get epochs per sample based upon edge probability
23 |     epochs_per_sample = n_epochs * graph.data
24 | 
25 |     head = graph.row
26 |     tail = graph.col
27 |     weight = graph.data
28 | 
29 |     return graph, epochs_per_sample, head, tail, weight, n_vertices
30 | 
31 | class UMAPDataset(Dataset):
32 |     def __init__(self, data, graph_, n_epochs=200):
33 |         graph, epochs_per_sample, head, tail, weight, n_vertices = get_graph_elements(graph_, n_epochs)
34 |         # repeat each edge in proportion to its membership strength, so likelier edges are sampled more often
35 |         self.edges_to_exp, self.edges_from_exp = (
36 |             np.repeat(head, epochs_per_sample.astype("int")),
37 |             np.repeat(tail, epochs_per_sample.astype("int")),
38 |         )
39 |         shuffle_mask = np.random.permutation(np.arange(len(self.edges_to_exp)))
40 |         self.edges_to_exp = self.edges_to_exp[shuffle_mask].astype(np.int64)
41 |         self.edges_from_exp = self.edges_from_exp[shuffle_mask].astype(np.int64)
42 |         self.data = torch.Tensor(data)
43 | 
44 |     def __len__(self):
45 |         return int(self.data.shape[0])  # one "epoch" draws as many edges as there are samples
46 | 
47 |     def __getitem__(self, index):
48 |         edges_to_exp = self.data[self.edges_to_exp[index]]
49 |         edges_from_exp = self.data[self.edges_from_exp[index]]
50 |         return (edges_to_exp, edges_from_exp)
51 | 
52 | class MatchDataset(Dataset):
53 |     def __init__(self, data, embeddings):
54 |         self.embeddings = torch.Tensor(embeddings)
55 |         self.data = data
56 |     def __len__(self):
57 |         return int(self.data.shape[0])
58 |     def __getitem__(self, index):
59 |         return self.data[index], self.embeddings[index]
--------------------------------------------------------------------------------
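Note on `UMAPDataset`: `get_graph_elements` converts the fuzzy simplicial set to COO form and derives `epochs_per_sample = n_epochs * graph.data`, so each edge is repeated in proportion to its membership strength before shuffling. A minimal sketch of that sampling logic on a made-up 3-vertex graph (the weights are purely illustrative):

```py
import numpy as np
from scipy.sparse import coo_matrix

# hypothetical toy graph: 3 vertices, data = edge membership strengths
graph = coo_matrix(
    (np.array([1.0, 0.5, 0.1]), (np.array([0, 0, 1]), np.array([1, 2, 2]))),
    shape=(3, 3),
)

n_epochs = 10
epochs_per_sample = n_epochs * graph.data            # [10., 5., 1.]
heads = np.repeat(graph.row, epochs_per_sample.astype("int"))
tails = np.repeat(graph.col, epochs_per_sample.astype("int"))
# edge (0, 1) appears 10x, (0, 2) 5x, (1, 2) once: stronger edges
# are drawn more often when the dataset is indexed during training
print(len(heads))  # 16
```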
/umap_pytorch/main.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | from torch import nn
4 | from torch.utils.data import DataLoader
5 | from torch.nn.functional import mse_loss
6 | import torch.nn.functional as F
7 | 
8 | from umap_pytorch.data import UMAPDataset, MatchDataset
9 | from umap_pytorch.modules import get_umap_graph, umap_loss
10 | from umap_pytorch.model import default_encoder, default_decoder
11 | 
12 | from umap.umap_ import find_ab_params
13 | import dill
14 | from umap import UMAP
15 | 
16 | """ Model """
17 | 
18 | 
19 | class Model(pl.LightningModule):
20 |     def __init__(
21 |         self,
22 |         lr: float,
23 |         encoder: nn.Module,
24 |         decoder=None,
25 |         beta=1.0,
26 |         min_dist=0.1,
27 |         reconstruction_loss=F.binary_cross_entropy_with_logits,
28 |         match_nonparametric_umap=False,
29 |     ):
30 |         super().__init__()
31 |         self.lr = lr
32 |         self.encoder = encoder
33 |         self.decoder = decoder
34 |         self.beta = beta  # weight for reconstruction loss
35 |         self.match_nonparametric_umap = match_nonparametric_umap
36 |         self.reconstruction_loss = reconstruction_loss
37 |         self._a, self._b = find_ab_params(1.0, min_dist)
38 | 
39 |     def configure_optimizers(self):
40 |         return torch.optim.AdamW(self.parameters(), lr=self.lr)
41 | 
42 |     def training_step(self, batch, batch_idx):
43 |         if not self.match_nonparametric_umap:
44 |             (edges_to_exp, edges_from_exp) = batch
45 |             embedding_to, embedding_from = self.encoder(edges_to_exp), self.encoder(edges_from_exp)
46 |             encoder_loss = umap_loss(embedding_to, embedding_from, self._a, self._b, edges_to_exp.shape[0], negative_sample_rate=5)
47 |             self.log("umap_loss", encoder_loss, prog_bar=True)
48 | 
49 |             if self.decoder:
50 |                 recon = self.decoder(embedding_to)
51 |                 recon_loss = self.reconstruction_loss(recon, edges_to_exp)
52 |                 self.log("recon_loss", recon_loss, prog_bar=True)
53 |                 return encoder_loss + self.beta * recon_loss
54 |             else:
55 |                 return encoder_loss
56 | 
57 |         else:
58 |             data, embedding = batch
59 |             embedding_parametric = self.encoder(data)
60 |             encoder_loss = mse_loss(embedding_parametric, embedding)
61 |             self.log("encoder_loss", encoder_loss, prog_bar=True)
62 |             if self.decoder:
63 |                 recon = self.decoder(embedding_parametric)
64 |                 recon_loss = self.reconstruction_loss(recon, data)
65 |                 self.log("recon_loss", recon_loss, prog_bar=True)
66 |                 return encoder_loss + self.beta * recon_loss
67 |             else:
68 |                 return encoder_loss
69 | 
70 | 
71 | """ Datamodule """
72 | 
73 | 
74 | class Datamodule(pl.LightningDataModule):
75 |     def __init__(
76 |         self,
77 |         dataset,
78 |         batch_size,
79 |         num_workers,
80 |     ):
81 |         super().__init__()
82 |         self.dataset = dataset
83 |         self.batch_size = batch_size
84 |         self.num_workers = num_workers
85 | 
86 |     def train_dataloader(self) -> DataLoader:
87 |         return DataLoader(
88 |             dataset=self.dataset,
89 |             batch_size=self.batch_size,
90 |             num_workers=self.num_workers,
91 |             shuffle=True,
92 |         )
93 | 
94 | class PUMAP:
95 |     def __init__(
96 |         self,
97 |         encoder=None,
98 |         decoder=None,
99 |         n_neighbors=10,
100 |         min_dist=0.1,
101 |         metric="euclidean",
102 |         n_components=2,
103 |         beta=1.0,
104 |         reconstruction_loss=F.binary_cross_entropy_with_logits,
105 |         random_state=None,
106 |         lr=1e-3,
107 |         epochs=10,
108 |         batch_size=64,
109 |         num_workers=1,
110 |         num_gpus=1,
111 |         match_nonparametric_umap=False,
112 |     ):
113 |         self.encoder = encoder
114 |         self.decoder = decoder
115 |         self.n_neighbors = n_neighbors
116 |         self.min_dist = min_dist
117 |         self.metric = metric
118 |         self.n_components = n_components
119 |         self.beta = beta
120 |         self.reconstruction_loss = reconstruction_loss
121 |         self.random_state = random_state
122 |         self.lr = lr
123 |         self.epochs = epochs
124 |         self.batch_size = batch_size
125 |         self.num_workers = num_workers
126 |         self.num_gpus = num_gpus
127 |         self.match_nonparametric_umap = match_nonparametric_umap
128 | 
129 |     def fit(self, X):
130 |         trainer = pl.Trainer(accelerator='gpu', devices=self.num_gpus, max_epochs=self.epochs)  # honor num_gpus instead of hard-coding one device
131 |         encoder = default_encoder(X.shape[1:], self.n_components) if self.encoder is None else self.encoder
132 | 
133 |         if self.decoder is None or isinstance(self.decoder, nn.Module):
134 |             decoder = self.decoder
135 |         elif self.decoder is True:
136 |             decoder = default_decoder(X.shape[1:], self.n_components)
137 | 
138 | 
139 |         if not self.match_nonparametric_umap:
140 |             self.model = Model(self.lr, encoder, decoder, beta=self.beta, min_dist=self.min_dist, reconstruction_loss=self.reconstruction_loss)
141 |             graph = get_umap_graph(X, n_neighbors=self.n_neighbors, metric=self.metric, random_state=self.random_state)
142 |             trainer.fit(
143 |                 model=self.model,
144 |                 datamodule=Datamodule(UMAPDataset(X, graph), self.batch_size, self.num_workers)
145 |             )
146 |         else:
147 |             print("Fitting non-parametric UMAP")
148 |             non_parametric_umap = UMAP(n_neighbors=self.n_neighbors, min_dist=self.min_dist, metric=self.metric, n_components=self.n_components, random_state=self.random_state, verbose=True)
149 |             non_parametric_embeddings = non_parametric_umap.fit_transform(torch.flatten(X, 1, -1).numpy())
150 |             self.model = Model(self.lr, encoder, decoder, beta=self.beta, reconstruction_loss=self.reconstruction_loss, match_nonparametric_umap=self.match_nonparametric_umap)
151 |             print("Training NN to match embeddings")
152 |             trainer.fit(
153 |                 model=self.model,
154 |                 datamodule=Datamodule(MatchDataset(X, non_parametric_embeddings), self.batch_size, self.num_workers)
155 |             )
156 | 
157 |     @torch.no_grad()
158 |     def transform(self, X):
159 |         print(f"Reducing array of shape {X.shape} to ({X.shape[0]}, {self.n_components})")
160 |         return self.model.encoder(X.to(next(self.model.encoder.parameters()).device)).detach().cpu().numpy()  # move input to the encoder's device
161 | 
162 |     @torch.no_grad()
163 |     def inverse_transform(self, Z):
164 |         return self.model.decoder(Z.to(next(self.model.decoder.parameters()).device)).detach().cpu().numpy()
165 | 
166 |     def save(self, path):
167 |         with open(path, 'wb') as oup:
168 |             dill.dump(self, oup)
169 |         print(f"Pickled PUMAP object at {path}")
170 | 
171 | def load_pumap(path):
172 |     print("Loading PUMAP object from pickled file.")
173 |     with open(path, 'rb') as inp:
174 |         return dill.load(inp)
175 | 
176 | if __name__ == "__main__":
177 |     pass
--------------------------------------------------------------------------------
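Since `PUMAP` accepts any `nn.Module` as its encoder, a custom architecture can replace the default MLP. A minimal sketch (this encoder is illustrative, not part of the library; a CUDA device is assumed because the training utilities place tensors on the GPU):

```py
import torch
from torch import nn
from umap_pytorch import PUMAP

# hypothetical custom encoder: flatten -> small MLP -> 2-D embedding
encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 2),
).cuda()

pumap = PUMAP(encoder=encoder, n_components=2, epochs=5)
data = torch.randn((10000, 512))
pumap.fit(data)
embedding = pumap.transform(data)  # (10000, 2)
```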
/umap_pytorch/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | 
5 | class conv_encoder(nn.Module):
6 |     def __init__(self, n_components=2):
7 |         super().__init__()
8 |         self.encoder = nn.Sequential(
9 |             nn.Conv2d(
10 |                 in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1,
11 |             ),
12 |             nn.Conv2d(
13 |                 in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1,
14 |             ),
15 |             nn.Flatten(),
16 |             nn.Linear(6272, 512),  # 128 * 7 * 7, assumes 1x28x28 inputs
17 |             nn.ReLU(),
18 |             nn.Linear(512, 512),
19 |             nn.ReLU(),
20 |             nn.Linear(512, n_components)
21 |         ).cuda()  # GPU placement is hard-coded; a CUDA device is assumed
22 |     def forward(self, X):
23 |         return self.encoder(X)
24 | 
25 | class default_encoder(nn.Module):
26 |     def __init__(self, dims, n_components=2):
27 |         super().__init__()
28 |         self.encoder = nn.Sequential(
29 |             nn.Flatten(),
30 |             nn.Linear(np.prod(dims), 200),
31 |             nn.ReLU(),
32 |             nn.Linear(200, 200),
33 |             nn.ReLU(),
34 |             nn.Linear(200, 200),
35 |             nn.ReLU(),
36 |             nn.Linear(200, n_components),
37 |         ).cuda()
38 | 
39 |     def forward(self, X):
40 |         return self.encoder(X)
41 | 
42 | class default_decoder(nn.Module):
43 |     def __init__(self, dims, n_components):
44 |         super().__init__()
45 |         self.dims = dims
46 |         self.decoder = nn.Sequential(
47 |             nn.Linear(n_components, 200),
48 |             nn.ReLU(),
49 |             nn.Linear(200, 200),
50 |             nn.ReLU(),
51 |             nn.Linear(200, 200),
52 |             nn.ReLU(),
53 |             nn.Linear(200, np.prod(dims)),
54 |         ).cuda()
55 |     def forward(self, X):
56 |         return self.decoder(X).view(X.shape[0], *self.dims)
57 | 
58 | 
59 | if __name__ == "__main__":
60 |     model = conv_encoder(2)
61 |     print(model)  # print the architecture (model.parameters would only print a bound method)
62 |     print(model(torch.randn((12, 1, 28, 28)).cuda()).shape)
--------------------------------------------------------------------------------
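A quick shape check of the default encoder/decoder pair (a sketch; a CUDA device is assumed since both modules call `.cuda()` internally):

```py
import torch
from umap_pytorch.model import default_encoder, default_decoder

dims = (1, 28, 28)
enc = default_encoder(dims, n_components=2)
dec = default_decoder(dims, n_components=2)

x = torch.randn((8, *dims)).cuda()
z = enc(x)      # (8, 2): flattened input -> low-dimensional embedding
x_hat = dec(z)  # (8, 1, 28, 28): embedding -> reconstruction (logits)
print(z.shape, x_hat.shape)
```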
/umap_pytorch/modules.py:
--------------------------------------------------------------------------------
1 | from pynndescent import NNDescent
2 | import numpy as np
3 | from sklearn.utils import check_random_state
4 | from umap.umap_ import fuzzy_simplicial_set
5 | import torch
6 | 
7 | def convert_distance_to_probability(distances, a=1.0, b=1.0):
8 |     return -torch.log1p(a * distances ** (2 * b))  # log q, where q = 1 / (1 + a * d^(2b))
9 | 
10 | def compute_cross_entropy(
11 |     probabilities_graph, probabilities_distance, EPS=1e-4, repulsion_strength=1.0
12 | ):
13 |     # cross entropy
14 |     attraction_term = -probabilities_graph * torch.nn.functional.logsigmoid(
15 |         probabilities_distance
16 |     )
17 |     repellant_term = (
18 |         -(1.0 - probabilities_graph)
19 |         * (torch.nn.functional.logsigmoid(probabilities_distance) - probabilities_distance)
20 |         * repulsion_strength)
21 | 
22 |     # balance the expected losses between attraction and repulsion
23 |     CE = attraction_term + repellant_term
24 |     return attraction_term, repellant_term, CE
25 | 
26 | def umap_loss(embedding_to, embedding_from, _a, _b, batch_size, negative_sample_rate=5):
27 |     # get negative samples by randomly shuffling the batch
28 |     embedding_neg_to = embedding_to.repeat(negative_sample_rate, 1)
29 |     repeat_neg = embedding_from.repeat(negative_sample_rate, 1)
30 |     embedding_neg_from = repeat_neg[torch.randperm(repeat_neg.shape[0])]
31 |     distance_embedding = torch.cat((
32 |         (embedding_to - embedding_from).norm(dim=1),
33 |         (embedding_neg_to - embedding_neg_from).norm(dim=1)
34 |     ), dim=0)
35 | 
36 |     # convert distances to probabilities (in log space)
37 |     probabilities_distance = convert_distance_to_probability(
38 |         distance_embedding, _a, _b
39 |     )
40 |     # set true probabilities based on negative sampling
41 |     probabilities_graph = torch.cat(
42 |         (torch.ones(batch_size), torch.zeros(batch_size * negative_sample_rate)), dim=0,
43 |     )
44 | 
45 |     # compute cross entropy on whichever device the embeddings live on
46 |     (attraction_loss, repellant_loss, ce_loss) = compute_cross_entropy(
47 |         probabilities_graph.to(probabilities_distance.device),
48 |         probabilities_distance,
49 |     )
50 |     loss = torch.mean(ce_loss)
51 |     return loss
52 | 
53 | def get_umap_graph(X, n_neighbors=10, metric="cosine", random_state=None):
54 |     random_state = check_random_state(None) if random_state is None else random_state
55 |     # number of trees in random projection forest
56 |     n_trees = 5 + int(round((X.shape[0]) ** 0.5 / 20.0))
57 |     # max number of nearest neighbor iters to perform
58 |     n_iters = max(5, int(round(np.log2(X.shape[0]))))
59 | 
60 |     # get nearest neighbors
61 |     nnd = NNDescent(
62 |         X.reshape((len(X), np.prod(np.shape(X)[1:]))),
63 |         n_neighbors=n_neighbors,
64 |         metric=metric,
65 |         n_trees=n_trees,
66 |         n_iters=n_iters,
67 |         max_candidates=60,
68 |         verbose=True
69 |     )
70 |     # get indices and distances
71 |     knn_indices, knn_dists = nnd.neighbor_graph
72 | 
73 |     # build fuzzy_simplicial_set
74 |     umap_graph, sigmas, rhos = fuzzy_simplicial_set(
75 |         X=X,
76 |         n_neighbors=n_neighbors,
77 |         metric=metric,
78 |         random_state=random_state,
79 |         knn_indices=knn_indices,
80 |         knn_dists=knn_dists,
81 |     )
82 | 
83 |     return umap_graph
--------------------------------------------------------------------------------
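For intuition about `umap_loss`: each positive pair targets probability 1 and each of the `negative_sample_rate` shuffled negative pairs targets 0; distances are mapped to log-similarities `log q = -log(1 + a * d^(2b))` (with `a`, `b` fitted from `min_dist`) and scored with the sampled cross-entropy above. A minimal sketch on random embeddings (runs on CPU given the device-agnostic loss; the numbers are illustrative):

```py
import torch
from umap.umap_ import find_ab_params
from umap_pytorch.modules import umap_loss

a, b = find_ab_params(spread=1.0, min_dist=0.1)  # curve parameters UMAP fits from min_dist
batch_size = 64
embedding_to = torch.randn(batch_size, 2)    # embeddings of edge heads
embedding_from = torch.randn(batch_size, 2)  # embeddings of edge tails

loss = umap_loss(embedding_to, embedding_from, a, b, batch_size, negative_sample_rate=5)
print(loss)  # scalar cross-entropy; decreases as linked points move closer together
```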