├── LICENSE ├── README.md ├── den.py ├── losses.py ├── models.py └── plots.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Isaac Robinson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DEN 2 | 3 | We present a visualization algorithm based on a novel unsupervised Siameseneural network training regime and loss function, called Differentiating EmbeddingNetworks (DEN). The Siamese neural network finds differentiating or similarfeatures between specific pairs of samples in a dataset, and uses these features toembed the dataset in a lower dimensional space where it can be visualized. Unlikeexisting visualization algorithms such as UMAP ort-SNE, DEN is parametric,meaning it can be interpreted by techniques such as SHAP. To interpret DEN, wecreate an end-to-end parametric clustering algorithm on top of the visualization,and then leverage SHAP scores to determine which features in the sample spaceare important for understanding the structures shown in the visualization basedon the clusters found. We compare DEN visualizations with existing techniqueson a variety of datasets, including image and scRNA-seq data. We then showthat our clustering algorithm performs similarly to the state of the art despite nothaving prior knowledge of the number of clusters, and sets a new state of the arton FashionMNIST. Finally, we demonstrate finding differentiating features of adataset. 4 | 5 | Link to paper: https://arxiv.org/abs/2006.06640 6 | 7 | ## Prereqs 8 | 9 | We highly recommend running this with a GPU. It is VERY slow on CPU. Our package will automatically discover and use any CUDA-enabled GPU discovered. 10 | 11 | To get started, make sure you have PyTorch, Sklearn, Numpy, progressbar2, and matplotlib (with its 3D toolkit) installed. Then, install SHAP (https://github.com/slundberg/shap) with pip: 12 | `pip install shap` 13 | 14 | ## Usage 15 | 16 | DEN.py shows example usage downloading and running the USPS dataset from torchvision. DEN can be run like a standard Sklearn classifier. 17 | 18 | # MORE DOCUMENTATION TO COME SOON 19 | -------------------------------------------------------------------------------- /den.py: -------------------------------------------------------------------------------- 1 | from sklearn.base import BaseEstimator 2 | from sklearn.metrics import normalized_mutual_info_score, mutual_info_score, silhouette_score, davies_bouldin_score, calinski_harabasz_score, v_measure_score, adjusted_mutual_info_score, log_loss 3 | from sklearn.metrics.pairwise import pairwise_distances 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | from sklearn.cluster import KMeans 6 | from sklearn.neighbors import KNeighborsClassifier 7 | from sklearn import linear_model, naive_bayes 8 | from sklearn.preprocessing import StandardScaler 9 | from sklearn.model_selection import train_test_split, cross_val_score 10 | from scipy.linalg import eigh 11 | from scipy.optimize import linear_sum_assignment 12 | from scipy.stats import entropy as get_entropy 13 | import numpy as np 14 | import torch 15 | import torch.optim as optim 16 | from torch.utils.data import TensorDataset, DataLoader, ConcatDataset 17 | from torch.nn.utils.rnn import pad_sequence 18 | import progressbar 19 | import copy 20 | import matplotlib.pyplot as plt 21 | import shap 22 | import contextlib 23 | import sys 24 | from collections import defaultdict, Counter 25 | 26 | from models import * 27 | from losses import f_loss, reg_betainc 28 | from plots import plot_2d, plot_3d 29 | 30 | 31 | simple_classifier = linear_model.LogisticRegression(solver = 'lbfgs', n_jobs = -1) 32 | 33 | 34 | def change_cluster_labels_to_sequential(clusters): 35 | labels = np.unique(clusters) 36 | clusters_to_labels = {cluster:i for i, cluster in enumerate(labels)} 37 | seq_clusters = np.array([clusters_to_labels[cluster] for cluster in clusters]) 38 | 39 | return seq_clusters 40 | 41 | def make_cost_matrix(c1, c2): 42 | c1 = change_cluster_labels_to_sequential(c1) 43 | c2 = change_cluster_labels_to_sequential(c2) 44 | 45 | uc1 = np.unique(c1) 46 | uc2 = np.unique(c2) 47 | l1 = uc1.size 48 | l2 = uc2.size 49 | assert(l1 == l2 and np.all(uc1 == uc2)), str(uc1) + " vs " + str(uc2) 50 | 51 | m = np.ones([l1, l2]) 52 | for i in range(l1): 53 | it_i = np.nonzero(c1 == uc1[i])[0] 54 | for j in range(l2): 55 | it_j = np.nonzero(c2 == uc2[j])[0] 56 | m_ij = np.intersect1d(it_j, it_i) 57 | m[i,j] = -m_ij.size 58 | 59 | return m 60 | 61 | def get_accuracy(clusters, labels): 62 | cost = make_cost_matrix(clusters, labels) 63 | row_ind, col_ind = linear_sum_assignment(cost) 64 | to_labels = {i: ind for i, ind in enumerate(col_ind)} 65 | clusters_as_labels = list(map(to_labels.get, clusters)) 66 | acc = np.sum(clusters_as_labels == labels) / labels.shape[0] 67 | 68 | return acc 69 | 70 | def tokens_to_tfidf(x): 71 | list_of_strs = [' '.join(str(token) for token in item if token != 0) for item in x] 72 | out = TfidfVectorizer().fit_transform(list_of_strs) 73 | 74 | return out 75 | 76 | class DummyFile(object): 77 | def write(self, x): pass 78 | def flush(self): pass 79 | 80 | @contextlib.contextmanager 81 | def nostdout(): 82 | save_stdout = sys.stdout 83 | sys.stdout = DummyFile() 84 | yield 85 | sys.stdout = save_stdout 86 | 87 | class DEN(BaseEstimator): 88 | def __init__( 89 | self, 90 | n_components = 2, 91 | model = 'auto', 92 | min_neighbors = 1, 93 | max_neighbors = 10, 94 | snn = True, 95 | batch_size = 256, 96 | ignore = 1, 97 | metric = 'euclidean', 98 | neighbors_preprocess = None, 99 | use_gpu = True, 100 | learning_rate = 1e-3, 101 | optimizer_override = None, 102 | epochs = 10, 103 | verbose_level = 3, 104 | random_seed = 37, 105 | gamma = 1, 106 | semisupervised = False, 107 | cluster_subnet_dropout_p = .3, 108 | is_tokens = False, 109 | cluster_subsample_n = 1000, 110 | initial_zero_cutoff = 1e-2, 111 | minimum_zero_cutoff = 1e-7, 112 | update_zero_cutoff = False, 113 | internal_dim = 128, 114 | cluster_subnet_training_epochs = 50, 115 | semisupervised_weight = None, 116 | l2_penalty = 0, 117 | prune_graph = False, 118 | fine_tune_end_to_end = True, 119 | fine_tune_epochs = 50, 120 | simple_classifier = simple_classifier, 121 | final_training_epochs = 20, 122 | final_dropout_p = .3, 123 | min_p = 0, 124 | max_correlation = 0 125 | ): 126 | self.n_components = n_components 127 | self.model = model 128 | self.min_neighbors = min_neighbors 129 | self.max_neighbors = max_neighbors 130 | self.snn = snn 131 | self.batch_size = batch_size 132 | self.ignore = ignore 133 | self.metric = metric 134 | self.neighbors_preprocess = neighbors_preprocess 135 | self.use_gpu = use_gpu 136 | self.learning_rate = learning_rate 137 | self.optimizer_override = optimizer_override 138 | self.epochs = epochs 139 | self.verbose_level = verbose_level 140 | self.random_seed = random_seed 141 | self.gamma = gamma 142 | self.semisupervised = semisupervised 143 | self.cluster_subnet_dropout_p = cluster_subnet_dropout_p 144 | self.is_tokens = False # forces TF-IDF preprocessing if preprocessing unspecified 145 | self.cluster_subsample_n = cluster_subsample_n 146 | self.initial_zero_cutoff = initial_zero_cutoff 147 | self.minimum_zero_cutoff = minimum_zero_cutoff 148 | self.update_zero_cutoff = update_zero_cutoff 149 | self.internal_dim = internal_dim 150 | self.cluster_subnet_training_epochs = cluster_subnet_training_epochs 151 | self.semisupervised_weight = semisupervised_weight 152 | self.l2_penalty = l2_penalty 153 | self.prune_graph = prune_graph 154 | self.fine_tune_end_to_end = fine_tune_end_to_end 155 | self.fine_tune_epochs = fine_tune_epochs 156 | self.simple_classifier = simple_classifier 157 | self.final_training_epochs = final_training_epochs 158 | # self.final_model = final_model 159 | self.final_dropout_p = final_dropout_p 160 | self.min_p = min_p 161 | self.max_correlation = max_correlation 162 | 163 | self.best_full_net = None 164 | self.best_embedding_net = None 165 | self.optimizer = None 166 | self.semisupervised_model = None 167 | # self.final_model = None 168 | 169 | def find_differentiating_features(self, sample, context, n_context_samples = 400, feature_names = None): 170 | assert self.best_full_net is not None, "have not trained a prediction network yet!" 171 | 172 | if n_context_samples < context.shape[0]: 173 | context_subsample_inds = np.random.choice(context.shape[0], n_context_samples, replace = False) 174 | context_subsample = context[context_subsample_inds] 175 | else: 176 | context_subsample = context 177 | 178 | e = shap.DeepExplainer(self.best_full_net, context_subsample) 179 | 180 | if sample.shape[0] == 1: 181 | # only one sample so assuming need to add batch dimension 182 | sample = sample.unsqueeze(0) 183 | 184 | with nostdout(): 185 | shap_values, indexes = e.shap_values(sample, ranked_outputs = 1) 186 | 187 | if type(sample) is not np.ndarray: 188 | sample = sample.cpu().numpy() 189 | 190 | if len(context.shape) == 4: 191 | # assuming image 192 | shap_values = [np.swapaxes(np.swapaxes(s, 2, 3), 1, -1) for s in shap_values] 193 | if sample.shape[1] == 1: 194 | # need valid image shape for matplotlib 195 | sample = sample.squeeze(1) 196 | shap.image_plot(shap_values, -sample) 197 | else: 198 | # assuming not image 199 | shap.force_plot(e.expected_value[0], shap_values[0], sample, feature_names = feature_names, matplotlib = True) 200 | 201 | def summerize_differentiating_features(self, X, n_samples = 200, n_context_samples = 400): 202 | # split the dataset into clusters and average differentiating features in each cluster 203 | n_samples = min(n_samples, X.shape[0]) 204 | sample_inds = np.random.choice(X.shape[0], n_samples, replace = False) 205 | samples = X[sample_inds] 206 | clusters = self.predict(samples) 207 | 208 | n_context_samples = min(n_context_samples, X.shape[0]) 209 | context_sample_inds = np.random.choice(X.shape[0], n_context_samples, replace = False) 210 | context_samples = X[context_sample_inds] 211 | context_samples = context_samples.to(self.device) 212 | 213 | e = shap.DeepExplainer(self.best_full_net, context_samples) 214 | 215 | summerizations = defaultdict(lambda : np.zeros(X.shape[1:])) 216 | average_samples = defaultdict(lambda : np.zeros(X.shape[1:])) 217 | counts = Counter(clusters) 218 | 219 | self._print_with_verbosity("finding differentiating features across the dataset...", 1) 220 | 221 | for cluster, sample in self._progressbar_with_verbosity(zip(clusters, samples), 1, max_value = n_samples): 222 | with nostdout(): 223 | shap_values, indexes = e.shap_values(sample.unsqueeze(0), ranked_outputs = 1) 224 | summerizations[cluster] += shap_values[0].squeeze(0) / counts[cluster] 225 | average_samples[cluster] += sample.cpu().numpy() / counts[cluster] 226 | 227 | # recall that dictionaries are ordered in Python3 228 | summery = np.array(list(summerizations.values())).squeeze(1) 229 | averages = np.array(list(average_samples.values())).squeeze(1) 230 | 231 | shap.image_plot(summery, -averages) 232 | 233 | 234 | def _print_with_verbosity(self, message, level, strict = False): 235 | if level <= self.verbose_level and (not strict or level == self.verbose_level): 236 | print(message) 237 | 238 | def _progressbar_with_verbosity(self, data, level, max_value = None, strict = False): 239 | if level <= self.verbose_level and (not strict or level == self.verbose_level): 240 | for datum in progressbar.progressbar(data, max_value = max_value): 241 | yield datum 242 | else: 243 | for datum in data: 244 | yield datum 245 | 246 | def _select_model(self, X, n_outputs = None, dropout_p = 0): 247 | if n_outputs is None: 248 | n_outputs = self.n_components 249 | # not sure if allowed to modify model attribute under sklearn rules 250 | if type(X) is tuple or type(X) is list: 251 | self._print_with_verbosity("assuming token-based data, using bag-of-words model", 1) 252 | self.is_tokens = True 253 | vocab = set() 254 | for x in X: 255 | vocab.update(x) 256 | vocab_size = len(vocab) 257 | to_model = BOWNN(n_outputs, vocab_size, internal_dim = self.internal_dim) 258 | else: 259 | n_dims = len(X.shape) 260 | if n_dims == 2: 261 | self._print_with_verbosity("using fully connected neural network", 1) 262 | to_model = FFNN(n_outputs, X.shape[1], internal_dim = self.internal_dim) 263 | elif n_dims == 4: 264 | self._print_with_verbosity("using convolutional neural network", 1) 265 | n_layers = int(np.log2(min(X.shape[2], X.shape[3]))) 266 | to_model = CNN(n_outputs, n_layers, internal_dim = self.internal_dim, p = dropout_p) 267 | # self.model = ClusterNet(X.shape[-1]*X.shape[-2], self.n_components) 268 | else: 269 | assert False, "not sure which neural network to use based off data provided" 270 | 271 | return to_model 272 | 273 | def _get_near_and_far_pairs_mem_efficient_chunks(self, X, block_size = 512, return_sorted = True): 274 | n_neighbors = self.max_neighbors 275 | 276 | closest = [] 277 | furthest = [] 278 | if type(X) is np.ndarray: 279 | splits = np.array_split(X, max(X.shape[0] // block_size, 1)) 280 | max_value = len(splits) 281 | else: 282 | inds = list(range(0, X.shape[0], block_size)) 283 | inds.append(None) 284 | splits = (X[inds[i]:inds[i+1]] for i in range(len(inds) - 1)) 285 | max_value = len(inds) - 1 286 | 287 | self._print_with_verbosity(f"using metric {self.metric} to build nearest neighbors graph", 2) 288 | 289 | for first in self._progressbar_with_verbosity(splits, 2, max_value = max_value): 290 | dists = pairwise_distances(first, X, n_jobs = -1, metric = self.metric) 291 | # dists = cdist(first, X, metric = metric) 292 | this_closest = np.argpartition(dists, n_neighbors + 1)[:, :n_neighbors+1] 293 | if return_sorted: 294 | original_set = set(this_closest[-1]) 295 | relevant = dists[np.arange(this_closest.shape[0])[:, None], this_closest] 296 | sorted_inds = np.argsort(relevant) 297 | this_closest = this_closest[np.arange(sorted_inds.shape[0])[:, None], sorted_inds] 298 | assert set(this_closest[-1]) == original_set, "something went wrong with sorting" 299 | this_closest = this_closest[:, 1:] 300 | closest.append(this_closest) 301 | 302 | probs = dists / np.sum(dists, axis = 1)[:, None] 303 | this_furthest = np.array([np.random.choice(len(probs[i]), n_neighbors, False, probs[i]) for i in range(len(probs))]) 304 | furthest.append(np.array(this_furthest)) 305 | 306 | closest = np.concatenate(closest) 307 | furthest = np.concatenate(furthest) 308 | 309 | return closest, furthest 310 | 311 | def _build_dataset(self, X, y = None): 312 | # returns Dataset object 313 | neighbors_X = X.view(X.shape[0], -1).cpu().numpy() 314 | 315 | if self.is_tokens and self.neighbors_preprocess is None: 316 | self._print_with_verbosity("using tokenized data without neighbors preprocessing so using TF-IDF transform", 2) 317 | self.neighbors_preprocess = tokens_to_tfidf 318 | self.metric = 'cosine' 319 | 320 | if self.neighbors_preprocess is not None: 321 | neighbors_X = self.neighbors_preprocess(neighbors_X) 322 | 323 | closest, furthest = self._get_near_and_far_pairs_mem_efficient_chunks(neighbors_X) 324 | 325 | samples = [] 326 | paired = [] 327 | 328 | # for semisupervised version 329 | # assuming y has positive integer class labels 330 | # and -1 if there is no label 331 | first_label = [] 332 | second_label = [] 333 | 334 | self._print_with_verbosity("building dataset from nearest neighbors graph", 1) 335 | 336 | already_paired = set() 337 | for first, seconds in enumerate(closest): 338 | represented = 0 339 | for ind, second in enumerate(seconds[::-1]): # matters if sorted and min_neighbors so closest are last 340 | if self.snn: 341 | if first not in closest[second]: 342 | n_left = len(seconds) - ind 343 | if n_left > self.min_neighbors - represented: 344 | continue 345 | 346 | if self.semisupervised and self.prune_graph: 347 | if y[first] != y[second] and y[first] != -1 and y[second] != -1: 348 | continue 349 | 350 | represented += 1 351 | 352 | if tuple(sorted([first, second])) not in already_paired and first != second: 353 | first_data = X[first] 354 | second_data = X[second] 355 | stack = torch.stack([first_data, second_data]) 356 | 357 | samples.append(stack) 358 | paired.append(1) 359 | already_paired.add(tuple(sorted([first, second]))) 360 | 361 | if y is not None: 362 | first_label.append(y[first]) 363 | second_label.append(y[second]) 364 | else: 365 | first_label.append(-1) 366 | second_label.append(-1) 367 | 368 | already_paired = set() 369 | for first, seconds in enumerate(furthest): 370 | for second in seconds: 371 | if self.semisupervised and self.prune_graph: 372 | if y[first] == y[second] and y[first] != -1 and y[second] != -1: 373 | continue 374 | 375 | if tuple(sorted([first, second])) not in already_paired and first != second: 376 | first_data = X[first] 377 | second_data = X[second] 378 | stack = torch.stack([first_data, second_data]) 379 | 380 | samples.append(stack) 381 | paired.append(0) 382 | already_paired.add(tuple(sorted([first, second]))) 383 | 384 | if y is not None: 385 | first_label.append(y[first]) 386 | second_label.append(y[second]) 387 | else: 388 | first_label.append(-1) 389 | second_label.append(-1) 390 | 391 | samples = torch.stack(samples) 392 | paired = torch.Tensor(np.array(paired)) 393 | first_label = torch.Tensor(np.array(first_label)) 394 | second_label = torch.Tensor(np.array(second_label)) 395 | 396 | dataset = TensorDataset(samples, paired, first_label.long(), second_label.long()) 397 | 398 | return dataset 399 | 400 | def _orthgonality_regularizer(self, x): 401 | diff = 0 402 | for i in range(x.shape[1]): 403 | for j in range(i+1, x.shape[1]): 404 | diff = diff + torch.abs(F.cosine_similarity(x[:, i], x[:, j], dim = 0)) 405 | diff = diff / ((x.shape[1]*(x.shape[1]-1))/2) 406 | 407 | return diff 408 | 409 | def _train_siamese_one_epoch(self, data_loader): 410 | epoch_loss = 0 411 | self.model.train() 412 | for data, target, first_label, second_label in self._progressbar_with_verbosity(data_loader, 1): 413 | self.optimizer.zero_grad() 414 | 415 | data = data.to(self.device) 416 | target = target.to(self.device) 417 | first_label = first_label.to(self.device) 418 | second_label = second_label.to(self.device) 419 | 420 | output_1 = self.model(data[:, 0]) 421 | output_2 = self.model(data[:, 1]) 422 | 423 | loss = f_loss(output_1, output_2, target, ignore = self.ignore, device = self.device, min_p = self.min_p) 424 | 425 | if self.semisupervised: 426 | which = first_label != -1 427 | first_label_pred = self.semisupervised_model.cluster_net(output_1[which]) 428 | loss = loss + F.cross_entropy(first_label_pred, first_label[which])*self.semisupervised_weight 429 | which = second_label != -1 430 | second_label_pred = self.semisupervised_model.cluster_net(output_2[which]) 431 | loss = loss + F.cross_entropy(second_label_pred, second_label[which])*self.semisupervised_weight 432 | 433 | if self.l2_penalty > 0: 434 | loss = loss + self.l2_penalty*torch.mean((torch.norm(output_1, p = 2, dim = 1) + torch.norm(output_2, p = 2, dim = 1))) 435 | 436 | if self.max_correlation is not None: 437 | diff = (self._orthgonality_regularizer(output_1) + self._orthgonality_regularizer(output_2)) / 2 438 | loss = loss + torch.max(torch.Tensor([self.max_correlation]).to(self.device), diff) - self.max_correlation 439 | 440 | loss.backward() 441 | self.optimizer.step() 442 | 443 | epoch_loss += loss.item() 444 | 445 | del output_1 446 | del output_2 447 | 448 | self._print_with_verbosity(f"training loss: {epoch_loss / len(data_loader)}", 1) 449 | 450 | def _train_one_epoch(self, model, data_loader, optimizer, crit): 451 | model.train() 452 | for data, target in data_loader: 453 | optimizer.zero_grad() 454 | 455 | data = data.to(self.device) 456 | target = target.to(self.device) 457 | 458 | pred = model(data) 459 | 460 | loss = crit(pred, target) 461 | 462 | loss.backward() 463 | optimizer.step() 464 | 465 | def transform(self, X, to_numpy = True, batch_size = 4096, model = None): 466 | if self.is_tokens: 467 | X = pad_sequence(X, padding_value = 0, batch_first = True) 468 | 469 | if model is None: 470 | assert self.best_embedding_net is not None, "no embedding model trained yet!" 471 | model = self.best_embedding_net 472 | # embeds the data 473 | dataset = TensorDataset(X) 474 | embed_loader = DataLoader(dataset, shuffle = False, batch_size = batch_size) 475 | 476 | embeddings = [] 477 | model.eval() 478 | with torch.no_grad(): 479 | for data in embed_loader: 480 | data = data[0].to(self.device) 481 | embedding = model(data).cpu() 482 | if to_numpy: 483 | embedding = embedding.numpy() 484 | embeddings.append(embedding) 485 | 486 | if to_numpy: 487 | embeddings = np.concatenate(embeddings) 488 | embeddings = embeddings.reshape(len(X), -1) 489 | else: 490 | embeddings = torch.cat(embeddings) 491 | embeddings = embeddings.view(len(X), -1) 492 | 493 | return embeddings 494 | 495 | def _get_exp_dist(self, data_loader): 496 | # sets self.exp_dist based off mean of means dist between positive pairs 497 | self.model.eval() 498 | 499 | cumulative_dist = 0 500 | with torch.no_grad(): 501 | for data, target, first_label, second_label in data_loader: 502 | data = data.to(self.device) 503 | target = target.to(self.device) 504 | 505 | should_be_close = target == 1 506 | if torch.sum(should_be_close) == 0: 507 | continue 508 | 509 | output_1 = self.model(data[should_be_close, 0]) 510 | output_2 = self.model(data[should_be_close, 1]) 511 | 512 | d = torch.norm(output_1 - output_2, p = 2, dim = 1) 513 | 514 | # get parameters for f distribution. not sure these are right.. 515 | d1 = torch.Tensor([output_1.shape[-1]]).to(self.device) 516 | d2 = torch.Tensor([1]).to(self.device) 517 | 518 | # compute p-value 519 | p = reg_betainc(d1*d/(d1*d+d2), d1/2, d2/2) 520 | # reject null hypothesis 521 | d = d[p < self.ignore] 522 | # do means 523 | cumulative_dist += torch.mean(d).item() 524 | 525 | del output_1 526 | del output_2 527 | del should_be_close 528 | 529 | avg_dist = cumulative_dist / len(data_loader) 530 | 531 | self.exp_dist = avg_dist 532 | 533 | def _cluster(self, X): 534 | # runs spectral clustering based off self.exp_dist as Gaussian kernel bandwidth 535 | # sets self.n_clusters and returns cluster_assignments 536 | X = X.reshape(X.shape[0], -1) 537 | 538 | n = min(X.shape[0], self.cluster_subsample_n) 539 | 540 | inds = np.random.choice(X.shape[0], n, replace = False) 541 | D = pairwise_distances(X[inds], n_jobs = -1, metric = 'euclidean') 542 | 543 | sigma = (self.exp_dist*self.gamma)**2 544 | A = np.exp(-D**2 / sigma) # known bug: sigma should be larger because subsampling 545 | 546 | sums = A.sum(axis = 1) 547 | D = np.diag(sums) 548 | L = D - A 549 | 550 | vals, vecs = eigh(L, turbo = True)#, eigvals = [0, int(X.shape[0]**.5)]) # assuming only sqrt possible clusters 551 | # print(vals) 552 | 553 | n_zeros = np.sum(vals <= self.zero_cutoff) 554 | self._print_with_verbosity(f"found {n_zeros} candidate clusters", 3) 555 | self._print_with_verbosity(f"running k-means..", 3) 556 | k = KMeans(n_zeros, n_init = 100) 557 | init_clusters = k.fit_predict(vecs[:, :n_zeros]) 558 | # print(np.unique(init_clusters)) 559 | 560 | self._print_with_verbosity(f"applying KNN filter..", 3) 561 | n_neighbors = int(2*np.log2(X.shape[0])) 562 | clusters = KNeighborsClassifier(n_neighbors).fit(X[inds], init_clusters).predict(X) 563 | clusters = change_cluster_labels_to_sequential(clusters) 564 | 565 | self.n_clusters = np.unique(clusters).shape[0] 566 | 567 | if self.update_zero_cutoff: 568 | self._update_zero_cutoff(vals) 569 | 570 | return clusters 571 | 572 | def _update_zero_cutoff(self, eign): 573 | # slowly decrease zero cutoff for spectral clustering calculation 574 | # reduces noise in clustering 575 | # using a separate function because might make this more complex in the future 576 | # this cutoff is just the first eigenvalue that DID NOT correspond to a cluster 577 | # new_zero_cutoff = min(self.zero_cutoff, eign[self.n_clusters]) 578 | # new_zero_cutoff = max(new_zero_cutoff, 1e-8) # don't want negative, that's just numerical error 579 | new_zero_cutoff = max(10*eign[self.n_clusters], self.minimum_zero_cutoff) 580 | 581 | if new_zero_cutoff != self.zero_cutoff: 582 | self._print_with_verbosity(f"updating eigenvalue zero cutoff from {self.zero_cutoff} to {new_zero_cutoff}", 3) 583 | self.zero_cutoff = new_zero_cutoff 584 | 585 | def predict(self, X, model = None, return_embedding = False): 586 | if model is None: 587 | assert self.best_full_net is not None, "have not trained a prediction network yet!" 588 | model = self.best_full_net # if not self.final_model_trained else self.final_model 589 | 590 | dataset = TensorDataset(X) 591 | data_loader = DataLoader(dataset, batch_size = 4096, shuffle = False) 592 | preds = [] 593 | embeddings = [] 594 | 595 | model = model.to(self.device) 596 | model.eval() 597 | with torch.no_grad(): 598 | for data in self._progressbar_with_verbosity(data_loader, 3): 599 | data = data[0].to(self.device) 600 | 601 | if return_embedding: 602 | embedding = model.embed_net(data) 603 | _, pred = torch.max(model.cluster_net(embedding), 1) 604 | embeddings.extend(embedding.cpu().numpy()) 605 | else: 606 | _, pred = torch.max(model(data), 1) 607 | 608 | preds.extend(pred.cpu().numpy()) 609 | 610 | embeddings = np.array(embeddings) 611 | preds = np.array(preds) 612 | 613 | if return_embedding: 614 | return preds, embeddings 615 | else: 616 | return preds 617 | 618 | def _build_cluster_subnet(self, X, transformed, clusters): 619 | # creates clustering subnet and updates best model 620 | # sets self.cluster_subnet 621 | 622 | self._print_with_verbosity("training cluster subnet to predict spectral labels", 2) 623 | 624 | cluster_counts = torch.zeros(self.n_clusters).float().to(self.device) 625 | for cluster_assignment in clusters: 626 | cluster_counts[cluster_assignment] += 1 627 | cluster_weights = len(clusters)/self.n_clusters/cluster_counts 628 | 629 | dataset = TensorDataset(torch.Tensor(transformed), torch.Tensor(clusters).long()) 630 | data_loader = DataLoader(dataset, shuffle = True, batch_size = self.batch_size) 631 | 632 | cluster_subnet = ClusterNet(transformed.shape[-1], self.n_clusters, p = self.cluster_subnet_dropout_p) 633 | cluster_subnet_optimizer = optim.Adam(cluster_subnet.parameters()) 634 | cluster_subnet_crit = nn.CrossEntropyLoss(weight = cluster_weights) 635 | cluster_subnet.train() 636 | cluster_subnet = cluster_subnet.to(self.device) 637 | 638 | for i in self._progressbar_with_verbosity(range(self.cluster_subnet_training_epochs), 2): 639 | self._train_one_epoch(cluster_subnet, data_loader, cluster_subnet_optimizer, cluster_subnet_crit) 640 | 641 | # now fine-tune the whole pipeline 642 | dataset = TensorDataset(X, torch.Tensor(clusters).long()) 643 | data_loader = DataLoader(dataset, shuffle = True, batch_size = self.batch_size) 644 | 645 | full_net = FullNet(copy.deepcopy(self.model), cluster_subnet) 646 | full_net_optimizer = optim.Adam(full_net.parameters(), lr = 1e-4) 647 | full_net_crit = cluster_subnet_crit 648 | full_net.train() 649 | full_net = full_net.to(self.device) 650 | 651 | if self.fine_tune_end_to_end: 652 | self._print_with_verbosity("fine-tuning whole end-to-end network", 2) 653 | 654 | for i in self._progressbar_with_verbosity(range(self.fine_tune_epochs), 2): 655 | self._train_one_epoch(full_net, data_loader, full_net_optimizer, full_net_crit) 656 | 657 | preds = self.predict(X, model = full_net, return_embedding = False) 658 | 659 | # new_transformed = self.transform(X, model = full_net.embed_net) 660 | 661 | # delta_mi = silhouette_score(new_transformed, preds) 662 | # delta_mi = adjusted_mutual_info_score(preds, clusters) 663 | # preds_entropy = get_entropy(list(Counter(preds).values())) 664 | # embedding_score = self._test_label_fit(embedding, preds) # maybe this should use transformed instead? 665 | # sample_score = self._test_label_fit(X, preds) # compensate for random 666 | # delta_mi = preds_entropy * embedding_score * sample_score # average ability to pattern-match times information content of labels 667 | # random_labels_score = self._test_label_fit(X, np.random.randint(0, self.n_clusters, X.shape[0])) 668 | # self._print_with_verbosity(f"this delta mi: {delta_mi}, from embedding: {embedding_score}, from original data: {sample_score}, entropy: {preds_entropy}", 1) 669 | # if delta_mi > self.best_delta_mi: 670 | # self._print_with_verbosity(f"found new best delta mi", 1) 671 | 672 | # we're just going to take the most recent epoch 673 | # this is reasonable because of the new changes to the F-distribution loss 674 | # self.best_delta_mi = delta_mi 675 | self.best_full_net = full_net 676 | self.best_n_clusters = self.n_clusters 677 | self.best_embedding_net = copy.deepcopy(self.model) 678 | 679 | return preds 680 | 681 | def _test_label_fit(self, X, y, test_proportion = .2): 682 | # trains a simple classifier to predict the labels from the dataset 683 | # if the labels are a good fit, this score should go up 684 | if np.unique(y).shape[0] == 1: 685 | # single label dataset has accuracy 1 686 | return 1 687 | 688 | X = X.reshape(X.shape[0], -1) 689 | # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = test_proportion) 690 | # c = naive_bayes.ComplementNB() 691 | # X -= X.min() # get rid of negative values for CNB 692 | c = self.simple_classifier 693 | with nostdout(): 694 | # score = cross_val_score(c, X, y, n_jobs = -1).mean() # pretty sure this uses reproducable random state by default 695 | # since only training internally, only test internally for this 696 | score = np.exp(-log_loss(y, c.fit(X, y).predict_proba(X))) 697 | 698 | return score 699 | 700 | def fit(self, X, y = None, y_for_verification = None, plot = False): 701 | # assert not self.semisupervised, "semisupervised not supported yet" 702 | 703 | self.best_delta_mi = -1 704 | self.best_full_net = None 705 | self.best_embedding_net = None 706 | # self.final_model = None 707 | # self.final_model_trained = False 708 | self.best_n_clusters = 1 709 | self.zero_cutoff = self.initial_zero_cutoff 710 | self.exp_dist = 0 711 | 712 | if self.random_seed is not None: 713 | np.random.seed(self.random_seed) 714 | 715 | use_y_to_verify_performance = y_for_verification is not None 716 | self.semisupervised = self.semisupervised and y is not None 717 | 718 | if self.semisupervised and self.semisupervised_weight is None: 719 | self.semisupervised_weight = np.sum(y != -1) / y.shape[0] 720 | 721 | if self.semisupervised: 722 | n_classes = np.unique(y[y != -1]).shape[0] # because of the -1 723 | 724 | if use_y_to_verify_performance: 725 | verify_n_classes = np.unique(y_for_verification).shape[0] 726 | self._print_with_verbosity(f"number of classes in verification set: {verify_n_classes}", 3) 727 | 728 | if self.model == "auto": 729 | self.model = self._select_model(X) 730 | 731 | if self.is_tokens: 732 | X = pad_sequence(X, padding_value = 0, batch_first = True) 733 | 734 | if type(X) is not torch.Tensor: 735 | X = torch.Tensor(X) 736 | 737 | self.device = torch.device("cuda") if (torch.cuda.is_available() and self.use_gpu) else torch.device("cpu") 738 | if self.device.type == "cpu": 739 | self._print_with_verbosity("WARNING: using CPU, may be very slow", 0, strict = True) 740 | 741 | self._print_with_verbosity(f"using torch device {self.device}", 1) 742 | 743 | self._print_with_verbosity("building dataset", 1) 744 | 745 | dataset = self._build_dataset( 746 | X, 747 | y = y if self.semisupervised else None, 748 | ) 749 | 750 | data_loader = DataLoader(dataset, shuffle = True, batch_size = self.batch_size) 751 | 752 | self.model = self.model.to(self.device) 753 | 754 | if self.optimizer_override is None: 755 | self.optimizer = optim.Adam(self.model.parameters(), lr = self.learning_rate) 756 | else: 757 | self.optimizer = self.optimizer_override(self.model.parameters(), lr = self.learning_rate) 758 | 759 | if self.semisupervised: 760 | label_subnet = ClusterNet(self.n_components, n_classes).to(self.device) 761 | self.semisupervised_model = FullNet(self.model, label_subnet).to(self.device) 762 | self.optimizer = optim.Adam(self.semisupervised_model.parameters(), lr = self.learning_rate) 763 | 764 | self._print_with_verbosity("training", 1) 765 | 766 | for i in self._progressbar_with_verbosity(range(self.epochs), 0, strict = True): 767 | self.model.train() 768 | self._print_with_verbosity(f"this is epoch {i}", 1) 769 | self._train_siamese_one_epoch(data_loader) 770 | self.model.eval() 771 | transformed = self.transform(X, model = self.model) 772 | 773 | self._get_exp_dist(data_loader) 774 | self._print_with_verbosity(f"found expected distance between related points as {self.exp_dist}", 3) 775 | cluster_assignments = self._cluster(transformed) 776 | self._print_with_verbosity(f"found {self.n_clusters} clusters", 1) 777 | 778 | preds = self._build_cluster_subnet(X, transformed, cluster_assignments) 779 | 780 | if use_y_to_verify_performance: 781 | nmi_score = normalized_mutual_info_score(cluster_assignments, y_for_verification, 'geometric') 782 | self._print_with_verbosity(f"NMI of cluster labels with y: {nmi_score}", 2) 783 | 784 | nmi_score = normalized_mutual_info_score(preds, y_for_verification, 'geometric') 785 | self._print_with_verbosity(f"NMI of network predictions with y: {nmi_score}", 1) 786 | 787 | if self.n_clusters == verify_n_classes: 788 | acc_score = get_accuracy(cluster_assignments, y_for_verification) 789 | self._print_with_verbosity(f"accuracy of cluster labels: {acc_score}", 2) 790 | 791 | if np.unique(preds).shape[0] == verify_n_classes: 792 | acc_score = get_accuracy(preds, y_for_verification) 793 | self._print_with_verbosity(f"accuracy of network predictions: {acc_score}", 1) 794 | else: 795 | self._print_with_verbosity(f"number of predicted classes did not match number of clusters so not computing accuracy, correct {verify_n_classes} vs {self.n_clusters}", 2) 796 | 797 | if plot: 798 | if self.n_components == 2: 799 | plot_2d(transformed, cluster_assignments, show = False, no_legend = True) 800 | 801 | if use_y_to_verify_performance: 802 | plot_2d(transformed, y_for_verification, show = False, no_legend = True) 803 | 804 | plt.show() 805 | 806 | elif self.n_components == 3: 807 | plot_3d(transformed, cluster_assignments, show = False) 808 | 809 | if use_y_to_verify_performance: 810 | plot_3d(transformed, y_for_verification, show = False) 811 | 812 | plt.show() 813 | 814 | 815 | if __name__ == "__main__": 816 | from torchvision.datasets import MNIST, USPS, FashionMNIST, CIFAR10 817 | from torchtext.datasets import AG_NEWS 818 | 819 | n = None 820 | # semisupervised_proportion = .2 821 | 822 | e = DEN(n_components = 2, internal_dim = 128) 823 | 824 | USPS_data_train = USPS("./", train = True, download = True) 825 | USPS_data_test = USPS("./", train = False, download = True) 826 | USPS_data = ConcatDataset([USPS_data_test, USPS_data_train]) 827 | X, y = zip(*USPS_data) 828 | 829 | y_numpy = np.array(y[:n]) 830 | X_numpy = np.array([np.asarray(X[i]) for i in range(n if n is not None else len(X))]) 831 | X = torch.Tensor(X_numpy).unsqueeze(1) 832 | 833 | # which = np.random.choice(len(y_numpy), int((1-semisupervised_proportion)*len(y_numpy)), replace = False) 834 | # y_for_verification = copy.deepcopy(y_numpy) 835 | # y_numpy[which] = -1 836 | 837 | # news_train, news_test = AG_NEWS('./', ngrams = 1) 838 | # X, y = zip(*([item[1], item[0]] for item in news_test)) 839 | # X = X[:n] 840 | # y = y[:n] 841 | # y_numpy = np.array(y) 842 | # y_for_verification = copy.deepcopy(y_numpy) 843 | 844 | # X_numpy = np.load("shekhar_data_pca_40.npy")[:n] 845 | # y_numpy_strs = np.load("shekhar_labels.npy", allow_pickle = True)[:n] 846 | # str_to_ind = {name:i for i, name in enumerate(np.unique(y_numpy_strs))} 847 | # y_numpy = np.array([str_to_ind[name] for name in y_numpy_strs]) 848 | # X = torch.Tensor(X_numpy) 849 | # which = y_numpy < 16 # to just focus on interesting stuff 850 | # X = X[which] 851 | # y_numpy = y_numpy[which] 852 | y_for_verification = copy.deepcopy(y_numpy) 853 | 854 | e.fit(X, None, y_for_verification = y_for_verification, plot = True) 855 | # e.save("test_thing.pt") 856 | # e.load("test_thing.pt") 857 | # # e.find_differentiating_features(X[0], X) 858 | # e.summerize_differentiating_features(X) 859 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 4 | 5 | # 2F1 approximation based off https://projecteuclid.org/download/pdf_1/euclid.aos/1031689021 6 | 7 | def F21(a, b, c, x): 8 | y = y_hat(a, b, c, x) 9 | r = r21(a, b, c, x, y) 10 | out = (c**(c-.5))*(r**(-.5)) 11 | out *= (y/a)**a 12 | out *= ((1-y)/(c-a))**(c-a) 13 | out *= (1-x*y)**(-b) 14 | return out 15 | 16 | 17 | def r21(a, b, c, x, y): 18 | out = (y**2)/a 19 | out += ((1-y)**2)/(c-a) 20 | out -= ((b*x**2)*(y**2)*((1-y)**2))/(((1-x*y)**2)*a*(c-a)) 21 | return out 22 | 23 | def y_hat(a, b, c, x): 24 | t = tau(a, b, c, x) 25 | return 2*a/(torch.sqrt(t**2 - 4*a*x*(c - b)) - t) 26 | 27 | def tau(a, b, c, x): 28 | return x*(b-a)-c 29 | 30 | def betainc(x, a, b): 31 | f = F21(a, 1-b, a+1, x) 32 | return (x**a)*f/a 33 | 34 | def beta(a, b): 35 | return torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a+b)) 36 | 37 | def reg_betainc(x, a, b): 38 | return betainc(x, a, b)/beta(a, b) 39 | 40 | def tcdf(t, v, device = device): 41 | t = t.to(device) 42 | v = v.to(device) 43 | x = v/(t**2+v) 44 | half = torch.Tensor([1/2]).to(device) 45 | i = reg_betainc(x, v/2, half) 46 | 47 | p = 1 - i/2 48 | 49 | return p 50 | 51 | def fcdf(x, d1, d2, device = device): 52 | d1 = d1.to(device) 53 | d2 = d2.to(device) 54 | x = x.to(device) 55 | 56 | y = d1*x/(d1*x + d2) 57 | p = reg_betainc(y, d1/2, d2/2) 58 | 59 | return p 60 | 61 | def f_loss(x1, x2, y, ignore = .9, epsilon = 1e-4, device = device, min_p = 0): 62 | if type(x1) is tuple: 63 | v = x1[0].shape[-1] 64 | # print(v) 65 | else: 66 | v = x1.shape[-1] 67 | v = torch.Tensor([v]).to(device) 68 | 69 | paired = y == 1 70 | 71 | dist = torch.norm(x1 - x2, p = 2, dim = 1)**2 72 | 73 | d1 = torch.Tensor([1]).to(device) 74 | d2 = v 75 | 76 | # compute p-value 77 | p = reg_betainc(d1*dist/(d1*dist+d2), d1/2, d2/2) 78 | p = p.to(device) 79 | p[~paired] = 1 - p[~paired] 80 | 81 | if min_p > 0: 82 | p[paired] = torch.max(p[paired], torch.Tensor([min_p]).to(device)) 83 | 84 | # reject hypothesis that this model explains the data if p is significant 85 | usage = p < ignore 86 | # or interpret as 'these values are too extreme for this model to meaningfully optimize' 87 | # ie let discovered features determine their locations 88 | # usage = torch.abs(p - .5) > ignore - .5 89 | 90 | paired_loss = torch.sum(p[paired]*usage[paired]) 91 | assert not torch.isnan(paired_loss), str(p[paired]) 92 | not_paired_loss = torch.sum(p[~paired]*usage[~paired]) 93 | assert not torch.isnan(not_paired_loss), str(p[not_paired]) 94 | 95 | total_loss = torch.sum(usage[paired])*not_paired_loss + torch.sum(usage[~paired])*paired_loss 96 | total_loss = total_loss / len(paired) 97 | 98 | total_loss = total_loss / torch.sum(usage) 99 | 100 | assert not torch.isnan(total_loss), str(t) 101 | 102 | total_loss = total_loss.to(device) 103 | 104 | del p 105 | del usage 106 | 107 | return total_loss -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BOWNN(nn.Module): 6 | def __init__(self, n_components, vocab_size, internal_dim = 64): 7 | super().__init__() 8 | 9 | self.e = nn.EmbeddingBag(vocab_size, internal_dim, mode = 'max') 10 | self.out = nn.Linear(internal_dim, n_components, bias = False) 11 | 12 | def forward(self, x): 13 | x = x.squeeze(1).long() 14 | x = self.e(x) 15 | x = self.out(x) 16 | 17 | return x 18 | 19 | class SimpleCNNBlock(nn.Module): 20 | def __init__(self, in_channels, out_channels): 21 | super().__init__() 22 | 23 | # internal_dim = int((in_channels*out_channels)**.5) # geometric mean 24 | internal_dim = 2*max(in_channels, out_channels) 25 | 26 | self.conv1 = nn.Conv2d(in_channels, internal_dim, 3, 1, 1) 27 | self.bn1 = nn.BatchNorm2d(internal_dim) 28 | self.conv2 = nn.Conv2d(internal_dim, out_channels, 3, 1, 1) 29 | 30 | def forward(self, x): 31 | x = self.bn1(torch.relu(self.conv1(x))) 32 | x = self.conv2(x) 33 | 34 | return x 35 | 36 | class SEBlock(nn.Module): 37 | def __init__(self, cnn_block, in_channels, out_channels, ratio = 16): 38 | super().__init__() 39 | 40 | internal_dim = out_channels // ratio 41 | 42 | self.cnn_block = cnn_block 43 | 44 | self.in_pool = nn.AdaptiveAvgPool2d(1) 45 | self.lin1 = nn.Linear(in_channels, internal_dim) 46 | self.out = nn.Linear(internal_dim, out_channels) 47 | 48 | def forward(self, x): 49 | layers = self.cnn_block(x) 50 | se = self.in_pool(x).view(x.shape[0], -1) 51 | se = torch.relu(self.lin1(se)) 52 | se = torch.sigmoid(self.out(se)) 53 | x = layers * se.view(x.shape[0], -1, 1, 1) 54 | 55 | return x 56 | 57 | class CNN(nn.Module): 58 | def __init__(self, n_components, n_layers, internal_dim = 64, n_channels = 1, p = 0): 59 | super().__init__() 60 | 61 | self.n_layers = n_layers 62 | 63 | self.conv_layers = nn.ModuleList() 64 | self.pool_layers = nn.ModuleList() 65 | self.bn_layers = nn.ModuleList() 66 | 67 | for i in range(n_layers): 68 | in_dim = n_channels if i == 0 else internal_dim*2**(i-1) 69 | out_dim = internal_dim*2**i 70 | this_conv = nn.Conv2d(in_dim, out_dim, 3, 1, 1) # simplest version 71 | # this_conv = SEBlock(nn.Conv2d(in_dim, out_dim, 3, 1, 1), in_dim, out_dim) 72 | # this_conv = SEBlock(SimpleCNNBlock(in_dim, out_dim), in_dim, out_dim) 73 | self.conv_layers.append(this_conv) 74 | 75 | if i != n_layers - 1: 76 | this_pool = nn.MaxPool2d(2) 77 | else: 78 | this_pool = nn.AdaptiveAvgPool2d(1) 79 | self.pool_layers.append(this_pool) 80 | 81 | this_bn = nn.BatchNorm2d(out_dim) 82 | self.bn_layers.append(this_bn) 83 | 84 | dense_dim = internal_dim*2**(n_layers - 1) 85 | self.do1 = nn.Dropout(p) 86 | self.lin1 = nn.Linear(dense_dim, dense_dim) 87 | self.do2 = nn.Dropout(p) 88 | self.out = nn.Linear(dense_dim, n_components) 89 | 90 | def forward(self, x): 91 | for i in range(self.n_layers): 92 | x = self.bn_layers[i](torch.relu(self.pool_layers[i](self.conv_layers[i](x)))) 93 | 94 | x = x.view(x.shape[0], -1) 95 | x = self.do1(x) 96 | x = torch.relu(self.lin1(x)) 97 | x = self.do2(x) 98 | x = self.out(x) 99 | 100 | return x 101 | 102 | class FFNN(nn.Module): 103 | def __init__(self, n_components, input_dim, internal_dim, n_hidden_layers = 2): 104 | super().__init__() 105 | 106 | self.internal_dim = internal_dim 107 | 108 | self.lin_in = nn.Linear(input_dim, self.internal_dim) 109 | 110 | self.layers = nn.ModuleList() 111 | for i in range(n_hidden_layers - 1): 112 | self.layers.append(nn.Linear(self.internal_dim, self.internal_dim)) 113 | 114 | self.out = nn.Linear(self.internal_dim, n_components) 115 | 116 | def forward(self, x): 117 | x = x.view(x.shape[0], -1) 118 | x = torch.relu(self.lin_in(x)) 119 | 120 | for layer in self.layers: 121 | x = torch.relu(layer(x)) 122 | 123 | x = self.out(x) 124 | 125 | return x 126 | 127 | class ClusterNet(nn.Module): 128 | def __init__(self, n_components, n_classes, scale = 4, p = .3): 129 | super().__init__() 130 | 131 | self.n_classes = n_classes 132 | internal_dim = n_classes*scale 133 | 134 | self.lin1 = nn.Linear(n_components, internal_dim) 135 | self.do1 = nn.AlphaDropout(p) 136 | self.lin2 = nn.Linear(internal_dim, internal_dim) 137 | self.do2 = nn.AlphaDropout(p) 138 | self.lin3 = nn.Linear(internal_dim, internal_dim) 139 | self.do3 = nn.AlphaDropout(p) 140 | self.out = nn.Linear(internal_dim, n_classes) 141 | 142 | def forward(self, x): 143 | x = x.view(x.shape[0], -1) 144 | 145 | x = self.do1(F.selu(self.lin1(x))) 146 | x = self.do2(F.selu(self.lin2(x))) 147 | x = self.do3(F.selu(self.lin3(x))) 148 | x = self.out(x) 149 | 150 | return x 151 | 152 | class FullNet(nn.Module): 153 | def __init__(self, embed_net, cluster_net): 154 | super().__init__() 155 | 156 | self.embed_net = embed_net 157 | self.cluster_net = cluster_net 158 | 159 | def forward(self, x): 160 | x = self.embed_net(x) 161 | x = self.cluster_net(x) 162 | 163 | return x -------------------------------------------------------------------------------- /plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import numpy as np 5 | from collections import Counter 6 | np.random.seed(37) 7 | 8 | def plot_2d(reduced, y = None, s = 1, alpha = .2, show = True, no_legend = False, y_names = None, plot_device = None): 9 | if plot_device is None: 10 | plot_device = plt 11 | plot_device.figure() 12 | if y is not None: 13 | s = plt.scatter(reduced[:, 0], reduced[:, 1], s = s, alpha = alpha, c = y) 14 | if not no_legend: 15 | if y_names is None: 16 | y_names = list(np.unique(y)) 17 | plot_device.legend(loc = 'upper right', handles = s.legend_elements()[0], labels = y_names) 18 | else: 19 | plot_device.scatter(reduced[:, 0], reduced[:, 1], s = s, alpha = alpha) 20 | 21 | if show: 22 | plt.show() 23 | 24 | def plot_for_compare(embeddings, y = None, s = 1, alpha = .2, show = True): 25 | fig, axs = plt.subplots(1, len(embeddings)) 26 | 27 | print(axs) 28 | print(axs.shape) 29 | 30 | for i, embedding in enumerate(embeddings): 31 | axs[i].scatter(embedding[:, 0], embedding[:, 1], c = y, alpha = alpha, s = s) 32 | 33 | fig.tight_layout() 34 | 35 | if show: 36 | plt.show() 37 | 38 | 39 | def plot_3d(reduced, y = None, s = 1, alpha = .2, show = True): 40 | fig = plt.figure() 41 | if y is not None: 42 | ax = fig.add_subplot(111, projection='3d') 43 | ax.scatter(reduced[:, 0], reduced[:, 1], reduced[:, 2], s = s, alpha = alpha, c = y) 44 | else: 45 | ax = fig.add_subplot(111, projection='3d') 46 | ax.scatter(reduced[:, 0], reduced[:, 1], reduced[:, 2], s = s, alpha = alpha) 47 | 48 | if show: 49 | plt.show() 50 | 51 | def plot_2d_with_images(reduced, images, show = True, n = 200, zoom = .5): 52 | fig, ax = plt.subplots() 53 | inds = np.random.choice(reduced.shape[0], n, replace = False) 54 | reduced = reduced[inds] 55 | images = images[inds] 56 | ax.scatter(reduced[:, 0], reduced[:, 1], s = 0) 57 | 58 | for point, img in zip(reduced, images): 59 | img = OffsetImage(np.array(img), cmap = 'gray_r', zoom = zoom) 60 | ab = AnnotationBbox(img, point, frameon = False) 61 | ax.add_artist(ab) 62 | 63 | if show: 64 | plt.show() 65 | --------------------------------------------------------------------------------