├── LICENSE
├── README.md
├── app.py
├── configs
│   ├── CDR.yaml
│   └── ICDR.yaml
├── data
│   └── H5 Data
│       ├── texture.h5
│       ├── usps.h5
│       └── wifi.h5
├── dataset
│   ├── datasets.py
│   ├── samplers.py
│   ├── transforms.py
│   └── warppers.py
├── experiments
│   ├── icdr_trainer.py
│   └── trainer.py
├── model
│   ├── baseline_encoder.py
│   ├── cdr.py
│   ├── icdr.py
│   ├── nce_loss.py
│   └── nx_cdr.py
├── model_weights
│   └── usps.pth.tar
├── prototype.png
├── requirements.txt
├── results
│   └── CDR
│       └── n15
│           └── usps_demo
│               ├── config.yaml
│               ├── embeddings_1000.npy
│               ├── loss_1000.jpg
│               ├── usps_vis_1000.jpg
│               └── usps_vis_500.jpg
├── static
│   ├── css
│   │   ├── common
│   │   │   ├── bootstrap.min.css
│   │   │   └── eu_index.css
│   │   ├── container.css
│   │   ├── control.css
│   │   ├── lasso.css
│   │   ├── link_view.css
│   │   ├── myCss.css
│   │   └── project_view.css
│   ├── icon
│   │   ├── cannotlink.png
│   │   ├── cannotlink2.png
│   │   └── mustlink.png
│   └── js
│       ├── common
│       │   ├── boostrap.min.js
│       │   ├── d3-lasso.min.js
│       │   ├── d3.v5.min.js
│       │   ├── eu_index.js
│       │   ├── jquery.min.js
│       │   └── vue.js
│       ├── main.js
│       ├── models
│       │   ├── ContourModel.js
│       │   ├── LinkModel.js
│       │   ├── NewScatterModel.js
│       │   ├── ParallelModel.js
│       │   └── StateMachine.js
│       └── utils.js
├── teaser.png
├── templates
│   └── index.html
├── train.py
├── utils
│   ├── common_utils.py
│   ├── constant_pool.py
│   ├── link_utils.py
│   ├── logger.py
│   ├── math_utils.py
│   ├── nn_utils.py
│   └── umap_utils.py
├── vis.jpg
└── vis.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 DRLib
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # CDR - Interactive Visual Cluster Analysis by Contrastive Dimensionality Reduction
3 | 
4 | ![teaser](teaser.png)
5 | 
6 | ## Environment setup
7 | 
8 | This project is based on `Python 3.6` and `PyTorch 1.7.0`. See `requirements.txt` for all prerequisites; you can install them using the following command.
9 | 
10 | ```bash
11 | pip install -r requirements.txt
12 | ```
13 | 
14 | ## Datasets
15 | 
16 | |               | Size  | Dimensionality | Clusters |  Type   | Link |
17 | | :-----------: | :---: | :------------: | :------: | :-----: | :----------------------------------------------------------: |
18 | | Animals | 10000 | 512 | 10 | image | [Kaggle](https://www.kaggle.com/datasets/alessiocorrado99/animals10) |
19 | | Anuran calls | 7195 | 22 | 8 | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/Anuran+Calls+%28MFCCs%29) |
20 | | Banknote | 1097 | 4 | 2 | text | [UCI](https://archive.ics.uci.edu/ml/datasets/banknote+authentication) |
21 | | Cifar10 | 10000 | 512 | 10 | image | [Alex Krizhevsky](https://www.cs.toronto.edu/~kriz/cifar.html) |
22 | | Cnae9 | 864 | 856 | 9 | text | [UCI](https://archive.ics.uci.edu/ml/datasets/cnae-9) |
23 | | Cats-vs-Dogs | 10000 | 512 | 2 | image | [Kaggle](https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset) |
24 | | Fish | 9000 | 512 | 9 | image | [Kaggle](https://www.kaggle.com/datasets/crowww/a-large-scale-fish-dataset) |
25 | | Food | 3585 | 512 | 11 | image | [Kaggle](https://www.kaggle.com/datasets/anshulmehtakaggl/themassiveindianfooddataset) |
26 | | Har | 8240 | 561 | 6 | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/human+activity+recognition+using+smartphones) |
27 | | Isolet | 1920 | 617 | 8 | text | [UCI](https://archive.ics.uci.edu/ml/datasets/isolet) |
28 | | ML binary | 1000 | 10 | 2 | tabular | [Kaggle](https://www.kaggle.com/datasets/rhythmcam/ml-binary-classification-study-data) |
29 | | MNIST | 10000 | 784 | 10 | image | [Yann LeCun](http://yann.lecun.com/exdb/mnist/) |
30 | | Pendigits | 8794 | 16 | 10 | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/pen-based+recognition+of+handwritten+digits) |
31 | | Retina | 10000 | 50 | 12 | tabular | [Paper](https://www.cell.com/fulltext/S0092-8674(15)00549-8) |
32 | | Satimage | 5148 | 36 | 6 | image | [UCI](https://archive.ics.uci.edu/ml/datasets/Statlog+(Landsat+Satellite)) |
33 | | Stanford Dogs | 1384 | 512 | 7 | image | [Stanford University](http://vision.stanford.edu/aditya86/ImageNetDogs/) |
34 | | Texture | 4400 | 40 | 11 | text | [KEEL](https://sci2s.ugr.es/keel/dataset.php?cod=72) |
35 | | USPS | 7440 | 256 | 10 | image | [Kaggle](https://www.kaggle.com/bistaumanga/usps-dataset) |
36 | | Weathers | 900 | 512 | 4 | image | [Kaggle](https://www.kaggle.com/datasets/vijaygiitk/multiclass-weather-dataset) |
37 | | WiFi | 1600 | 7 | 4 | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/Wireless+Indoor+Localization) |
38 | 
39 | For image datasets such as Animals, Cifar10, Cats-vs-Dogs, Fish, Food, Stanford Dogs, and Weathers, we use [SimCLR](https://github.com/sthalles/SimCLR) to obtain their 512-dimensional representations.
40 | 
41 | All datasets are provided in **H5 format** (e.g. `usps.h5`) and must be stored in **`data/H5 Data`**. For image datasets, name the images `0.jpg,1.jpg,...,n-1.jpg` and put them in the `static/images/(dataset name)` directory (e.g. `static/images/usps`).
42 | 
43 | ## Pre-trained model weights
44 | 
45 | The pre-trained model weights for all of the above datasets can be found on [Google Drive](https://drive.google.com/drive/folders/19WYgUcOI6cOYSUPK_w1eICSr0ceRK9Zb?usp=sharing).
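If you want to plug in your own data in the H5 layout described above, note that the loaders in `dataset/datasets.py` read exactly two keys from each file: `x` (the data matrix) and `y` (the integer labels). The following is a minimal sketch of writing a tabular dataset in that layout; the file name and toy shapes are illustrative only, and image datasets store `x` as an `(n, height, width, channels)` array instead, judging from `MyImageDataset`.

```python
import h5py
import numpy as np

# Toy tabular dataset in the expected layout: 'x' holds the
# (n_samples, n_dims) data matrix, 'y' the integer class labels,
# mirroring the bundled usps.h5 / wifi.h5 files.
x = np.random.rand(1000, 16).astype(np.float32)
y = np.random.randint(0, 4, size=1000)

with h5py.File("data/H5 Data/my_dataset.h5", "w") as f:
    f.create_dataset("x", data=x)
    f.create_dataset("y", data=y)
```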
46 | 
47 | ## Training
48 | 
49 | To train the model on USPS with a single GPU, check the configuration file `configs/CDR.yaml` and try the following command:
50 | 
51 | ```bash
52 | python train.py --configs configs/CDR.yaml
53 | ```
54 | 
55 | ## Config File
56 | 
57 | The configuration files can be found under the `./configs` folder, where we provide two config files in `.yaml` format. Guidance on several key parameters from the paper is given below.
58 | 
59 | - **n_neighbors (K):** It determines **the granularity of the local structure** to be preserved in the low-dimensional space. Too small a value can cause a single high-dimensional cluster to be projected into two low-dimensional clusters, while too large a value aggravates the problem of cluster overlap. The default setting is **K = 15**.
60 | - **batch_size (B):** It determines the number of negative samples. A larger value is generally better, but the choice also depends on the data size. We recommend **`B = n/10`**, where `n` is the number of instances.
61 | - **temperature (t):** It determines how strictly the model preserves neighborhoods. The smaller the value, the more strictly the model maintains the neighborhood, but more erroneous neighbors are kept as well. The default setting is **t = 0.15**.
62 | - **separate_upper (μ):** It determines the intensity of cluster separation. The larger the value, the higher the degree of cluster separation. The default setting is **μ = 0.11**.
63 | 
64 | ## Load pre-trained model for visualization
65 | 
66 | To use our pre-trained model, try the following command:
67 | 
68 | ```bash
69 | # python vis.py --configs 'configuration file path' --ckpt_path 'model weights path'
70 | 
71 | # Example on USPS dataset
72 | python vis.py --configs configs/CDR.yaml --ckpt_path model_weights/usps.pth.tar
73 | ```
74 | 
75 | ## Prototype interface
76 | 
77 | To use our prototype interface for interactive visual cluster analysis, try the following command:
78 | 
79 | ```bash
80 | python app.py --configs configs/ICDR.yaml
81 | ```
82 | 
83 | After that, the prototype interface can be accessed at [http://127.0.0.1:5000](http://127.0.0.1:5000).
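The interface is served by the Flask app in `app.py`, which exposes its functionality through a few HTTP routes (`/load_dataset_list`, `/train_for_vis`, `/constraint_resume`). As a quick smoke test, you can query the dataset list that populates the UI directly; below is a minimal sketch using only the standard library. Note that the exact fields of each entry depend on `ConfigInfo.DATASETS_META` from `utils/constant_pool.py`, which is not included in this listing.

```python
import json
from urllib.request import urlopen

# Ask the running prototype server which datasets it discovered
# under data/H5 Data (served by the /load_dataset_list route).
with urlopen("http://127.0.0.1:5000/load_dataset_list") as resp:
    payload = json.load(resp)

for entry in payload["data"]:
    print(entry)
```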
84 | 
85 | 
86 | 
87 | ![frontend_07](prototype.png)
88 | 
89 | ## Cite
90 | ```bibtex
91 | @article{xia2022interactive,
92 |   title={Interactive visual cluster analysis by contrastive dimensionality reduction},
93 |   author={Xia, Jiazhi and Huang, Linquan and Lin, Weixing and Zhao, Xin and Wu, Jing and Chen, Yang and Zhao, Ying and Chen, Wei},
94 |   journal={IEEE Transactions on Visualization and Computer Graphics},
95 |   volume={29},
96 |   number={1},
97 |   pages={734--744},
98 |   year={2022},
99 |   publisher={IEEE}
100 | }
101 | ```
102 | 
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | import argparse
4 | import os
5 | from datetime import timedelta
6 | 
7 | import h5py
8 | from flask import Flask, render_template, request
9 | from experiments.icdr_trainer import ICLPTrainer
10 | from model.cdr import CDRModel
11 | from model.icdr import ICDRModel
12 | from utils.constant_pool import *
13 | from utils.common_utils import get_principle_components, get_config
14 | from utils.link_utils import LinkInfo
15 | import numpy as np
16 | 
17 | 
18 | app = Flask(__name__)
19 | experimenter: ICLPTrainer
20 | app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1)
21 | 
22 | 
23 | def wrap_results(embeddings, principle_comps=None, attr_names=None):
24 |     ret_dict = {}
25 |     ret_dict["embeddings"] = embeddings.tolist()
26 |     ret_dict["label"] = experimenter.get_label()
27 |     if principle_comps is not None:
28 |         ret_dict["low_data"] = principle_comps.tolist()
29 |         ret_dict["attrs"] = attr_names
30 |     return ret_dict
31 | 
32 | 
33 | def build_link_info(embeddings, min_dist):
34 |     links = request.form.get("links")
35 |     link_spreads = request.form.get("link_spreads")
36 |     finetune_epochs = request.form.get("finetune_epochs", type=int)
37 | 
38 |     links = np.array(eval(links))
39 |     print(links)
40 |     link_spreads = np.array(eval(link_spreads))
41 | 
42 |     if links.shape[0] == 0:
43 |         experimenter.link_info = None
44 |         return experimenter.link_info
45 | 
46 |     if experimenter.link_info is None:
47 |         experimenter.link_info = LinkInfo(links, link_spreads, finetune_epochs, embeddings, min_dist)
48 |     else:
49 |         experimenter.link_info.process_cur_links(links, link_spreads, embeddings)
50 | 
51 |     return experimenter.link_info
52 | 
53 | 
54 | def update_config():
55 |     global configs
56 |     ds_name = request.form.get("dataset", type=str)
57 |     configs.exp_params.dataset = ds_name
58 |     configs.exp_params.n_neighbors = request.form.get("n_neighbors", type=int)
59 |     configs.training_params.epoch_nums = request.form.get("epoch_nums", type=int)
60 |     configs.exp_params.input_dims = request.form.get("input_dims", type=int)
61 |     configs.exp_params.split_upper = request.form.get("split_upper", type=float)
62 |     configs.exp_params.batch_size = int(request.form.get("n_samples", type=int) / 10)
63 | 
64 | 
65 | def load_experiment(cfg):
66 |     method_name = CDR_METHOD if cfg.exp_params.gradient_redefine else NX_CDR_METHOD
67 |     result_save_dir = ConfigInfo.RESULT_SAVE_DIR.format(method_name, cfg.exp_params.n_neighbors)
68 |     # Create the ICDR model
69 |     clr_model = ICDRModel(cfg, device=device)
70 |     global experimenter
71 |     experimenter = ICLPTrainer(clr_model, cfg.exp_params.dataset, cfg, result_save_dir, None, device=device)
72 | 
73 | 
74 | @app.route("/")
75 | def index():
76 |     return render_template("index.html")
77 | 
78 | 
79 | @app.route("/load_dataset_list")
80 | def load_dataset_list():
data = [] 82 | for item in ConfigInfo.AVAILABLE_DATASETS: 83 | data_obj = {} 84 | for i, k in enumerate(ConfigInfo.DATASETS_META): 85 | data_obj[k] = item[i] 86 | data.append(data_obj) 87 | 88 | return {"data": data} 89 | 90 | 91 | @app.route("/train_for_vis", methods=["POST"]) 92 | def train_for_vis(): 93 | update_config() 94 | load_experiment(configs) 95 | 96 | embeddings = experimenter.train_for_visualize() 97 | principle_comps, attr_names = get_principle_components(experimenter.dataset.data, attr_names=None) 98 | ret_dict = wrap_results(embeddings, principle_comps, attr_names) 99 | return ret_dict 100 | 101 | 102 | @app.route("/constraint_resume", methods=["POST"]) 103 | def constraint_resume(): 104 | update_config() 105 | link_info = build_link_info(experimenter.pre_embeddings, experimenter.configs.exp_params.min_dist) 106 | ft_epoch = request.form.get("finetune_epochs", type=int) 107 | 108 | ml_strength = request.form.get("ml_strength", type=float) 109 | cl_strength = request.form.get("cl_strength", type=float) 110 | experimenter.update_link_stat(link_info, is_finetune=True, finetune_epoch=ft_epoch) 111 | 112 | if link_info is not None: 113 | experimenter.model.link_stat_update(ft_epoch, experimenter.steady_epoch, ml_strength, cl_strength) 114 | 115 | embeddings = experimenter.resume_train(ft_epoch) 116 | return wrap_results(embeddings) 117 | 118 | 119 | def parse_args(): 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("--configs", type=str, default="configs/ICDR.yaml", help="configuration file path") 122 | parser.add_argument("--device", type=str, default="cpu") 123 | return parser.parse_args() 124 | 125 | 126 | def load_available_data(): 127 | for item in os.listdir(ConfigInfo.DATASET_CACHE_DIR): 128 | ds = item.split(".")[0] 129 | n_samples, dims = np.array(h5py.File(os.path.join(ConfigInfo.DATASET_CACHE_DIR, item), "r")['x']).shape 130 | ds_type = "image" if os.path.exists(os.path.join(ConfigInfo.IMAGE_DIR, ds)) else "tabular" 131 | ConfigInfo.AVAILABLE_DATASETS.append([ds, n_samples, dims, ds_type]) 132 | 133 | 134 | if __name__ == '__main__': 135 | app.jinja_env.variable_start_string = '[[' 136 | app.jinja_env.variable_end_string = ']]' 137 | 138 | args = parse_args() 139 | device = args.device 140 | config_path = args.configs 141 | configs = get_config() 142 | configs.merge_from_file(config_path) 143 | load_available_data() 144 | load_experiment(configs) 145 | app.run(debug=False) 146 | -------------------------------------------------------------------------------- /configs/CDR.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | dataset: "usps" 3 | input_dims: 256 # (28, 28, 1) 4 | LR: 0.001 5 | batch_size: 512 6 | n_neighbors: 15 7 | optimizer: "adam" # adam or sgd 8 | scheduler: "multi_step" # cosine or multi_step or on_plateau 9 | temperature: 0.15 10 | gradient_redefine: True 11 | separate_upper: 0.1 12 | separation_begin_ratio: 0.25 13 | steady_begin_ratio: 0.875 14 | 15 | training_params: 16 | epoch_nums: 1000 17 | epoch_print_inter_ratio: 0.1 18 | val_inter_ratio: 1 19 | ckp_inter_ratio: 1 -------------------------------------------------------------------------------- /configs/ICDR.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | dataset: "wifi" 3 | input_dims: 7 # (28, 28, 1) 4 | LR: 0.001 5 | batch_size: 128 6 | n_neighbors: 15 7 | optimizer: "adam" # adam or sgd 8 | scheduler: "multi_step" # cosine or multi_step or on_plateau 9 | 
temperature: 0.15 10 | min_dist: 0.1 11 | separate_upper: 0.11 12 | gradient_redefine: True 13 | separation_begin_ratio: 0.25 14 | steady_begin_ratio: 0.875 15 | 16 | training_params: 17 | epoch_nums: 1000 18 | epoch_print_inter_ratio: 0.1 19 | val_inter_ratio: 0.5 20 | ckp_inter_ratio: 1 -------------------------------------------------------------------------------- /data/H5 Data/texture.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/data/H5 Data/texture.h5 -------------------------------------------------------------------------------- /data/H5 Data/usps.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/data/H5 Data/usps.h5 -------------------------------------------------------------------------------- /data/H5 Data/wifi.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/data/H5 Data/wifi.h5 -------------------------------------------------------------------------------- /dataset/datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | import random 5 | 6 | import h5py 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from scipy.sparse import csr_matrix 11 | from sklearn.decomposition import PCA 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms as transforms 14 | 15 | from utils.constant_pool import ConfigInfo 16 | from utils.logger import InfoLogger 17 | from utils.nn_utils import compute_knn_graph, compute_accurate_knn, cal_snn_similarity 18 | from utils.umap_utils import fuzzy_simplicial_set, construct_edge_dataset, compute_local_membership 19 | 20 | MACHINE_EPSILON = np.finfo(np.double).eps 21 | 22 | 23 | class MyTextDataset(Dataset): 24 | def __init__(self, dataset_name, root_dir): 25 | self.dataset_name = dataset_name 26 | self.root_dir = root_dir 27 | self.data_file_path = os.path.join(root_dir, dataset_name + ".h5") 28 | self.data = None 29 | self.target = None 30 | self.data_num = 0 31 | self.min_neighbor_num = 0 32 | self.symmetry_knn_indices = None 33 | self.symmetry_knn_weights = None 34 | self.symmetry_knn_dists = None 35 | self.transform = None 36 | self.__load_data() 37 | 38 | def __len__(self): 39 | return self.data.shape[0] 40 | 41 | def __load_data(self): 42 | if not self._check_exists(): 43 | raise RuntimeError('Dataset not found.' 
+ 44 | ' You can use download=True to download it') 45 | 46 | train_data, train_labels = \ 47 | load_local_h5_by_path(self.data_file_path, ['x', 'y']) 48 | self.data = train_data 49 | self.targets = train_labels 50 | self.data_num = self.data.shape[0] 51 | 52 | def __getitem__(self, index): 53 | text, target = self.data[index], int(self.targets[index]) 54 | text = torch.tensor(text, dtype=torch.float) 55 | 56 | return text, target 57 | 58 | def _check_exists(self): 59 | return os.path.exists(self.data_file_path) 60 | 61 | def update_transform(self, new_transform): 62 | self.transform = new_transform 63 | 64 | def get_data(self, index): 65 | res = self.data[index] 66 | return torch.tensor(res, dtype=torch.float) 67 | 68 | def get_label(self, index): 69 | return int(self.targets[index]) 70 | 71 | def get_dims(self): 72 | return int(self.data.shape[1]) 73 | 74 | def get_all_data(self, data_num=-1): 75 | if data_num == -1: 76 | return self.data 77 | else: 78 | return self.data[torch.randperm(self.data_num)[:data_num], :] 79 | 80 | def get_data_shape(self): 81 | return self.data[0].shape 82 | 83 | 84 | class MyImageDataset(MyTextDataset): 85 | def __init__(self, dataset_name, root_dir, transform=None): 86 | MyTextDataset.__init__(self, dataset_name, root_dir) 87 | self.transform = transform 88 | 89 | def __getitem__(self, index): 90 | img, target = self.data[index], int(self.targets[index]) 91 | 92 | img = np.squeeze(img) 93 | mode = 'RGB' if len(img.shape) == 3 else 'L' 94 | if mode == 'RGB': 95 | img = Image.fromarray(img, mode=mode) 96 | if self.transform is not None: 97 | img = self.transform(img) 98 | 99 | return img, target 100 | 101 | def get_data(self, index): 102 | res = self.data[index] 103 | res = res.astype(np.uint8) 104 | return res 105 | 106 | def get_all_data(self, data_num=-1): 107 | if data_num == -1: 108 | return np.transpose(self.data, (0, 3, 1, 2)) 109 | else: 110 | return np.transpose(self.data[torch.randperm(self.data_num)[:data_num], :, :, :], (0, 3, 1, 2)) 111 | 112 | 113 | class UMAPTextDataset(MyTextDataset): 114 | def __init__(self, dataset_name, root_dir, repeat=1): 115 | MyTextDataset.__init__(self, dataset_name, root_dir) 116 | self.repeat = repeat 117 | self.edge_data = None 118 | self.edge_num = None 119 | self.edge_weight = None 120 | self.raw_knn_weights = None 121 | 122 | def build_fuzzy_simplicial_set(self, knn_cache_path, pairwise_cache_path, n_neighbors): 123 | 124 | knn_indices, knn_distances = compute_knn_graph(self.data, knn_cache_path, n_neighbors, pairwise_cache_path) 125 | umap_graph, sigmas, rhos, self.raw_knn_weights = fuzzy_simplicial_set( 126 | X=self.data, 127 | n_neighbors=n_neighbors, 128 | knn_indices=knn_indices, 129 | knn_dists=knn_distances) 130 | return umap_graph, sigmas, rhos 131 | 132 | def umap_process(self, knn_cache_path, pairwise_cache_path, n_neighbors, embedding_epoch): 133 | umap_graph, sigmas, rhos = self.build_fuzzy_simplicial_set(knn_cache_path, pairwise_cache_path, n_neighbors) 134 | self.edge_data, self.edge_num, self.edge_weight = construct_edge_dataset( 135 | self.data, umap_graph, embedding_epoch) 136 | 137 | return self.edge_data, self.edge_num 138 | 139 | def __getitem__(self, index): 140 | to_data, from_data = self.edge_data[0][index], self.edge_data[1][index] 141 | return torch.tensor(to_data, dtype=torch.float), torch.tensor(from_data, dtype=torch.float) 142 | 143 | def __len__(self): 144 | return self.edge_num 145 | 146 | 147 | class UMAPImageDataset(MyImageDataset, UMAPTextDataset): 148 | def __init__(self, 
dataset_name, root_dir, transform=None, repeat=1): 149 | MyImageDataset.__init__(self, dataset_name, root_dir, transform) 150 | UMAPTextDataset.__init__(self, dataset_name, root_dir, repeat) 151 | self.transform = transform 152 | 153 | def __getitem__(self, index): 154 | to_data, from_data = self.edge_data[0][index], self.edge_data[1][index] 155 | if self.transform is not None: 156 | to_data = self.transform(to_data) 157 | from_data = self.transform(from_data) 158 | 159 | return to_data, from_data 160 | 161 | 162 | class CDRTextDataset(MyTextDataset): 163 | def __init__(self, dataset_name, root_dir): 164 | MyTextDataset.__init__(self, dataset_name, root_dir) 165 | 166 | def __getitem__(self, index): 167 | text, target = self.data[index], int(self.targets[index]) 168 | x, x_sim, idx, sim_idx = self.transform(text, index) 169 | if not isinstance(x, torch.Tensor): 170 | x = torch.tensor(x, dtype=torch.float) 171 | x_sim = torch.tensor(x_sim, dtype=torch.float) 172 | return [x, x_sim, idx, sim_idx], target 173 | 174 | def sample_data(self, indices): 175 | x = self.data[indices] 176 | if not isinstance(x, torch.Tensor): 177 | x = torch.tensor(x, dtype=torch.float) 178 | return x 179 | 180 | 181 | class CDRImageDataset(MyImageDataset): 182 | def __init__(self, dataset_name, root_dir, transform=None): 183 | MyImageDataset.__init__(self, dataset_name, root_dir, transform) 184 | self.transform = transform 185 | 186 | def __getitem__(self, index): 187 | img, target = self.data[index], int(self.targets[index]) 188 | 189 | img = np.squeeze(img) 190 | mode = 'RGB' if len(img.shape) == 3 else 'L' 191 | img = Image.fromarray(img, mode=mode) 192 | if self.transform is not None: 193 | img = self.transform(img, index) 194 | 195 | return img, target 196 | 197 | def sample_data(self, indices): 198 | num = len(indices) 199 | first_data = self.data[indices[0]] 200 | ret_data = torch.empty((num, first_data.shape[2], first_data.shape[0], first_data.shape[1])) 201 | count = 0 202 | transform = transforms.ToTensor() 203 | for index in indices: 204 | img = np.squeeze(self.data[index]) 205 | mode = 'RGB' if len(img.shape) == 3 else 'L' 206 | img = Image.fromarray(img, mode=mode) 207 | img = transform(img) 208 | ret_data[count, :, :, :] = img.unsqueeze(0) 209 | count += 1 210 | return ret_data 211 | 212 | 213 | class UMAPCDRTextDataset(CDRTextDataset): 214 | def __init__(self, dataset_name, root_dir): 215 | CDRTextDataset.__init__(self, dataset_name, root_dir) 216 | self.umap_graph = None 217 | self.raw_knn_weights = None 218 | self.sym_no_norm_weights = None 219 | self.min_neighbor_num = None 220 | self.knn_dist = None 221 | self.knn_indices = None 222 | 223 | def build_fuzzy_simplicial_set(self, knn_indices, knn_distances, n_neighbors, symmetric): 224 | self.umap_graph, sigmas, rhos, self.raw_knn_weights, knn_dist = fuzzy_simplicial_set( 225 | X=self.data, 226 | n_neighbors=n_neighbors, knn_indices=knn_indices, 227 | knn_dists=knn_distances, return_dists=True, symmetric=symmetric) 228 | self.symmetry_knn_dists = knn_dist.tocoo() 229 | 230 | def umap_process(self, knn_indices, knn_distances, n_neighbors, symmetric): 231 | self.build_fuzzy_simplicial_set(knn_indices, knn_distances, n_neighbors, symmetric) 232 | 233 | self.data_num = knn_indices.shape[0] 234 | n_samples = self.data_num 235 | 236 | nn_indices, nn_weights, self.min_neighbor_num, raw_weights, nn_dists \ 237 | = get_kw_from_coo(self.umap_graph, n_neighbors, n_samples, self.symmetry_knn_dists) 238 | 239 | self.symmetry_knn_indices = np.array(nn_indices, 
dtype=object) 240 | self.symmetry_knn_weights = np.array(nn_weights, dtype=object) 241 | self.symmetry_knn_dists = np.array(nn_dists, dtype=object) 242 | self.sym_no_norm_weights = np.array(raw_weights, dtype=object) 243 | 244 | 245 | class UMAPCDRImageDataset(CDRImageDataset, UMAPCDRTextDataset): 246 | def __init__(self, dataset_name, root_dir, transform=None): 247 | CDRImageDataset.__init__(self, dataset_name, root_dir, transform) 248 | UMAPCDRTextDataset.__init__(self, dataset_name, root_dir) 249 | 250 | 251 | def get_kw_from_coo(csr_graph, n_neighbors, n_samples, dist_csr=None): 252 | nn_indices = [] 253 | nn_weights = [] 254 | raw_weights = [] 255 | nn_dists = [] 256 | 257 | tmp_min_neighbor_num = n_neighbors 258 | for i in range(1, n_samples + 1): 259 | pre = csr_graph.indptr[i-1] 260 | idx = csr_graph.indptr[i] 261 | cur_indices = csr_graph.indices[pre:idx] 262 | if dist_csr is not None: 263 | nn_dists.append(dist_csr.data[pre:idx]) 264 | tmp_min_neighbor_num = min(tmp_min_neighbor_num, idx - pre) 265 | cur_weights = csr_graph.data[pre:idx] 266 | 267 | nn_indices.append(cur_indices) 268 | cur_sum = np.sum(cur_weights) 269 | nn_weights.append(cur_weights / cur_sum) 270 | raw_weights.append(cur_weights) 271 | return nn_indices, nn_weights, tmp_min_neighbor_num, raw_weights, nn_dists 272 | 273 | 274 | def load_local_h5_by_path(dataset_path, keys): 275 | f = h5py.File(dataset_path, "r") 276 | res = [] 277 | for key in keys: 278 | res.append(f[key][:]) 279 | f.close() 280 | return res 281 | -------------------------------------------------------------------------------- /dataset/samplers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import random 4 | import time 5 | from typing import Iterator 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Sampler 10 | from torch.utils.data.sampler import T_co 11 | 12 | 13 | class CustomSampler(Sampler): 14 | def __init__(self, train_indices): 15 | Sampler.__init__(self, None) 16 | self.indices = train_indices 17 | self.random = True 18 | 19 | def update_indices(self, new_indices, is_random): 20 | self.indices = new_indices 21 | self.random = is_random 22 | 23 | def __iter__(self): 24 | if self.random: 25 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 26 | else: 27 | return (self.indices[i] for i in range(len(self.indices))) 28 | 29 | def __len__(self): 30 | return len(self.indices) 31 | 32 | -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import math 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | from torchvision import transforms as transforms 9 | 10 | 11 | def assign_weighted_neighbor_samples(symmetry_knn_indices, symmetry_knn_weights, n_neighbors, epoch_num): 12 | 13 | n_samples = symmetry_knn_indices.shape[0] 14 | repo_size = epoch_num 15 | sample_num_per_neighbor = symmetry_knn_weights * repo_size 16 | neighbor_sample_repo = np.empty((n_samples, repo_size), dtype=np.int) 17 | 18 | for i in range(n_samples): 19 | sample_num_per_neighbor[i] = np.ceil(sample_num_per_neighbor[i]).astype(np.int) 20 | if np.sum(sample_num_per_neighbor[i]) < epoch_num: 21 | sample_num_per_neighbor[i] = \ 22 | np.ones_like(sample_num_per_neighbor[i]) * np.ceil(epoch_num / n_neighbors) 23 | 24 | tmp_num = 
len(symmetry_knn_indices[i]) 25 | tmp_repo = np.repeat(symmetry_knn_indices[i].astype(np.int), sample_num_per_neighbor[i][:tmp_num].astype(np.int).squeeze()) 26 | num = min(repo_size, len(tmp_repo)) 27 | np.random.shuffle(tmp_repo) 28 | neighbor_sample_repo[i, :num] = tmp_repo[:num].astype(np.int) 29 | 30 | return neighbor_sample_repo 31 | 32 | 33 | class SimCLRDataTransform(object): 34 | def __init__(self, epoch_num, train_dataset, is_image, transform, n_neighbors, norm_nn_indices, 35 | norm_nn_weights): 36 | 37 | self.epoch_num = epoch_num 38 | self.transform = transform 39 | self.train_dataset = train_dataset 40 | self.n_samples = norm_nn_indices.shape[0] 41 | 42 | self.neighbor_sample_repo = None 43 | self.neighbor_sample_index = None 44 | self.init_norm_nn_indices = norm_nn_indices 45 | self.init_norm_nn_weights = norm_nn_weights 46 | 47 | self.build_neighbor_repo(epoch_num, n_neighbors, norm_nn_indices, norm_nn_weights) 48 | 49 | self.is_image = is_image 50 | if self.is_image: 51 | self.transform = transforms.ToTensor() 52 | 53 | def build_neighbor_repo(self, epoch_num, n_neighbors, norm_nn_indices=None, norm_nn_weights=None): 54 | if norm_nn_indices is None: 55 | norm_nn_indices = self.init_norm_nn_indices 56 | if norm_nn_weights is None: 57 | norm_nn_weights = self.init_norm_nn_weights 58 | self.neighbor_sample_repo = assign_weighted_neighbor_samples(norm_nn_indices, 59 | norm_nn_weights, n_neighbors, 60 | epoch_num) 61 | self.neighbor_sample_index = np.zeros(self.n_samples, dtype=np.int) 62 | 63 | def _neighbor_index_fixed(self, index): 64 | sim_index = self.neighbor_sample_repo[index][self.neighbor_sample_index[index]] 65 | self.neighbor_sample_index[index] += 1 66 | return sim_index 67 | 68 | def __call__(self, sample, index): 69 | x = sample 70 | if self.transform is not None: 71 | x = self.transform(sample) 72 | else: 73 | x = torch.tensor(x, dtype=torch.float) 74 | sim_index = self._neighbor_index_fixed(index) 75 | x_sim = self.train_dataset.get_data(sim_index) 76 | if self.transform is not None: 77 | x_sim = self.transform(x_sim) 78 | return x.float(), x_sim.float(), index, sim_index 79 | 80 | -------------------------------------------------------------------------------- /dataset/warppers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import math 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | from dataset.samplers import CustomSampler 8 | from dataset.transforms import SimCLRDataTransform 9 | from dataset.datasets import * 10 | from utils.nn_utils import compute_knn_graph 11 | 12 | 13 | def build_dataset(dataset_name, is_image, root_dir): 14 | data_augment = transforms.Compose([ 15 | transforms.ToTensor() 16 | ]) 17 | if is_image: 18 | train_dataset = UMAPCDRImageDataset(dataset_name, root_dir, data_augment) 19 | else: 20 | train_dataset = UMAPCDRTextDataset(dataset_name, root_dir) 21 | data_augment = None 22 | return data_augment, train_dataset 23 | 24 | 25 | class DataSetWrapper(object): 26 | def __init__(self, batch_size): 27 | self.batch_size = batch_size 28 | self.batch_num = 0 29 | self.test_batch_num = 0 30 | self.knn_indices = None 31 | self.knn_distances = None 32 | self.symmetric_nn_indices = None 33 | self.symmetric_nn_weights = None 34 | self.symmetric_nn_dists = None 35 | self.sym_no_norm_weights = None 36 | self.n_neighbor = 0 37 | self.shifted_data = None 38 | 39 | def get_data_loaders(self, epoch_num, dataset_name, root_dir, n_neighbors, knn_cache_path, 
pairwise_cache_path, 40 | is_image=True, symmetric="UMAP"): 41 | self.n_neighbor = n_neighbors 42 | data_augment, train_dataset = build_dataset(dataset_name, is_image, root_dir) 43 | 44 | self.knn_indices, self.knn_distances = compute_knn_graph(train_dataset.data, knn_cache_path, n_neighbors, 45 | pairwise_cache_path, accelerate=True) 46 | 47 | self.distance2prob(train_dataset, symmetric) 48 | 49 | train_indices, train_num = self.update_transform(data_augment, epoch_num, is_image, train_dataset) 50 | 51 | train_loader = self._get_train_validation_data_loaders(train_dataset, train_indices) 52 | 53 | return train_loader, train_num 54 | 55 | def update_transform(self, data_augment, epoch_num, is_image, train_dataset): 56 | 57 | train_dataset.update_transform(SimCLRDataTransform(epoch_num, train_dataset, is_image, data_augment, 58 | self.n_neighbor, self.symmetric_nn_indices, 59 | self.symmetric_nn_weights)) 60 | train_num = train_dataset.data_num 61 | 62 | train_indices = list(range(train_num)) 63 | self.batch_num = math.floor(train_num / self.batch_size) 64 | 65 | return train_indices, train_num 66 | 67 | def distance2prob(self, train_dataset, symmetric): 68 | 69 | train_dataset.umap_process(self.knn_indices, self.knn_distances, self.n_neighbor, symmetric) 70 | self.symmetric_nn_indices = train_dataset.symmetry_knn_indices 71 | self.symmetric_nn_weights = train_dataset.symmetry_knn_weights 72 | self.symmetric_nn_dists = train_dataset.symmetry_knn_dists 73 | self.sym_no_norm_weights = train_dataset.sym_no_norm_weights 74 | 75 | def _get_train_validation_data_loaders(self, train_dataset, train_indices): 76 | np.random.shuffle(train_indices) 77 | train_sampler = CustomSampler(train_indices) 78 | train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler, 79 | drop_last=True, shuffle=False) 80 | 81 | return train_loader 82 | -------------------------------------------------------------------------------- /experiments/icdr_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import random 4 | import torch 5 | from experiments.trainer import CDRTrainer 6 | from utils.constant_pool import * 7 | from utils.link_utils import CANNOT_LINK, MUST_LINK, LinkInfo, UN_SPREAD 8 | 9 | 10 | class ICLPTrainer(CDRTrainer): 11 | def __init__(self, clr_model, dataset_name, configs, result_save_dir, config_path, device='cuda', 12 | log_path="log.txt"): 13 | CDRTrainer.__init__(self, clr_model, dataset_name, configs, result_save_dir, config_path, device, log_path) 14 | self.link_info = None 15 | self.dataset = None 16 | self.finetune_epoch = 0 17 | self.finetune_data_ratio = 0.5 18 | self.minimum_finetune_data_num = 400 19 | self.finetune_data_num = 0 20 | self.steady_epoch = 2 21 | self.is_finetune = False 22 | self.has_link = False 23 | 24 | def get_label(self): 25 | return self.train_loader.dataset.targets.tolist() 26 | 27 | def update_link_stat(self, link_info, is_finetune, finetune_epoch): 28 | self.message_queue = Queue() 29 | self.link_info = link_info 30 | self.has_link = link_info is not None 31 | self.finetune_epoch = finetune_epoch 32 | self.is_finetune = is_finetune 33 | 34 | self._update_finetune_dataset() 35 | 36 | self._update_symm_knn_by_cl() 37 | 38 | def _step_prepare(self, *args): 39 | if self.dataset is None: 40 | self.dataset = self.train_loader.dataset 41 | return super()._step_prepare(*args) 42 | 43 | def _train_step(self, *args): 44 | x, x_sim, epoch, indices, 
sim_indices = args 45 | self.optimizer.zero_grad() 46 | 47 | resp_data = None 48 | link_embeddings = None 49 | final_link_types = [] 50 | final_x_ranks = [] 51 | final_link_weights = [] 52 | if self.has_link: 53 | all_indices = np.concatenate([indices, sim_indices]) 54 | all_link_indices = self.link_info.crs_indices[:, [1, 2]] 55 | all_link_types = self.link_info.crs_indices[:, 3] 56 | all_link_weights = self.link_info.crs_indices[:, 4] 57 | flattened_all_link_indices = np.ravel(all_link_indices) 58 | 59 | link_related_indices, x_ranks, y_ranks = np.intersect1d(all_indices, flattened_all_link_indices, 60 | return_indices=True) 61 | if len(link_related_indices) > 0: 62 | per_link_indices = (y_ranks // 2).astype(np.int) 63 | resp_indices = all_link_indices[per_link_indices, ((y_ranks + 1) % 2).astype(np.int)] 64 | resp_link_types = all_link_types[per_link_indices] 65 | resp_link_weights = all_link_weights[per_link_indices] 66 | final_indices = [] 67 | 68 | for i in range(len(resp_indices)): 69 | if resp_indices[i] not in link_related_indices: 70 | final_indices.append(resp_indices[i]) 71 | final_link_types.append(resp_link_types[i]) 72 | final_x_ranks.append(x_ranks[i]) 73 | final_link_weights.append(resp_link_weights[i]) 74 | 75 | final_indices = np.array(final_indices, dtype=np.int) 76 | final_link_types = np.array(final_link_types, dtype=np.int) 77 | final_x_ranks = np.array(final_x_ranks, dtype=np.int) 78 | final_link_weights = np.array(final_link_weights, dtype=np.float) 79 | 80 | if len(final_indices) > 0: 81 | resp_data = self.dataset.data[final_indices] 82 | resp_data = torch.tensor(resp_data, dtype=torch.float).to(self.device, non_blocking=True) 83 | if self.is_image: 84 | resp_data /= 255. 85 | 86 | if resp_data is not None: 87 | x_and_resp = torch.cat([x, resp_data], dim=0) 88 | resp_num = resp_data.shape[0] 89 | x_and_resp_embeddings = self.encode(x_and_resp)[1] 90 | 91 | link_embeddings = x_and_resp_embeddings[-resp_num:] 92 | x_embeddings = x_and_resp_embeddings[:-resp_num] 93 | x_sim_embeddings = self.encode(x_sim)[1] 94 | else: 95 | x_embeddings = self.encode(x)[1] 96 | x_sim_embeddings = self.encode(x_sim)[1] 97 | 98 | train_loss = self.model.compute_loss(x_embeddings, x_sim_embeddings, epoch, link_embeddings, final_link_types, 99 | final_x_ranks, final_link_weights) 100 | 101 | train_loss.backward() 102 | self.optimizer.step() 103 | return train_loss 104 | 105 | def _update_symm_knn_by_cl(self): 106 | 107 | def delete_neighbor(self_idx, other_indices): 108 | inter_knn_indices = np.setdiff1d(symm_knn_indices[self_idx], other_indices) 109 | if len(inter_knn_indices) != len(symm_knn_indices[self_idx]): 110 | indices = [] 111 | for item in inter_knn_indices: 112 | indices.append(np.argwhere(symm_knn_indices[self_idx] == item)[0][0]) 113 | symm_knn_indices[self_idx] = inter_knn_indices 114 | symm_no_norm_weights[self_idx] = symm_no_norm_weights[self_idx][indices] 115 | 116 | link_weight = 1 117 | symm_knn_indices = self.dataset.symmetry_knn_indices 118 | symm_knn_weights = self.dataset.symmetry_knn_weights 119 | symm_no_norm_weights = self.dataset.sym_no_norm_weights 120 | 121 | link_num = self.link_info.new_link_num if self.link_info is not None else 0 122 | if link_num == 0: 123 | self.dataset.transform.build_neighbor_repo(self.finetune_epoch, self.n_neighbors, symm_knn_indices, 124 | symm_knn_weights) 125 | return 126 | 127 | link_crs = self.link_info.new_crs_indices 128 | link_sims = self.link_info.new_crs_sims 129 | repeat_num = self.link_info.new_link_spreads + 
~(self.link_info.new_link_spreads.astype(np.bool)) 130 | 131 | link_spreads = np.repeat(self.link_info.new_link_spreads, repeat_num) 132 | 133 | total_link_num = link_crs.shape[0] 134 | 135 | for i in range(total_link_num): 136 | uuid, h_idx, t_idx, link_type, _ = link_crs[i] 137 | h_sim, t_sim = link_sims[i][[1, 2]] 138 | h_idx = int(h_idx) 139 | t_idx = int(t_idx) 140 | 141 | if link_type == MUST_LINK: 142 | if link_spreads[i] == UN_SPREAD: 143 | symm_knn_indices[h_idx] = np.array([t_idx], dtype=np.int) 144 | symm_knn_indices[t_idx] = np.array([h_idx], dtype=np.int) 145 | symm_no_norm_weights[h_idx] = np.array([1], dtype=np.float) 146 | symm_no_norm_weights[t_idx] = np.array([1], dtype=np.float) 147 | else: 148 | symm_knn_indices[h_idx] = np.append(symm_knn_indices[h_idx], t_idx) 149 | symm_knn_indices[t_idx] = np.append(symm_knn_indices[t_idx], h_idx) 150 | 151 | h_weight, t_weight = h_sim, t_sim 152 | w = h_weight * t_weight * link_weight 153 | 154 | symm_no_norm_weights[h_idx] = np.append(symm_no_norm_weights[h_idx], w) 155 | symm_no_norm_weights[t_idx] = np.append(symm_no_norm_weights[t_idx], w) 156 | else: 157 | delete_neighbor(h_idx, t_idx) 158 | delete_neighbor(t_idx, h_idx) 159 | 160 | symm_knn_weights[h_idx] = symm_no_norm_weights[h_idx] / np.sum(symm_no_norm_weights[h_idx]) 161 | symm_knn_weights[t_idx] = symm_no_norm_weights[t_idx] / np.sum(symm_no_norm_weights[t_idx]) 162 | 163 | self.dataset.transform.build_neighbor_repo(self.finetune_epoch, self.n_neighbors, symm_knn_indices, 164 | symm_knn_weights) 165 | 166 | def _after_epoch(self, ckp_save_inter, epoch, training_loss, training_loss_history, val_inter): 167 | ret_val = super()._after_epoch(ckp_save_inter, epoch, training_loss, training_loss_history, val_inter) 168 | if epoch % 10 == 0 and self.has_link: 169 | self._update_finetune_dataset() 170 | pass 171 | return ret_val 172 | 173 | def _update_finetune_dataset(self): 174 | sampled_num = max(int(self.n_samples * self.finetune_data_ratio), self.minimum_finetune_data_num) 175 | sampled_indices = random.sample(list(np.arange(0, self.n_samples, 1)), sampled_num) 176 | if self.link_info is not None: 177 | sampled_indices = np.union1d(sampled_indices, 178 | np.ravel(self.link_info.crs_indices[:, [1, 2]]).astype(np.int)) 179 | self.train_loader.sampler.update_indices(sampled_indices, True) 180 | self.finetune_data_num = len(sampled_indices) 181 | -------------------------------------------------------------------------------- /experiments/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import math 4 | 5 | import torch 6 | 7 | from dataset.warppers import DataSetWrapper 8 | from utils.common_utils import check_path_exists, time_stamp_to_date_time_adjoin 9 | from utils.math_utils import * 10 | import matplotlib 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | import os 14 | from torch.optim.lr_scheduler import MultiStepLR 15 | import shutil 16 | from utils.constant_pool import ConfigInfo 17 | from multiprocessing import Queue 18 | from utils.logger import InfoLogger, LogWriter 19 | import seaborn as sns 20 | import time 21 | 22 | 23 | def draw_loss(training_loss, idx, save_path=None): 24 | plt.figure() 25 | plt.plot(idx, training_loss, color="blue", label="training loss") 26 | plt.legend() 27 | plt.xlabel("epochs") 28 | plt.ylabel("loss") 29 | if save_path is not None: 30 | plt.savefig(save_path) 31 | # plt.show() 32 | 33 | 34 | def 
draw_projections(embeddings, labels, vis_save_path): 35 | x = embeddings[:, 0] 36 | y = embeddings[:, 1] 37 | 38 | plt.figure(figsize=(8, 8)) 39 | if labels is None: 40 | sns.scatterplot(x=x, y=y, s=8, legend=False, alpha=1.0) 41 | else: 42 | classes = np.unique(labels) 43 | num_classes = classes.shape[0] 44 | palette = "tab10" if num_classes <= 10 else "tab20" 45 | sns.scatterplot(x=x, y=y, hue=labels, s=8, palette=palette, legend=False, alpha=0.8) 46 | plt.xticks([]) 47 | plt.yticks([]) 48 | 49 | if vis_save_path is not None: 50 | plt.savefig(vis_save_path, dpi=600, bbox_inches='tight', pad_inches=0.1) 51 | # plt.show() 52 | 53 | 54 | class CDRTrainer: 55 | def __init__(self, model, dataset_name, configs, result_save_dir, config_path, device='cuda', 56 | log_path="log.txt"): 57 | self.model = model 58 | self.config_path = config_path 59 | self.configs = configs 60 | self.device = device 61 | self.result_save_dir = result_save_dir 62 | self.dataset_name = dataset_name 63 | self.batch_size = configs.exp_params.batch_size 64 | self.epoch_num = configs.training_params.epoch_nums 65 | self.n_neighbors = configs.exp_params.n_neighbors 66 | self.print_iter = int(self.configs.training_params.epoch_print_inter_ratio * self.epoch_num) 67 | self.is_image = not isinstance(self.configs.exp_params.input_dims, int) 68 | self.lr = configs.exp_params.LR 69 | self.ckp_save_dir = self.result_save_dir 70 | 71 | self.batch_num = 0 72 | self.val_inter = 0 73 | self.start_epoch = 0 74 | self.train_loader = None 75 | self.launch_date_time = None 76 | self.optimizer = None 77 | self.scheduler = None 78 | 79 | self.tmp_log_path = log_path 80 | self.log_process = None 81 | self.log_path = None 82 | self.message_queue = Queue() 83 | self.pre_embeddings = None 84 | self.fixed_k = 15 85 | 86 | self.clr_dataset = None 87 | self.resume_epochs = 0 88 | self.model.to(self.device) 89 | self.steps = 0 90 | self.resume_start_epoch = self.resume_epochs if self.resume_epochs > 0 else self.epoch_num 91 | self.gradient_redefine = configs.exp_params.gradient_redefine 92 | self.warmup_epochs = 0 93 | self.separation_epochs = 0 94 | if self.gradient_redefine: 95 | self.warmup_epochs = int(self.epoch_num * configs.exp_params.separation_begin_ratio) 96 | self.separation_epochs = int(self.epoch_num * configs.exp_params.steady_begin_ratio) 97 | 98 | def update_configs(self, configs): 99 | self.configs = configs 100 | self.dataset_name = configs.exp_params.dataset 101 | self.epoch_num = configs.training_params.epoch_nums 102 | 103 | def encode(self, x): 104 | return self.model.encode(x) 105 | 106 | def _train_begin(self, launch_time_stamp=None): 107 | self.sta_time = time.time() if launch_time_stamp is None else launch_time_stamp 108 | 109 | InfoLogger.info("Start Training for {} Epochs".format(self.epoch_num - self.start_epoch)) 110 | 111 | param_template = "Experiment Configurations: \nDataset: %s Epochs: %d Batch Size: %d \n" \ 112 | "Learning rate: %4f Optimizer: %s\n" 113 | 114 | param_str = param_template % (self.dataset_name, self.epoch_num, self.batch_size, 115 | self.lr, self.configs.exp_params.optimizer) 116 | 117 | InfoLogger.info(param_str) 118 | self.message_queue.put(param_str) 119 | 120 | InfoLogger.info("Start Training for {} Epochs".format(self.epoch_num)) 121 | if self.launch_date_time is None: 122 | if launch_time_stamp is None: 123 | launch_time_stamp = int(time.time()) 124 | self.launch_date_time = time_stamp_to_date_time_adjoin(launch_time_stamp) 125 | 126 | self.result_save_dir = 
os.path.join(self.result_save_dir, 127 | "{}_{}".format(self.dataset_name, self.launch_date_time)) 128 | 129 | self.log_path = os.path.join(self.result_save_dir, "log.txt") 130 | self.ckp_save_dir = self.result_save_dir 131 | 132 | if self.optimizer is None: 133 | self.init_optimizer() 134 | self.init_scheduler(cur_epochs=self.epoch_num) 135 | 136 | val_inter = math.ceil(self.epoch_num * self.configs.training_params.val_inter_ratio) 137 | ckp_save_inter = math.ceil(self.epoch_num * self.configs.training_params.ckp_inter_ratio) 138 | 139 | return val_inter, ckp_save_inter 140 | 141 | def init_optimizer(self): 142 | if self.configs.exp_params.optimizer == "adam": 143 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=0.0001) 144 | elif self.configs.exp_params.optimizer == "sgd": 145 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, 146 | weight_decay=0.0001) 147 | else: 148 | raise RuntimeError("Unsupported optimizer! Please check the configuration and ensure the param " 149 | "name is one of 'adam/sgd'") 150 | 151 | def init_scheduler(self, cur_epochs, base=0, gamma=0.1, milestones=None): 152 | if milestones is None: 153 | milestones = [0.8] 154 | if self.configs.exp_params.scheduler == "multi_step": 155 | self.scheduler = MultiStepLR(self.optimizer, milestones=[int(base + p * cur_epochs) for p in milestones], 156 | gamma=gamma, last_epoch=-1) 157 | elif self.configs.exp_params.scheduler == "cosine": 158 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=len(self.train_loader), 159 | eta_min=0.00001, last_epoch=-1) 160 | else: 161 | raise RuntimeError("Unsupported learning scheduler! Please check the configuration and ensure the param " 162 | "name is one of 'multi_step/cosine'") 163 | 164 | def _before_epoch(self, epoch): 165 | self.model = self.model.to(self.device) 166 | if self.gradient_redefine: 167 | if epoch == self.warmup_epochs: 168 | self.train_loader.dataset.transform.build_neighbor_repo(self.separation_epochs - self.warmup_epochs, 169 | self.n_neighbors) 170 | elif epoch == self.separation_epochs: 171 | self.train_loader.dataset.transform.build_neighbor_repo(self.epoch_num - self.separation_epochs, 172 | self.n_neighbors) 173 | 174 | train_iterator = iter(self.train_loader) 175 | return train_iterator, 0 176 | 177 | def _step_prepare(self, *args): 178 | data, epoch = args 179 | x, x_sim, indices, sim_indices = data[0] 180 | 181 | x = x.to(self.device, non_blocking=True) 182 | x_sim = x_sim.to(self.device, non_blocking=True) 183 | return x, x_sim, epoch, indices, sim_indices 184 | 185 | def _train_step(self, *args): 186 | x, x_sim, epoch, indices, sim_indices = args 187 | 188 | self.optimizer.zero_grad() 189 | _, x_embeddings, _, x_sim_embeddings = self.forward(x, x_sim) 190 | 191 | train_loss = self.model.compute_loss(x_embeddings, x_sim_embeddings, epoch) 192 | 193 | train_loss.backward() 194 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=2.0, norm_type=2) 195 | self.optimizer.step() 196 | return train_loss 197 | 198 | def model_prepare(self): 199 | self.model.preprocess() 200 | 201 | def _after_epoch(self, ckp_save_inter, epoch, training_loss, training_loss_history, val_inter): 202 | 203 | if self.configs.exp_params.scheduler == "cosine" and epoch >= 10: 204 | self.scheduler.step() 205 | elif self.configs.exp_params.scheduler == "multi_step": 206 | self.scheduler.step() 207 | 208 | train_loss = training_loss / self.batch_num 209 | if epoch % 
self.print_iter == 0: 210 | epoch_template = 'Epoch %d/%d, Train Loss: %.5f, ' 211 | epoch_output = epoch_template % (epoch, self.epoch_num, train_loss) 212 | InfoLogger.info(epoch_output) 213 | self.message_queue.put(epoch_output) 214 | 215 | training_loss_history.append(train_loss) 216 | embeddings = self.post_epoch(ckp_save_inter, epoch, val_inter) 217 | 218 | return embeddings 219 | 220 | def _train_end(self, training_loss_history, embeddings): 221 | np.save(os.path.join(self.result_save_dir, "embeddings_{}.npy".format(self.epoch_num)), embeddings) 222 | self.message_queue.put("end") 223 | self.save_weights(self.epoch_num) 224 | 225 | x_idx = np.linspace(self.start_epoch, self.epoch_num, self.epoch_num - self.start_epoch) 226 | save_path = os.path.join(self.result_save_dir, 227 | "loss_{}.jpg".format(self.epoch_num)) 228 | draw_loss(training_loss_history, x_idx, save_path) 229 | self.log_process.join(timeout=5) 230 | shutil.copyfile(self.tmp_log_path, self.log_path) 231 | InfoLogger.info("Training process logging to {}".format(self.log_path)) 232 | 233 | def train(self, launch_time_stamp=None): 234 | self.val_inter, ckp_save_inter = self._train_begin(launch_time_stamp) 235 | embeddings = None 236 | net = self.model 237 | net.batch_num = self.batch_num 238 | training_loss_history = [] 239 | 240 | for epoch in range(self.start_epoch, self.epoch_num): 241 | train_iterator, training_loss = self._before_epoch(epoch) 242 | for idx, data in enumerate(train_iterator): 243 | self.steps += 1 244 | train_data = self._step_prepare(data, epoch) 245 | loss = self._train_step(*train_data) 246 | training_loss += loss 247 | 248 | embeddings = self._after_epoch(ckp_save_inter, epoch + 1, training_loss, training_loss_history, 249 | self.val_inter) 250 | 251 | self._train_end(training_loss_history, embeddings) 252 | return embeddings 253 | 254 | def resume_train(self, resume_epoch): 255 | self.resume_start_epoch = self.epoch_num 256 | self.start_epoch = self.epoch_num 257 | self.epoch_num = self.resume_start_epoch + resume_epoch 258 | self.optimizer.param_groups[0]['lr'] = self.lr 259 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=len(self.train_loader), 260 | eta_min=0.00001, last_epoch=-1) 261 | return self.train() 262 | 263 | def save_weights(self, epoch, prefix_name=None): 264 | if prefix_name is None: 265 | prefix_name = epoch 266 | if not os.path.exists(self.ckp_save_dir): 267 | os.mkdir(self.ckp_save_dir) 268 | weight_save_path = os.path.join(self.ckp_save_dir, "{}.pth.tar". 
269 | format(prefix_name)) 270 | torch.save({'epoch': epoch, 'state_dict': self.model.state_dict(), 271 | 'optimizer': self.optimizer.state_dict(), 272 | 'lr': self.lr, 'launch_time': self.launch_date_time}, weight_save_path) 273 | InfoLogger.info("model weights successfully saved to {}".format(weight_save_path)) 274 | 275 | def forward(self, x, x_sim): 276 | return self.model.forward(x, x_sim) 277 | 278 | def load_weights(self, checkpoint_path, train=True): 279 | self.preprocess(train) 280 | model_ckpt = torch.load(checkpoint_path, map_location=torch.device(self.device)) 281 | self.model.load_state_dict(model_ckpt['state_dict']) 282 | self.init_optimizer() 283 | self.optimizer.load_state_dict(model_ckpt['optimizer']) 284 | self.optimizer.param_groups[0]['lr'] = self.lr 285 | for state in self.optimizer.state.values(): 286 | for k, v in state.items(): 287 | if torch.is_tensor(v): 288 | state[k] = v.to(self.device) 289 | return model_ckpt 290 | 291 | def load_weights_train(self, checkpoint_path): 292 | model_ckpt = self.load_weights(checkpoint_path) 293 | self.start_epoch = model_ckpt["epoch"] 294 | self.launch_date_time = model_ckpt["launch_time"] 295 | self.train() 296 | 297 | def load_weights_visualization(self, checkpoint_path, vis_save_path, device='cuda'): 298 | self.load_weights(checkpoint_path, train=False) 299 | embeddings = self.visualize(vis_save_path, device=device) 300 | return embeddings 301 | 302 | def train_for_visualize(self): 303 | InfoLogger.info("Start train for Visualize") 304 | launch_time_stamp = int(time.time()) 305 | self.preprocess() 306 | self.pre_embeddings = self.train(launch_time_stamp) 307 | return self.pre_embeddings 308 | 309 | def cal_lower_embeddings(self, data): 310 | if self.is_image: 311 | data = data / 255. 312 | embeddings = self.acquire_latent_code_allin(data) 313 | return embeddings 314 | 315 | def visualize(self, vis_save_path=None, device="cuda"): 316 | self.model.to(device) 317 | data = torch.tensor(self.train_loader.dataset.get_all_data()).to(device).float() 318 | embeddings = self.cal_lower_embeddings(data) 319 | 320 | draw_projections(embeddings, self.train_loader.dataset.targets, vis_save_path) 321 | 322 | return embeddings 323 | 324 | def acquire_latent_code(self, inputs): 325 | return self.model.acquire_latent_code(inputs) 326 | 327 | def acquire_latent_code_allin(self, data): 328 | with torch.no_grad(): 329 | self.model.eval() 330 | embeddings = self.model.acquire_latent_code(data).cpu().numpy() 331 | self.model.train() 332 | return embeddings 333 | 334 | def preprocess(self, train=True): 335 | self.build_dataset() 336 | if train: 337 | self.log_process = LogWriter(self.tmp_log_path, self.log_path, self.message_queue) 338 | self.log_process.start() 339 | self.model_prepare() 340 | 341 | def build_dataset(self): 342 | knn_cache_path = os.path.join(ConfigInfo.NEIGHBORS_CACHE_DIR, 343 | "{}_k{}.npy".format(self.dataset_name, self.n_neighbors)) 344 | pairwise_cache_path = os.path.join(ConfigInfo.PAIRWISE_DISTANCE_DIR, "{}.npy".format(self.dataset_name)) 345 | check_path_exists(ConfigInfo.NEIGHBORS_CACHE_DIR) 346 | check_path_exists(ConfigInfo.PAIRWISE_DISTANCE_DIR) 347 | 348 | cdr_dataset = DataSetWrapper(self.batch_size) 349 | resume_start_epoch = self.resume_start_epoch 350 | if self.gradient_redefine: 351 | resume_start_epoch = self.warmup_epochs 352 | 353 | self.train_loader, self.n_samples = cdr_dataset.get_data_loaders( 354 | resume_start_epoch, self.dataset_name, ConfigInfo.DATASET_CACHE_DIR, self.n_neighbors, knn_cache_path, 355 | 
pairwise_cache_path, self.is_image) 356 | 357 | self.batch_num = cdr_dataset.batch_num 358 | self.model.batch_num = self.batch_num 359 | 360 | def post_epoch(self, ckp_save_inter, epoch, val_inter): 361 | embeddings = None 362 | vis_save_path = os.path.join(self.result_save_dir, '{}_vis_{}.jpg'.format(self.dataset_name, epoch)) 363 | 364 | if epoch % val_inter == 0: 365 | if not os.path.exists(self.result_save_dir): 366 | os.makedirs(self.result_save_dir) 367 | if self.config_path is not None: 368 | shutil.copyfile(self.config_path, os.path.join(self.result_save_dir, "config.yaml")) 369 | 370 | embeddings = self.visualize(vis_save_path, device=self.device) 371 | 372 | # save model 373 | if epoch % ckp_save_inter == 0: 374 | if not os.path.exists(self.ckp_save_dir): 375 | os.makedirs(self.ckp_save_dir) 376 | self.save_weights(epoch) 377 | 378 | return embeddings 379 | 380 | -------------------------------------------------------------------------------- /model/baseline_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | 7 | def get_encoder(encoder_name, input_size, input_dims, input_channels): 8 | try: 9 | if encoder_name == "CBR": 10 | encoder = Encoder(input_size, in_channels=input_channels) 11 | encoder_out_dims = encoder.output_dims 12 | else: 13 | encoder = FCEncoder(input_dims) 14 | encoder_out_dims = encoder.hidden_dims[-1] 15 | return encoder, encoder_out_dims 16 | except: 17 | raise Exception("Invalid model name. Check the config file and pass one of: resnet18 or resnet50 or CBR or " 18 | "FC") 19 | 20 | 21 | class Encoder(nn.Module): 22 | def __init__(self, input_size, in_channels=1, hidden_dims=None): 23 | super(Encoder, self).__init__() 24 | if hidden_dims is None: 25 | hidden_dims = [64, 128, 256, 512] 26 | self.hidden_dims = hidden_dims 27 | self.input_size = input_size 28 | self.in_channels = in_channels 29 | modules = [] 30 | for h_dim in self.hidden_dims: 31 | modules.append(nn.Sequential( 32 | nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1), 33 | nn.BatchNorm2d(h_dim), 34 | nn.ReLU() 35 | )) 36 | in_channels = h_dim 37 | modules.append(nn.Flatten()) 38 | 39 | self.encoder = nn.Sequential(*modules) 40 | tmp = torch.zeros((2, self.in_channels, self.input_size, self.input_size)) 41 | self.output_dims = self.encoder.forward(tmp).shape[1] 42 | 43 | def forward(self, x): 44 | h = self.encoder(x) 45 | return h 46 | 47 | 48 | class FCEncoder(nn.Module): 49 | def __init__(self, in_features, hidden_dims=None): 50 | nn.Module.__init__(self) 51 | 52 | if hidden_dims is None: 53 | hidden_dims = [128, 256, 256, 512] 54 | self.hidden_dims = hidden_dims 55 | modules = [] 56 | 57 | in_dim = in_features 58 | for dim in hidden_dims: 59 | modules.append(nn.Sequential( 60 | nn.Linear(in_dim, dim), 61 | nn.BatchNorm1d(dim), 62 | nn.ReLU() 63 | )) 64 | in_dim = dim 65 | self.encoder = nn.Sequential(*modules) 66 | 67 | def forward(self, x): 68 | return self.encoder(x) 69 | -------------------------------------------------------------------------------- /model/cdr.py: -------------------------------------------------------------------------------- 1 | from model.nce_loss import NT_Xent, Mixture_NT_Xent 2 | from model.nx_cdr import NX_CDRModel 3 | import torch 4 | 5 | 6 | class CDRModel(NX_CDRModel): 7 | def __init__(self, cfg, device='cuda'): 8 | NX_CDRModel.__init__(self, cfg, device) 9 | 10 | self.a = 
torch.tensor(-40) 11 | self.miu = torch.tensor(cfg.exp_params.separate_upper) 12 | self.lower_thresh = torch.tensor(0.015) 13 | self.scale = torch.tensor(0.13) 14 | self.alpha = torch.tensor(5) 15 | 16 | self.separate_epoch = int(self.epoch_num * cfg.exp_params.separation_begin_ratio) 17 | self.steady_epoch = int(self.epoch_num * cfg.exp_params.steady_begin_ratio) 18 | 19 | def preprocess(self): 20 | self.build_model() 21 | self.criterion = NT_Xent.apply 22 | 23 | def _post_loss(self, logits, x_embeddings, epoch, item_weights, *args): 24 | if self.separate_epoch <= epoch <= self.steady_epoch: 25 | epoch_ratio = torch.tensor((epoch - self.separate_epoch) / (self.steady_epoch - self.separate_epoch)) 26 | cur_lower_thresh = 0.001 + (self.lower_thresh - 0.001) * epoch_ratio 27 | loss = Mixture_NT_Xent.apply(logits, torch.tensor(self.temperature), self.alpha, self.a, self.miu, 28 | cur_lower_thresh, self.scale, item_weights) 29 | else: 30 | loss = self.criterion(logits, torch.tensor(self.temperature), item_weights) 31 | return loss 32 | -------------------------------------------------------------------------------- /model/icdr.py: -------------------------------------------------------------------------------- 1 | from model.cdr import CDRModel 2 | from utils.link_utils import MUST_LINK, CANNOT_LINK 3 | import torch 4 | 5 | 6 | class ICDRModel(CDRModel): 7 | def __init__(self, cfg, device='cuda'): 8 | CDRModel.__init__(self, cfg, device) 9 | self.cur_ml_num = 0 10 | self.cur_cl_num = 0 11 | self.cur_ml_indices = None 12 | self.cur_cl_indices = None 13 | self.max_prob_thresh = 0.95 14 | self.ml_strength = 0 15 | self.cl_strength = 0 16 | self.gather_weight = 0.2 17 | 18 | def link_stat_update(self, finetune_epochs, steady_epoch, ml_strength, cl_strength): 19 | self.separate_epoch = self.epoch_num 20 | self.steady_epoch = self.separate_epoch + finetune_epochs - steady_epoch 21 | self.epoch_num = self.separate_epoch + finetune_epochs 22 | self.ml_strength = ml_strength 23 | self.cl_strength = cl_strength 24 | 25 | def batch_logits(self, x_embeddings, x_sim_embeddings, *args): 26 | logits = super().batch_logits(x_embeddings, x_sim_embeddings, *args) 27 | 28 | link_embeddings, cur_link_types, related_indices = args[-4:-1] 29 | cur_link_num = len(related_indices) 30 | if cur_link_num == 0: 31 | self.cur_ml_num = 0 32 | self.cur_cl_num = 0 33 | return logits 34 | 35 | all_embeddings = torch.cat([x_embeddings, x_sim_embeddings], dim=0) 36 | cl_indices = torch.where(torch.tensor(cur_link_types) == CANNOT_LINK)[0] 37 | ml_indices = torch.where(torch.tensor(cur_link_types) == MUST_LINK)[0] 38 | self.cur_cl_indices = cl_indices 39 | self.cur_ml_indices = ml_indices 40 | self.cur_ml_num = len(ml_indices) 41 | self.cur_cl_num = len(cl_indices) 42 | 43 | if self.cur_ml_num > 0: 44 | h_embeddings = all_embeddings[related_indices[ml_indices]] 45 | t_embeddings = link_embeddings[ml_indices] 46 | 47 | positives = self.similarity_func(h_embeddings, t_embeddings, self.min_dist)[0].view(-1, 1) 48 | negatives = logits[related_indices[ml_indices], 1:].view(self.cur_ml_num, -1) 49 | ml_logits = torch.cat([positives, negatives.detach()], dim=1) 50 | logits = torch.cat((logits, ml_logits), dim=0) 51 | 52 | if self.cur_cl_num > 0: 53 | h_embeddings = all_embeddings[related_indices[cl_indices]] 54 | t_embeddings = link_embeddings[cl_indices] 55 | 56 | negative = self.similarity_func(h_embeddings, t_embeddings, self.min_dist)[0].view(-1, 1) 57 | other_negatives = logits[related_indices[cl_indices], 
1:-1].view(self.cur_cl_num, -1) 58 | positive = logits[related_indices[cl_indices], 0].view(self.cur_cl_num, 1).clone() 59 | 60 | indices = torch.where(positive > self.max_prob_thresh)[0] 61 | positive[indices] = positive[indices].detach() 62 | 63 | cl_logits = torch.cat([positive, other_negatives.detach(), negative], dim=1) 64 | logits = torch.cat((logits, cl_logits), dim=0) 65 | 66 | return logits 67 | 68 | def compute_loss(self, x_embeddings, x_sim_embeddings, *args): 69 | epoch = args[0] 70 | total_link_weights = args[-1] 71 | logits = self.batch_logits(x_embeddings, x_sim_embeddings, *args) 72 | loss = self._post_loss(logits, x_embeddings, epoch, total_link_weights, *args) 73 | return loss 74 | 75 | def _post_loss(self, logits, x_embeddings, epoch, total_link_weights, *args): 76 | total_link_num = self.cur_ml_num + self.cur_cl_num 77 | normal_num = self.batch_size * 2 78 | normal_loss = super()._post_loss(logits[:normal_num], x_embeddings[:normal_num], epoch, None, *args) 79 | if total_link_num == 0: 80 | return normal_loss 81 | else: 82 | total_link_weights = torch.tensor(total_link_weights, dtype=torch.float).to(self.device) 83 | t = self.temperature 84 | ml_loss, cl_loss = 0, 0 85 | link_logits = logits[-total_link_num:] 86 | if self.cur_ml_num > 0: 87 | ml_logits = link_logits[:self.cur_ml_num] 88 | ml_link_weight = total_link_weights[self.cur_ml_indices] 89 | 90 | indices = torch.where(ml_logits[:, 0] < self.max_prob_thresh)[0] 91 | if len(indices) > 0: 92 | ml_loss = super()._post_loss(ml_logits[indices], None, epoch, ml_link_weight[indices], *args) 93 | 94 | if self.cur_cl_num > 0: 95 | cl_logits = link_logits[-self.cur_cl_num:] 96 | cl_link_weight = total_link_weights[self.cur_cl_indices] 97 | pos_cl_logits = torch.cat([cl_logits[:, 0].unsqueeze(1), cl_logits[:, 1:].detach()], dim=1) 98 | neg_cl_logits = torch.cat([1 - cl_logits[:, -1].unsqueeze(1), cl_logits[:, 1:].detach()], dim=1) 99 | 100 | indices = torch.where(cl_logits[:, -1] > self.lower_thresh)[0] 101 | 102 | cl_pos_loss = self.criterion(pos_cl_logits[indices], torch.tensor(t), cl_link_weight[indices]) 103 | cl_neg_loss = self.criterion(neg_cl_logits[indices], torch.tensor(t), cl_link_weight[indices]) 104 | 105 | cl_loss = self.gather_weight * cl_pos_loss + self.cl_strength * cl_neg_loss 106 | 107 | link_loss = (self.ml_strength * ml_loss + cl_loss) 108 | loss = (normal_loss + link_loss) 109 | return loss 110 | -------------------------------------------------------------------------------- /model/nce_loss.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | import torch 3 | 4 | import math 5 | 6 | normal_obj = torch.distributions.normal.Normal(0, 1) 7 | 8 | 9 | def torch_norm_pdf(data): 10 | return torch.exp(-torch.square(data) / 2) / math.sqrt(2 * math.pi) 11 | 12 | 13 | def torch_norm_cdf(data): 14 | global normal_obj 15 | return normal_obj.cdf(data) 16 | 17 | 18 | def torch_skewnorm_pdf(data, a, loc, scale): 19 | y = (data - loc) / scale 20 | output = 2 * torch_norm_pdf(y) * torch_norm_cdf(a * y) / scale 21 | return output 22 | 23 | 24 | def torch_app_skewnorm_func(data, r, a=-40, loc=0.11, scale=0.13): 25 | y = torch_skewnorm_pdf(data, a, loc, scale) 26 | y = y * r 27 | return y 28 | 29 | 30 | class NT_Xent(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, probabilities, t, item_weights): 34 | exp_prob = torch.exp(probabilities / t) 35 | 36 | similarities = exp_prob / torch.sum(exp_prob, dim=1).unsqueeze(1) 37 | 38 | 
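# `probabilities` holds one positive logit (column 0) and 2N-2 negative logits per
# row, as laid out by batch_logits; row-normalizing exp(prob / t) therefore turns
# column 0 into the positive-pair probability, and this forward pass reduces to a
# standard InfoNCE loss. The custom backward below applies its hand-derived gradient.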
ctx.save_for_backward(similarities, t, item_weights) 39 | 40 | pos_loss = -torch.log(similarities[:, 0]).mean() 41 | 42 | return pos_loss 43 | 44 | @staticmethod 45 | def backward(ctx, grad_output): 46 | similarities, t, item_weights = ctx.saved_tensors 47 | pos_grad_coeff = -((torch.sum(similarities, dim=1) - similarities[:, 0]) / t).unsqueeze(1) 48 | neg_grad_coeff = similarities[:, 1:] / t 49 | grad = torch.cat([pos_grad_coeff, neg_grad_coeff], dim=1) * grad_output / similarities.shape[0] 50 | if item_weights is not None: 51 | grad *= item_weights.view(-1, 1) 52 | return grad, None, None 53 | 54 | 55 | class Mixture_NT_Xent(Function): 56 | 57 | @staticmethod 58 | def forward(ctx, probabilities, t, alpha, a, loc, lower_thresh, scale, item_weight): 59 | 60 | def nt_xent_grad(data, tau): 61 | exp_prob = torch.exp(data / tau) 62 | norm_exp_prob = exp_prob / torch.sum(exp_prob, dim=1).unsqueeze(1) 63 | gradients = norm_exp_prob[:, 1:] / tau 64 | return norm_exp_prob, gradients 65 | 66 | similarities, exp_neg_grad_coeff = nt_xent_grad(probabilities, t) 67 | 68 | skewnorm_prob = torch_skewnorm_pdf(probabilities[:, 1:], a, loc, scale) 69 | skewnorm_similarities = skewnorm_prob / torch.sum(skewnorm_prob, dim=1).unsqueeze(1) 70 | sn_max_val_indices = torch.argmax(skewnorm_similarities, dim=1) 71 | rows = torch.arange(0, skewnorm_similarities.shape[0], 1) 72 | skewnorm_max_value = skewnorm_similarities[rows, sn_max_val_indices].unsqueeze(1) 73 | ref_exp_value = exp_neg_grad_coeff[rows, sn_max_val_indices].unsqueeze(1) 74 | 75 | raw_alpha = ref_exp_value / skewnorm_max_value 76 | 77 | ctx.save_for_backward(probabilities, similarities, t, skewnorm_similarities, loc, lower_thresh, alpha, 78 | raw_alpha, item_weight) 79 | 80 | pos_loss = -torch.log(similarities[:, 0]).mean() 81 | 82 | return pos_loss 83 | 84 | @staticmethod 85 | def backward(ctx, grad_output): 86 | prob, exp_sims, t, sn_sims, loc, lower_thresh, alpha, raw_alpha, item_weights = ctx.saved_tensors 87 | 88 | pos_grad_coeff = -((torch.sum(exp_sims, dim=1) - exp_sims[:, 0]) / t).unsqueeze(1) 89 | high_thresh = loc 90 | sn_sims[prob[:, 1:] < lower_thresh] = 0 91 | sn_sims[prob[:, 1:] >= high_thresh] = 0 92 | exp_sims[:, 1:][prob[:, 1:] < lower_thresh] = 0 93 | 94 | neg_grad_coeff = exp_sims[:, 1:] / t + alpha * sn_sims * raw_alpha 95 | grad = torch.cat([pos_grad_coeff, neg_grad_coeff], dim=1) * grad_output / exp_sims.shape[0] 96 | if item_weights is not None: 97 | grad *= item_weights.view(-1, 1) 98 | return grad, None, None, None, None, None, None, None 99 | 100 | -------------------------------------------------------------------------------- /model/nx_cdr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.nn import Module 3 | from model.nce_loss import torch_app_skewnorm_func 4 | from utils.math_utils import get_similarity_function, get_correlated_mask 5 | from utils.umap_utils import find_ab_params 6 | from model.baseline_encoder import * 7 | 8 | 9 | def exp_ce(data_matrix, t, data_labels, accumulation="MEAN"): 10 | exp_data = torch.exp(data_matrix / t) 11 | return ce(exp_data, data_labels, accumulation) 12 | 13 | 14 | def skewnorm_ce(data_matrix, ratio, data_labels, accumulation="MEAN"): 15 | sn_data = torch_app_skewnorm_func(data_matrix, ratio) 16 | return ce(sn_data, data_labels, accumulation) 17 | 18 | 19 | def ce(data_matrix, data_labels, accumulation="MEAN"): 20 | softmax_data = data_matrix / torch.sum(data_matrix, dim=1).unsqueeze(1) 21 | loss = 
-torch.log(softmax_data[torch.arange(0, data_matrix.shape[0]), data_labels]) 22 | if accumulation == "MEAN": 23 | return torch.mean(loss) 24 | elif accumulation == "SUM": 25 | return torch.sum(loss) 26 | else: 27 | return loss 28 | 29 | 30 | class NX_CDRModel(Module): 31 | def __init__(self, cfg, device='cuda'): 32 | Module.__init__(self) 33 | self.device = device 34 | self.config = cfg 35 | self.input_dims = cfg.exp_params.input_dims 36 | self.encoder_name = "Conv" if isinstance(self.input_dims, int) else "FC" 37 | self.in_channels = 1 if isinstance(self.input_dims, int) else self.input_dims[-1] 38 | 39 | self.input_size = int(np.sqrt(self.input_dims / self.in_channels)) 40 | self.latent_dim = 2 41 | self.batch_size = cfg.exp_params.batch_size 42 | self.similarity_method = "umap" 43 | self.temperature = cfg.exp_params.temperature 44 | 45 | self.batch_num = 0 46 | self.max_neighbors = 0 47 | self.encoder = None 48 | self.pro_head = None 49 | 50 | self.criterion = None 51 | self.correlated_mask = get_correlated_mask(2 * self.batch_size) 52 | self.min_dist = 0.1 53 | 54 | self._a, self._b = find_ab_params(1, self.min_dist) 55 | self.similarity_func = get_similarity_function(self.similarity_method) 56 | 57 | self.reduction = "mean" 58 | self.epoch_num = self.config.training_params.epoch_nums 59 | self.batch_count = 0 60 | 61 | def build_model(self): 62 | encoder, encoder_out_dims = get_encoder(self.encoder_name, self.input_size, self.input_dims, self.in_channels) 63 | self.encoder = encoder 64 | pro_dim = 512 65 | self.pro_head = nn.Sequential( 66 | nn.Linear(encoder_out_dims, pro_dim), 67 | nn.ReLU(), 68 | nn.Linear(pro_dim, self.latent_dim) 69 | ) 70 | 71 | def preprocess(self): 72 | self.build_model() 73 | self.criterion = nn.CrossEntropyLoss(reduction=self.reduction) 74 | 75 | def encode(self, x): 76 | if x is None: 77 | return None, None 78 | reps = self.encoder(x) 79 | reps = reps.squeeze() 80 | 81 | embeddings = self.pro_head(reps) 82 | return reps, embeddings 83 | 84 | def forward(self, x, x_sim): 85 | # get the representations and the projections 86 | x_reps, x_embeddings = self.encode(x) # [N,C] 87 | 88 | # get the representations and the projections 89 | x_sim_reps, x_sim_embeddings = self.encode(x_sim) # [N,C] 90 | 91 | return x_reps, x_embeddings, x_sim_reps, x_sim_embeddings 92 | 93 | def acquire_latent_code(self, inputs): 94 | reps, embeddings = self.encode(inputs) 95 | return embeddings 96 | 97 | def compute_loss(self, x_embeddings, x_sim_embeddings, *args): 98 | epoch = args[0] 99 | logits = self.batch_logits(x_embeddings, x_sim_embeddings, *args) 100 | loss = self._post_loss(logits, x_embeddings, epoch, None, *args) 101 | return loss 102 | 103 | def _post_loss(self, logits, x_embeddings, epoch, item_weights, *args): 104 | labels = torch.zeros(logits.shape[0]).to(self.device).long() 105 | loss = self.criterion(logits / self.temperature, labels) 106 | return loss 107 | 108 | def batch_logits(self, x_embeddings, x_sim_embeddings, *args): 109 | self.batch_count += 1 110 | all_embeddings = torch.cat([x_embeddings, x_sim_embeddings], dim=0) 111 | representations = all_embeddings.unsqueeze(0).repeat(all_embeddings.shape[0], 1, 1) 112 | similarity_matrix, pairwise_dist = self.similarity_func(representations.transpose(0, 1), representations, 113 | self.min_dist) 114 | 115 | l_pos = torch.diag(similarity_matrix, self.batch_size) 116 | r_pos = torch.diag(similarity_matrix, -self.batch_size) 117 | 118 | positives = torch.cat([l_pos, r_pos]).view(all_embeddings.shape[0], 1) 119 | 
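# The diagonals at offset ±batch_size pair each of the 2N embeddings with its
# augmented view, yielding exactly one positive per row; the correlated mask on the
# next line keeps the remaining 2N-2 similarities of each row as negatives.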
negatives = similarity_matrix[self.correlated_mask].view(all_embeddings.shape[0], -1) 120 | 121 | logits = torch.cat((positives, negatives), dim=1) 122 | return logits 123 | 124 | 125 | -------------------------------------------------------------------------------- /model_weights/usps.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/model_weights/usps.pth.tar -------------------------------------------------------------------------------- /prototype.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/prototype.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.10.0 2 | antlr4-python3-runtime==4.8 3 | anykeystore==0.2 4 | appdirs==1.4.4 5 | astor==0.8.1 6 | astunparse==1.6.3 7 | atomicwrites==1.4.0 8 | attrs==21.2.0 9 | audeer==1.16.0 10 | audiofile==1.0.0 11 | audioread==2.1.9 12 | audtorch==0.6.4 13 | cachetools==4.1.1 14 | certifi==2020.12.5 15 | cffi==1.14.6 16 | chardet==3.0.4 17 | click==7.1.2 18 | colorama==0.4.4 19 | cxxfilt==0.3.0 20 | cycler==0.10.0 21 | dataclasses==0.8 22 | decorator==4.4.2 23 | defusedxml==0.7.1 24 | docopt==0.6.2 25 | easydict==1.9 26 | et-xmlfile==1.0.1 27 | Flask==1.1.2 28 | future==0.18.2 29 | gast==0.2.2 30 | google-auth==1.22.1 31 | google-auth-oauthlib==0.4.1 32 | google-pasta==0.2.0 33 | greenlet==1.1.1 34 | grpcio==1.33.1 35 | h5py==2.10.0 36 | hupper==1.10.3 37 | hydra-core==1.1.1 38 | idna==2.10 39 | importlib-metadata==2.0.0 40 | importlib-resources==5.2.2 41 | iniconfig==1.1.1 42 | itsdangerous==1.1.0 43 | jdcal==1.4.1 44 | Jinja2==2.11.3 45 | joblib==0.17.0 46 | Keras-Applications==1.0.8 47 | Keras-Preprocessing==1.1.2 48 | kiwisolver==1.2.0 49 | librosa==0.8.1 50 | llvmlite==0.34.0 51 | Markdown==3.3.2 52 | MarkupSafe==1.1.1 53 | matplotlib==3.3.2 54 | networkx==2.5 55 | numba==0.51.2 56 | numpy==1.19.5 57 | oauthlib==3.1.0 58 | olefile==0.46 59 | omegaconf==2.1.1 60 | opencv-python==4.5.4.60 61 | openpyxl==3.0.6 62 | opt-einsum==3.3.0 63 | packaging==21.0 64 | pandas==1.1.5 65 | PasteDeploy==2.1.1 66 | pbkdf2==1.3 67 | Pillow==8.0.1 68 | pipreqs==0.4.11 69 | plaster==1.0 70 | plaster-pastedeploy==0.7 71 | plotly==5.5.0 72 | pluggy==1.0.0 73 | pooch==1.5.1 74 | protobuf==3.13.0 75 | psutil==5.8.0 76 | py==1.10.0 77 | pyasn1==0.4.8 78 | pyasn1-modules==0.2.8 79 | pycparser==2.20 80 | pynndescent==0.5.1 81 | pyparsing==2.4.7 82 | pyramid==2.0 83 | pyramid-mailer==0.15.1 84 | pytest==6.2.5 85 | python-dateutil==2.8.1 86 | python3-openid==3.2.0 87 | pytz==2020.4 88 | PyYAML==5.3.1 89 | repoze.sendmail==4.4.1 90 | requests==2.24.0 91 | requests-oauthlib==1.3.0 92 | resampy==0.2.2 93 | rsa==4.6 94 | scikit-learn==0.23.2 95 | scipy==1.4.1 96 | seaborn==0.11.2 97 | six==1.15.0 98 | SoundFile==0.10.3.post1 99 | sox==1.4.1 100 | SQLAlchemy==1.4.25 101 | tabulate==0.8.9 102 | tb-nightly==1.14.0a20190301 103 | tenacity==8.0.1 104 | termcolor==1.1.0 105 | tf-estimator-nightly==1.14.0.dev2019030115 106 | threadpoolctl==2.1.0 107 | toml==0.10.2 108 | torch==1.7.0 109 | torchvision==0.8.1 110 | tqdm==4.50.2 111 | transaction==3.0.1 112 | translationstring==1.4 113 | typing-extensions==3.7.4.3 114 | umap-learn==0.4.6 115 | urllib3==1.25.11 116 | velruse==1.1.1 117 | 
venusian==3.0.0 118 | WebOb==1.8.7 119 | Werkzeug==1.0.1 120 | wincertstore==0.2 121 | wrapt==1.12.1 122 | WTForms==2.3.3 123 | wtforms-recaptcha==0.3.2 124 | xlrd==2.0.1 125 | yarg==0.1.9 126 | zipp==3.3.1 127 | zope.deprecation==4.4.0 128 | zope.interface==5.4.0 129 | zope.sqlalchemy==1.6 130 | -------------------------------------------------------------------------------- /results/CDR/n15/usps_demo/config.yaml: -------------------------------------------------------------------------------- 1 | exp_params: 2 | dataset: "usps" 3 | input_dims: 256 # (16, 16, 1) 4 | LR: 0.001 5 | batch_size: 512 6 | n_neighbors: 15 7 | optimizer: "adam" # adam or sgd 8 | scheduler: "multi_step" # cosine or multi_step or on_plateau 9 | temperature: 0.15 10 | gradient_redefine: True 11 | split_upper: 0.1 12 | separation_begin_ratio: 0.25 13 | steady_begin_ratio: 0.875 14 | 15 | training_params: 16 | epoch_nums: 1000 17 | epoch_print_inter_ratio: 0.1 18 | val_inter_ratio: 0.5 19 | ckp_inter_ratio: 1 -------------------------------------------------------------------------------- /results/CDR/n15/usps_demo/embeddings_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/results/CDR/n15/usps_demo/embeddings_1000.npy -------------------------------------------------------------------------------- /results/CDR/n15/usps_demo/loss_1000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/results/CDR/n15/usps_demo/loss_1000.jpg -------------------------------------------------------------------------------- /results/CDR/n15/usps_demo/usps_vis_1000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/results/CDR/n15/usps_demo/usps_vis_1000.jpg -------------------------------------------------------------------------------- /results/CDR/n15/usps_demo/usps_vis_500.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/results/CDR/n15/usps_demo/usps_vis_500.jpg -------------------------------------------------------------------------------- /static/css/container.css: -------------------------------------------------------------------------------- 1 | #container{ 2 | padding: 0; 3 | margin: 5px 10px 10px 10px; 4 | height: 1000px; 5 | } 6 | .container_up{ 7 | height: 830px; 8 | padding: 2px; 9 | margin: 0; 10 | border: 0; 11 | } 12 | .container_bottom{ 13 | height: 230px; 14 | padding: 0; 15 | } 16 | .row{ 17 | padding: 0; 18 | margin: 0; 19 | } 20 | 21 | .panel-title{ 22 | font-size: 8px; 23 | color: #297bcc; 24 | } 25 | .panel-heading{ 26 | padding: 5px 8px; 27 | } 28 | .sub_head { 29 | background-color: #e1f0fa !important; 30 | } 31 | .panel-body{ 32 | padding: 3px; 33 | } 34 | .panel-info{ 35 | border: lightgray 5px; 36 | } 37 | .view{ 38 | width: 100%; 39 | height: 97%; 40 | } 41 | .title{ 42 | background-color: #b6d6eb; 43 | /*background-color: #cee8ef;*/ 44 | color: #1b243a; 45 | /*color: #81bedc;*/ 46 | margin: 0; 47 | padding-left: 10px; 48 | margin-bottom: 4px; 49 | font-size: 36px; 50 | height: 50px; 51 | line-height: 50px; 52 | /*color: #8bb7d0;*/ 53 | } 54 | -------------------------------------------------------------------------------- 
/static/css/control.css: -------------------------------------------------------------------------------- 1 | #control_div_content{ 2 | height: 95%; 3 | } 4 | .control_div{ 5 | margin: 0px; 6 | margin-bottom: 8px; 7 | } 8 | .ds_info{ 9 | padding: 5px; 10 | font-size: 22px; 11 | color: rgb(42, 42, 42); 12 | margin-bottom: 5px; 13 | } 14 | .para_item{ 15 | height: 55px; 16 | margin: 21px 0px 21px 0px; 17 | } 18 | .control_btn{ 19 | width: 100px; 20 | font-size: 16px; 21 | height: 30px; 22 | margin-left: 10px; 23 | line-height: 30px; 24 | padding:0; 25 | } 26 | .param_span, .param_input{ 27 | height: 43.5px; 28 | font-size: 10px; 29 | padding: 2px 4px; 30 | } 31 | .param_span{ 32 | width: 200px; 33 | font-size: 24px; 34 | background-color: #fdfdfd; 35 | color: #333; 36 | } 37 | .param_input{ 38 | width: 90px; 39 | display: inline; 40 | padding-left: 10px; 41 | } 42 | .param_lock{ 43 | background-color: rgba(239, 239, 239, 0.3); 44 | color:#bcbcc0 45 | } 46 | .dataset_ul{ 47 | width: 210px; 48 | } 49 | .panel_title_info{ 50 | font-size: 30px; 51 | } 52 | .panel_heading_info{ 53 | height: 40px; 54 | } 55 | .el-input__inner{ 56 | font-size: 20px !important; 57 | height: 50px; 58 | } 59 | -------------------------------------------------------------------------------- /static/css/lasso.css: -------------------------------------------------------------------------------- 1 | circle { 2 | fill-opacity: 0.5; 3 | } 4 | 5 | .lasso path { 6 | stroke: rgb(80,80,80); 7 | stroke-width:2px; 8 | } 9 | 10 | .lasso .drawn { 11 | fill-opacity: .05; 12 | } 13 | 14 | .lasso .loop_close { 15 | fill:none; 16 | stroke-dasharray: 4,4; 17 | } 18 | 19 | .lasso .origin { 20 | fill:#3399FF; 21 | fill-opacity:.5; 22 | } 23 | 24 | .point_not_possible { 25 | fill: rgb(200,200,200); 26 | } 27 | 28 | .point_possible { 29 | fill: #FFBF00; 30 | } 31 | /* Keep consistent with the yellow used in node_link */ 32 | .point_selected { 33 | fill: #FFBF00; 34 | } 35 | 36 | .image_selected{ 37 | opacity: 1; 38 | } 39 | 40 | .image_possible{ 41 | opacity: 0.75; 42 | } 43 | .image_not_possible{ 44 | opacity: 0.5; 45 | } -------------------------------------------------------------------------------- /static/css/link_view.css: -------------------------------------------------------------------------------- 1 | .link_btn{ 2 | width: 50px; 3 | } 4 | 5 | .link_list{ 6 | height: 48.2%; 7 | /* overflow:auto; */ 8 | margin-bottom: 8px; 9 | } 10 | .inter_icon{ 11 | height: 20px; 12 | margin-right: 5px; 13 | } 14 | .link_div_column{ 15 | width: 90px; 16 | height: 100%; 17 | display:inline-block; 18 | font-size: 20px; 19 | /* border-right-width: 2px; 20 | border-right-style: solid; 21 | border-right-color:#e9e3e3 */ 22 | } 23 | .link_div_item{ 24 | width: 150px; 25 | height: 100%; 26 | display:inline-block; 27 | text-align:center; 28 | font-size: 20px; 29 | padding: 3px; 30 | border-right-width: 2px; 31 | border-right-style: solid; 32 | border-right-color:#e9e3e3 33 | } 34 | .link_div_small{ 35 | height: 44px; 36 | text-align:center; 37 | line-height:44px; 38 | padding-right: 5px; 39 | color: black !important; 40 | /* color: rgb(144, 147, 153); */ 41 | } 42 | .link_div_large{ 43 | height: 130px; 44 | text-align:center; 45 | line-height:130px; 46 | padding-right: 5px; 47 | color: black !important; 48 | /* color: rgb(144, 147, 153); */ 49 | 50 | } -------------------------------------------------------------------------------- /static/css/myCss.css: -------------------------------------------------------------------------------- 1 | .panel{ 2 | border: 0; 3 | } 4 | 
.panel_content{ 5 | background-color: #fdfdfd; 6 | } 7 | .panel-heading{ 8 | border-bottom: 0; 9 | padding-left: 10px; 10 | height: 45px; 11 | /* background-color: ; */ 12 | } 13 | 14 | .input-group-addon{ 15 | border: 0; 16 | } 17 | .el-switch__label *{ 18 | font-size: 16px; 19 | } 20 | #project_view { 21 | background-color: white; 22 | } 23 | 24 | -------------------------------------------------------------------------------- /static/css/project_view.css: -------------------------------------------------------------------------------- 1 | #project_view{ 2 | position: absolute; 3 | height: 91%; 4 | width: 100%; 5 | } 6 | 7 | .project_component_area{ 8 | height: 45px; 9 | padding: 7px 10px 7px 10px; 10 | background-color: #f6f6f6; 11 | } 12 | .project_button{ 13 | color: #409EFF; 14 | background: #ecf5ff; 15 | font-size:16px; 16 | height: 30px; 17 | font-weight:bold; 18 | border-radius:3px; 19 | padding: 0px 8px 0px 5px; 20 | display: inline-block; 21 | float: left; 22 | white-space: nowrap; 23 | cursor: pointer; 24 | border: 1px solid #b3d8ff; 25 | -webkit-appearance: none; 26 | text-align: center; 27 | box-sizing: border-box; 28 | outline: 0; 29 | margin-left: 5px; 30 | transition: .1s; 31 | font-weight: 500; 32 | } 33 | .button_selected{ 34 | background: #409EFF; 35 | border-color: #409EFF; 36 | color: #FFF; 37 | } 38 | .button_not_selected{ 39 | color: #409EFF; 40 | background: #ecf5ff; 41 | border-color: #b3d8ff; 42 | } 43 | .tooltip{ 44 | font-size:10px; 45 | height:auto; 46 | position:absolute; 47 | text-align:center; 48 | border-style:solid; 49 | border-width:1px; 50 | background-color:white; 51 | border-radius:5px; 52 | padding-left:2px; 53 | padding-right:2px; 54 | } 55 | 56 | .thick-green-border { 57 | width: 100px; 58 | margin-top: 40px; 59 | margin-left: 40px; 60 | margin-bottom: -10px; 61 | margin-right: 20px; 62 | padding-top: 20px; 63 | border-style: solid; 64 | border-width: 10px; 65 | border-color: green; 66 | } -------------------------------------------------------------------------------- /static/icon/cannotlink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/static/icon/cannotlink.png -------------------------------------------------------------------------------- /static/icon/cannotlink2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/static/icon/cannotlink2.png -------------------------------------------------------------------------------- /static/icon/mustlink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/static/icon/mustlink.png -------------------------------------------------------------------------------- /static/js/common/d3-lasso.min.js: -------------------------------------------------------------------------------- 1 | !function(t,n){"object"==typeof exports&&"undefined"!=typeof module?n(exports,require("d3-selection"),require("d3-drag")):"function"==typeof define&&define.amd?define(["exports","d3-selection","d3-drag"],n):n(t.d3=t.d3||{},t.d3,t.d3)}(this,function(t,n,r){"use strict";function e(t,n){return n={exports:{}},t(n,n.exports),n.exports}function o(){function t(t){function 
u(){p=[],h="",_.attr("d",null),m.attr("d",null),r.nodes().forEach(function(t){t.__lasso.possible=!1,t.__lasso.selected=!1,t.__lasso.hoverSelect=!1,t.__lasso.loopSelect=!1;var n=t.getBoundingClientRect();t.__lasso.lassoPoint=[Math.round(n.left+n.width/2),Math.round(n.top+n.height/2)]}),s&&r.on("mouseover.lasso",function(){this.__lasso.hoverSelect=!0}),i.start()}function l(){var t,n;"touchmove"===d3.event.sourceEvent.type?(t=d3.event.sourceEvent.touches[0].clientX,n=d3.event.sourceEvent.touches[0].clientY):(t=d3.event.sourceEvent.clientX,n=d3.event.sourceEvent.clientY);var s=d3.mouse(this)[0],u=d3.mouse(this)[1];""===h?(h=h+"M "+s+" "+u,v=[t,n],d=[s,u],b.attr("cx",s).attr("cy",u).attr("r",7).attr("display",null)):h=h+" L "+s+" "+u,p.push([t,n]);var l=Math.sqrt(Math.pow(t-v[0],2)+Math.pow(n-v[1],2)),f="M "+s+" "+u+" L "+d[0]+" "+d[1];_.attr("d",h),m.attr("d",f),a=l<=e,a&&o?m.attr("display",null):m.attr("display","none"),r.nodes().forEach(function(t){t.__lasso.loopSelect=!(!a||!o)&&c(p,t.__lasso.lassoPoint)<1,t.__lasso.possible=t.__lasso.hoverSelect||t.__lasso.loopSelect}),i.draw()}function f(){r.on("mouseover.lasso",null),r.nodes().forEach(function(t){t.__lasso.selected=t.__lasso.possible,t.__lasso.possible=!1}),_.attr("d",null),m.attr("d",null),b.attr("display","none"),i.end()}var h,v,d,p,g=t.append("g").attr("class","lasso"),_=g.append("path").attr("class","drawn"),m=g.append("path").attr("class","loop_close"),b=g.append("circle").attr("class","origin"),M=d3.drag().on("start",u).on("drag",l).on("end",f);n.call(M)}var n,r=[],e=75,o=!0,a=!1,s=!0,i={start:function(){},draw:function(){},end:function(){}};return t.items=function(n){if(!arguments.length)return r;r=n;var e=r.nodes();return e.forEach(function(t){t.__lasso={possible:!1,selected:!1}}),t},t.possibleItems=function(){return r.filter(function(){return this.__lasso.possible})},t.selectedItems=function(){return r.filter(function(){return this.__lasso.selected})},t.notPossibleItems=function(){return r.filter(function(){return!this.__lasso.possible})},t.notSelectedItems=function(){return r.filter(function(){return!this.__lasso.selected})},t.closePathDistance=function(n){return arguments.length?(e=n,t):e},t.closePathSelect=function(n){return arguments.length?(o=n===!0,t):o},t.isPathClosed=function(n){return arguments.length?(a=n===!0,t):a},t.hoverSelect=function(n){return arguments.length?(s=n===!0,t):s},t.on=function(n,r){if(!arguments.length)return i;if(1===arguments.length)return i[n];var e=["start","draw","end"];return e.indexOf(n)>-1&&(i[n]=r),t},t.targetArea=function(r){return arguments.length?(n=r,t):n},t}var a=e(function(t){function n(t,n,e){var o=t*n,a=r*t,s=a-t,i=a-s,u=t-i,l=r*n,f=l-n,c=l-f,h=n-c,v=o-i*c,d=v-u*c,p=d-i*h,g=u*h-p;return e?(e[0]=g,e[1]=o,e):[g,o]}t.exports=n;var r=+(Math.pow(2,27)+1)}),s=e(function(t){function n(t,n){var r=t+n,e=r-t,o=r-e,a=n-e,s=t-o,i=s+a;return i?[i,r]:[r]}function r(t,r){var e=0|t.length,o=0|r.length;if(1===e&&1===o)return n(t[0],r[0]);var a,s,i=e+o,u=new Array(i),l=0,f=0,c=0,h=Math.abs,v=t[f],d=h(v),p=r[c],g=h(p);d=o?(a=v,f+=1,f=o?(a=v,f+=1,f>1;return["sum(",o(t.slice(0,n)),",",o(t.slice(n)),")"].join("")}function i(t){if(2===t.length)return[["sum(prod(",t[0][0],",",t[1][1],"),prod(-",t[0][1],",",t[1][0],"))"].join("")];for(var r=[],a=0;a0){if(a<=0)return s;e=o+a}else{if(!(o<0))return s;if(a>=0)return s;e=-(o+a)}var i=b*e;return s>=i||s<=-i?s:y(t,n,r)},function(t,n,r,e){var 
o=t[0]-e[0],a=n[0]-e[0],s=r[0]-e[0],i=t[1]-e[1],u=n[1]-e[1],l=r[1]-e[1],f=t[2]-e[2],c=n[2]-e[2],h=r[2]-e[2],v=a*l,d=s*u,p=s*i,g=o*l,_=o*u,m=a*i,b=f*(v-d)+c*(p-g)+h*(_-m),y=(Math.abs(v)+Math.abs(d))*Math.abs(f)+(Math.abs(p)+Math.abs(g))*Math.abs(c)+(Math.abs(_)+Math.abs(m))*Math.abs(h),x=M*y;return b>x||-b>x?b:w(t,n,r,e)}];h()}),c=e(function(t){function n(t,n){for(var e=n[0],o=n[1],a=t.length,s=1,i=a,u=0,l=a-1;u0;){var b=(l+a-1)%a,M=t[b];if(M[1]!==o)break;var y=M[0];_=Math.min(_,y),m=Math.max(m,y),l=b}if(0===l)return _<=e&&e<=m?0:1;i=l+1}for(var w=t[(l+a-1)%a][1];u+1 d[0]); 20 | max.y = d3.max(embeddings, d => d[1]); 21 | min.x = d3.min(embeddings, d => d[0]); 22 | min.y = d3.min(embeddings, d => d[1]); 23 | xScale = d3 24 | .scaleLinear() 25 | .domain([min.x, max.x]) 26 | .range([margin.left, width - margin.right]); 27 | yScale = d3 28 | .scaleLinear() 29 | .domain([-max.y, -min.y]) 30 | .range([margin.top, height - margin.bottom]); 31 | // prepare data 32 | let contourFunc = d3 33 | .contourDensity() 34 | .x(function (d) { 35 | return xScale(d[0]); 36 | }) 37 | .y(function (d) { 38 | return yScale(-d[1]); 39 | }) 40 | .bandwidth(5) 41 | .thresholds(10) 42 | 43 | let contourMapData = contourFunc(embeddings); 44 | 45 | function ticks(start, end, count) { 46 | let result = [], 47 | increment = (end - start) / count; 48 | for (let i = 0; i <= count; i++) { 49 | result.push(start + i * increment); 50 | } 51 | return result; 52 | } 53 | 54 | let colorBlue = d3 55 | .scalePow() 56 | .domain(ticks(0, d3.max(contourMapData.map((d) => d.value)), 3)) 57 | .range(["#ffffff", "#b3cddd"]); 58 | 59 | var data = contourMapData.map((d) => { 60 | return { 61 | path: d3.geoPath()(d), 62 | color: colorBlue(d.value), 63 | }; 64 | }); 65 | //draw 66 | contour_canvas 67 | .selectAll("path") 68 | .data(data) 69 | .enter() 70 | .append("path") 71 | 72 | contour_canvas.exit().remove(); 73 | 74 | contour_canvas 75 | .selectAll("path") 76 | .attr("stroke", "steelblue") 77 | .attr("stroke-width",0.25) 78 | .attr("d", (d) => d.path) 79 | .attr("fill", (d) => d.color); 80 | 81 | let x1 = embeddings[select_node_list[0]][0]; 82 | let y1 = embeddings[select_node_list[0]][1]; 83 | let x2 = embeddings[select_node_list[1]][0]; 84 | let y2 = embeddings[select_node_list[1]][1]; 85 | 86 | contour_node_canvas.append("circle") 87 | .attr("cx", d => xScale(x1)) 88 | .attr("cy", d => yScale(-y1)) 89 | .attr("r", 2) 90 | .attr("fill", 'black') 91 | .attr("stroke-width", "1px"); 92 | 93 | contour_node_canvas.append("circle") 94 | .attr("cx", d => xScale(x2)) 95 | .attr("cy", d => yScale(-y2)) 96 | .attr("r", 2) 97 | .attr("fill", 'black') 98 | .attr("stroke-width", "1px"); 99 | 100 | contour_link_canvas.append("line") 101 | .attr("x1", xScale(x1)) 102 | .attr("y1", yScale(-y1)) 103 | .attr("x2", xScale(x2)) 104 | .attr("y2", yScale(-y2)) 105 | // .attr('link_type',link_type) 106 | .attr("stroke", function (d) { 107 | if (link_type == "M") return link_colors['must']; 108 | else if (link_type === 'C') return link_colors['cannot']; 109 | }) 110 | .attr("stroke-width", "1px"); 111 | } 112 | } -------------------------------------------------------------------------------- /static/js/models/LinkModel.js: -------------------------------------------------------------------------------- 1 | function LinkModel() { 2 | 3 | const svg = d3.select("#project_view_scatter"); 4 | this.listen_link = function(){ 5 | const svg = d3.select("#project_view_scatter"); 6 | let select_node_list = []; 7 | let model = this 8 | function add_link_data(item){ 9 | 
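// In must-link / cannot-link mode, clicks on the bound selector ('circle' or
// 'image') collect two distinct point indices; once the pair is complete it is
// forwarded via send_node_list and the selection buffer is cleared.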
svg.selectAll(item) 10 | .on('click',function (d,i) { 11 | console.log(d) 12 | if(project_object.mustlink_state || project_object.cannotlink_state){ 13 | if (select_node_list.length<2) { 14 | if (select_node_list.length === 1 && select_node_list[0] === i) 15 | return; 16 | select_node_list.push(i); 17 | if (select_node_list.length === 2) { 18 | model.send_node_list(select_node_list); 19 | select_node_list = []; 20 | } 21 | } 22 | } 23 | }) 24 | } 25 | if(project_object.picture_state==0){ 26 | console.log("circle listen") 27 | add_link_data('circle') 28 | }else{ 29 | console.log("image listen") 30 | add_link_data('image') 31 | } 32 | } 33 | 34 | this.send_node_list = function(select_node_list){ 35 | if (project_object.mustlink_state) { 36 | link_object.addMustLink(select_node_list) 37 | } else if (project_object.cannotlink_state) { 38 | link_object.addCannotLink(select_node_list) 39 | } 40 | } 41 | 42 | this.draw_link_pipeline = function(select_node_list,link_type){ 43 | if(project_object.picture_state==0){ 44 | let arr = this.get_points_position(select_node_list,link_type) 45 | this.draw_link(arr,select_node_list,link_type) 46 | }else if(project_object.picture_state==1){ 47 | let arr = this.get_picture_position(select_node_list,link_type) 48 | this.draw_link(arr,select_node_list,link_type) 49 | } 50 | } 51 | this.get_points_position = function(select_node_list){ 52 | let nodes = svg.selectAll('circle'); 53 | let source = nodes._groups[0][select_node_list[0]]; 54 | let target = nodes._groups[0][select_node_list[1]]; 55 | let x1 = source.cx.animVal['value']; 56 | let y1 = source.cy.animVal['value']; 57 | let x2 = target.cx.animVal['value']; 58 | let y2 = target.cy.animVal['value']; 59 | return [x1,y1,x2,y2] 60 | } 61 | this.get_picture_position = function(select_node_list){ 62 | let nodes = svg.selectAll('image'); 63 | let source = nodes._groups[0][select_node_list[0]]; 64 | let target = nodes._groups[0][select_node_list[1]]; 65 | let x1 = source.x.animVal['value']+project_object.picture_width/2; 66 | let y1 = source.y.animVal['value']+project_object.picture_height/2; 67 | let x2 = target.x.animVal['value']+project_object.picture_width/2; 68 | let y2 = target.y.animVal['value']+project_object.picture_height/2; 69 | return [x1,y1,x2,y2] 70 | } 71 | this.draw_node_border = function(){ 72 | let mustlink_id = link_object.get_mustlink_id() 73 | let cannotlink_id = link_object.get_cannotlink_id() 74 | d3.selectAll("circle").attr("stroke","") 75 | mustlink_id.forEach(id => { 76 | d3.select("#circle_"+id).attr("stroke",link_colors['must']) 77 | }); 78 | cannotlink_id.forEach(id => { 79 | d3.select("#circle_"+id).attr("stroke",link_colors['cannot']) 80 | }); 81 | } 82 | this.draw_link = function(arr,select_node_list,link_type) { 83 | let x1 = arr[0], y1 = arr[1], x2 = arr[2], y2 = arr[3]; 84 | let links_canvas = d3.select('#links_canvas'); 85 | this.draw_node_border() 86 | 87 | links_canvas.append("line") 88 | .attr("x1", x1) 89 | .attr("y1", y1) 90 | .attr("x2", x2) 91 | .attr("y2", y2) 92 | .attr("id","line_"+select_node_list[0]+"_"+select_node_list[1]) 93 | .attr('source_',select_node_list[0]) 94 | .attr('target_',select_node_list[1]) 95 | .attr('link_type',link_type) 96 | .attr("stroke", function (d) { 97 | if(link_type==='must')return link_colors['must']; 98 | else if(link_type==='cannot') return link_colors['cannot']; 99 | }) 100 | .attr("stroke-width", "2px") 101 | .attr("opacity",function(d,i){ 102 | if(select_node_list[2]==true)return 0.5; 103 | else if(select_node_list[2]==false)return 0; 104 | });
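// The appended line starts hidden (opacity 0) unless the caller passes a truthy
// third element in select_node_list; active_link()/not_active_link() later toggle
// this opacity to show or hide the link.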
105 | }; 106 | 107 | this.delete_link = function(link) { 108 | const svg = d3.select("#project_view_scatter"); 109 | this.draw_node_border() 110 | svg.select("#line_"+link.head+"_"+link.tail).remove(); 111 | } 112 | 113 | this.active_link = function(link){ 114 | const svg = d3.select("#project_view_scatter"); 115 | this.draw_node_border() 116 | console.log(link); 117 | let type = link.index[0] === "M"? "must":"cannot"; 118 | this.draw_link_pipeline([link.head, link.tail], type); 119 | svg.select("#line_"+link.head+"_"+link.tail).attr("opacity", 0.5); 120 | } 121 | 122 | this.not_active_link = function(link){ 123 | const svg = d3.select("#project_view_scatter"); 124 | this.draw_node_border() 125 | svg.select("#line_"+link.head+"_"+link.tail).attr("opacity", 0); 126 | } 127 | 128 | this.highlight_link = function(link){ 129 | const svg = d3.select("#project_view_scatter"); 130 | svg.select("#line_"+link.head+"_"+link.tail).attr("stroke-width", "3px"); 131 | } 132 | 133 | this.not_highlight_link = function(link){ 134 | const svg = d3.select("#project_view_scatter"); 135 | svg.select("#line_"+link.head+"_"+link.tail).attr("stroke-width", "2px"); 136 | } 137 | } -------------------------------------------------------------------------------- /static/js/models/NewScatterModel.js: -------------------------------------------------------------------------------- 1 | Array.prototype.unique = function (a) { 2 | return function () { 3 | return this.filter(a) 4 | } 5 | }(function (a, b, c) { 6 | return c.indexOf(a, b + 1) < 0 7 | }); 8 | 9 | function NewScatterModel(svgId) { 10 | this.svg = d3.select(svgId); 11 | this.width = this.svg.node().parentNode.clientWidth; 12 | this.height = this.svg.node().parentNode.clientHeight; 13 | this.svg.attr("width", this.width).attr("height", this.height); 14 | this.margin = {top: 20, right: 20, bottom: 20, left: 20}; 15 | 16 | this.canvas = this.svg.append('g'); 17 | this.links_canvas = this.canvas.append('g').attr("id", "links_canvas"); 18 | this.points_canvas = this.canvas.append('g'); 19 | this.pictures_canvas = this.canvas.append('g'); 20 | 21 | var tooltip = d3.select("body") 22 | .append("div") 23 | .attr("opacity", 0) 24 | .attr("class", "tooltip"); 25 | var text_tooltip = tooltip.append('div') 26 | var image_tooltip = tooltip.append('img') 27 | 28 | this.max_sample_num = 30 29 | this.sample_num = 10 30 | this.img_dir = "static/images/" 31 | this.lasso_ids = [] 32 | 33 | this.draw_pipeline = function (result) { 34 | this.result = result 35 | this.process(result) 36 | this.pictures_canvas.selectAll('image').remove() 37 | this.points_canvas.selectAll('circle').remove() 38 | this.links_canvas.selectAll("line").remove() 39 | if (project_object.picture_state == 0) { 40 | this.draw_points() 41 | } else if (project_object.picture_state == 1) { 42 | this.draw_pictures() 43 | } else { 44 | console.log('error') 45 | } 46 | linkModel.listen_link() 47 | this.call_mode() 48 | } 49 | 50 | this.process = function (result) { 51 | this.data = result.embeddings 52 | this.label = result.label 53 | this.label_map() 54 | const max = {}; 55 | const min = {}; 56 | max.x = d3.max(this.data, d => d[0]); 57 | max.y = d3.max(this.data, d => d[1]); 58 | min.x = d3.min(this.data, d => d[0]); 59 | min.y = d3.min(this.data, d => d[1]); 60 | this.xScale = d3 61 | .scaleLinear() 62 | .domain([min.x, max.x]) 63 | .range([this.margin.left, this.width - this.margin.right]); 64 | this.yScale = d3 65 | .scaleLinear() 66 | .domain([-max.y, -min.y]) 67 | .range([this.margin.top, this.height - 
this.margin.bottom]); 68 | } 69 | 70 | this.draw_points = function () { 71 | let lasso_ids = this.lasso_ids 72 | let self_obj = this; 73 | this.items = this.points_canvas.selectAll("circle") 74 | .data(this.data) 75 | .enter() 76 | .append("circle") 77 | .attr("cx", d => this.xScale(d[0])) 78 | .attr("cy", d => this.yScale(-d[1])) 79 | .attr("r", 4) 80 | .attr("id", (d, i) => "circle_" + i) 81 | .attr("cls", (d, i) => this.label[i]) 82 | .attr("fill", (d, i) => node_colors['normal']) 83 | .on('mouseover', function (d, i) { 84 | 85 | tooltip 86 | .style("display", "block") 87 | .style("left", (d3.event.pageX + 5) + "px") 88 | .style("top", (d3.event.pageY - 35) + "px") 89 | .style("opacity", 1); 90 | 91 | text_tooltip.html("ID:" + i + " ") 92 | 93 | if (para_object.selected_dataset_type == 'image') { 94 | image_tooltip 95 | .attr('src', self_obj.generate_path(i)) 96 | .attr("width", "100") 97 | .attr("height", "100") 98 | .style("display", "block") 99 | } else { 100 | image_tooltip.style("display", "none") 101 | } 102 | 103 | if (lasso_ids.indexOf(i) > -1) return true; 104 | d3.select(this) 105 | .attr("fill", node_colors['hover']) 106 | .attr("r", 7) 107 | paraModel.highlightSingleLine(i); 108 | }) 109 | .on('mouseout', function (d, i) { 110 | 111 | tooltip.style("display", "none") 112 | 113 | if (lasso_ids.indexOf(i) > -1) return true; 114 | d3.select(this) 115 | .attr("fill", node_colors['normal']) 116 | .attr("r", 4) 117 | 118 | paraModel.notHighlightSingleLine(i); 119 | 120 | }) 121 | 122 | if (lasso_ids.length > 0) { 123 | this.points_canvas.selectAll("circle") 124 | .filter(function (d, i) { 125 | if (lasso_ids.indexOf(i) > -1) return true; 126 | else return false 127 | }) 128 | .attr("fill", node_colors['lasso']) 129 | paraModel.highlightLines(lasso_ids) 130 | } 131 | } 132 | 133 | this.draw_pictures = function () { 134 | let xScale = this.xScale 135 | let yScale = this.yScale 136 | let lasso_ids = this.lasso_ids 137 | link_object.getLinkID() 138 | 139 | let picture_to_draw_ids = sample_images(this.label, this.sample_num).concat(lasso_ids).concat(link_object.link_ids) 140 | let self_obj = this; 141 | this.items = this.pictures_canvas.selectAll("image") 142 | .data(this.data) 143 | .enter() 144 | .append("image") 145 | .attr("x", d => this.xScale(d[0]) - project_object.picture_width / 2) 146 | .attr("y", d => this.yScale(-d[1]) - project_object.picture_height / 2) 147 | .attr("id", (d, i) => "images_" + i) 148 | .attr('opacity', 0.8) 149 | .style('display', function (d, i) { 150 | 151 | if (picture_to_draw_ids.indexOf(i) > -1) return "block"; 152 | else return "none"; 153 | }) 154 | .attr('xlink:href', function (d, i) { 155 | if (picture_to_draw_ids.indexOf(i) > -1) return self_obj.generate_path(i); 156 | }) 157 | .attr("width", project_object.picture_width) 158 | .attr("height", project_object.picture_height) 159 | .on('mouseover', function (d, i) { 160 | tooltip 161 | .style("display", "block") 162 | .style("left", (d3.event.pageX + 5) + "px") 163 | .style("top", (d3.event.pageY - 35) + "px") 164 | .style("opacity", 1.0); 165 | text_tooltip.html("ID:" + i + " ") 166 | image_tooltip 167 | .attr('src', self_obj.generate_path(i)) 168 | .attr("width", "120") 169 | .attr("height", "120") 170 | .style("display", "block") 171 | if (lasso_ids.indexOf(i) > -1) return true; 172 | let x = xScale(d[0]) - project_object.picture_width / 2 - 1 173 | let y = yScale(-d[1]) - project_object.picture_height / 2 - 1 174 | d3.select(this) 175 | .attr("x", x) 176 | .attr("y", y) 177 | .attr("width", 
project_object.picture_width + 2) 178 | .attr("height", project_object.picture_height + 2) 179 | 180 | paraModel.highlightSingleLine(i); 181 | }) 182 | .on('mouseout', function (d, i) { 183 | tooltip.style("display", "none") 184 | if (lasso_ids.indexOf(i) > -1) return true; 185 | let x = xScale(d[0]) - project_object.picture_width / 2 186 | let y = yScale(-d[1]) - project_object.picture_height / 2 187 | d3.select(this) 188 | .attr("x", x) 189 | .attr("y", y) 190 | .attr("width", project_object.picture_width) 191 | .attr("height", project_object.picture_height) 192 | paraModel.notHighlightSingleLine(i); 193 | 194 | }) 195 | 196 | if (lasso_ids.length > 0) { 197 | this.pictures_canvas.selectAll("image") 198 | .filter(function (d, i) { 199 | if (lasso_ids.indexOf(i) > -1) return true; 200 | else return false 201 | }) 202 | .attr("opacity", 1) 203 | paraModel.highlightLines(lasso_ids) 204 | } 205 | } 206 | 207 | this.call_mode = function () { 208 | if (project_object.mustlink_state || project_object.cannotlink_state) { 209 | console.log("link") 210 | this.call_link_mode() 211 | } else { 212 | console.log("lasso") 213 | this.call_lasso_mode() 214 | } 215 | } 216 | 217 | this.call_link_mode = function () { 218 | 219 | this.remove_lasso() 220 | this.svg.on(".drag", null); 221 | } 222 | 223 | this.call_lasso_mode = function () { 224 | let svg = this.svg 225 | svg.on(".zoom", null) 226 | 227 | var lasso = d3.lasso() 228 | .closePathSelect(true) 229 | .closePathDistance(100) 230 | .targetArea(svg) 231 | 232 | if (project_object.picture_state == 0) { 233 | lasso = this.call_points_lasso_mode(svg, lasso) 234 | } else { 235 | lasso = this.call_pictures_lasso_mode(svg, lasso) 236 | } 237 | svg.call(lasso) 238 | } 239 | 240 | this.call_points_lasso_mode = function (svg, lasso) { 241 | let lasso_ids = this.lasso_ids 242 | var lasso_start = function () { 243 | lasso.items() 244 | .attr("r", 4) // reset size 245 | .attr("fill", node_colors['not_lasso']) 246 | 247 | lasso_ids.splice(0, lasso_ids.length) 248 | paraModel.notHighlightLines(); 249 | }; 250 | 251 | var lasso_draw = function () { 252 | // Style the possible dots 253 | lasso.possibleItems() 254 | .attr("fill", node_colors['lasso']) 255 | 256 | // Style the not possible dot 257 | lasso.notPossibleItems() 258 | .attr("fill", node_colors['not_lasso']) 259 | }; 260 | 261 | var lasso_end = function () { 262 | // Reset the color of all dots 263 | lasso.items() 264 | .attr("fill", node_colors['normal']) 265 | 266 | // Style the selected dots 267 | lasso.selectedItems() 268 | .attr("fill", node_colors['lasso']) 269 | 270 | let selected = lasso.selectedItems()._groups[0] 271 | for (let i = 0; i < selected.length; i++) 272 | lasso_ids.push(parseInt(selected[i].id.slice(7))) 273 | console.log(lasso_ids); 274 | // Reset the style of the not selected dots 275 | paraModel.highlightLines(lasso_ids) 276 | }; 277 | 278 | 279 | lasso.items(svg.selectAll("circle")) 280 | .on("start", lasso_start) 281 | .on("draw", lasso_draw) 282 | .on("end", lasso_end); 283 | 284 | return lasso 285 | } 286 | 287 | this.call_pictures_lasso_mode = function (svg, lasso) { 288 | let lasso_ids = this.lasso_ids 289 | var lasso_start = function () { 290 | lasso.items() 291 | .classed("image_not_possible", true) 292 | .classed("image_possible", false) 293 | console.log(lasso_ids) 294 | 295 | lasso_ids.splice(0, lasso_ids.length) 296 | paraModel.notHighlightLines(); 297 | }; 298 | 299 | var lasso_draw = function () { 300 | // Style the possible dots 301 | lasso.possibleItems() 302 | 
.classed("image_not_possible", false) 303 | .classed("image_possible", true) 304 | 305 | // Style the not possible dot 306 | lasso.notPossibleItems() 307 | .classed("image_not_possible", true) 308 | .classed("image_possible", false) 309 | }; 310 | 311 | var lasso_end = function () { 312 | // Style the selected dots 313 | lasso.selectedItems() 314 | .classed("image_selected", true) 315 | .classed("image_possible", false) 316 | 317 | let selected = lasso.selectedItems()._groups[0] 318 | for (let i = 0; i < selected.length; i++) 319 | lasso_ids.push(parseInt(selected[i].id.slice(7))) 320 | // Reset the style of the not selected dots 321 | lasso.notSelectedItems() 322 | .classed("image_selected", false) 323 | .classed("image_possible", false) 324 | console.log(lasso_ids) 325 | paraModel.highlightLines(lasso_ids) 326 | }; 327 | 328 | lasso.items(svg.selectAll("image")) 329 | .on("start", lasso_start) 330 | .on("draw", lasso_draw) 331 | .on("end", lasso_end); 332 | 333 | return lasso 334 | } 335 | 336 | this.remove_lasso = function () { 337 | let svg = this.svg 338 | if (project_object.picture_state == 0) { 339 | svg.selectAll('circle') 340 | .attr("r", 4) 341 | .classed("point_not_possible", false) 342 | .classed("point_possible", false) 343 | .classed("point_selected", false) 344 | } else { 345 | svg.selectAll('image') 346 | .classed("image_selected", false) 347 | .classed("image_possible", false) 348 | .classed("image_not_possible", true) 349 | } 350 | } 351 | 352 | this.generate_path = function (idx) { 353 | let src_dir = this.img_dir + para_object.selected_dataset_name; 354 | return src_dir + "/" + idx + ".jpg" 355 | 356 | } 357 | 358 | this.label_map = function () { 359 | let unique_label = this.label.unique(); 360 | this.label_dict = {}; 361 | for (let i = 0; i < unique_label.length; i++) { 362 | this.label_dict[unique_label[i]] = i; 363 | } 364 | } 365 | } 366 | 367 | function sample_images(labels, sample_num) { 368 | let n_samples = labels.length; 369 | let sampled_indices = []; 370 | let label_count = {}; 371 | let unique_label = labels.unique(); 372 | for (let i = 0; i < unique_label.length; i++) 373 | label_count[unique_label[i]] = 0 374 | 375 | let indices = generateArray(0, n_samples-1); 376 | indices = shuffleSelf(indices, n_samples); 377 | // console.log(indices) 378 | 379 | for (let i = 0; i < n_samples; i++) { 380 | let idx = indices[i] 381 | let cur_label = labels[idx]; 382 | if (label_count[cur_label] <= sample_num) { 383 | label_count[cur_label] += 1 384 | sampled_indices.push(idx) 385 | } 386 | } 387 | return sampled_indices 388 | } 389 | 390 | function generateArray(start, end) { 391 | return Array.from(new Array(end + 1).keys()).slice(start) 392 | } 393 | 394 | function shuffleSelf(array, size) { 395 | var index = -1, 396 | length = array.length, 397 | lastIndex = length - 1; 398 | 399 | size = size === undefined ? 
length : size; 400 | while (++index < size) { 401 | var rand = index + Math.floor(Math.random() * (lastIndex - index + 1)) 402 | value = array[rand]; 403 | 404 | array[rand] = array[index]; 405 | 406 | array[index] = value; 407 | } 408 | array.length = size; 409 | return array; 410 | } -------------------------------------------------------------------------------- /static/js/models/ParallelModel.js: -------------------------------------------------------------------------------- 1 | function ParallelModel(svgId) { 2 | const parallelSVG = d3.select(svgId); 3 | const width = parallelSVG.node().parentNode.clientWidth; 4 | const height = parallelSVG.node().parentNode.clientHeight; 5 | parallelSVG.attr("width", width).attr("height", height); 6 | let margin = {top: 30, right: 40, bottom: 15, left: 40}; 7 | 8 | this.drawParallelPlot = function(result){ 9 | parallelSVG.selectAll('g').remove(); 10 | this.line_g1 = parallelSVG.append('g'); 11 | this.line_g2 = parallelSVG.append('g'); 12 | this.line_g3 = parallelSVG.append('g'); 13 | this.axis_g = parallelSVG.append('g'); 14 | this.text_g = parallelSVG.append('g'); 15 | let attr = result.attr 16 | let data = result.data 17 | this.xScale=d3.scaleLinear() 18 | .domain([0, attr.length-1]) 19 | .range([margin.left, width - margin.right]); 20 | 21 | if (para_object.selected_dataset_name === "Wifi") { 22 | this.yScales = d3.zip(...(data.map((item) => d3.permute(item, attr)))).map((subject) => { 23 | return d3.scaleLinear() 24 | .domain([-100, -10]) 25 | .range([margin.top, height - margin.bottom]); 26 | }); 27 | } else { 28 | this.yScales = d3.zip(...(data.map((item) => d3.permute(item, attr)))).map((subject) => { 29 | return d3.scaleLinear() 30 | .domain([d3.min(subject), d3.max(subject)]) 31 | .range([margin.top, height - margin.bottom]); 32 | }); 33 | } 34 | 35 | // this.yScales = d3.zip(...(data.map((item) => d3.permute(item, attr)))).map((subject) => { 36 | // return d3.scaleLinear() 37 | // .domain([d3.min(subject), d3.max(subject)]) 38 | // .range([margin.top, height - margin.bottom]); 39 | // }); 40 | 41 | this.renderLines(data,attr); 42 | this.renderAxis(); 43 | this.renderText(attr) 44 | } 45 | 46 | this.renderAxis = function(){ 47 | let axis_g = this.axis_g; 48 | xScale = this.xScale 49 | yScales = this.yScales 50 | yScales.forEach((scale, index) => { 51 | axis_g 52 | .append('g') 53 | .attr('transform', 'translate(' + xScale(index) + ',0)' ) 54 | .call(d3.axisLeft(scale).ticks(5)); 55 | }); 56 | axis_g.selectAll("text") 57 | .attr("font-size",12) 58 | 59 | axis_g.selectAll("text") 60 | .clone(true) 61 | .lower() 62 | .attr('fill', 'none') 63 | .attr('stroke-width', 5) 64 | .attr('stroke-linejoin', 'round') 65 | .attr('stroke', 'white') 66 | } 67 | 68 | this.renderLines = function(data,attr){ 69 | let xScale = this.xScale 70 | let yScales = this.yScales 71 | function generatePoints(d) { 72 | return d3.permute(d, attr).map((item, index) => { 73 | return [ 74 | xScale(index), 75 | yScales[index](item) 76 | ]; 77 | }); 78 | } 79 | 80 | let line_g1 = this.line_g1 81 | .selectAll('.line') 82 | .data(data); 83 | 84 | const linesEnter1 = line_g1.enter() 85 | .append('g') 86 | 87 | linesEnter1.append('path') 88 | .attr('stroke', node_colors['lasso_line']) 89 | .attr('opacity',0.1) 90 | .attr('stroke-width', 2) 91 | .attr('fill', 'none') 92 | .attr('d', (d) => d3.line()(generatePoints(d))) 93 | 94 | let line_g2 = this.line_g2 95 | .selectAll('.line') 96 | .data(data); 97 | 98 | const linesEnter2 = line_g2.enter() 99 | .append('g') 100 | 101 | 
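// renderLines stacks three copies of the same polylines: line_g1 is the faint,
// always-visible context layer (opacity 0.1), while line_g2 and line_g3 start fully
// transparent and are revealed by highlightLines (lasso selection) and
// highlightSingleLine (hover) respectively.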
linesEnter2.append('path') 102 | .attr('stroke', node_colors['lasso_line']) 103 | .attr('opacity',0) 104 | .attr('stroke-width', 2) 105 | .attr('fill', 'none') 106 | .attr('d', (d) => d3.line()(generatePoints(d))) 107 | 108 | let line_g3 = this.line_g3 109 | .selectAll('.line') 110 | .data(data); 111 | 112 | const linesEnter3 = line_g3.enter() 113 | .append('g') 114 | 115 | linesEnter3.append('path') 116 | .attr('stroke', node_colors['lasso_line']) 117 | .attr('opacity',0) 118 | .attr('stroke-width', 2) 119 | .attr('fill', 'none') 120 | .attr('d', (d) => d3.line()(generatePoints(d))) 121 | } 122 | 123 | this.renderText = function(attr){ 124 | let text_g = this.text_g 125 | .selectAll('text') 126 | .data(attr) 127 | .enter() 128 | .append('g') 129 | .append('text') 130 | .attr('dx','1em') 131 | .attr('transform', (d,i)=>'translate(' + (xScale(i)-20) + ',10)' ) 132 | .style("text-anchor", "middle") 133 | .text((d,i)=>d) 134 | } 135 | 136 | this.highlightSingleLine = function(id){ 137 | this.line_g3.selectAll('path') 138 | .filter(function(d,i){ 139 | if(id==i)return true; 140 | else return false; 141 | }) 142 | .attr('stroke', node_colors['hover']) 143 | .attr('stroke-width', 3) 144 | .attr('opacity',0.5) 145 | } 146 | 147 | this.notHighlightSingleLine = function(id){ 148 | this.line_g3.selectAll('path') 149 | .filter(function(d,i){ 150 | if(id==i)return true; 151 | else return false; 152 | }) 153 | .attr('stroke', node_colors['lasso_line']) 154 | .attr('stroke-width', 2) 155 | .attr('opacity', 0) 156 | } 157 | 158 | this.highlightLines = function(ids){ 159 | 160 | if(ids.length==0) return 161 | this.line_g2.selectAll('path').attr('opacity',0) 162 | this.line_g2.selectAll('path') 163 | .filter(function(d,i){ 164 | if(ids.indexOf(i)>-1)return true; 165 | else return false; 166 | }) 167 | .attr('stroke', node_colors['lasso']) 168 | .attr('stroke-width', 2) 169 | .attr('opacity',0.5) 170 | } 171 | 172 | this.notHighlightLines = function(){ 173 | this.line_g2.selectAll('path').attr('opacity',0) 174 | this.line_g2.selectAll('path') 175 | .attr('stroke', node_colors['lasso_line']) 176 | .attr('stroke-width', 2) 177 | .attr('opacity', 0) 178 | } 179 | 180 | } -------------------------------------------------------------------------------- /static/js/models/StateMachine.js: -------------------------------------------------------------------------------- 1 | MUST_LINK = 1 2 | CANNOT_LINK = 0 3 | SPREAD = 1 4 | UN_SPREAD = 0 5 | ACTIVE = 1 6 | INACTIVE = 0 7 | 8 | function StateMachine() { 9 | this.state='empty'; 10 | let css_state_machine = new CssStateMachine(); 11 | let interact_state_machine = new InteractStateMachine(); 12 | this.change_state = function (new_state) { 13 | if(new_state === this.state){ 14 | new_state = 'empty'; 15 | } 16 | this.state=new_state; 17 | css_state_machine.change_state(new_state); 18 | interact_state_machine.change_state(new_state); 19 | }; 20 | 21 | this.init_links = function (){ 22 | interact_state_machine.init_links(); 23 | } 24 | 25 | this.all_state_init = function () { 26 | this.state='empty'; 27 | css_state_machine.all_state_init(); 28 | interact_state_machine.all_state_init(); 29 | }; 30 | 31 | this.get_all_link = function () { 32 | return interact_state_machine.get_all_link(); 33 | }; 34 | 35 | this.get_link_spreads = function () { 36 | return interact_state_machine.get_link_spreads(); 37 | }; 38 | 39 | this.get_must_link = function () { 40 | return interact_state_machine.get_must_link(); 41 | }; 42 | 43 | this.get_cannot_link = function () { 44 | return 
interact_state_machine.get_cannot_link(); 45 | }; 46 | 47 | } 48 | 49 | 50 | function CssStateMachine() { 51 | let past_state='empty'; 52 | let state2button={'add_must':'#must_button','add_cannot':'#cannot_button','delete_link':'#delete_button'}; 53 | this.change_state = function (new_state) { 54 | if (new_state==='empty'){ 55 | d3.select(".inter_button_selected").attr('class','inter_button'); 56 | past_state = new_state; 57 | } 58 | else{ 59 | past_state = new_state; 60 | d3.select(".inter_button_selected").attr('class','inter_button'); 61 | d3.select(state2button[new_state]).attr('class','inter_button_selected'); 62 | } 63 | console.log('css state',past_state) 64 | }; 65 | this.all_state_init = function () { 66 | 67 | past_state='empty'; 68 | d3.select(".inter_button_selected").attr('class','inter_button'); 69 | } 70 | } 71 | 72 | 73 | function InteractStateMachine() { 74 | let past_state='empty'; 75 | let select_node_list = []; 76 | let link_spread = []; 77 | let link_list = []; 78 | let must_link_list = []; 79 | let cannot_link_list = []; 80 | 81 | this.change_state = function (new_state) { 82 | let snapshot2DSVG = d3.select('#origin_scatter'); 83 | // snapshot2DSVG.selectAll('circle').style("stroke",'white'); 84 | // snapshot2DSVG = d3.select('#constrained_scatter'); 85 | // snapshot2DSVG.selectAll('circle').style("stroke",'white'); 86 | select_node_list = []; 87 | if(new_state === 'empty'){ 88 | past_state = 'empty'; 89 | action_query_data('#origin_scatter'); 90 | action_query_data('#constrained_scatter'); 91 | } 92 | else if(new_state==='add_must' || new_state==='add_cannot'){ 93 | past_state = new_state; 94 | select_node('#origin_scatter'); 95 | select_node('#constrained_scatter'); 96 | } 97 | else if('delete_link' === new_state){ 98 | past_state=new_state; 99 | delete_link('#origin_scatter'); 100 | delete_link('#constrained_scatter'); 101 | } 102 | console.log('interact state',past_state) 103 | }; 104 | 105 | 106 | function select_node(svgId) { 107 | const snapshot2DSVG = d3.select(svgId); 108 | 109 | snapshot2DSVG 110 | .selectAll('circle') 111 | .on('click',function (d,i) { 112 | if(past_state==='add_must'||past_state==='add_cannot'){ 113 | if (select_node_list.length<2) { 114 | if (select_node_list.length === 1 && select_node_list[0] === i) 115 | return; 116 | select_node_list.push(i); 117 | if (select_node_list.length === 1){ 118 | } 119 | else if (select_node_list.length === 2) { 120 | send_node_list(svgId); 121 | select_node_list = []; 122 | } 123 | } 124 | } 125 | }) 126 | } 127 | 128 | function action_query_data(svgId) { 129 | 130 | const snapshot2DSVG = d3.select(svgId); 131 | snapshot2DSVG.selectAll('circle') 132 | .on('click',function (d, i) { 133 | if(past_state === 'empty') 134 | query_data(i); 135 | }) 136 | } 137 | 138 | function send_node_list(svgId) { 139 | if (past_state === 'add_must') { 140 | linkModel.drawLink(svgId, select_node_list, 'must'); 141 | select_node_list.push(MUST_LINK); 142 | must_link_list.push(select_node_list); 143 | } else if (past_state === 'add_cannot') { 144 | linkModel.drawLink(svgId,select_node_list, 'cannot'); 145 | select_node_list.push(CANNOT_LINK); 146 | cannot_link_list.push(select_node_list); 147 | } 148 | link_list.push(select_node_list); 149 | // link_spread.push(SPREAD); 150 | link_spread.push(UN_SPREAD); 151 | } 152 | 153 | function delete_link(svgId) { 154 | 155 | let state = past_state; 156 | function search(lists, list) { // endpoints read back via .attr() are strings, while the stored indices are numbers, so compare as strings 157 | for (let i = 0; i < lists.length; i++) { 158 | if (String(lists[i][0]) === list[0]) { 159 | if (String(lists[i][1]) 
=== list[1]) { 160 | return i; 161 | } 162 | } 163 | } 164 | } 165 | 166 | const snapshot2DSVG = d3.select(svgId); 167 | snapshot2DSVG 168 | .selectAll('line') 169 | .on('click', function (d, i) { 170 | if (state === 'delete_link') { 171 | let line = d3.select(this); 172 | let source = line.attr('source_'); 173 | let target = line.attr('target_'); 174 | let link_type = line.attr('link_type'); 175 | let index = -1; 176 | if (link_type === 'must') { 177 | index = search(must_link_list, [source, target]); 178 | must_link_list.splice(index, 1) 179 | } else { 180 | index = search(cannot_link_list, [source, target]); 181 | cannot_link_list.splice(index, 1) 182 | } 183 | index = search(link_list, [source, target]); 184 | link_list.splice(index, 1); 185 | d3.select(this).remove(); 186 | } 187 | }) 188 | } 189 | this.all_state_init = function () { 190 | 191 | console.log("init state!"); 192 | past_state='empty'; 193 | select_node_list = []; 194 | link_list = []; 195 | link_spread = []; 196 | must_link_list = []; 197 | cannot_link_list = []; 198 | }; 199 | 200 | this.init_links = function () { 201 | link_list = []; 202 | link_spread = []; 203 | must_link_list = []; 204 | cannot_link_list = []; 205 | select_node_list = []; 206 | } 207 | 208 | this.get_link_spreads = function () { 209 | return link_spread; 210 | }; 211 | 212 | this.get_all_link = function () { 213 | return link_list; 214 | }; 215 | 216 | this.get_must_link = function () { 217 | return must_link_list; 218 | }; 219 | 220 | this.get_cannot_link = function () { 221 | return cannot_link_list; 222 | } 223 | } 224 | 225 | -------------------------------------------------------------------------------- /static/js/utils.js: -------------------------------------------------------------------------------- 1 | 2 | function ajax_for_data(url, para_list, type) { 3 | let data = -1; 4 | $.ajax({ 5 | url: url, 6 | data: para_list, 7 | type: type, 8 | async: false, // synchronous for this request only, instead of flipping the global $.ajaxSettings.async flag 9 | success: function (result) { 10 | data = result; // record the payload so callers can tell success from the -1 sentinel 11 | data_object.data = result.data 12 | data_object.label = result.label 13 | data_object.low_data = result.low_data 14 | data_object.attrs = result.attrs 15 | } 16 | }); 17 | return data; 18 | } 19 | 20 | 21 | function ajax_for_get_projection(url, para_list) { 22 | $.ajax({ 23 | url: url, 24 | data: para_list, 25 | type: 'POST', 26 | dataType: 'json', 27 | success: function (result) { 28 | 29 | let scatter_result = {embeddings: result.embeddings, label: result.label}; 30 | 31 | let parr_data = []; 32 | let attr_names = result.attrs; 33 | for (let i = 0; i < result.low_data.length; i++) { 34 | let single_obj = {} 35 | for (let j = 0; j < attr_names.length; j++) 36 | single_obj[attr_names[j]] = result.low_data[i][j]; 37 | parr_data.push(single_obj); 38 | } 39 | 40 | let parr_result = {'data': parr_data, 'attr': attr_names}; 41 | 42 | scatterModel1.draw_pipeline(scatter_result); 43 | paraModel.drawParallelPlot(parr_result); 44 | 45 | 46 | } 47 | }); 48 | 49 | } 50 | 51 | function ajax_for_get_projection2(url, para_list) { 52 | 53 | $.ajax({ 54 | url: url, 55 | data: para_list, 56 | type: 'POST', 57 | dataType: 'json', 58 | success: function (result) { 59 | 60 | let scatter_result = {embeddings: result.embeddings, label: result.label}; 61 | scatterModel1.draw_pipeline(scatter_result); 62 | 63 | d3.select("#project_view_scatter").selectAll('line').remove(); 64 | // d3.select("#project_view_scatter").selectAll('line').attr("opacity", 0); 65 | 66 | link_object.restoreMustLink() 67 | link_object.restoreCannotLink() 68 | } 69 | }); 70 | }
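// Note: ajax_for_get_projection re-keys each low_data row by attribute name so
// ParallelModel can d3.permute() it, while ajax_for_get_projection2 only
// refreshes the scatter and restores the stored must-link / cannot-link marks.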
-------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/teaser.png -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Interactive CDR 6 | 7 | 8 | 9 | 11 | 12 | 13 | {# #} 14 | {# #} 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |
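<!-- Layout: a control board (dataset selection, training and projection
     parameters), the projection view with must-link / cannot-link editing,
     and a parallel coordinates view. -->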
31 |
32 |
33 |
CONTROL BOARD
34 |
35 |
36 |
37 |

Dataset Selection

38 |
39 | 48 |
num: 49 | 50 | {{ selected_dataset_num }} 51 | 52 |
53 |
dim: 54 | 55 | {{ selected_dataset_dim }} 56 | 57 |
58 |
type: 59 | 60 | {{ selected_dataset_type }} 61 | 62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |

Training Parameters

70 |
71 |
77 |
78 |
79 |
80 |
81 |
82 |

Projection Parameters

83 |
84 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
PROJECTION VIEW
95 |
96 |
97 | 107 | 108 | Projection 110 | 111 |
112 |
113 | 114 | 121 | 128 |
129 |
130 |
131 |
132 | 133 |
134 |
135 |
136 | 184 |
185 |
186 |
187 |
PARALLEL COORDINATES
188 |
189 | 190 |
191 |
192 |
193 |
194 | 195 | 196 | 590 | 591 | 592 | 593 | 594 | 595 | 596 | 597 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | from model.cdr import CDRModel 4 | import torch 5 | 6 | from utils.common_utils import get_config 7 | from utils.constant_pool import * 8 | import argparse 9 | from experiments.trainer import CDRTrainer 10 | import os 11 | 12 | log_path = "log.txt" 13 | 14 | 15 | def cdr_pipeline(config_path): 16 | 17 | cfg.merge_from_file(config_path) 18 | method_name = CDR_METHOD if cfg.exp_params.gradient_redefine else NX_CDR_METHOD 19 | result_save_dir = ConfigInfo.RESULT_SAVE_DIR.format(method_name, cfg.exp_params.n_neighbors) 20 | if not os.path.exists(result_save_dir): 21 | os.makedirs(result_save_dir) 22 | 23 | clr_model = CDRModel(cfg, device=device) 24 | trainer = CDRTrainer(clr_model, cfg.exp_params.dataset, cfg, result_save_dir, config_path, 25 | device=device, log_path=log_path) 26 | trainer.train_for_visualize() 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--configs", type=str, default="configs/CDR.yaml", help="configuration file path") 32 | parser.add_argument("--device", type=str, default="cpu") 33 | return parser.parse_args() 34 | 35 | 36 | if __name__ == '__main__': 37 | args = parse_args() 38 | cfg = get_config() 39 | device = args.device 40 | cdr_pipeline(args.configs) 41 | -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import os 4 | import os.path 5 | import time 6 | from multiprocessing import Queue 7 | 8 | import numpy as np 9 | import yaml 10 | from easydict import EasyDict as edict 11 | from sklearn.decomposition import PCA 12 | from yaml import FullLoader 13 | 14 | from dataset.datasets import MyImageDataset, MyTextDataset 15 | 16 | DATE_TIME_ADJOIN_FORMAT = "%Y%m%d_%Hh%Mm%Ss" 17 | 18 | 19 | def check_path_exists(path): 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | 24 | def get_dataset(dataset_name, root_dir, train, transform=None, is_image=True): 25 | if is_image: 26 | dataset = MyImageDataset(dataset_name, root_dir, train, transform) 27 | else: 28 | dataset = MyTextDataset(dataset_name, root_dir, train) 29 | return dataset 30 | 31 | 32 | def get_principle_components(data, attr_names=None, target_components=16): 33 | n_samples = data.shape[0] 34 | flattened_data = np.reshape(data, (n_samples, -1)) 35 | 36 | if flattened_data.shape[1] <= target_components: 37 | low_data = flattened_data 38 | else: 39 | pca = PCA() 40 | z = pca.fit_transform(flattened_data) 41 | low_data = z[:, :target_components] 42 | 43 | if attr_names is None: 44 | attr_names = ["A{}".format(i) for i in range(low_data.shape[1])] 45 | return low_data, attr_names 46 | 47 | 48 | def time_stamp_to_date_time_adjoin(time_stamp): 49 | time_array = time.localtime(time_stamp) 50 | return time.strftime(DATE_TIME_ADJOIN_FORMAT, time_array) 51 | 52 | 53 | class QueueSet: 54 | def __init__(self): 55 | self.eval_data_queue = Queue() 56 | self.eval_result_queue = Queue() 57 | 58 | self.test_eval_data_queue = Queue() 59 | self.test_eval_result_queue = Queue() 60 | 61 | 62 | class YamlParser(edict): 63 | """ 64 | This is yaml parser based on EasyDict. 
65 | """ 66 | 67 | def __init__(self, cfg_dict=None, config_file=None): 68 | if cfg_dict is None: 69 | cfg_dict = {} 70 | 71 | if config_file is not None: 72 | assert (os.path.isfile(config_file)) 73 | with open(config_file, 'r') as fo: 74 | cfg_dict.update(yaml.load(fo.read(), Loader=FullLoader)) 75 | 76 | super(YamlParser, self).__init__(cfg_dict) 77 | 78 | def merge_from_file(self, config_file): 79 | with open(config_file, 'r') as fo: 80 | self.update(yaml.load(fo.read(), Loader=FullLoader)) 81 | 82 | def merge_from_dict(self, config_dict): 83 | self.update(config_dict) 84 | 85 | 86 | def get_config(config_file=None): 87 | return YamlParser(config_file=config_file) -------------------------------------------------------------------------------- /utils/constant_pool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import numpy as np 4 | from multiprocessing import Queue 5 | 6 | CDR_METHOD = "CDR" 7 | NX_CDR_METHOD = "NX_CDR" 8 | 9 | 10 | class ProjectSettings: 11 | LABEL_COLORS = {0: 'blue', 1: 'orange', 2: 'green', 3: 'red', 4: 'blueviolet', 5: 'maroon', 6: 'deeppink', 12 | 7: 'greenyellow', 8: 'olive', 9: 'cyan', 10: 'yellow', 11: 'purple'} 13 | 14 | 15 | class ConfigInfo: 16 | # method_name, dataset_name, method_name+dataset_name.ckpt 17 | MODEL_CONFIG_PATH = "./configs/" 18 | RESULT_SAVE_DIR = r"results\{}\n{}" 19 | NEIGHBORS_CACHE_DIR = r"data\neighbors" 20 | PAIRWISE_DISTANCE_DIR = r"data\pair_distance" 21 | DATASET_CACHE_DIR = r"data\H5 Data" 22 | IMAGE_DIR = r"static/images" 23 | 24 | DATASETS_META = ["name", "num", "dim", "type"] 25 | AVAILABLE_DATASETS = [] 26 | -------------------------------------------------------------------------------- /utils/link_utils.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import random 3 | 4 | import numpy as np 5 | 6 | from utils.umap_utils import find_ab_params, convert_distance_to_probability 7 | 8 | MUST_LINK = 1 9 | CANNOT_LINK = 0 10 | SPREAD = 1 11 | UN_SPREAD = 0 12 | ACTIVE = 1 13 | INACTIVE = 0 14 | 15 | 16 | class LinkInfo: 17 | def __init__(self, links, link_spreads, finetune_epochs, pre_embeddings, min_dist): 18 | self.links = links 19 | self.link_uuid = links[:, 0] 20 | self.link_indices = links[:, 1:3] 21 | self.link_types = links[:, 3] 22 | self.link_activate = links[:, 4] 23 | self.link_spreads = link_spreads 24 | self.link_num = links.shape[0] 25 | self.new_link_num = self.link_num 26 | 27 | self.min_dist = min_dist 28 | self._a, self._b = find_ab_params(1.0, min_dist) 29 | 30 | self.finetune_epochs = finetune_epochs 31 | 32 | self.all_link_indices = None 33 | 34 | self.crs_indices = None 35 | self.crs_sims = None 36 | self.new_crs_indices = None 37 | self.new_crs_sims = None 38 | self.new_link_spreads = None 39 | 40 | self.iter_count = 0 41 | self.weaken_rate = 0.5 42 | 43 | self.construct_csr(pre_embeddings, links, link_spreads) 44 | 45 | def process_cur_links(self, links, link_spreads, pre_embeddings): 46 | link_uuid = links[:, 0] 47 | link_activates = links[:, 4] 48 | pre_link_uuid = self.link_uuid 49 | 50 | added_link_indices = [] 51 | preserve_link_indices = [] 52 | preserved_new_link_activate = [] 53 | list_link_uuid = list(link_uuid) 54 | for i, uuid in enumerate(pre_link_uuid): 55 | idx = list_link_uuid.index(uuid) if uuid in list_link_uuid else -1 56 | if idx >= 0: 57 | preserve_link_indices.append(i) 58 | preserved_new_link_activate.append(link_activates[idx]) 59 | 
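        # Diff the incoming batch against the previous links by UUID: UUIDs seen
        # before are preserved (with their updated activation state collected
        # above); UUIDs appearing for the first time are gathered below as newly
        # added links.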
60 | for i, uuid in enumerate(link_uuid): 61 | if uuid not in pre_link_uuid: 62 | added_link_indices.append(i) 63 | 64 | preserve_link_uuid = pre_link_uuid 65 | 66 | added_links = links[added_link_indices] if len(added_link_indices) > 0 else [] 67 | added_link_spreads = link_spreads[added_link_indices] if len(added_link_indices) > 0 else [] 68 | 69 | self._preserve_old_links(preserve_link_uuid, preserve_link_indices) 70 | self._link_activate_change(preserved_new_link_activate) 71 | self.add_new_links(added_links, added_link_spreads, pre_embeddings) 72 | 73 | def add_new_links(self, added_links, added_link_spreads, pre_embeddings): 74 | self.new_link_num = len(added_link_spreads) 75 | if self.new_link_num <= 0: 76 | return 77 | self.link_num += self.new_link_num 78 | self.links = np.concatenate([self.links, added_links], axis=0) 79 | self.link_uuid = np.concatenate([self.link_uuid, added_links[:, 0]], axis=0) 80 | self.link_indices = np.concatenate([self.link_indices, added_links[:, 1:3]], axis=0) 81 | self.link_types = np.concatenate([self.link_types, added_links[:, 3]], axis=0) 82 | self.link_activate = np.concatenate([self.link_activate, added_links[:, 4]], axis=0) 83 | self.link_spreads = np.concatenate([self.link_spreads, added_link_spreads], axis=0) 84 | self.construct_csr(pre_embeddings, added_links, added_link_spreads) 85 | 86 | def _preserve_old_links(self, preserve_link_uuid, preserve_link_indices): 87 | self.link_num = len(preserve_link_indices) 88 | self.link_uuid = self.link_uuid[preserve_link_indices] 89 | self.link_indices = self.link_indices[preserve_link_indices] 90 | self.link_spreads = self.link_spreads[preserve_link_indices] 91 | self.link_types = self.link_types[preserve_link_indices] 92 | self.link_activate = self.link_activate[preserve_link_indices] 93 | 94 | reserved_crs_indices = [] 95 | for i in range(len(self.crs_indices)): 96 | if self.crs_indices[i, 0] in preserve_link_uuid: 97 | reserved_crs_indices.append(i) 98 | 99 | self.crs_indices = self.crs_indices[reserved_crs_indices] 100 | self.crs_sims = self.crs_sims[reserved_crs_indices] 101 | 102 | def _link_activate_change(self, new_link_activate): 103 | 104 | weaken_indices = [] 105 | zero_set_indices = [] 106 | one_set_indices = [] 107 | for i in range(self.link_num): 108 | pre_stat = self.link_activate[i] 109 | cur_stat = new_link_activate[i] 110 | if pre_stat == cur_stat: 111 | if pre_stat == ACTIVE: 112 | weaken_indices.append(i) 113 | else: 114 | zero_set_indices.append(i) 115 | else: 116 | if pre_stat == ACTIVE: 117 | zero_set_indices.append(i) 118 | else: 119 | one_set_indices.append(i) 120 | 121 | weaken_link_ids = self.link_uuid[weaken_indices] if len(weaken_indices) > 0 else [] 122 | zero_set_link_ids = self.link_uuid[zero_set_indices] if len(zero_set_indices) > 0 else [] 123 | one_set_link_ids = self.link_uuid[one_set_indices] if len(one_set_indices) > 0 else [] 124 | self.weaken_old_links(weaken_link_ids, zero_set_link_ids, one_set_link_ids) 125 | 126 | def weaken_old_links(self, weaken_link_ids, zero_set_link_ids, one_set_link_ids): 127 | for i in range(self.crs_indices.shape[0]): 128 | link_uuid = self.crs_indices[i][0] 129 | 130 | if link_uuid in weaken_link_ids: 131 | pre_link_weight = self.crs_indices[i][-1] 132 | new_link_weight = pre_link_weight * self.weaken_rate 133 | if new_link_weight <= 0.05: 134 | pass 135 | elif link_uuid in zero_set_link_ids: 136 | new_link_weight = 0 137 | else: 138 | new_link_weight = 1 139 | 140 | self.crs_indices[i][-1] = new_link_weight 141 | 142 | def 
construct_csr(self, embeddings, cur_new_links, cur_new_link_spreads): 143 | 144 | link_uuid = cur_new_links[:, 0] 145 | link_indices = cur_new_links[:, 1:3] 146 | link_types = cur_new_links[:, 3] 147 | 148 | self.new_link_spreads = np.array(cur_new_link_spreads, dtype=int) 149 | link_num = len(cur_new_link_spreads) 150 | cur_total_link_num = np.sum(cur_new_link_spreads) + np.sum(~np.array(cur_new_link_spreads, dtype=bool)) 151 | if link_num <= 0: 152 | return 153 | 154 | crs_indices = np.ones((cur_total_link_num, 5), dtype=float) 155 | crs_sims = np.ones((cur_total_link_num, 5)) 156 | 157 | count = 0 158 | for i in range(link_num): 159 | h_idx, t_idx = link_indices[i] 160 | uuid = link_uuid[i] 161 | cur_type = link_types[i] 162 | crs_indices[count] = [uuid, h_idx, t_idx, cur_type, 2] 163 | crs_sims[count] = [uuid, 1, 1, cur_type, 2] 164 | count += 1 165 | 166 | self.new_crs_indices = crs_indices 167 | self.new_crs_sims = crs_sims 168 | if self.crs_indices is None: 169 | self.crs_indices = crs_indices 170 | self.crs_sims = crs_sims 171 | else: 172 | self.crs_indices = np.concatenate([self.crs_indices, crs_indices], axis=0) 173 | self.crs_sims = np.concatenate([self.crs_sims, crs_sims], axis=0) 174 | 175 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import logging 4 | import shutil 5 | from multiprocessing import Process 6 | 7 | 8 | class InfoLogger: 9 | LOG_LEVEL = logging.INFO 10 | LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" 11 | logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT) 12 | 13 | @staticmethod 14 | def change_level(new_level): 15 | InfoLogger.LOG_LEVEL = new_level 16 | logging.basicConfig(level=new_level, format=InfoLogger.LOG_FORMAT) 17 | 18 | @staticmethod 19 | def change_format(new_format): 20 | InfoLogger.LOG_FORMAT = new_format 21 | logging.basicConfig(level=InfoLogger.LOG_LEVEL, format=new_format) 22 | 23 | @staticmethod 24 | def info(message): 25 | logging.info(message) 26 | 27 | @staticmethod 28 | def warn(message): 29 | logging.warning(message) 30 | 31 | @staticmethod 32 | def debug(message): 33 | logging.debug(message) 34 | 35 | @staticmethod 36 | def error(message): 37 | logging.error(message) 38 | 39 | 40 | class LogWriter(Process): 41 | def __init__(self, file_path, save_path, message_queue): 42 | self.name = "logging process" 43 | Process.__init__(self, name=self.name) 44 | self.file_path = file_path 45 | self.save_path = save_path 46 | self.file = None 47 | self.message_queue = message_queue 48 | 49 | def run(self) -> None: 50 | self.file = open(self.file_path, "a") 51 | self.file.truncate(0) 52 | while True: 53 | message = self.message_queue.get() 54 | if message == "end": 55 | break 56 | self.file.write(message + "\n") 57 | 58 | self.file.close() 59 | -------------------------------------------------------------------------------- /utils/math_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from pynndescent import NNDescent 7 | 8 | from utils.umap_utils import find_ab_params, convert_distance_to_probability 9 | import networkx as nx 10 | import bisect 11 | from scipy.spatial.distance import pdist, squareform 12 | 13 | MACHINE_EPSILON = np.finfo(np.double).eps 14 | 15 | 16 | 
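# Each kernel below maps two batches of embeddings to a pair
# (similarity_matrix, pairwise_distance_matrix); get_similarity_function()
# dispatches on "umap" / "tsne" / "exp" / "cosine". The UMAP kernel caches its
# fitted (a, b) curve parameters in module-level globals and only refits when
# min_dist changes.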
def _student_t_similarity(rep1, rep2, *args): 17 | pairwise_matrix = torch.norm(rep1 - rep2, dim=-1) 18 | similarity_matrix = 1 / (1 + pairwise_matrix ** 2) 19 | return similarity_matrix, pairwise_matrix 20 | 21 | 22 | def _exp_similarity(rep1, rep2, *args): 23 | pairwise_matrix = torch.norm(rep1 - rep2, dim=-1) 24 | similarity_matrix = torch.exp(-pairwise_matrix ** 2) 25 | return similarity_matrix, pairwise_matrix 26 | 27 | 28 | def _cosine_similarity(rep1, rep2, *args): 29 | x = rep2[0] 30 | x = F.normalize(x, dim=1) 31 | similarity_matrix = torch.matmul(x, x.T).clamp(min=1e-7) 32 | pairwise_matrix = torch.norm(rep1 - rep2, dim=-1) 33 | return similarity_matrix, pairwise_matrix 34 | 35 | 36 | a = None 37 | b = None 38 | pre_min_dist = -1 39 | 40 | 41 | def _umap_similarity(rep1, rep2, min_dist=0.1): 42 | pairwise_matrix = torch.norm(rep1 - rep2, dim=-1) 43 | global a, b, pre_min_dist 44 | if a is None or pre_min_dist != min_dist: 45 | pre_min_dist = min_dist 46 | a, b = find_ab_params(1.0, min_dist) 47 | 48 | similarity_matrix = convert_distance_to_probability(pairwise_matrix, a, b) 49 | return similarity_matrix, pairwise_matrix 50 | 51 | 52 | def get_similarity_function(similarity_method): 53 | if similarity_method == "umap": 54 | return _umap_similarity 55 | elif similarity_method == "tsne": 56 | return _student_t_similarity 57 | elif similarity_method == "exp": 58 | return _exp_similarity 59 | elif similarity_method == "cosine": 60 | return _cosine_similarity 61 | 62 | 63 | def get_correlated_mask(batch_size): 64 | diag = np.eye(batch_size) 65 | l1 = np.eye(batch_size, batch_size, k=int(-batch_size / 2)) 66 | l2 = np.eye(batch_size, batch_size, k=int(batch_size / 2)) 67 | mask = torch.from_numpy((diag + l1 + l2)) 68 | mask = (1 - mask).type(torch.bool) 69 | return mask 70 | -------------------------------------------------------------------------------- /utils/nn_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from pynndescent import NNDescent 5 | from sklearn.metrics import pairwise_distances 6 | 7 | from utils.logger import InfoLogger 8 | 9 | 10 | def cal_snn_similarity(knn, cache_path=None): 11 | if cache_path is not None and os.path.exists(cache_path): 12 | _, snn_sim = np.load(cache_path) 13 | InfoLogger.info("directly load accurate neighbor_graph from {}".format(cache_path)) 14 | return knn, snn_sim 15 | 16 | snn_sim = np.zeros_like(knn) 17 | n_samples, n_neighbors = knn.shape 18 | for i in range(n_samples): 19 | sample_nn = knn[i] 20 | for j, neighbor_idx in enumerate(sample_nn): 21 | neighbor_nn = knn[int(neighbor_idx)] 22 | snn_num = len(np.intersect1d(sample_nn, neighbor_nn)) 23 | snn_sim[i][j] = snn_num / n_neighbors 24 | if cache_path is not None and not os.path.exists(cache_path): 25 | np.save(cache_path, [knn, snn_sim]) 26 | InfoLogger.info("successfully compute snn similarity and save to {}".format(cache_path)) 27 | return knn, snn_sim 28 | 29 | 30 | def compute_accurate_knn(flattened_data, k, neighbors_cache_path=None, pairwise_cache_path=None, metric="euclidean"): 31 | cur_path = None 32 | if neighbors_cache_path is not None: 33 | cur_path = neighbors_cache_path.replace(".npy", "_ac.npy") 34 | 35 | if cur_path is not None and os.path.exists(cur_path): 36 | knn_indices, knn_distances = np.load(cur_path) 37 | InfoLogger.info("directly load accurate neighbor_graph from {}".format(cur_path)) 38 | else: 39 | preload = flattened_data.shape[0] <= 30000 40 | 41 | pairwise_distance = 
get_pairwise_distance(flattened_data, metric, pairwise_cache_path, preload=preload) 42 | sorted_indices = np.argsort(pairwise_distance, axis=1) 43 | knn_indices = sorted_indices[:, 1:k+1] 44 | knn_distances = [] 45 | for i in range(knn_indices.shape[0]): 46 | knn_distances.append(pairwise_distance[i, knn_indices[i]]) 47 | knn_distances = np.array(knn_distances) 48 | if cur_path is not None: 49 | np.save(cur_path, [knn_indices, knn_distances]) 50 | InfoLogger.info("successfully compute accurate neighbor_graph and save to {}".format(cur_path)) 51 | return knn_indices, knn_distances 52 | 53 | 54 | def compute_knn_graph(all_data, neighbors_cache_path, k, pairwise_cache_path, 55 | metric="euclidean", max_candidates=60, accelerate=False): 56 | flattened_data = all_data.reshape((len(all_data), np.prod(all_data.shape[1:]))) 57 | 58 | if not accelerate: 59 | knn_indices, knn_distances = compute_accurate_knn(flattened_data, k, neighbors_cache_path, pairwise_cache_path) 60 | return knn_indices, knn_distances 61 | 62 | if neighbors_cache_path is not None and os.path.exists(neighbors_cache_path): 63 | neighbor_graph = np.load(neighbors_cache_path) 64 | knn_indices, knn_distances = neighbor_graph 65 | InfoLogger.info("directly load approximate neighbor_graph from {}".format(neighbors_cache_path)) 66 | else: 67 | # number of trees in random projection forest 68 | n_trees = 5 + int(round((all_data.shape[0]) ** 0.5 / 20.0)) 69 | # max number of nearest neighbor iters to perform 70 | n_iters = max(5, int(round(np.log2(all_data.shape[0])))) 71 | nnd = NNDescent( 72 | flattened_data, 73 | n_neighbors=k+1, 74 | metric=metric, 75 | n_trees=n_trees, 76 | n_iters=n_iters, 77 | max_candidates=max_candidates, 78 | verbose=False 79 | ) 80 | 81 | knn_indices, knn_distances = nnd.neighbor_graph 82 | knn_indices = knn_indices[:, 1:] 83 | knn_distances = knn_distances[:, 1:] 84 | 85 | if neighbors_cache_path is not None: 86 | np.save(neighbors_cache_path, [knn_indices, knn_distances]) 87 | InfoLogger.info("successfully compute approximate neighbor_graph and save to {}".format(neighbors_cache_path)) 88 | return knn_indices, knn_distances 89 | 90 | 91 | def get_pairwise_distance(flattened_data, metric, pairwise_distance_cache_path=None, preload=False): 92 | if pairwise_distance_cache_path is not None and preload and os.path.exists(pairwise_distance_cache_path): 93 | pairwise_distance = np.load(pairwise_distance_cache_path) 94 | InfoLogger.info("directly load pairwise distance from {}".format(pairwise_distance_cache_path)) 95 | else: 96 | pairwise_distance = pairwise_distances(flattened_data, metric=metric, squared=False) 97 | pairwise_distance[pairwise_distance < 1e-12] = 0.0 98 | if preload and pairwise_distance_cache_path is not None: 99 | np.save(pairwise_distance_cache_path, pairwise_distance) 100 | InfoLogger.info("successfully compute pairwise distance and save to {}".format(pairwise_distance_cache_path)) 101 | return pairwise_distance 102 | -------------------------------------------------------------------------------- /utils/umap_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | import numpy 4 | import numpy as np 5 | import torch 6 | import scipy 7 | from scipy.sparse import coo_matrix, csr_matrix 8 | from scipy.optimize import curve_fit 9 | import matplotlib.pyplot as plt 10 | 11 | from utils.nn_utils import cal_snn_similarity, compute_accurate_knn 12 | 13 | INT32_MIN = np.iinfo(np.int32).min + 1 14 | INT32_MAX 
= np.iinfo(np.int32).max - 1 15 | 16 | SMOOTH_K_TOLERANCE = 1e-5 17 | MIN_K_DIST_SCALE = 1e-3 18 | NPY_INFINITY = np.inf 19 | 20 | 21 | def get_graph_elements(graph_, n_epochs): 22 | # CSR -> COO 23 | graph = graph_.tocoo() 24 | # eliminate duplicate entries by summing them together 25 | graph.sum_duplicates() 26 | # number of vertices in dataset 27 | n_vertices = graph.shape[1] 28 | # get the number of epochs based on the size of the dataset 29 | if n_epochs is None: 30 | # For smaller datasets we can use more epochs 31 | if graph.shape[0] <= 10000: 32 | n_epochs = 500 33 | else: 34 | n_epochs = 200 35 | # remove elements with very low probability 36 | graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0 37 | graph.eliminate_zeros() 38 | epochs_per_sample = make_epochs_per_sample(graph.data, n_epochs) 39 | 40 | head = graph.row 41 | tail = graph.col 42 | weight = graph.data 43 | 44 | return graph, epochs_per_sample, head, tail, weight, n_vertices 45 | 46 | 47 | def make_epochs_per_sample(weights, n_epochs): 48 | result = -1.0 * np.ones(weights.shape[0], dtype=np.float64) 49 | n_samples = n_epochs * (weights / weights.max()) 50 | result[n_samples > 0] = float(n_epochs) / n_samples[n_samples > 0] 51 | return result 52 | 53 | 54 | def construct_edge_dataset( 55 | X, 56 | graph_, 57 | n_epochs 58 | ): 59 | graph, epochs_per_sample, head, tail, weight, n_vertices = get_graph_elements( 60 | graph_, n_epochs 61 | ) 62 | 63 | edges_to_exp, edges_from_exp = ( 64 | np.repeat(head, epochs_per_sample.astype("int")), 65 | np.repeat(tail, epochs_per_sample.astype("int")), 66 | ) 67 | 68 | shuffle_mask = np.random.permutation(range(len(edges_to_exp))) 69 | edges_to_exp = edges_to_exp[shuffle_mask] 70 | edges_from_exp = edges_from_exp[shuffle_mask] 71 | 72 | embedding_to_from_indices = np.array([edges_to_exp, edges_from_exp]) 73 | embedding_to_from_indices_re = np.repeat(embedding_to_from_indices, 2, 1) 74 | np.random.shuffle(embedding_to_from_indices_re) 75 | embedding_to_from_data = X[embedding_to_from_indices[0, :]], X[embedding_to_from_indices[1, :]] 76 | 77 | return embedding_to_from_data, len(edges_to_exp), weight 78 | 79 | 80 | def fuzzy_simplicial_set( 81 | X, 82 | n_neighbors, 83 | knn_indices=None, 84 | knn_dists=None, 85 | set_op_mix_ratio=1.0, 86 | local_connectivity=1.0, 87 | apply_set_operations=True, 88 | return_dists=None, 89 | symmetric="TSNE", 90 | ): 91 | if knn_indices is None or knn_dists is None: 92 | pass 93 | 94 | knn_dists = knn_dists.astype(np.float32) 95 | sigmas, rhos = smooth_knn_dist( 96 | knn_dists, 97 | float(n_neighbors), 98 | local_connectivity=float(local_connectivity), 99 | ) 100 | 101 | rows, cols, vals, dists = compute_membership_strengths( 102 | knn_indices, knn_dists, sigmas, rhos, return_dists 103 | ) 104 | 105 | origin_knn_weights = vals.reshape(knn_indices.shape) 106 | result = scipy.sparse.coo_matrix( 107 | (vals, (rows, cols)), shape=(X.shape[0], X.shape[0]) 108 | ) 109 | result.eliminate_zeros() 110 | 111 | if apply_set_operations: 112 | transpose = result.transpose() 113 | if symmetric == "UMAP": 114 | prod_matrix = result.multiply(transpose) 115 | result = ( 116 | set_op_mix_ratio * (result + transpose - prod_matrix) 117 | + (1.0 - set_op_mix_ratio) * prod_matrix 118 | ) 119 | elif symmetric == "TSNE": 120 | result = (result + transpose) / 2 121 | else: 122 | raise RuntimeError("Unsupported symmetric way! 
Please ensure the param " 123 | "name is one of 'UMAP/TSNE'") 124 | 125 | result.eliminate_zeros() 126 | 127 | if return_dists is None: 128 | return result, sigmas, rhos, origin_knn_weights 129 | else: 130 | if return_dists: 131 | dmat = coo_matrix( 132 | (dists, (rows, cols)), shape=(X.shape[0], X.shape[0]) 133 | ) 134 | 135 | dists = dmat.maximum(dmat.transpose()).todok() 136 | else: 137 | dists = None 138 | 139 | return result, sigmas, rhos, origin_knn_weights, dists 140 | 141 | 142 | def smooth_knn_dist(distances, k, n_iter=64, local_connectivity=1.0, bandwidth=1.0): 143 | target = np.log2(k) * bandwidth 144 | rho = np.zeros(distances.shape[0], dtype=np.float32) 145 | result = np.zeros(distances.shape[0], dtype=np.float32) 146 | 147 | mean_distances = np.mean(distances) 148 | 149 | for i in range(distances.shape[0]): 150 | lo = 0.0 151 | hi = NPY_INFINITY 152 | mid = 1.0 153 | 154 | ith_distances = distances[i] 155 | non_zero_dists = ith_distances[ith_distances > 0.0] 156 | if non_zero_dists.shape[0] >= local_connectivity: 157 | index = int(np.floor(local_connectivity)) 158 | interpolation = local_connectivity - index 159 | if index > 0: 160 | rho[i] = non_zero_dists[index - 1] 161 | if interpolation > SMOOTH_K_TOLERANCE: 162 | rho[i] += interpolation * ( 163 | non_zero_dists[index] - non_zero_dists[index - 1] 164 | ) 165 | else: 166 | rho[i] = interpolation * non_zero_dists[0] 167 | elif non_zero_dists.shape[0] > 0: 168 | rho[i] = np.max(non_zero_dists) 169 | 170 | for n in range(n_iter): 171 | psum = 0.0 172 | for j in range(1, distances.shape[1]): 173 | d = distances[i, j] - rho[i] 174 | if d > 0: 175 | psum += np.exp(-(d / mid)) 176 | else: 177 | psum += 1.0 178 | 179 | if np.fabs(psum - target) < SMOOTH_K_TOLERANCE: 180 | break 181 | if psum > target: 182 | hi = mid 183 | mid = (lo + hi) / 2.0 184 | else: 185 | lo = mid 186 | if hi == NPY_INFINITY: 187 | mid *= 2 188 | else: 189 | mid = (lo + hi) / 2.0 190 | result[i] = mid 191 | 192 | if rho[i] > 0.0: 193 | mean_ith_distances = np.mean(ith_distances) 194 | if result[i] < MIN_K_DIST_SCALE * mean_ith_distances: 195 | result[i] = MIN_K_DIST_SCALE * mean_ith_distances 196 | else: 197 | if result[i] < MIN_K_DIST_SCALE * mean_distances: 198 | result[i] = MIN_K_DIST_SCALE * mean_distances 199 | 200 | return result, rho 201 | 202 | 203 | def compute_membership_strengths( 204 | knn_indices, knn_dists, sigmas, rhos, return_dists=False 205 | ): 206 | n_samples = knn_indices.shape[0] 207 | n_neighbors = knn_indices.shape[1] 208 | 209 | rows = np.zeros(knn_indices.size, dtype=np.int32) 210 | cols = np.zeros(knn_indices.size, dtype=np.int32) 211 | vals = np.zeros(knn_indices.size, dtype=np.float32) 212 | if return_dists: 213 | dists = np.zeros(knn_indices.size, dtype=np.float32) 214 | else: 215 | dists = None 216 | 217 | for i in range(n_samples): 218 | for j in range(n_neighbors): 219 | if knn_indices[i, j] == -1: 220 | continue 221 | 222 | if knn_indices[i, j] == i: 223 | val = 0.0 224 | 225 | elif knn_dists[i, j] - rhos[i] <= 0.0 or sigmas[i] == 0.0: 226 | val = 1.0 227 | else: 228 | 229 | val = np.exp(-((knn_dists[i, j] - rhos[i]) / (sigmas[i]))) 230 | 231 | rows[i * n_neighbors + j] = i 232 | cols[i * n_neighbors + j] = knn_indices[i, j] 233 | vals[i * n_neighbors + j] = val 234 | if return_dists: 235 | dists[i * n_neighbors + j] = knn_dists[i, j] 236 | 237 | return rows, cols, vals, dists 238 | 239 | 240 | def find_ab_params(spread, min_dist): 241 | def curve(x, a, b): 242 | return 1.0 / (1.0 + a * x ** (2 * b)) 243 | xv = np.linspace(0, 
spread * 3, 300) 244 | yv = np.zeros(xv.shape) 245 | yv[xv < min_dist] = 1.0 246 | yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread) 247 | params, covar = curve_fit(curve, xv, yv) 248 | return params[0], params[1] 249 | 250 | 251 | def convert_distance_to_probability(distances, a=1.0, b=1.0): 252 | return 1.0 / (1.0 + a * distances ** (2 * b)) 253 | 254 | 255 | def compute_local_membership(knn_dist, knn_indices, local_connectivity=1): 256 | knn_dist = knn_dist.astype(np.float32) 257 | sigmas, rhos = smooth_knn_dist( 258 | knn_dist, 259 | float(knn_indices.shape[1]), 260 | local_connectivity=float(local_connectivity), 261 | ) 262 | rows, cols, vals, dists = compute_membership_strengths( 263 | knn_indices+1, knn_dist, sigmas, rhos, False 264 | ) 265 | return vals 266 | -------------------------------------------------------------------------------- /vis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DRLib/CDR/bf83fbd80263b75db1b821c4d8e4b43f307249ac/vis.jpg -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | 4 | from model.cdr import CDRModel 5 | import torch 6 | 7 | from utils.common_utils import get_config 8 | from utils.constant_pool import * 9 | import argparse 10 | from experiments.trainer import CDRTrainer 11 | import os 12 | 13 | log_path = "log.txt" 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--configs", type=str, default="configs/CDR.yaml", help="configuration file path") 19 | parser.add_argument("--ckpt", type=str, default="model_weights/usps.pth.tar") 20 | parser.add_argument("--device", type=str, default="cpu") 21 | return parser.parse_args() 22 | 23 | 24 | if __name__ == '__main__': 25 | args = parse_args() 26 | cfg = get_config() 27 | cfg.merge_from_file(args.configs) 28 | device = args.device 29 | clr_model = CDRModel(cfg, device=device) 30 | trainer = CDRTrainer(clr_model, cfg.exp_params.dataset, cfg, None, args.configs, 31 | device=device, log_path=log_path) 32 | trainer.load_weights_visualization(args.ckpt, vis_save_path="vis.jpg", device=device) 33 | --------------------------------------------------------------------------------
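The UMAP-style similarity curve fitted by `find_ab_params` is reused by the `"umap"` kernel in `utils/math_utils.py` and by `LinkInfo` in `utils/link_utils.py`. Below is a minimal sketch of how it behaves, assuming it is run from the repository root so that the `utils` package is importable:

```python
import numpy as np
import torch

from utils.math_utils import get_similarity_function
from utils.umap_utils import convert_distance_to_probability, find_ab_params

# Fit curve(x) = 1 / (1 + a * x**(2 * b)) so that similarity stays near 1.0
# below min_dist and decays roughly exponentially beyond it.
a, b = find_ab_params(spread=1.0, min_dist=0.1)

distances = np.linspace(0.0, 3.0, 7)
print(np.round(convert_distance_to_probability(distances, a, b), 3))
# -> monotonically decreasing values starting at 1.0

# The same curve applied to torch embeddings through the kernel dispatcher;
# "tsne", "exp" and "cosine" kernels are selectable the same way.
sim_fn = get_similarity_function("umap")
rep1, rep2 = torch.randn(8, 2), torch.randn(8, 2)
similarity, pairwise = sim_fn(rep1, rep2, 0.1)  # two shape-(8,) tensors
```

Because `find_ab_params` solves a small `scipy.optimize.curve_fit` problem, `_umap_similarity` caches the fitted parameters in module-level globals and only refits when `min_dist` changes.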