├── environment.yml ├── run_galaxc.sh ├── run.sh ├── LICENSE ├── requirements.txt ├── README.md ├── data.py ├── predict_main.py ├── utils.py ├── train_main.py └── network.py /environment.yml: -------------------------------------------------------------------------------- 1 | name: galaxc 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.6.9 6 | - anaconda 7 | - pip 8 | - pip: 9 | - -r file:requirements.txt 10 | -------------------------------------------------------------------------------- /run_galaxc.sh: -------------------------------------------------------------------------------- 1 | conda env create -n environment.yml 2 | conda init bash 3 | source ~/.bashrc 4 | conda activate environment.yml 5 | pip install hnswlib 6 | sudo apt --yes install bc git 7 | git clone https://github.com/kunaldahiya/pyxclib.git 8 | cd pyxclib 9 | python setup.py install --user 10 | echo "the pwd is : $PWD" 11 | cd .. 12 | bash run.sh /mnt/my_storage/DLTSEastUSBackup/GraphXML/data/$1/ $2 $3 $4 $5 $6 $7 $8 $9 ${10} ${11} ${12} ${13} ${14} ${15} | tee /mnt/my_storage/DLTSEastUSBackup/GraphXML/Logs/$1/`cat /dev/urandom | tr -cd 'a-f0-9' | head -c 32`.txt 13 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python -u train_main.py \ 2 | --dataset $1 \ 3 | --save-model 0 \ 4 | --devices ${16} \ 5 | --num-epochs $2 \ 6 | --num-HN-epochs $3 \ 7 | --batch-size $4 \ 8 | --lr $5 \ 9 | --attention-lr $6 \ 10 | --adjust-lr $7 \ 11 | --dlr-factor 0.5 \ 12 | --mpt $8 \ 13 | --restrict-edges-num $9 \ 14 | --restrict-edges-head-threshold 20 \ 15 | --num-random-samples ${10} \ 16 | --random-shuffle-nbrs ${11} \ 17 | --fanouts ${12} \ 18 | --num-HN-shortlist ${13} \ 19 | --embedding-type DX \ 20 | --run-type NR \ 21 | --num-validation 25000 \ 22 | --validation-freq -1 \ 23 | --num-shortlist 500 \ 24 | --predict-ova 0 \ 25 | --A ${14} \ 26 | --B ${15} \ 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Extreme-classification 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | adal==1.2.5 3 | argon2-cffi==20.1.0 4 | astor==0.8.1 5 | async-generator==1.10 6 | attrs==20.3.0 7 | azure-datalake-store==0.0.51 8 | backcall==0.2.0 9 | bleach==3.3.0 10 | boto3==1.15.16 11 | botocore==1.18.18 12 | certifi==2020.6.20 13 | cffi==1.14.3 14 | chardet==3.0.4 15 | click==7.1.2 16 | cloudpickle==1.6.0 17 | cryptography==3.2 18 | cycler==0.10.0 19 | Cython==0.29.21 20 | dataclasses==0.7 21 | decorator==4.4.2 22 | defusedxml==0.6.0 23 | entrypoints==0.3 24 | filelock==3.0.12 25 | future==0.18.2 26 | gast==0.2.2 27 | gensim==3.8.3 28 | google-pasta==0.2.0 29 | grpcio==1.33.1 30 | h5py==2.10.0 31 | hnswlib==0.4.0 32 | idna==2.10 33 | importlib-metadata==2.0.0 34 | ipykernel==5.4.3 35 | ipython==7.16.1 36 | ipython-genutils==0.2.0 37 | ipywidgets==7.6.3 38 | jedi==0.17.2 39 | Jinja2==2.11.3 40 | jmespath==0.10.0 41 | joblib==0.16.0 42 | jsonschema==3.2.0 43 | jupyter==1.0.0 44 | jupyter-client==6.1.11 45 | jupyter-console==6.2.0 46 | jupyter-core==4.7.1 47 | jupyterlab-pygments==0.1.2 48 | jupyterlab-widgets==1.0.0 49 | Keras-Applications==1.0.8 50 | Keras-Preprocessing==1.1.2 51 | kiwisolver==1.2.0 52 | llvmlite==0.34.0 53 | Markdown==3.3.3 54 | MarkupSafe==1.1.1 55 | matplotlib==3.3.2 56 | mistune==0.8.4 57 | nbclient==0.5.1 58 | nbconvert==6.0.7 59 | nbformat==5.1.2 60 | nest-asyncio==1.5.1 61 | networkx==2.5 62 | nltk==3.5 63 | nmslib==2.0.6 64 | notebook==6.2.0 65 | numba==0.51.1 66 | numpy==1.19.3 67 | opt-einsum==3.3.0 68 | packaging==20.4 69 | pandas==1.1.3 70 | pandocfilters==1.4.3 71 | parso==0.7.1 72 | pexpect==4.8.0 73 | pickleshare==0.7.5 74 | Pillow==8.0.1 75 | plotly==4.11.0 76 | pluggy==0.13.1 77 | prometheus-client==0.9.0 78 | prompt-toolkit==3.0.14 79 | protobuf==3.13.0 80 | psutil==5.7.2 81 | ptyprocess==0.7.0 82 | pybind11==2.5.0 83 | pycparser==2.20 84 | Pygments==2.7.4 85 | PyJWT==1.7.1 86 | pyparsing==2.4.7 87 | pyrsistent==0.17.3 88 | python-dateutil==2.8.1 89 | python-jsonrpc-server==0.4.0 90 | python-language-server==0.36.2 91 | pytz==2020.1 92 | PyYAML==5.3.1 93 | pyzmq==22.0.2 94 | qtconsole==5.0.2 95 | QtPy==1.9.0 96 | randomgen==1.19.3 97 | regex==2020.7.14 98 | requests==2.24.0 99 | retrying==1.3.3 100 | s3transfer==0.3.3 101 | sacremoses==0.0.43 102 | scikit-learn==0.23.2 103 | scipy==1.5.4 104 | Send2Trash==1.5.0 105 | sentencepiece==0.1.94 106 | six==1.15.0 107 | sklearn==0.0 108 | smart-open==4.1.0 109 | tensorboard==1.15.0 110 | tensorflow-estimator==1.15.1 111 | tensorflow-gpu==1.15.0 112 | termcolor==1.1.0 113 | terminado==0.9.2 114 | testpath==0.4.4 115 | threadpoolctl==2.1.0 116 | tokenizers==0.5.2 117 | torch==1.6.0 118 | tornado==6.1 119 | tqdm==4.48.2 120 | traitlets==4.3.3 121 | transformers==2.8.0 122 | typing==3.7.4.3 123 | ujson==4.0.2 124 | urllib3==1.25.11 125 | wcwidth==0.2.5 126 | webencodings==0.5.1 127 | Werkzeug==1.0.1 128 | widgetsnbextension==3.5.1 129 | wrapt==1.12.1 130 | zipp==3.4.0 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GalaXC 2 | ## [GalaXC: Graph Neural Networks with Labelwise Attention for Extreme Classification](http://manikvarma.org/pubs/saini21.pdf) 3 | ```bib 4 | @InProceedings{Saini21, 5 | author = {Saini, D. and Jain, A.K. and Dave, K. and Jiao, J. and Singh, A. 
and Zhang, R. and Varma, M.}, 6 | title = {GalaXC: Graph Neural Networks with Labelwise Attention for Extreme Classification}, 7 | booktitle = {Proceedings of The Web Conference}, 8 | month = "April", 9 | year = "2021", 10 | } 11 | ``` 12 | 13 | #### Setup GalaXC 14 | ```bash 15 | git clone https://github.com/Extreme-classification/GalaXC.git 16 | conda env create -f GalaXC/environment.yml 17 | conda activate galaxc 18 | pip install hnswlib 19 | git clone https://github.com/kunaldahiya/pyxclib.git 20 | cd pyxclib 21 | python setup.py install 22 | cd ../GalaXC 23 | ``` 24 | 25 | #### Dataset Structure 26 | Your dataset should have the following structure: 27 | ``` 28 | DatasetName (e.g. LF-AmazonTitles-131K) 29 | │ trn_X.txt (text for trn documents, one text in each line) 30 | | tst_X.txt (text for tst documents, one text in each line) 31 | | Y.txt (text for labels, one text in each line) 32 | │ trn_X_Y.txt (trn labels in spmat format) 33 | | tst_X_Y.txt (tst labels in spmat format) 34 | | filter_labels_test.txt (filter file of (test document, label) pairs whose texts are identical) 35 | │ 36 | └───XXCondensedData (embeddings for tst, trn documents and labels; for benchmark datasets, XX=DX[Astec]) 37 | │ trn_point_embs.npy (2D numpy matrix for trn document embeddings) 38 | │ tst_point_embs.npy (2D numpy matrix for tst document embeddings) 39 | | label_embs.npy (2D numpy matrix for label embeddings) 40 | 41 | ``` 42 | 43 | We have provided the DX embeddings (from Module 1 of Astec) for the public benchmark datasets for ease of use. Got better (higher-recall) embeddings from somewhere? Just plug in the new ones and GalaXC will perform better, with no code changes needed! These files for LF-AmazonTitles-131K, LF-WikiSeeAlsoTitles-320K and LF-AmazonTitles-1.3M can be found [here](https://drive.google.com/drive/folders/1PamOpzMV6NlgvBEOwpxPZ4dahnun-dtN?usp=sharing). Except for the files in DXCondensedData, all other files are copies of the datasets from [The Extreme Classification Repository](http://manikvarma.org/downloads/XC/XMLRepository.html). 
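To sanity-check your own (or the provided) embeddings against this layout, the minimal sketch below loads the files the same way `train_main.py` does and verifies that the matrix shapes line up. The dataset path and the `EMB_TYPE` value are placeholders; `numpy` and `pyxclib` are assumed to be installed as per the setup above.

```python
import numpy as np
from xclib.data import data_utils

DATASET = "/your/path/to/data/LF-AmazonTitles-131K"  # placeholder path
EMB_TYPE = "DX"  # embeddings live in {EMB_TYPE}CondensedData

# Dense document/label embeddings (e.g. from Module 1 of Astec)
trn_point_embs = np.load(f"{DATASET}/{EMB_TYPE}CondensedData/trn_point_embs.npy")
tst_point_embs = np.load(f"{DATASET}/{EMB_TYPE}CondensedData/tst_point_embs.npy")
label_embs = np.load(f"{DATASET}/{EMB_TYPE}CondensedData/label_embs.npy")

# Ground-truth label matrices in sparse (spmat) format
trn_X_Y = data_utils.read_sparse_file(f"{DATASET}/trn_X_Y.txt", force_header=True)
tst_X_Y = data_utils.read_sparse_file(f"{DATASET}/tst_X_Y.txt", force_header=True)

# One embedding row per document / label, matching the label matrices
assert trn_point_embs.shape[0] == trn_X_Y.shape[0]
assert tst_point_embs.shape[0] == tst_X_Y.shape[0]
assert label_embs.shape[0] == trn_X_Y.shape[1] == tst_X_Y.shape[1]
```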
44 | 45 | 46 | #### Sample Runs 47 | To reproduce the numbers on public benchmark datasets reported in the paper, the sample runs are 48 | 49 | **LF-AmazonTitles-131K** 50 | ```bash 51 | python -u -W ignore train_main.py --dataset /your/path/to/data/LF-AmazonTitles-131K --save-model 0 --devices cuda:0 --num-epochs 30 --num-HN-epochs 0 --batch-size 256 --lr 0.001 --attention-lr 0.001 --adjust-lr 5,10,15,20,25,28 --dlr-factor 0.5 --mpt 0 --restrict-edges-num -1 --restrict-edges-head-threshold 20 --num-random-samples 30000 --random-shuffle-nbrs 0 --fanouts 4,3,2 --num-HN-shortlist 500 --embedding-type DX --run-type NR --num-validation 25000 --validation-freq -1 --num-shortlist 500 --predict-ova 0 --A 0.6 --B 2.6 52 | ``` 53 | 54 | **LF-WikiSeeAlsoTitles-320K** 55 | ```bash 56 | python -u -W ignore train_main.py --dataset /your/path/to/data/LF-WikiSeeAlsoTitles-320K --save-model 0 --devices cuda:0 --num-epochs 30 --num-HN-epochs 0 --batch-size 256 --lr 0.001 --attention-lr 0.05 --adjust-lr 5,10,15,20,25,28 --dlr-factor 0.5 --mpt 0 --restrict-edges-num -1 --restrict-edges-head-threshold 20 --num-random-samples 32000 --random-shuffle-nbrs 0 --fanouts 4,3,2 --num-HN-shortlist 500 --repo 1 --embedding-type DX --run-type NR --num-validation 25000 --validation-freq -1 --num-shortlist 500 --predict-ova 0 --A 0.55 --B 1.5 57 | ``` 58 | 59 | **LF-AmazonTitles-1.3M** 60 | ```bash 61 | python -u -W ignore train_main.py --dataset /your/path/to/data/LF-AmazonTitles-1.3M --save-model 0 --devices cuda:0 --num-epochs 24 --num-HN-epochs 15 --batch-size 512 --lr 0.001 --attention-lr 0.05 --adjust-lr 4,8,12,16,18,20,22 --dlr-factor 0.5 --mpt 0 --restrict-edges-num 5 --restrict-edges-head-threshold 20 --num-random-samples 100000 --random-shuffle-nbrs 1 --fanouts 3,3,3 --num-HN-shortlist 500 --embedding-type DX --run-type NR --num-validation 25000 --validation-freq -1 --num-shortlist 500 --predict-ova 0 --A 0.6 --B 2.6 62 | ``` 63 | 64 | ### YOU MAY ALSO LIKE 65 | - [DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents](https://github.com/Extreme-classification/deepxml) 66 | - [DECAF: Deep Extreme Classification with Label Features](https://github.com/Extreme-classification/DECAF) 67 | - [ECLARE: Extreme Classification with Label Graph Correlations](https://github.com/Extreme-classification/ECLARE) 68 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | import torch.utils.data 6 | 7 | import numpy as np 8 | import math 9 | import time 10 | import os 11 | import pickle 12 | import random 13 | import nmslib 14 | import sys 15 | from scipy.sparse import csr_matrix, lil_matrix, load_npz, hstack, vstack 16 | 17 | from torch.utils.data import IterableDataset, DataLoader 18 | 19 | 20 | class Graph(): 21 | def __init__(self, feat_data, adj_lists, random_shuffle_nbrs): 22 | self.feat_data = feat_data 23 | self.adj_lists = adj_lists 24 | self.random_shuffle_nbrs = random_shuffle_nbrs 25 | 26 | def sample_neighbors( 27 | self, 28 | nodes: np.array, 29 | count: int = 10, 30 | default_node: int = -1, 31 | default_weight: float = 0.0, 32 | default_node_type: int = -1, 33 | ) -> (np.array, np.array, np.array, np.array): 34 | res = np.empty((len(nodes), count), dtype=np.int64) 35 | for i in range(len(nodes)): 36 | universe = 
np.array(self.adj_lists[nodes[i]], dtype=np.int64) 37 | 38 | if(self.random_shuffle_nbrs == 1): 39 | np.random.shuffle(universe) 40 | 41 | # If there are no neighbors, fill results with a dummy value. 42 | if len(universe) == 0: 43 | res[i] = np.full(count, -1, dtype=np.int64) 44 | else: 45 | repetitions = int(count / len(universe)) + 1 46 | res[i] = np.resize(np.tile(universe, repetitions), count) 47 | 48 | return ( 49 | res, 50 | np.full((len(nodes), count), 0.0, dtype=np.float32), 51 | np.full((len(nodes), count), -1, dtype=np.int32), 52 | np.full((len(nodes)), 0, dtype=np.int32), 53 | ) 54 | 55 | def node_features(self, nodes: np.array) -> np.array: 56 | return torch.Tensor(self.feat_data[nodes]) 57 | 58 | 59 | class DatasetGraph(torch.utils.data.Dataset): 60 | def __init__(self, X_Y, hard_negs): 61 | self.X_Y = X_Y 62 | self.res_dict = [self.X_Y.indices[self.X_Y.indptr[i]: self.X_Y.indptr[i + 1]] 63 | for i in range(len(self.X_Y.indptr) - 1)] 64 | self.hard_negs = [list(set(hard_negs[i]) - set(self.res_dict[i])) 65 | for i in range(self.X_Y.shape[0])] 66 | 67 | print( 68 | "Shape of X_Y = ", self.X_Y.shape, len( 69 | self.res_dict), len( 70 | hard_negs[1]), len( 71 | self.hard_negs[1]), self.res_dict[1]) 72 | 73 | def __getitem__(self, index): 74 | return (index, self.res_dict[index], self.hard_negs[index]) 75 | 76 | def update_hard_negs(self, hard_negs): 77 | self.hard_negs = [list(set(hard_negs[i]) - set(self.res_dict[i])) 78 | for i in range(self.X_Y.shape[0])] 79 | 80 | def __len__(self): 81 | return self.X_Y.shape[0] 82 | 83 | 84 | class GraphCollator(): 85 | def __init__(self, model, num_labels, num_random=0, 86 | train=1, num_hard_neg=10): 87 | self.model = model 88 | self.train = train 89 | self.num_hard_neg = num_hard_neg 90 | self.num_labels = num_labels 91 | self.num_random = num_random 92 | 93 | def __call__(self, batch): 94 | context = {} 95 | context["inputs"] = np.array([b[0] for b in batch], dtype=np.int64) 96 | self.model.query(context) 97 | 98 | if(self.train): 99 | all_labels_pos = [b[1] for b in batch] 100 | # hard_neg = np.array([b[2][:self.num_hard_neg] for b in batch], dtype=np.int64) 101 | label_ids = np.zeros((self.num_labels, ), dtype=np.bool) 102 | 103 | label_ids[[x for subl in all_labels_pos for x in subl]] = 1 104 | label_ids[[x for b in batch for x in b[2]]] = 1 105 | # label_ids[np.ravel(hard_neg)] = 1 106 | 107 | random_neg = np.random.choice( 108 | np.where( 109 | label_ids == 0)[0], 110 | self.num_random, 111 | replace=False) 112 | label_ids[random_neg] = 1 113 | 114 | label_map = { 115 | x: i for i, x in enumerate( 116 | np.where( 117 | label_ids == 1)[0])} 118 | 119 | batch_Y = np.zeros((len(batch), len(label_map)), dtype=np.float32) 120 | for i, labels in enumerate(all_labels_pos): 121 | for l in labels: 122 | batch_Y[i][label_map[l]] = 1.0 123 | 124 | context["Y"] = torch.from_numpy(batch_Y) 125 | context["label_ids"] = torch.tensor(label_ids) 126 | else: 127 | if(not(batch[0][1] is None)): # prediction 128 | if(len(batch[0]) == 2): # shortlist per point 129 | context["label_ids"] = torch.LongTensor( 130 | [b[1] for b in batch]) 131 | elif(len(batch[0]) == 3): # OvA 132 | context["label_ids"] = None 133 | else: # embeddings calc 134 | context["indices"] = np.array( 135 | [b[2] for b in batch], dtype=np.int64) 136 | context['batch_size'] = len(batch) 137 | return context 138 | 139 | 140 | class DatasetGraphPrediction(torch.utils.data.Dataset): 141 | def __init__(self, start, end, prediction_shortlist): 142 | self.start = start 143 | self.end = end 
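        # prediction_shortlist: (num_points, num_shortlist) int array of candidate label
        # indices for each point; pass None to score every label (OvA prediction).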
144 | self.prediction_shortlist = prediction_shortlist 145 | 146 | def __getitem__(self, index): 147 | if(self.prediction_shortlist is None): 148 | return (index + self.start, "dummy", "dummy") 149 | return (index + self.start, self.prediction_shortlist[index]) 150 | 151 | def __len__(self): 152 | return self.end - self.start 153 | 154 | 155 | class DatasetGraphPredictionEncode(torch.utils.data.Dataset): 156 | def __init__(self, nodes): 157 | self.nodes = nodes 158 | 159 | def __getitem__(self, index): 160 | return (self.nodes[index], None, index) 161 | 162 | def __len__(self): 163 | return len(self.nodes) 164 | -------------------------------------------------------------------------------- /predict_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parameter import Parameter 7 | import torch.nn.functional as F 8 | import torch.utils.data 9 | 10 | import numpy as np 11 | import math 12 | import time 13 | import os 14 | import pickle 15 | import random 16 | import nmslib 17 | import sys 18 | from scipy.sparse import csr_matrix, lil_matrix, load_npz, hstack, vstack 19 | 20 | from xclib.data import data_utils 21 | from xclib.utils.sparse import normalize 22 | import xclib.evaluation.xc_metrics as xc_metrics 23 | 24 | from data import * 25 | from utils import * 26 | from network import HNSW 27 | 28 | 29 | def predict(net, pred_batch): 30 | """ 31 | head shorty None means predict OvA on head 32 | """ 33 | net.eval() 34 | torch.set_grad_enabled(False) 35 | 36 | out_ans = net.forward(pred_batch, False) 37 | out_ans = out_ans.detach().cpu().numpy() 38 | if(pred_batch["label_ids"] is None): 39 | return out_ans, None 40 | return out_ans, pred_batch["label_ids"].detach().cpu().numpy() 41 | 42 | 43 | def update_predicted(row_indices, predicted_batch_labels, 44 | predicted_labels, remapping, top_k): 45 | batch_size = row_indices.shape[0] 46 | top_values, top_indices = predicted_batch_labels.topk( 47 | k=top_k, dim=1, sorted=False) 48 | ind = np.zeros((top_k * batch_size, 2), dtype=np.int64) 49 | ind[:, 0] = np.repeat(row_indices, [top_k] * batch_size) 50 | if(remapping is not None): 51 | ind[:, 1] = [remapping[x] 52 | for x in top_indices.cpu().numpy().flatten('C')] 53 | else: 54 | ind[:, 1] = [x for x in top_indices.cpu().numpy().flatten('C')] 55 | vals = top_values.cpu().detach().numpy().flatten('C') 56 | predicted_labels[ind[:, 0], ind[:, 1]] = vals 57 | 58 | 59 | def update_predicted_shortlist( 60 | row_indices, predicted_batch_labels, predicted_labels, shortlist, remapping, top_k=10): 61 | if(len(predicted_batch_labels.shape) == 1): 62 | predicted_batch_labels = predicted_batch_labels[None, :] 63 | m = predicted_batch_labels.shape[0] 64 | 65 | top_indices = np.argsort(predicted_batch_labels, axis=1)[ 66 | :, ::-1][:, :top_k] 67 | top_values = predicted_batch_labels[np.arange(m)[:, None], top_indices] 68 | 69 | batch_size, shortlist_size = shortlist.shape 70 | ind = np.zeros((top_k * batch_size, 2), dtype=np.int) 71 | ind[:, 0] = np.repeat(row_indices, [top_k] * batch_size) 72 | 73 | if(remapping is not None): 74 | ind[:, 1] = [remapping[x] 75 | for x in np.ravel(shortlist[np.arange(m)[:, None], top_indices])] 76 | else: 77 | ind[:, 1] = [x for x in np.ravel( 78 | shortlist[np.arange(m)[:, None], top_indices])] 79 | 80 | predicted_labels[ind[:, 0], ind[:, 1]] = np.ravel(top_values) 81 | 82 | 83 | def 
run_validation(val_predicted_labels, tst_X_Y_val, 84 | tst_exact_remove, tst_X_Y_trn, inv_prop): 85 | data = [] 86 | indptr = [0] 87 | indices = [] 88 | for i in range(val_predicted_labels.shape[0]): 89 | _indices1 = val_predicted_labels.indices[val_predicted_labels.indptr[i]: val_predicted_labels.indptr[i + 1]] 90 | _vals1 = val_predicted_labels.data[val_predicted_labels.indptr[i]: val_predicted_labels.indptr[i + 1]] 91 | 92 | _indices, _vals = [], [] 93 | for _ind, _val in zip(_indices1, _vals1): 94 | if (_ind not in tst_exact_remove[i]) and ( 95 | _ind not in tst_X_Y_trn.indices[tst_X_Y_trn.indptr[i]: tst_X_Y_trn.indptr[i + 1]]): 96 | _indices.append(_ind) 97 | _vals.append(_val) 98 | 99 | indices += list(_indices) 100 | data += list(_vals) 101 | indptr.append(len(indices)) 102 | 103 | _pred = csr_matrix( 104 | (data, indices, indptr), shape=( 105 | val_predicted_labels.shape)) 106 | 107 | print(tst_X_Y_val.shape, _pred.shape) 108 | acc = xc_metrics.Metrics(tst_X_Y_val, inv_psp=inv_prop) 109 | acc = acc.eval(_pred, 5) 110 | _recall = recall(tst_X_Y_val, _pred, 100) 111 | return (acc, _recall), _pred 112 | 113 | 114 | def encode_nodes(net, context): 115 | net.eval() 116 | torch.set_grad_enabled(False) 117 | 118 | embed3 = net.third_layer_enc(context["encoder"]) 119 | embed2 = net.second_layer_enc(context["encoder"]["node_feats"]) 120 | embed1 = net.first_layer_enc( 121 | context["encoder"]["node_feats"]["node_feats"]) 122 | 123 | # embed = torch.stack((net.transform1(embed1.t()), net.transform2(embed2.t()), net.transform3(embed3.t())), dim=1) 124 | embed = torch.stack((embed1.t(), embed2.t(), embed3.t()), dim=1) 125 | embed = torch.mean(embed, dim=1) 126 | 127 | return embed 128 | 129 | 130 | def validate(head_net, params, partition_indices, label_remapping, 131 | label_embs, tst_point_embs, tst_X_Y_val, tst_exact_remove, tst_X_Y_trn, use_graph_embs, topK): 132 | _start = params["num_trn"] 133 | _end = _start + params["num_tst"] 134 | 135 | if(use_graph_embs): 136 | label_nodes = [label_remapping[i] for i in range(len(label_remapping))] 137 | 138 | val_dataset = DatasetGraphPredictionEncode(label_nodes) 139 | hce = GraphCollator(head_net, params["num_labels"], None, train=0) 140 | encode_loader = torch.utils.data.DataLoader( 141 | val_dataset, 142 | batch_size=500, 143 | num_workers=10, 144 | collate_fn=hce, 145 | shuffle=False, 146 | pin_memory=True) 147 | 148 | label_embs_graph = np.zeros( 149 | (len(label_nodes), params["hidden_dims"]), dtype=np.float32) 150 | cnt = 0 151 | for batch in encode_loader: 152 | # print (len(label_nodes), cnt*512) 153 | cnt += 1 154 | encoded = encode_nodes(head_net, batch) 155 | encoded = encoded.detach().cpu().numpy() 156 | label_embs_graph[batch["indices"]] = encoded 157 | 158 | val_dataset = DatasetGraphPredictionEncode( 159 | [i for i in range(_start, _end)]) 160 | hce = GraphCollator(head_net, params["num_labels"], None, train=0) 161 | encode_loader = torch.utils.data.DataLoader( 162 | val_dataset, 163 | batch_size=500, 164 | num_workers=10, 165 | collate_fn=hce, 166 | shuffle=False, 167 | pin_memory=True) 168 | 169 | tst_point_embs_graph = np.zeros( 170 | (params["num_tst"], params["hidden_dims"]), dtype=np.float32) 171 | for batch in encode_loader: 172 | encoded = encode_nodes(head_net, batch) 173 | encoded = encoded.detach().cpu().numpy() 174 | tst_point_embs_graph[batch["indices"]] = encoded 175 | 176 | label_features = label_embs_graph 177 | tst_point_features = tst_point_embs_graph 178 | else: 179 | label_features = label_embs 180 | 
tst_point_features = tst_point_embs[:params["num_tst"]] 181 | 182 | prediction_shortlists = [] 183 | BATCH_SIZE = 2000000 184 | 185 | t1 = time.time() 186 | for i in range(len(partition_indices)): 187 | print("building ANNS for partition = ", i) 188 | label_NGS = HNSW( 189 | M=100, 190 | efC=300, 191 | efS=params["num_shortlist"], 192 | num_threads=24) 193 | label_NGS.fit( 194 | label_features[partition_indices[i][0]: partition_indices[i][1]]) 195 | print("Done in ", time.time() - t1) 196 | t1 = time.time() 197 | 198 | tst_label_nbrs = np.zeros( 199 | (tst_point_features.shape[0], 200 | params["num_shortlist"]), 201 | dtype=np.int64) 202 | for i in range(0, tst_point_features.shape[0], BATCH_SIZE): 203 | print(i) 204 | _tst_label_nbrs, _ = label_NGS.predict( 205 | tst_point_features[i: i + BATCH_SIZE], params["num_shortlist"]) 206 | tst_label_nbrs[i: i + BATCH_SIZE] = _tst_label_nbrs 207 | 208 | prediction_shortlists.append(tst_label_nbrs) 209 | print("Done in ", time.time() - t1) 210 | t1 = time.time() 211 | 212 | if(len(partition_indices) == 1): 213 | prediction_shortlist = prediction_shortlists[0] 214 | else: 215 | prediction_shortlist = np.hstack(prediction_shortlists) 216 | print(prediction_shortlist.shape) 217 | 218 | del(prediction_shortlists) 219 | 220 | val_dataset = DatasetGraphPrediction(_start, _end, prediction_shortlist) 221 | hcp = GraphCollator(head_net, params["num_labels"], None, train=0) 222 | val_loader = torch.utils.data.DataLoader( 223 | val_dataset, 224 | batch_size=512, 225 | num_workers=10, 226 | collate_fn=hcp, 227 | shuffle=False, 228 | pin_memory=True) 229 | 230 | val_data = dict(val_labels=tst_X_Y_val[:params["num_tst"], :], 231 | val_loader=val_loader) 232 | 233 | val_predicted_labels = lil_matrix(val_data["val_labels"].shape) 234 | 235 | with torch.set_grad_enabled(False): 236 | for batch_idx, batch_data in enumerate(val_data["val_loader"]): 237 | val_preds, val_short = predict(head_net, batch_data) 238 | 239 | partition_length = val_short.shape[1] // len(partition_indices) 240 | for i in range(1, len(partition_indices)): 241 | val_short[:, i * 242 | partition_length: (i + 243 | 1) * 244 | partition_length] += partition_indices[i][0] 245 | 246 | update_predicted_shortlist((batch_data["inputs"]) - _start, val_preds, 247 | val_predicted_labels, val_short, None, topK) 248 | 249 | acc, _ = run_validation(val_predicted_labels.tocsr( 250 | ), val_data["val_labels"], tst_exact_remove, tst_X_Y_trn, params["inv_prop"]) 251 | print("acc = {}".format(acc)) 252 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | import torch.utils.data 6 | 7 | import numpy as np 8 | import numba as nb 9 | import math 10 | import time 11 | import os 12 | import pickle 13 | import random 14 | import nmslib 15 | import sys 16 | from scipy.spatial import distance 17 | from scipy.sparse import csr_matrix, lil_matrix, load_npz, hstack, vstack 18 | 19 | from xclib.data import data_utils 20 | from xclib.utils.sparse import normalize 21 | import xclib.evaluation.xc_metrics as xc_metrics 22 | 23 | from network import * 24 | from data import * 25 | import predict_main 26 | 27 | 28 | def remap_label_indices(trn_point_titles, label_titles): 29 | label_remapping = {} 30 | _new_label_index = len(trn_point_titles) 31 | trn_title_2_index = {x: i for i, x in 
enumerate(trn_point_titles)} 32 | 33 | for i, x in enumerate(label_titles): 34 | if(x in trn_title_2_index.keys()): 35 | label_remapping[i] = trn_title_2_index[x] 36 | else: 37 | label_remapping[i] = _new_label_index 38 | _new_label_index += 1 39 | 40 | print("_new_label_index =", _new_label_index) 41 | return label_remapping 42 | 43 | 44 | def make_csr_from_ll(ll, num_z): 45 | data = [] 46 | indptr = [0] 47 | indices = [] 48 | for x in ll: 49 | indices += list(x) 50 | data += [1.0] * len(x) 51 | indptr.append(len(indices)) 52 | 53 | return csr_matrix((data, indices, indptr), shape=(len(ll), num_z)) 54 | 55 | 56 | @nb.njit(cache=True) 57 | def _recall(true_labels_indices, true_labels_indptr, 58 | pred_labels_data, pred_labels_indices, pred_labels_indptr, k): 59 | fracs = [] 60 | for i in range(len(true_labels_indptr) - 1): 61 | _true_labels = true_labels_indices[true_labels_indptr[i]: true_labels_indptr[i + 1]] 62 | _data = pred_labels_data[pred_labels_indptr[i]: pred_labels_indptr[i + 1]] 63 | _indices = pred_labels_indices[pred_labels_indptr[i]: pred_labels_indptr[i + 1]] 64 | top_inds = np.argsort(_data)[::-1][:k] 65 | _pred_labels = _indices[top_inds] 66 | if(len(_true_labels) > 0): 67 | fracs.append(len(set(_pred_labels).intersection( 68 | set(_true_labels))) / len(_true_labels)) 69 | return np.mean(np.array(fracs, dtype=np.float32)) 70 | 71 | 72 | def recall(true_labels, pred_labels, k): 73 | return _recall(true_labels.indices.astype(np.int64), true_labels.indptr, 74 | pred_labels.data, pred_labels.indices.astype(np.int64), pred_labels.indptr, k) 75 | 76 | 77 | def create_params_dict(args, node_features, trn_X_Y, 78 | graph, NUM_PARTITIONS, NUM_TRN_POINTS): 79 | DIM = node_features.shape[1] 80 | params = dict(hidden_dims=DIM, 81 | feature_dim=DIM, 82 | embed_dims=DIM, 83 | lr=args.lr, 84 | attention_lr=args.attention_lr 85 | ) 86 | params["batch_size"] = args.batch_size 87 | params["reduction"] = "mean" 88 | params["batch_div"] = False 89 | params["num_epochs"] = args.num_epochs 90 | params["num_HN_epochs"] = args.num_HN_epochs 91 | params["dlr_factor"] = args.dlr_factor 92 | params["adjust_lr_epochs"] = set( 93 | [int(x) for x in args.adjust_lr.strip().split(",")]) 94 | params["num_random_samples"] = args.num_random_samples 95 | params["devices"] = [x.strip() 96 | for x in args.devices.strip().split(",") if len(x.strip()) != 0] 97 | 98 | params["fanouts"] = [int(x.strip()) for x in args.fanouts.strip().split( 99 | ",") if len(x.strip()) != 0] 100 | params["num_partitions"] = NUM_PARTITIONS 101 | params["num_labels"] = trn_X_Y.shape[1] 102 | params["graph"] = graph 103 | params["num_trn"] = NUM_TRN_POINTS 104 | params["inv_prop"] = xc_metrics.compute_inv_propesity( 105 | trn_X_Y, args.A, args.B) 106 | params["num_shortlist"] = args.num_shortlist 107 | params["num_HN_shortlist"] = args.num_HN_shortlist 108 | params["restrict_edges_num"] = args.restrict_edges_num 109 | params["restrict_edges_head_threshold"] = args.restrict_edges_head_threshold 110 | params["random_shuffle_nbrs"] = args.random_shuffle_nbrs 111 | 112 | return params 113 | 114 | 115 | def sample_anns_nbrs(label_features, tst_point_features, num_nbrs=4): 116 | """ 117 | Only works for case when a single graph can be built on all labels 118 | """ 119 | BATCH_SIZE = 2000000 120 | 121 | t1 = time.time() 122 | print("building ANNS for neighbor sampling for NR scenario") 123 | label_NGS = HNSW(M=100, efC=300, efS=500, num_threads=24) 124 | label_NGS.fit(label_features) 125 | print("Done in ", time.time() - t1) 126 | t1 = 
time.time() 127 | 128 | tst_label_nbrs = np.zeros( 129 | (tst_point_features.shape[0], num_nbrs), dtype=np.int64) 130 | for i in range(0, tst_point_features.shape[0], BATCH_SIZE): 131 | print(i) 132 | _tst_label_nbrs, _ = label_NGS.predict( 133 | tst_point_features[i: i + BATCH_SIZE], num_nbrs) 134 | tst_label_nbrs[i: i + BATCH_SIZE] = _tst_label_nbrs 135 | 136 | print("Done in ", time.time() - t1) 137 | t1 = time.time() 138 | 139 | return tst_label_nbrs 140 | 141 | 142 | def prepare_data(trn_X_Y, tst_X_Y, trn_point_features, tst_point_features, label_features, 143 | trn_point_titles, tst_point_titles, label_titles, args): 144 | if(args.run_type == "PR"): 145 | tst_valid_inds = np.where( 146 | tst_X_Y.indptr[1:] - tst_X_Y.indptr[:-1] > 1)[0] 147 | # in original dataset some points in tst have no labels 148 | print("point with 0 labels:", np.sum( 149 | tst_X_Y.indptr[1:] - tst_X_Y.indptr[:-1] == 0)) 150 | 151 | valid_tst_point_features = tst_point_features[tst_valid_inds] 152 | valid_tst_X_Y = tst_X_Y[tst_valid_inds, :] 153 | 154 | val_adj_list = [valid_tst_X_Y.indices[valid_tst_X_Y.indptr[i] 155 | : valid_tst_X_Y.indptr[i + 1]] for i in range(len(valid_tst_X_Y.indptr) - 1)] 156 | 157 | val_adj_list_trn = [x[:(len(x) // 2)] for x in val_adj_list] 158 | val_adj_list_val = [x[(len(x) // 2):] for x in val_adj_list] 159 | 160 | adj_list = [trn_X_Y.indices[trn_X_Y.indptr[i]: trn_X_Y.indptr[i + 1]] 161 | for i in range(len(trn_X_Y.indptr) - 1)] + val_adj_list_trn 162 | 163 | trn_point_titles = trn_point_titles + \ 164 | [tst_point_titles[i] for i in tst_valid_inds] 165 | 166 | label_remapping = remap_label_indices(trn_point_titles, label_titles) 167 | adj_list = [[label_remapping[x] for x in subl] for subl in adj_list] 168 | 169 | temp = {v: k for k, v in label_remapping.items() if v >= 170 | len(trn_point_titles)} 171 | print("len(label_remapping), len(temp), len(trn_point_titles)", 172 | len(label_remapping), len(temp), len(trn_point_titles)) 173 | 174 | new_label_indices = sorted(list(temp.keys())) 175 | 176 | _x = [temp[x] for x in new_label_indices] 177 | new_label_features = label_features[_x] 178 | lengths = [ 179 | trn_point_features.shape, 180 | valid_tst_point_features.shape, 181 | new_label_features.shape] 182 | print("lengths, sum([x[0] for x in lengths])", 183 | lengths, sum([x[0] for x in lengths])) 184 | 185 | node_features = np.vstack( 186 | [trn_point_features, valid_tst_point_features, new_label_features]) 187 | print("node_features.shape", node_features.shape) 188 | 189 | # add connections only between trn and lbl, tst points are lone nodes 190 | # and thus are not included in convs 191 | adjecency_lists = [[] for i in range(node_features.shape[0])] 192 | for i, l in enumerate(adj_list): 193 | for x in l: 194 | adjecency_lists[i].append(x) 195 | adjecency_lists[x].append(i) 196 | 197 | tst_X_Y_val = make_csr_from_ll(val_adj_list_val, trn_X_Y.shape[1]) 198 | tst_X_Y_trn = make_csr_from_ll(val_adj_list_trn, trn_X_Y.shape[1]) 199 | 200 | trn_X_Y = vstack([trn_X_Y, tst_X_Y_trn]) 201 | 202 | NUM_TRN_POINTS = trn_point_features.shape[0] 203 | 204 | elif(args.run_type == "NR"): 205 | tst_X_Y_val = tst_X_Y 206 | tst_X_Y_trn = lil_matrix(tst_X_Y_val.shape).tocsr() 207 | valid_tst_point_features = tst_point_features 208 | 209 | adj_list = [trn_X_Y.indices[trn_X_Y.indptr[i]: trn_X_Y.indptr[i + 1]] 210 | for i in range(len(trn_X_Y.indptr) - 1)] 211 | 212 | trn_point_titles = trn_point_titles + tst_point_titles 213 | 214 | label_remapping = remap_label_indices(trn_point_titles, 
label_titles) 215 | adj_list = [[label_remapping[x] for x in subl] for subl in adj_list] 216 | 217 | temp = {v: k for k, v in label_remapping.items() if v >= 218 | len(trn_point_titles)} 219 | print("len(label_remapping), len(temp), len(trn_point_titles)", 220 | len(label_remapping), len(temp), len(trn_point_titles)) 221 | 222 | new_label_indices = sorted(list(temp.keys())) 223 | 224 | _x = [temp[x] for x in new_label_indices] 225 | new_label_features = label_features[_x] 226 | lengths = [ 227 | trn_point_features.shape, 228 | valid_tst_point_features.shape, 229 | new_label_features.shape] 230 | print("lengths, sum([x[0] for x in lengths])", 231 | lengths, sum([x[0] for x in lengths])) 232 | 233 | node_features = np.vstack( 234 | [trn_point_features, valid_tst_point_features, new_label_features]) 235 | print("node_features.shape", node_features.shape) 236 | 237 | print("len(adj_list)", len(adj_list)) 238 | 239 | adjecency_lists = [[] for i in range(node_features.shape[0])] 240 | for i, l in enumerate(adj_list): 241 | for x in l: 242 | adjecency_lists[i].append(x) 243 | adjecency_lists[x].append(i) 244 | 245 | tst_valid_inds = np.arange(tst_X_Y_val.shape[0]) 246 | 247 | NUM_TRN_POINTS = trn_point_features.shape[0] 248 | 249 | if(args.restrict_edges_num >= 3): 250 | head_labels = np.where( 251 | np.sum( 252 | trn_X_Y.astype( 253 | np.bool), 254 | axis=0) > args.restrict_edges_head_threshold)[0] 255 | print( 256 | "Restricting edges: Number of head labels = {}".format( 257 | len(head_labels))) 258 | 259 | for lbl in head_labels: 260 | _nid = label_remapping[lbl] 261 | distances = distance.cdist([node_features[_nid]], [ 262 | node_features[x] for x in adjecency_lists[_nid]], "cosine")[0] 263 | sorted_indices = np.argsort(distances) 264 | 265 | new_nbrs = [] 266 | for k in range(min(args.restrict_edges_num, len(sorted_indices))): 267 | new_nbrs.append(adjecency_lists[_nid][sorted_indices[k]]) 268 | adjecency_lists[_nid] = new_nbrs 269 | 270 | return tst_valid_inds, trn_X_Y, tst_X_Y_trn, tst_X_Y_val, node_features, valid_tst_point_features, label_remapping, adjecency_lists, NUM_TRN_POINTS 271 | 272 | 273 | def create_validation_data(valid_tst_point_features, label_features, tst_X_Y_val, 274 | args, params, TST_TAKE, NUM_PARTITIONS): 275 | """ 276 | Create validation data. 
For val accuracy pattern observation 277 | This won't provide correct valdation picture as init(not graph) embeddings used and tst connection not added 278 | """ 279 | if(TST_TAKE == -1): 280 | TST_TAKE = valid_tst_point_features.shape[0] 281 | 282 | if(args.validation_freq != -1 and args.predict_ova == 0): 283 | print("Creating shortlists for validation using base embeddings...") 284 | prediction_shortlists = [] 285 | t1 = time.time() 286 | 287 | for i in range(NUM_PARTITIONS): 288 | NGS = HNSW( 289 | M=100, 290 | efC=300, 291 | efS=params["num_shortlist"], 292 | num_threads=24) 293 | NGS.fit(label_features[partition_indices[i] 294 | [0]: partition_indices[i][1]]) 295 | 296 | prediction_shortlist, _ = NGS.predict( 297 | valid_tst_point_features[:TST_TAKE], params["num_shortlist"]) 298 | prediction_shortlists.append(prediction_shortlist) 299 | 300 | if(NUM_PARTITIONS == 1): 301 | prediction_shortlist = prediction_shortlists[0] 302 | else: 303 | prediction_shortlist = np.hstack( 304 | [x for x in prediction_shortlists]) 305 | del(prediction_shortlists) 306 | print("prediction_shortlist.shape", prediction_shortlist.shape) 307 | print("Time taken in creating shortlists per point(ms)", 308 | ((time.time() - t1) / prediction_shortlist.shape[0]) * 1000) 309 | 310 | if(args.validation_freq != -1): 311 | _start = params["num_trn"] 312 | _end = _start + TST_TAKE 313 | print("_start, _end = ", _start, _end) 314 | 315 | if(args.predict_ova == 0): 316 | val_dataset = DatasetGraphPrediction( 317 | _start, _end, prediction_shortlist) 318 | else: 319 | val_dataset = DatasetGraphPrediction(_start, _end, None) 320 | hcp = GraphCollator(head_net, params["num_labels"], None, train=0) 321 | val_loader = torch.utils.data.DataLoader( 322 | val_dataset, 323 | batch_size=512, 324 | num_workers=10, 325 | collate_fn=hcp, 326 | shuffle=False, 327 | pin_memory=False) 328 | 329 | val_data = dict(val_labels=tst_X_Y_val[:TST_TAKE, :], 330 | val_loader=val_loader) 331 | else: 332 | val_data = None 333 | 334 | return val_data 335 | 336 | 337 | def sample_hard_negatives(head_net, label_remapping, partition_indices, num_trn, params): 338 | label_nodes = [label_remapping[i] for i in range(len(label_remapping))] 339 | 340 | val_dataset = DatasetGraphPredictionEncode(label_nodes) 341 | hce = GraphCollator(head_net, params["num_labels"], None, train=0) 342 | encode_loader = torch.utils.data.DataLoader( 343 | val_dataset, 344 | batch_size=512, 345 | num_workers=4, 346 | collate_fn=hce, 347 | shuffle=False, 348 | pin_memory=True) 349 | 350 | label_embs_graph = np.zeros( 351 | (len(label_nodes), 352 | params["hidden_dims"]), 353 | dtype=np.float32) 354 | for batch in encode_loader: 355 | encoded = predict_main.encode_nodes(head_net, batch) 356 | encoded = encoded.detach().cpu().numpy() 357 | label_embs_graph[batch["indices"]] = encoded 358 | 359 | val_dataset = DatasetGraphPredictionEncode( 360 | [i for i in range(num_trn)]) 361 | hce = GraphCollator(head_net, params["num_labels"], None, train=0) 362 | encode_loader = torch.utils.data.DataLoader( 363 | val_dataset, 364 | batch_size=512, 365 | num_workers=4, 366 | collate_fn=hce, 367 | shuffle=False, 368 | pin_memory=True) 369 | 370 | trn_point_embs_graph = np.zeros( 371 | (num_trn, params["hidden_dims"]), dtype=np.float32) 372 | for batch in encode_loader: 373 | encoded = predict_main.encode_nodes(head_net, batch) 374 | encoded = encoded.detach().cpu().numpy() 375 | trn_point_embs_graph[batch["indices"]] = encoded 376 | 377 | label_features = label_embs_graph 378 | 
trn_point_features = trn_point_embs_graph 379 | 380 | prediction_shortlists_trn = [] 381 | BATCH_SIZE = 2000000 382 | 383 | t1 = time.time() 384 | for i in range(len(partition_indices)): 385 | print("building ANNS for partition = ", i) 386 | label_NGS = HNSW( 387 | M=100, 388 | efC=300, 389 | efS=params["num_HN_shortlist"], 390 | num_threads=24) 391 | label_NGS.fit( 392 | label_features[partition_indices[i][0]: partition_indices[i][1]]) 393 | print("Done in ", time.time() - t1) 394 | t1 = time.time() 395 | 396 | trn_label_nbrs = np.zeros( 397 | (trn_point_features.shape[0], 398 | params["num_HN_shortlist"]), 399 | dtype=np.int64) 400 | for i in range(0, trn_point_features.shape[0], BATCH_SIZE): 401 | print(i) 402 | _trn_label_nbrs, _ = label_NGS.predict( 403 | trn_point_features[i: i + BATCH_SIZE], params["num_HN_shortlist"]) 404 | trn_label_nbrs[i: i + BATCH_SIZE] = _trn_label_nbrs 405 | 406 | prediction_shortlists_trn.append(trn_label_nbrs) 407 | print("Done in ", time.time() - t1) 408 | t1 = time.time() 409 | 410 | if(len(partition_indices) == 1): 411 | prediction_shortlist_trn = prediction_shortlists_trn[0] 412 | else: 413 | prediction_shortlist_trn = np.hstack(prediction_shortlists_trn) 414 | 415 | return prediction_shortlist_trn 416 | -------------------------------------------------------------------------------- /train_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parameter import Parameter 7 | import torch.nn.functional as F 8 | import torch.utils.data 9 | 10 | import numpy as np 11 | import math 12 | import time 13 | import os 14 | import pickle 15 | import random 16 | import nmslib 17 | import sys 18 | import argparse 19 | import warnings 20 | import logging 21 | from scipy.spatial import distance 22 | from scipy.sparse import csr_matrix, lil_matrix, load_npz, hstack, vstack, save_npz 23 | 24 | from xclib.data import data_utils 25 | from xclib.utils.sparse import normalize 26 | import xclib.evaluation.xc_metrics as xc_metrics 27 | 28 | from collections import defaultdict, Counter 29 | from network import * 30 | from data import * 31 | from predict_main import * 32 | from utils import * 33 | 34 | torch.manual_seed(22) 35 | torch.cuda.manual_seed_all(22) 36 | np.random.seed(22) 37 | 38 | 39 | def test(): 40 | if(RUN_TYPE == "NR"): 41 | # introduce the tst points into the graph, assume all tst points known 42 | # at once. For larger graphs, doing ANNS on trn_points, labels work 43 | # equally well. 
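        # sample_anns_nbrs builds an HNSW index over node_features and returns, for each
        # test point, its --prediction-introduce-edges nearest nodes; these neighbours are
        # appended to the test node's adjacency list below so that graph convolutions have
        # edges to aggregate over for the newly introduced test nodes.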
44 | tst_point_nbrs = sample_anns_nbrs( 45 | node_features, 46 | valid_tst_point_features, 47 | args.prediction_introduce_edges) 48 | val_adj_list_trn = [list(x) for x in tst_point_nbrs] 49 | 50 | for i, l in enumerate(val_adj_list_trn): 51 | for x in l: 52 | adjecency_lists[i + NUM_TRN_POINTS].append(x) 53 | new_graph = Graph( 54 | node_features, 55 | adjecency_lists, 56 | args.random_shuffle_nbrs) 57 | head_net.graph = new_graph 58 | 59 | t1 = time.time() 60 | validate(head_net, params, partition_indices, label_remapping, 61 | label_features, valid_tst_point_features, tst_X_Y_val, tst_exact_remove, tst_X_Y_trn, True, 100) 62 | print("Prediction time Per point(ms): ", 63 | ((time.time() - t1) / valid_tst_point_features.shape[0]) * 1000) 64 | 65 | 66 | def train(): 67 | if(args.mpt == 1): 68 | scaler = torch.cuda.amp.GradScaler() 69 | 70 | for epoch in range(params["num_epochs"]): 71 | 72 | epoch_train_start_time = time.time() 73 | head_net.train() 74 | torch.set_grad_enabled(True) 75 | 76 | num_batches = len(head_train_loader.dataset) // params["batch_size"] 77 | mean_loss = 0 78 | for batch_idx, batch_data in enumerate(head_train_loader): 79 | t1 = time.time() 80 | head_net.zero_grad() 81 | batch_size = batch_data['batch_size'] 82 | 83 | if(args.mpt == 1): 84 | with torch.cuda.amp.autocast(): 85 | out_ans = head_net.forward(batch_data) 86 | loss = head_criterion( 87 | out_ans, batch_data["Y"].to( 88 | out_ans.get_device())) 89 | elif(args.mpt == 0): 90 | out_ans = head_net.forward(batch_data) 91 | loss = head_criterion( 92 | out_ans, batch_data["Y"].to( 93 | out_ans.get_device())) 94 | 95 | if params["batch_div"]: 96 | loss = loss / batch_size 97 | mean_loss += loss.item() * batch_size 98 | 99 | if(args.mpt == 1): 100 | scaler.scale(loss).backward() # loss.backward() 101 | scaler.step(head_optimizer) # head_optimizer3.step() 102 | scaler.update() 103 | elif(args.mpt == 0): 104 | loss.backward() 105 | head_optimizer.step() 106 | del batch_data 107 | 108 | epoch_train_end_time = time.time() 109 | mean_loss /= len(head_train_loader.dataset) 110 | print( 111 | "Epoch: {}, loss: {}, time: {} sec".format( 112 | epoch, 113 | mean_loss, 114 | epoch_train_end_time - 115 | epoch_train_start_time)) 116 | logging.info( 117 | "Epoch: {}, loss: {}, time: {} sec".format( 118 | epoch, 119 | mean_loss, 120 | epoch_train_end_time - 121 | epoch_train_start_time)) 122 | 123 | if(epoch in params["adjust_lr_epochs"]): 124 | for param_group in head_optimizer.param_groups: 125 | param_group['lr'] = param_group['lr'] * params["dlr_factor"] 126 | 127 | if(val_data is not None and ((epoch == 0) or (epoch % args.validation_freq == 0) or (epoch == params["num_epochs"] - 1))): 128 | val_predicted_labels = lil_matrix(val_data["val_labels"].shape) 129 | 130 | t1 = time.time() 131 | with torch.set_grad_enabled(False): 132 | for batch_idx, batch_data in enumerate(val_data["val_loader"]): 133 | val_preds, val_short = predict(head_net, batch_data) 134 | 135 | if(not(val_short is None)): 136 | partition_length = val_short.shape[1] // len( 137 | partition_indices) 138 | for i in range(1, len(partition_indices)): 139 | val_short[:, i * 140 | partition_length: (i + 141 | 1) * 142 | partition_length] += partition_indices[i][0] 143 | 144 | update_predicted_shortlist((batch_data["inputs"]) - _start, val_preds, 145 | val_predicted_labels, val_short, None, 10) 146 | else: 147 | update_predicted(batch_data["inputs"] - _start, torch.from_numpy(val_preds), 148 | val_predicted_labels, None, 10) 149 | 150 | print( 151 | "Per point(ms): 
", 152 | ((time.time() - t1) / val_predicted_labels.shape[0]) * 1000) 153 | acc, _ = run_validation(val_predicted_labels.tocsr( 154 | ), val_data["val_labels"], tst_exact_remove, tst_X_Y_trn, inv_prop) 155 | print("acc = {}".format(acc)) 156 | logging.info("acc = {}".format(acc)) 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = argparse.ArgumentParser() 161 | 162 | parser.add_argument('--dataset', required=True, help='dataset name') 163 | parser.add_argument( 164 | '--devices', 165 | required=True, 166 | help=', separated list of devices to use for training') 167 | parser.add_argument( 168 | '--save-model', 169 | required=True, 170 | type=int, 171 | help='whether to save trained model or not') 172 | 173 | parser.add_argument( 174 | '--num-epochs', 175 | required=True, 176 | type=int, 177 | help='number of epochs to train the graph(with random negatives) for') 178 | parser.add_argument( 179 | '--num-HN-epochs', 180 | required=True, 181 | type=int, 182 | help='number of epochs to fine tune the classifiers for') 183 | parser.add_argument( 184 | '--batch-size', 185 | type=int, 186 | default=512, 187 | help='batch size to use') 188 | parser.add_argument( 189 | '--lr', 190 | required=True, 191 | type=float, 192 | help='learning rate for entire model except attention weights') 193 | parser.add_argument( 194 | '--attention-lr', 195 | required=True, 196 | type=float, 197 | help='learning rate for attention weights') 198 | parser.add_argument( 199 | '--adjust-lr', 200 | required=True, 201 | type=str, 202 | help=', separated epoch nums at which to adjust lr') 203 | parser.add_argument( 204 | '--dlr-factor', 205 | required=True, 206 | type=float, 207 | help='lr reduction factor') 208 | parser.add_argument( 209 | '--mpt', 210 | default="0", 211 | type=int, 212 | help='whether to do mixed precision training') 213 | 214 | parser.add_argument( 215 | '--restrict-edges-num', 216 | type=int, 217 | default=-1, 218 | help='take top neighbors when building graph') 219 | parser.add_argument( 220 | '--restrict-edges-head-threshold', 221 | type=int, 222 | default=3, 223 | help='take top neighbors for head labels having documents more than this') 224 | parser.add_argument( 225 | '--num-random-samples', 226 | required=True, 227 | type=int, 228 | help='num of batch random to sample') 229 | parser.add_argument( 230 | '--random-shuffle-nbrs', 231 | required=True, 232 | type=int, 233 | help='shuffle neighbors when sampling for a node') 234 | parser.add_argument( 235 | '--fanouts', 236 | default="3,3,3", 237 | type=str, 238 | help='fanouts for gcn') 239 | parser.add_argument( 240 | '--num-HN-shortlist', 241 | default=500, 242 | type=int, 243 | help='number of labels to shortlist for HN training') 244 | 245 | parser.add_argument( 246 | '--embedding-type', 247 | required=True, 248 | type=str, 249 | help='embedding type to use, a folder {embedding-type}CondensedData with embeddings files should be present') 250 | parser.add_argument( 251 | '--run-type', 252 | required=True, 253 | type=str, 254 | help='should be PR(Partial Reveal)/NR(No Reveal)') 255 | 256 | parser.add_argument( 257 | '--num-validation', 258 | default=25000, 259 | type=int, 260 | help='number of points to take for validation') 261 | parser.add_argument( 262 | '--validation-freq', 263 | default=6, 264 | type=int, 265 | help='validate after how many epochs, -1 means dont validate') 266 | parser.add_argument( 267 | '--num-shortlist', 268 | default=500, 269 | type=int, 270 | help='number of labels to shortlist per point for prediction') 271 | 
parser.add_argument( 272 | '--prediction-introduce-edges', 273 | default=4, 274 | type=int, 275 | help='number of edges to introduce from the test point') 276 | parser.add_argument( 277 | '--predict-ova', 278 | default=0, 279 | type=int, 280 | help='if to predict ova') 281 | 282 | parser.add_argument( 283 | '--A', 284 | default=0.55, 285 | type=float, 286 | help='param A for inv prop calculation') 287 | parser.add_argument( 288 | '--B', 289 | default=1.5, 290 | type=float, 291 | help='param B for inv prop calculation') 292 | 293 | # args = parser.parse_args() 294 | args, _ = parser.parse_known_args() 295 | print("***args=", args) 296 | 297 | DATASET = args.dataset 298 | NUM_PARTITIONS = len(args.devices.strip().split(",")) 299 | EMB_TYPE = args.embedding_type 300 | RUN_TYPE = args.run_type 301 | TST_TAKE = args.num_validation 302 | NUM_TRN_POINTS = -1 303 | 304 | ######################### Data load ######################### 305 | trn_point_titles = [ 306 | line.strip() for line in open( 307 | "{}/trn_X.txt".format(DATASET), 308 | "r", 309 | encoding="latin").readlines()] 310 | tst_point_titles = [ 311 | line.strip() for line in open( 312 | "{}/tst_X.txt".format(DATASET), 313 | "r", 314 | encoding="latin").readlines()] 315 | label_titles = [ 316 | line.strip() for line in open( 317 | "{}/Y.txt".format(DATASET), 318 | "r", 319 | encoding="latin").readlines()] 320 | print("len(trn_point_titles), len(tst_point_titles), len(label_titles) = ", len( 321 | trn_point_titles), len(tst_point_titles), len(label_titles)) 322 | 323 | trn_point_features = np.load( 324 | "{}/{}CondensedData/trn_point_embs.npy".format(DATASET, EMB_TYPE)) 325 | label_features = np.load( 326 | "{}/{}CondensedData/label_embs.npy".format(DATASET, EMB_TYPE)) 327 | tst_point_features = np.load( 328 | "{}/{}CondensedData/tst_point_embs.npy".format(DATASET, EMB_TYPE)) 329 | print( 330 | "trn_point_features.shape, tst_point_features.shape, label_features.shape", 331 | trn_point_features.shape, 332 | tst_point_features.shape, 333 | label_features.shape) 334 | 335 | trn_X_Y = data_utils.read_sparse_file( 336 | "{}/trn_X_Y.txt".format(DATASET), force_header=True) 337 | tst_X_Y = data_utils.read_sparse_file( 338 | "{}/tst_X_Y.txt".format(DATASET), force_header=True) 339 | 340 | tst_valid_inds, trn_X_Y, tst_X_Y_trn, tst_X_Y_val, node_features, valid_tst_point_features, label_remapping, adjecency_lists, NUM_TRN_POINTS = prepare_data(trn_X_Y, tst_X_Y, trn_point_features, tst_point_features, label_features, 341 | trn_point_titles, tst_point_titles, label_titles, args) 342 | 343 | hard_negs = [[] for i in range(node_features.shape[0])] 344 | 345 | print("trn_X_Y.shape, tst_X_Y_trn.shape, tst_X_Y_val.shape", 346 | trn_X_Y.shape, tst_X_Y_trn.shape, tst_X_Y_val.shape) 347 | 348 | temp = [ 349 | line.strip().split() for line in open( 350 | "{}/filter_labels_test.txt".format(DATASET), 351 | "r").readlines()] 352 | removed = defaultdict(list) 353 | for x in temp: 354 | removed[int(float(x[0]))].append(int(float(x[1]))) 355 | removed = dict(removed) 356 | del(temp) 357 | 358 | # remove from prediciton where label == point exactly text wise because 359 | # that is already removed from gt 360 | tst_exact_remove = { 361 | i: removed.get( 362 | tst_valid_inds[i], 363 | []) for i in range( 364 | len(tst_valid_inds))} 365 | print("len(tst_exact_remove)", len(tst_exact_remove)) 366 | 367 | print("node_features.shape, len(adjecency_lists)", 368 | node_features.shape, len(adjecency_lists)) 369 | graph = Graph(node_features, adjecency_lists, 
args.random_shuffle_nbrs) 370 | 371 | params = create_params_dict( 372 | args, 373 | node_features, 374 | trn_X_Y, 375 | graph, 376 | NUM_PARTITIONS, 377 | NUM_TRN_POINTS) 378 | print("***params=", params) 379 | 380 | ######################### M1/Phase1 Training(with random negatives) ## 381 | head_net = GalaXCBase(params["num_labels"], params["hidden_dims"], params["devices"], 382 | params["feature_dim"], params["fanouts"], params["graph"], params["embed_dims"]) 383 | 384 | head_optimizer = torch.optim.Adam([{'params': [head_net.classifier.classifiers[0].attention_weights], 'lr': params["attention_lr"]}, 385 | {"params": [param for name, param in head_net.named_parameters() if name != "classifier.classifiers.0.attention_weights"], "lr": params["lr"]}], lr=params["lr"]) 386 | 387 | # required to split classification layer onto multiple GPUs 388 | partition_size = math.ceil(trn_X_Y.shape[1] / NUM_PARTITIONS) 389 | partition_indices = [] 390 | for i in range(NUM_PARTITIONS): 391 | _start = i * partition_size 392 | _end = min(_start + partition_size, trn_X_Y.shape[1]) 393 | partition_indices.append((_start, _end)) 394 | 395 | print(partition_indices) 396 | 397 | val_data = create_validation_data(valid_tst_point_features, label_features, tst_X_Y_val, 398 | args, params, TST_TAKE, NUM_PARTITIONS) 399 | 400 | for handler in logging.root.handlers[:]: 401 | logging.root.removeHandler(handler) 402 | logging.basicConfig( 403 | filename="{}/GraphXML_log_{}.txt".format(DATASET, RUN_TYPE), level=logging.INFO) 404 | 405 | # training loop 406 | warnings.simplefilter('ignore') 407 | 408 | head_criterion = torch.nn.BCEWithLogitsLoss(reduction=params["reduction"]) 409 | print("Model parameters: ", params) 410 | print("Model configuration: ", head_net) 411 | 412 | head_train_dataset = DatasetGraph(trn_X_Y, hard_negs) 413 | print('Dataset Loaded') 414 | 415 | hc = GraphCollator( 416 | head_net, 417 | params["num_labels"], 418 | params["num_random_samples"], 419 | num_hard_neg=0) 420 | print('Collator created') 421 | 422 | head_train_loader = torch.utils.data.DataLoader( 423 | head_train_dataset, 424 | batch_size=params["batch_size"], 425 | num_workers=10, 426 | collate_fn=hc, 427 | shuffle=True, 428 | pin_memory=False 429 | ) 430 | 431 | inv_prop = xc_metrics.compute_inv_propesity(trn_X_Y, args.A, args.B) 432 | 433 | head_net.move_to_devices() 434 | 435 | if(args.mpt == 1): 436 | scaler = torch.cuda.amp.GradScaler() 437 | 438 | train() 439 | 440 | # should be kept as how many we want to test on 441 | params["num_tst"] = tst_X_Y_val.shape[0] 442 | 443 | if(args.save_model == 1): 444 | model_dir = "{}/GraphXMLModel{}".format(DATASET, RUN_TYPE) 445 | if not os.path.exists(model_dir): 446 | print("Making model dir...") 447 | os.makedirs(model_dir) 448 | 449 | torch.save( 450 | head_net.state_dict(), 451 | os.path.join( 452 | model_dir, 453 | "model_state_dict.pt")) 454 | with open(os.path.join(model_dir, "model_params.pkl"), "wb") as fout: 455 | pickle.dump(params, fout, protocol=4) 456 | 457 | if(params["num_HN_epochs"] <= 0): 458 | print("Accuracies with graph embeddings to shortlist:") 459 | test() 460 | sys.exit( 461 | "You have chosen not to fine tune classifiers using hard negatives by providing num_HN_epochs <= 0") 462 | 463 | print("==================================================================") 464 | 465 | ######################### M4/Phase2 Training(with hard negatives) #### 466 | print("***params=", params) 467 | print("****** Starting HN fine tuning of calssifiers ******") 468 | 469 | 
prediction_shortlist_trn = sample_hard_negatives( 470 | head_net, label_remapping, partition_indices, trn_X_Y.shape[0], params) 471 | 472 | head_criterion = torch.nn.BCEWithLogitsLoss(reduction=params["reduction"]) 473 | print("Model parameters: ", params) 474 | 475 | head_train_dataset = DatasetGraph(trn_X_Y, prediction_shortlist_trn) 476 | print('Dataset Loaded') 477 | 478 | params["num_tst"] = 25000 479 | 480 | head_optimizer = torch.optim.Adam([{'params': [head_net.classifier.classifiers[0].attention_weights], 'lr': params["attention_lr"]}, 481 | {"params": [param for name, param in head_net.classifier.named_parameters() if name != "classifiers.0.attention_weights"], "lr": params["lr"]}], lr=params["lr"]) 482 | 483 | validation_freq = 1 484 | 485 | hc = GraphCollator( 486 | head_net, 487 | params["num_labels"], 488 | 0, 489 | num_hard_neg=params["num_HN_shortlist"]) 490 | print('Collator created') 491 | 492 | head_train_loader = torch.utils.data.DataLoader( 493 | head_train_dataset, 494 | batch_size=params["batch_size"], 495 | num_workers=6, 496 | collate_fn=hc, 497 | shuffle=True, 498 | pin_memory=True 499 | ) 500 | 501 | inv_prop = xc_metrics.compute_inv_propesity(trn_X_Y, args.A, args.B) 502 | 503 | head_net.move_to_devices() 504 | 505 | if(args.mpt == 1): 506 | scaler = torch.cuda.amp.GradScaler() 507 | 508 | params["adjust_lr_epochs"] = np.arange(0, params["num_HN_epochs"], 4) 509 | params["num_epochs"] = params["num_HN_epochs"] 510 | 511 | train() 512 | 513 | print("==================================================================") 514 | print("Accuracies with graph embeddings to shortlist:") 515 | params["num_tst"] = tst_X_Y_val.shape[0] 516 | test() 517 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import nmslib 2 | from typing import Callable 3 | import logging 4 | import torch 5 | import numpy as np 6 | import math 7 | from scipy.sparse import csr_matrix, lil_matrix 8 | 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | import torch.nn.functional as F 12 | import torch.utils.data 13 | 14 | 15 | class MeanAggregator(nn.Module): 16 | """Aggregates a node's embeddings using mean of neighbors' embeddings.""" 17 | 18 | def __init__(self, features: Callable[[torch.Tensor], torch.Tensor]): 19 | super(MeanAggregator, self).__init__() 20 | self.features = features 21 | 22 | def forward(self, neighs: torch.Tensor, node_count: int, device): 23 | neigh_feats = self.features(neighs).to(device) 24 | 25 | nb_count = int(neigh_feats.shape[0] / node_count) 26 | fv_by_node = neigh_feats.view( 27 | node_count, nb_count, neigh_feats.shape[-1]) 28 | return fv_by_node.mean(1) 29 | 30 | 31 | class SumAggregator(nn.Module): 32 | """Aggregates a node's embeddings using mean of neighbors' embeddings.""" 33 | 34 | def __init__(self, features: Callable[[torch.Tensor], torch.Tensor]): 35 | super(SumAggregator, self).__init__() 36 | self.features = features 37 | 38 | def forward(self, neighs: torch.Tensor, node_count: int, device): 39 | neigh_feats = self.features(neighs).to(device) 40 | 41 | nb_count = int(neigh_feats.shape[0] / node_count) 42 | fv_by_node = neigh_feats.view( 43 | node_count, nb_count, neigh_feats.shape[-1]) 44 | return fv_by_node.sum(1) 45 | 46 | 47 | class SaintEncoder(nn.Module): 48 | """Encode a node's using 'convolutional' GraphSaint approach.""" 49 | 50 | def __init__( 51 | self, 52 | features, 53 | query_func, 54 | 
device_name, 55 | feature_dim: int, 56 | aggregator: nn.Module, 57 | num_sample: int, 58 | intermediate_dim: int, 59 | embed_dim: int = 300, 60 | activation_fn: callable = F.relu, 61 | base_model=None, 62 | ): 63 | super(SaintEncoder, self).__init__() 64 | 65 | self.device_name = device_name 66 | if base_model: 67 | self.base_model = base_model 68 | self.features = features 69 | if query_func is None: 70 | self.query_func = self.query_feature 71 | else: 72 | self.query_func = query_func 73 | self.aggregator = aggregator 74 | self.num_sample = num_sample 75 | self.activation_fn = activation_fn 76 | self.weight_1 = nn.Parameter( 77 | torch.FloatTensor( 78 | embed_dim // 2, 79 | intermediate_dim)) 80 | self.weight_2 = nn.Parameter( 81 | torch.FloatTensor( 82 | embed_dim // 2, 83 | intermediate_dim)) 84 | nn.init.xavier_uniform_(self.weight_1) 85 | nn.init.xavier_uniform_(self.weight_2) 86 | 87 | def query( 88 | self, 89 | nodes: np.array, 90 | graph 91 | ): 92 | context = {} 93 | neigh_nodes = graph.sample_neighbors(nodes, self.num_sample)[ 94 | 0 95 | ].flatten() 96 | 97 | context["node_feats"] = self.query_func( 98 | nodes, graph 99 | ) 100 | 101 | context["neighbor_feats"] = self.query_func( 102 | neigh_nodes, graph 103 | ) 104 | context["node_count"] = len(nodes) 105 | return context 106 | 107 | def query_feature( 108 | self, 109 | nodes: np.array, 110 | graph 111 | ): 112 | features = graph.node_features(nodes) 113 | return features 114 | 115 | def forward(self, context: dict): 116 | """Generate embeddings for a batch of nodes.""" 117 | neigh_feats = self.aggregator.forward( 118 | context["neighbor_feats"], context["node_count"], self.device_name 119 | ) 120 | self_feats = self.features(context["node_feats"]).to(self.device_name) 121 | 122 | # print (neigh_feats.shape, self_feats.shape) 123 | combined = torch.cat( 124 | [self.weight_1.mm(self_feats.t()), self.weight_2.mm(neigh_feats.t())], dim=0) 125 | combined = self.activation_fn(combined) 126 | 127 | return combined 128 | 129 | 130 | class SageEncoder(nn.Module): 131 | """Encode a node's using 'convolutional' GraphSage approach.""" 132 | 133 | def __init__( 134 | self, 135 | features, 136 | query_func, 137 | device_name, 138 | feature_dim: int, 139 | aggregator: nn.Module, 140 | num_sample: int, 141 | intermediate_dim: int, 142 | embed_dim: int = 300, 143 | activation_fn: callable = F.relu, 144 | base_model=None, 145 | ): 146 | super(SageEncoder, self).__init__() 147 | 148 | self.device_name = device_name 149 | if base_model: 150 | self.base_model = base_model 151 | self.features = features 152 | if query_func is None: 153 | self.query_func = self.query_feature 154 | else: 155 | self.query_func = query_func 156 | self.aggregator = aggregator 157 | self.num_sample = num_sample 158 | self.activation_fn = activation_fn 159 | self.weight = nn.Parameter( 160 | torch.FloatTensor( 161 | embed_dim, 162 | 2 * intermediate_dim)) 163 | nn.init.xavier_uniform_(self.weight) 164 | 165 | def query( 166 | self, 167 | nodes: np.array, 168 | graph, 169 | ): 170 | context = {} 171 | neigh_nodes = graph.sample_neighbors(nodes, self.num_sample)[ 172 | 0 173 | ].flatten() 174 | 175 | context["node_feats"] = self.query_func( 176 | nodes, graph 177 | ) 178 | 179 | context["neighbor_feats"] = self.query_func( 180 | neigh_nodes, graph 181 | ) 182 | context["node_count"] = len(nodes) 183 | return context 184 | 185 | def query_feature( 186 | self, 187 | nodes: np.array, 188 | graph, 189 | ): 190 | features = graph.node_features( 191 | nodes 192 | ) 193 | return 
features 194 | 195 | def forward(self, context: dict): 196 | """Generate embeddings for a batch of nodes.""" 197 | neigh_feats = self.aggregator.forward( 198 | context["neighbor_feats"], context["node_count"], self.device_name 199 | ) 200 | self_feats = self.features(context["node_feats"]).to(self.device_name) 201 | combined = torch.cat([self_feats, neigh_feats], dim=1) 202 | combined = self.activation_fn(self.weight.mm(combined.t())) 203 | 204 | return combined 205 | 206 | 207 | class GINEncoder(nn.Module): 208 | """Encode a node's using 'convolutional' GIN approach.""" 209 | 210 | def __init__( 211 | self, 212 | features, 213 | query_func, 214 | device_name, 215 | feature_dim: int, 216 | aggregator: nn.Module, 217 | num_sample: int, 218 | intermediate_dim: int, 219 | embed_dim: int = 300, 220 | activation_fn: callable = F.relu, 221 | base_model=None, 222 | ): 223 | super(GINEncoder, self).__init__() 224 | 225 | self.device_name = device_name 226 | if base_model: 227 | self.base_model = base_model 228 | self.features = features 229 | if query_func is None: 230 | self.query_func = self.query_feature 231 | else: 232 | self.query_func = query_func 233 | self.aggregator = aggregator 234 | self.num_sample = num_sample 235 | self.activation_fn = activation_fn 236 | self.eps = nn.Parameter(torch.rand(1)) 237 | 238 | def query( 239 | self, 240 | nodes: np.array, 241 | graph 242 | ): 243 | context = {} 244 | neigh_nodes = graph.sample_neighbors(nodes, self.num_sample)[ 245 | 0 246 | ].flatten() 247 | 248 | context["node_feats"] = self.query_func( 249 | nodes, graph 250 | ) 251 | 252 | context["neighbor_feats"] = self.query_func( 253 | neigh_nodes, graph 254 | ) 255 | 256 | context["node_count"] = len(nodes) 257 | return context 258 | 259 | def query_feature( 260 | self, 261 | nodes: np.array, 262 | graph, 263 | ): 264 | features = graph.node_features( 265 | nodes 266 | ) 267 | return features 268 | 269 | def forward(self, context: dict): 270 | """Generate embeddings for a batch of nodes.""" 271 | 272 | neigh_feats = self.aggregator.forward( 273 | context["neighbor_feats"], context["node_count"], self.device_name 274 | ) 275 | self_feats = self.features(context["node_feats"]).to(self.device_name) 276 | 277 | combined = torch.add(neigh_feats, (1.0 + self.eps) * self_feats) 278 | return combined.t() 279 | 280 | 281 | class LinearChunk(nn.Module): 282 | """One part for distributed fully connected layer""" 283 | 284 | def __init__(self, input_size, output_size, device_embeddings, bias=True): 285 | super(LinearChunk, self).__init__() 286 | self.device_embeddings = device_embeddings 287 | self.input_size = input_size 288 | self.output_size = output_size 289 | self.weight = Parameter( 290 | torch.Tensor( 291 | self.output_size, 292 | self.input_size)) 293 | if bias: 294 | self.bias = Parameter(torch.Tensor(self.output_size, )) 295 | else: 296 | self.register_parameter('bias', None) 297 | self.attention_weights = Parameter(torch.Tensor(self.output_size, 3)) 298 | 299 | self.sparse = False 300 | self.reset_parameters() 301 | self.act = torch.nn.Softmax(dim=1) 302 | 303 | def forward(self, input): 304 | if(input[1] is None): 305 | w = self.weight.unsqueeze( 306 | 1) * (self.act(self.attention_weights).unsqueeze(2)) 307 | x = input[0].mm(w.view((-1, input[0].shape[-1])).t() 308 | ) + self.bias.view(-1) 309 | return x 310 | else: 311 | if len(input[1].shape) == 1: 312 | # 350K X 1 X 300 350K X 3 X 1 313 | # .permute(0, 2, 1).reshape(-1, 900) 314 | w = (self.weight[input[1]].unsqueeze( 315 | 1)) * 
(self.act(self.attention_weights[input[1]])).unsqueeze(2) 316 | x = input[0].mm(w.view((-1, input[0].shape[-1])).t() 317 | ) + self.bias[input[1]].view(-1) 318 | return x 319 | elif len(input[1].shape) == 2: 320 | short_weights = F.embedding(input[1].to(self.device_embeddings), 321 | self.weight, 322 | sparse=self.sparse).view(input[1].shape[0] * input[1].shape[1], -1) 323 | 324 | short_bias = F.embedding(input[1].to(self.device_embeddings), 325 | self.bias.view(-1, 1), 326 | sparse=self.sparse) 327 | 328 | short_att = F.embedding(input[1].to(self.device_embeddings), 329 | self.attention_weights, 330 | sparse=self.sparse).view(input[1].shape[0] * input[1].shape[1], -1) 331 | 332 | w = short_weights.unsqueeze( 333 | 1) * (self.act(short_att).unsqueeze(2)) 334 | x = input[0].unsqueeze(1).repeat(1, 335 | input[1].shape[1], 336 | 1) * w.view((input[1].shape[0], 337 | input[1].shape[1], 338 | input[0].shape[-1])) 339 | 340 | x = x.sum(axis=2) + short_bias.squeeze() 341 | return x 342 | 343 | def move_to_devices(self): 344 | super().to(self.device_embeddings) 345 | 346 | def reset_parameters(self): 347 | nn.init.normal_(self.attention_weights) 348 | stdv = 1. / math.sqrt(self.weight.size(1)) 349 | self.weight.data.uniform_(-stdv, stdv) 350 | if self.bias is not None: 351 | self.bias.data.uniform_(-stdv, stdv) 352 | 353 | 354 | class LinearDistributed(nn.Module): 355 | """Distributed fully connected layer""" 356 | 357 | def __init__(self, input_size, output_size, device_embeddings): 358 | super(LinearDistributed, self).__init__() 359 | self.num_partitions = len(device_embeddings) 360 | self.device_embeddings = device_embeddings 361 | self.input_size = input_size 362 | self.output_size = output_size 363 | 364 | self.partition_size = math.ceil(output_size / self.num_partitions) 365 | self.partition_indices = [] 366 | for i in range(self.num_partitions): 367 | _start = i * self.partition_size 368 | _end = min(_start + self.partition_size, output_size) 369 | self.partition_indices.append((_start, _end)) 370 | 371 | print(self.partition_indices) 372 | 373 | self.classifiers = nn.ModuleList() 374 | for i in range(len(self.device_embeddings)): 375 | output_size = self.partition_indices[i][1] - \ 376 | self.partition_indices[i][0] 377 | self.classifiers.append( 378 | LinearChunk( 379 | input_size, 380 | output_size, 381 | self.device_embeddings[i])) 382 | 383 | self.reset_parameters() 384 | 385 | def forward(self, input): 386 | if(input[1] is None): 387 | total_x = [] 388 | for i in range(len(self.device_embeddings)): 389 | embed = input[0].to(self.device_embeddings[i]) 390 | x = self.classifiers[i]((embed, None)) 391 | total_x.append(x.to(self.device_embeddings[0])) 392 | total_x = torch.cat(total_x, dim=1) 393 | return total_x 394 | else: 395 | if len(input[1].shape) == 1: 396 | total_x = [] 397 | for i in range(len(self.device_embeddings)): 398 | _start = self.partition_indices[i][0] 399 | _end = self.partition_indices[i][1] 400 | embed = input[0].to(self.device_embeddings[i]) 401 | indices = input[1][_start: _end] 402 | 403 | x = self.classifiers[i]((embed, indices)) 404 | total_x.append(x.to(self.device_embeddings[0])) 405 | total_x = torch.cat(total_x, dim=1) 406 | return total_x 407 | elif len(input[1].shape) == 2: 408 | partition_length = input[1].shape[1] // len( 409 | self.partition_indices) 410 | total_x = [] 411 | for i in range(len(self.device_embeddings)): 412 | embed = input[0].to(self.device_embeddings[i]) 413 | short = input[1][:, i * 414 | partition_length: (i + 1) * partition_length] 
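# each LinearChunk scores its slice of the per-example label shortlist on its
# own GPU; the per-chunk scores are moved back to the first device and concatenated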
415 | x = self.classifiers[i]((embed, short)) 416 | total_x.append(x.to(self.device_embeddings[0])) 417 | total_x = torch.cat(total_x, dim=1) 418 | return total_x 419 | 420 | def move_to_devices(self): 421 | print("Moving to different devices...") 422 | for i in range(len(self.device_embeddings)): 423 | self.classifiers[i].move_to_devices() 424 | 425 | def reset_parameters(self): 426 | for i in range(len(self.device_embeddings)): 427 | self.classifiers[i].reset_parameters() 428 | 429 | 430 | class Residual(nn.Module): 431 | """Residual layer implementation""" 432 | 433 | def __init__(self, input_size, output_size, dropout, init='eye'): 434 | super(Residual, self).__init__() 435 | self.input_size = input_size 436 | self.output_size = output_size 437 | self.init = init 438 | self.dropout = dropout 439 | self.padding_size = self.output_size - self.input_size 440 | self.hidden_layer = nn.Sequential(nn.Linear(self.input_size, 441 | self.output_size), 442 | nn.BatchNorm1d(self.output_size), 443 | nn.ReLU(), 444 | nn.Dropout(self.dropout)) 445 | self.initialize(self.init) 446 | 447 | def forward(self, embed): 448 | temp = F.pad(embed, (0, self.padding_size), 'constant', 0) 449 | embed = self.hidden_layer(embed) + temp 450 | return embed 451 | 452 | def initialize(self, init_type): 453 | if init_type == 'random': 454 | nn.init.xavier_uniform_( 455 | self.hidden_layer[0].weight, 456 | gain=nn.init.calculate_gain('relu')) 457 | nn.init.constant_(self.hidden_layer[0].bias, 0.0) 458 | else: 459 | print("Using eye to initialize!") 460 | nn.init.eye_(self.hidden_layer[0].weight) 461 | nn.init.constant_(self.hidden_layer[0].bias, 0.0) 462 | 463 | 464 | class GalaXCBase(nn.Module): 465 | """Base class for GalaXC""" 466 | 467 | def __init__(self, num_labels, hidden_dims, device_names, 468 | feature_dim: int, 469 | fanouts: list, 470 | graph, 471 | embed_dim: int, 472 | dropout=0.5, num_clf_partitions=1, padding_idx=0): 473 | super(GalaXCBase, self).__init__() 474 | 475 | # only 1, 2 or 3 hops are allowed. 476 | assert len(fanouts) in [1, 2, 3] 477 | 478 | self.graph = graph 479 | self.fanouts = fanouts 480 | self.num_labels = num_labels 481 | self.feature_dim = feature_dim 482 | self.hidden_dims = hidden_dims 483 | self.embed_dim = embed_dim 484 | self.device_names = device_names 485 | self.device_name = self.device_names[0] 486 | self.device_embeddings = torch.device(self.device_name) 487 | 488 | self.dropout = dropout 489 | self.padding_idx = padding_idx 490 | self.num_clf_partitions = num_clf_partitions 491 | 492 | self._construct_embeddings() 493 | self.transform1 = self._construct_transform() 494 | self.transform2 = self._construct_transform() 495 | self.transform3 = self._construct_transform() 496 | self.classifier = self._construct_classifier() 497 | 498 | def query(self, context: dict): 499 | context["encoder"] = self.third_layer_enc.query( 500 | context["inputs"], 501 | self.graph 502 | ) 503 | 504 | def _construct_transform(self): 505 | return nn.Sequential(nn.ReLU(), nn.Dropout(self.dropout), Residual( 506 | self.embed_dim, self.hidden_dims, self.dropout)) 507 | 508 | def _construct_classifier(self): 509 | return LinearDistributed( 510 | self.hidden_dims, self.num_labels, self.device_names) 511 | 512 | def _construct_embeddings(self): 513 | """ 514 | Some calculation is repeated. Optimizing doesn't help much, keeping for simplicity.
515 | """ 516 | def feature_func(features): return features.squeeze(0) 517 | 518 | self.first_layer_enc = GINEncoder( 519 | features=feature_func, 520 | query_func=None, 521 | feature_dim=self.feature_dim, 522 | intermediate_dim=self.feature_dim, 523 | aggregator=SumAggregator(feature_func), 524 | embed_dim=self.embed_dim, 525 | num_sample=self.fanouts[0], 526 | device_name=self.device_name 527 | ) 528 | 529 | self.second_layer_enc = GINEncoder( 530 | features=lambda context: self.first_layer_enc(context).t(), 531 | query_func=self.first_layer_enc.query, 532 | feature_dim=self.feature_dim, 533 | intermediate_dim=self.embed_dim, 534 | aggregator=SumAggregator( 535 | lambda context: self.first_layer_enc(context).t() 536 | ), 537 | embed_dim=self.embed_dim, 538 | num_sample=self.fanouts[1], 539 | base_model=self.first_layer_enc, 540 | device_name=self.device_name 541 | ) 542 | 543 | self.third_layer_enc = GINEncoder( 544 | features=lambda context: self.second_layer_enc(context).t(), 545 | query_func=self.second_layer_enc.query, 546 | feature_dim=self.feature_dim, 547 | intermediate_dim=self.embed_dim, 548 | aggregator=SumAggregator( 549 | lambda context: self.second_layer_enc(context).t() 550 | ), 551 | embed_dim=self.embed_dim, 552 | num_sample=self.fanouts[2], 553 | base_model=self.second_layer_enc, 554 | device_name=self.device_name 555 | ) 556 | 557 | def encode(self, context): 558 | embed3 = self.third_layer_enc(context["encoder"]) 559 | embed2 = self.second_layer_enc(context["encoder"]["node_feats"]) 560 | embed1 = self.first_layer_enc( 561 | context["encoder"]["node_feats"]["node_feats"]) 562 | 563 | embed = torch.cat( 564 | (self.transform1( 565 | embed1.t()), self.transform2( 566 | embed2.t()), self.transform3( 567 | embed3.t())), dim=1) 568 | return embed 569 | 570 | def encode_graph_embedding(self, context): 571 | embed = self.embeddings(context["encoder"], self.device_embeddings) 572 | return embed.t() 573 | 574 | def forward(self, batch_data, only_head=True): 575 | encoded = self.encode(batch_data) 576 | 577 | return self.classifier((encoded, batch_data["label_ids"])) 578 | 579 | def initialize_embeddings(self, word_embeddings): 580 | self.embeddings.weight.data.copy_(torch.from_numpy(word_embeddings)) 581 | 582 | def initialize_classifier(self, clf_weights): 583 | self.classifier.weight.data.copy_(torch.from_numpy(clf_weights[:, -1])) 584 | self.classifier.bias.data.copy_( 585 | torch.from_numpy(clf_weights[:, -1]).view(-1, 1)) 586 | 587 | def get_clf_weights(self): 588 | return self.classifier.get_weights() 589 | 590 | def move_to_devices(self): 591 | self.third_layer_enc.to(self.device_embeddings) 592 | self.transform1.to(self.device_embeddings) 593 | self.transform2.to(self.device_embeddings) 594 | self.transform3.to(self.device_embeddings) 595 | self.classifier.move_to_devices() 596 | 597 | @property 598 | def num_trainable_params(self): 599 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 600 | 601 | @property 602 | def model_size(self): 603 | return self.num_trainable_params * 4 / math.pow(2, 20) 604 | 605 | 606 | class HNSW(object): 607 | """HNSW ANNS implementation""" 608 | 609 | def __init__(self, M, efC, efS, num_threads): 610 | self.index = nmslib.init(method='hnsw', space='cosinesimil') 611 | self.M = M 612 | self.num_threads = num_threads 613 | self.efC = efC 614 | self.efS = efS 615 | 616 | def fit(self, data, print_progress=True): 617 | self.index.addDataPointBatch(data) 618 | self.index.createIndex( 619 | {'M': self.M, 620 | 'indexThreadQty': 
self.num_threads, 621 | 'efConstruction': self.efC}, 622 | print_progress=print_progress 623 | ) 624 | 625 | def _filter(self, output, num_search): 626 | indices = np.zeros((len(output), num_search), dtype=np.int32) 627 | distances = np.zeros((len(output), num_search), dtype=np.float32) 628 | for idx, item in enumerate(output): 629 | indices[idx] = item[0] 630 | distances[idx] = item[1] 631 | return indices, distances 632 | 633 | def predict(self, data, num_search): 634 | self.index.setQueryTimeParams({'efSearch': self.efS}) 635 | output = self.index.knnQueryBatch( 636 | data, k=num_search, num_threads=self.num_threads 637 | ) 638 | indices, distances = self._filter(output, num_search) 639 | return indices, distances 640 | 641 | def save(self, fname): 642 | nmslib.saveIndex(self.index, fname) 643 | --------------------------------------------------------------------------------
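A minimal, self-contained sketch of how the HNSW wrapper above might be exercised on its own, assuming nmslib and numpy are installed; the random vectors and the parameter values (M, efC, efS, num_search) are illustrative stand-ins, not the repository's tuned settings.

import numpy as np

from network import HNSW

# illustrative sizes only; real runs index label embeddings learned by the model
label_embeddings = np.random.rand(1000, 300).astype(np.float32)
query_embeddings = np.random.rand(10, 300).astype(np.float32)

# assumed hyperparameters for the sketch
index = HNSW(M=50, efC=100, efS=100, num_threads=4)
index.fit(label_embeddings, print_progress=False)

# indices: (num_queries, num_search) nearest label ids; distances: cosine distances
indices, distances = index.predict(query_embeddings, num_search=5)
print(indices.shape, distances.shape)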