├── README.md
└── src
    ├── approx_eval.py
    ├── config.py
    ├── evaluate.py
    ├── preprocess.py
    ├── train.py
    ├── train.sh
    ├── util
    │   ├── Gather.cpp
    │   ├── Gather.h
    │   ├── Makefile
    │   ├── setup.py
    │   └── util.pyx
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
# Introduction
This repository has the official code for the algorithm MACH discussed in the NeurIPS 2019 paper [Extreme Classification in Log-Memory using Count-Min Sketch](https://papers.nips.cc/paper/9482-extreme-classification-in-log-memory-using-count-min-sketch-a-case-study-of-amazon-search-with-50m-products.pdf).
MACH proposes a novel zero-communication distributed training method for Extreme Classification (classification with millions of classes). We project the huge output
vector with millions of dimensions down to a small count-min sketch (CMS) matrix. We then train independent networks to predict each column of this CMS matrix
instead of the huge label vector.

If you find our approach interesting, please cite our paper with the following BibTeX:
```
@inproceedings{medini2019extreme,
  title={Extreme Classification in Log Memory using Count-Min Sketch: A Case Study of Amazon Search with 50M Products},
  author={Medini, Tharun Kumar Reddy and Huang, Qixuan and Wang, Yiqiu and Mohan, Vijai and Shrivastava, Anshumali},
  booktitle={Advances in Neural Information Processing Systems},
  pages={13244--13254},
  year={2019}
}
```

# Download links for datasets
Most of the public datasets that we use are available on the [Extreme Classification Repository (XML Repo)](http://manikvarma.org/downloads/XC/XMLRepository.html). Specific links are as follows:

1. [Amazon-670K](https://drive.google.com/file/d/1TLaXCNB_IDtLhk4ycOnyud0PswWAW6hR/view) / [Kaggle Link](https://www.kaggle.com/c/extreme-classification-amazon)
2. [Delicious-200K](https://drive.google.com/file/d/0B3lPMIHmG6vGR3lBWWYyVlhDLWM/view)
3. [Wiki10-31K](http://manikvarma.org/downloads/XC/XMLRepository.html)

After downloading any of the XML repo datasets, please unzip them and move the train and test files to any folder(s) of your choice. Update *train_data_loc* and *eval_data_loc* in *config.py*.

4. ODP dataset: [Train](http://hunch.net/~vw/odp_train.vw.gz) / [Test](http://hunch.net/~vw/odp_test.vw.gz).
The data format must be changed to match the datasets on the Extreme Classification repo.

5. Fine-grained ImageNet-22K dataset: [Train](http://hunch.net/~jl/datasets/imagenet/training.txt.gz) / [Test](http://hunch.net/~jl/datasets/imagenet/testing.txt.gz).
Yet again, the data format must be changed to match the datasets on the Extreme Classification repo.

# Running MACH

## Requirements
You are expected to have TensorFlow 1.x installed (1.8 - 1.14 should work) and at least 2 GPUs with 32 GB memory (or 4 GPUs with 16 GB memory). We will add support for TensorFlow 2.x in subsequent versions.
*Cython* is also required for importing a C++ function *gather_batch* during evaluation (if you cannot use C++ for any reason, please refer to the **Cython vs Python for evaluation** section below).
*sklearn* is required for importing *murmurhash3_32* (from sklearn.utils). Although the version requirements for *Cython* and *sklearn* are not as stringent as TensorFlow's, use Cython 0.29.14 and sklearn 0.22.2 in case you run into any issues.
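For concreteness, here is a minimal sketch of the class-to-bucket hashing at the heart of MACH. It mirrors *create_universal_lookups* in *src/utils.py*; the constants are the Amazon-3M values from the provided *config.py*:

```
import numpy as np
from sklearn.utils import murmurhash3_32 as mmh3

n_classes, B = 2812281, 10000  # Amazon-3M values from config.py

def universal_lookup(r):
    # Repetition r seeds its own hash function, so each of the R models
    # predicts an independent row of the count-min sketch.
    return np.array([mmh3(i, seed=r) % B for i in range(n_classes)])

print(universal_lookup(0)[:5])  # buckets of classes 0..4 under repetition 0
```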
## Configuration
After cloning the repo, move into the *src* folder and change the config.py file. Most of the configurations are self-explanatory. Some non-trivial ones are:
1. *feat_dim_orig* corresponds to the original input dimension of the dataset. Since this might be huge for some datasets, we feature-hash it down to a smaller dimension (set by *feat_hash_dim*).
2. *feat_hash_dim* is the smaller dimension that the input is hashed into. To avoid loss of information, we use different random seeds for the hash functions in each independent model.
3. *lookups_loc* is the location to save the murmurhash lookups for each model (each lookup is an *n_classes*-dimensional integer array with each value in [0, B-1]).
4. *B* stands for the number of buckets (*B* << *n_classes*).
5. *R* in eval_config is the number of models that we want to use for evaluation.
6. *R_per_gpu* stands for how many models can run simultaneously on a single GPU. We can generally run up to 8 models (each with around 400M parameters) at once on a single V100 GPU with 32 GB memory.

## Pre-processing
We use the TFRecords format for the input data (we will soon add support for loading from .txt files with tf.data). TFRecords allows the data to be streamed, with provision for prefetching and pseudo-shuffling.
Compared to writing a data loader that reads from a plain .txt file, TFRecords reduces GPU idle time and speeds up the training by 2-3x.

The code *preprocess.py* has 3 sub-parts. The first transforms the .txt data to TFRecords format. The second creates lookups for classes. The third creates lookups for input feature hashing.
Each line of data is assumed to be in the format *2980,3177,9026,9053,12256,12258 63:5.906890 87:3.700440 242:6.339850 499:2.584960 611:4.321930 672:2.807350*, where 2980,...,12258 are the true labels while the subsequent entries are input feature indexes and their respective values (see the parsing sketch below). If you want to parse any other data format, please modify the function *create_tfrecords* in *utils.py*.

Once you're clear with the data format, please run
```
python3 preprocess.py
```
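For reference, here is a minimal sketch of how one such line is parsed; it mirrors the parsing loop inside *create_tfrecords* (the example line is truncated):

```
line = '2980,3177,9026 63:5.906890 87:3.700440'  # truncated example line
itms = line.strip().split()
y_idxs = [int(itm) for itm in itms[0].split(',')]        # true label ids
x_idxs = [int(itm.split(':')[0]) for itm in itms[1:]]    # input feature indexes
x_vals = [float(itm.split(':')[1]) for itm in itms[1:]]  # input feature values
print(y_idxs, x_idxs, x_vals)
```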
## Training
The script *train.sh* has all the commands for running 16 models in parallel via tmux sessions. To run just one model, please run the following command.

```
python3 train.py --repetition=0 --gpu=0 --gpu_usage=0.45
```

*gpu_usage* limits the proportion of GPU memory that this model can use. By limiting it to 45%, we can train 2 models at once on each GPU.

## Evaluation
Please run the following commands to compile a Cython function. You may require *sudo* access to your machine for this step. In case you don't have it, read the section *Cython vs Python for evaluation* below.
```
cd src/util/
make clean
make
export PYTHONPATH=$(pwd)
```

After an error-free compilation, please change the config.py file to specify the number of models/repetitions *R*, which epoch's models to evaluate with, data paths, logfile paths, etc.
Adjust *R_per_gpu* in case you run out of GPU memory. Then run

```
python3 evaluate.py
```

The time taken to load the lookups and weights and to initialize the network is quite substantial: it takes >30 mins for 16 models, each with ~400M parameters. The time printed in the log files does not account for
the initialization. It only measures the actual evaluation time.

# Some finer details

## Cython vs Python for evaluation
The code evaluate.py imports a Cython function *gather_batch* to reconstruct the scores for all classes from the predicted count-min sketch logits. It also gets the top-5 predictions based on these gathered scores
using a priority-queue implementation. We use a simple *#pragma omp* parallelization across a batch of inputs.

Alternatively, we can do the same aggregation and partitioning of scores in Python and parallelize it using a *multiprocessing* *Pool* object. However, using exclusively Python for this step causes the evaluation
to be more than **5x slower**. Nevertheless, if you cannot use Cython, please comment out the *import gather_batch* and *gather_batch(...)* lines in *evaluate.py* and uncomment the following two snippets:
```
def process_logits(inp):
    R = inp.shape[0]
    B = inp.shape[1]
    ##
    scores = np.zeros(config.n_classes, dtype=float)
    ##
    for r in range(R):
        for b in range(B):
            val = inp[r,b]
            scores[inv_lookup[r, counts[r,b]:counts[r,b+1]]] += val
    ##
    top_idxs = np.argpartition(scores, -5)[-5:]
    temp = np.argsort(-scores[top_idxs])
    return top_idxs[temp]

p = Pool(config.n_cores)
```

and

```
top_preds = p.map(process_logits, logits_)
```

You should then be able to run
```
python3 evaluate.py
```

**Note:** *#pragma omp parallel* doesn't allow huge matrices to be passed on. Hence, adjust the *batch_size* in *eval_config* so that _batch_size*n_classes < 3 billion_. For example, with the Amazon-3M setting of *n_classes* = 2812281, this caps *batch_size* at roughly 1000 (the provided eval_config uses 720).

## tf.constant vs tf.Variable
In *evaluate.py*, the network initialization uses *tf.constant()* when loading saved weights instead of *tf.Variable()*, as we do not have to train the weights again (it might be slightly faster too).
However, some older versions of TensorFlow might not allow initializing a tensor with a >2GB matrix. In that case, we need to define a *tf.placeholder()* first, then define a *tf.Variable(tf.placeholder())*,
and feed the matrix later during variable initialization.
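A minimal sketch of this workaround follows; it mirrors what *train.py* does when *load_epoch* > 0 (the checkpoint filename below is just an illustration):

```
import numpy as np
import tensorflow as tf

# Load the saved matrix outside the graph; the graph only holds a placeholder,
# so no >2GB constant gets serialized into the graph definition.
W1_np = np.load('r_0_epoch_10.npz')['W1']  # illustrative checkpoint file
W1_ph = tf.placeholder(tf.float32, shape=W1_np.shape)
W1 = tf.Variable(W1_ph)

sess = tf.Session()
sess.run(tf.global_variables_initializer(), feed_dict={W1_ph: W1_np})
```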
## logits vs probabilities
In the paper, the proposed unbiased estimator for recovering all probabilities is proportional to the sum of the predicted bucket probabilities. However, for multilabel classification, it turns out that summing up the
predicted bucket logits instead of probabilities gives better precision, because logits have a wider range of values than probabilities. Nevertheless, if you want to experiment with probabilities, you can
simply change

```
logits_, y_idxs = sess.run([logits, next_y_idxs])
```

to

```
probs_, y_idxs = sess.run([probs, next_y_idxs])
```
--------------------------------------------------------------------------------
/src/approx_eval.py:
--------------------------------------------------------------------------------
from config import eval_config as config
import tensorflow as tf
import time
import numpy as np
import logging
import argparse
import os
import json
import glob
from utils import _parse_function
from multiprocessing import Pool


os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

############################## load lookups ################################
N = config.n_classes

lookup = np.zeros([config.R,config.n_classes]).astype(int)
inv_lookup = np.zeros([config.R,config.n_classes]).astype(int)
counts = np.zeros([config.R,config.B+1]).astype(int)
for r in range(config.R):
    lookup[r] = np.load(config.lookups_loc+'bucket_order_'+str(r)+'.npy')[:N]
    inv_lookup[r] = np.load(config.lookups_loc+'class_order_'+str(r)+'.npy')[:N]
    counts[r] = np.load(config.lookups_loc+'counts_'+str(r)+'.npy')[:config.B+1]

query_lookup = np.empty([config.R, config.feat_dim_orig], dtype=int)
for r in range(config.R):
    query_lookup[r] = np.load(config.query_lookups_loc+'bucket_order_'+str(r)+'.npy')

##################################
W1 = [None for r in range(config.R)]
b1 = [None for r in range(config.R)]
hidden_layer = [None for r in range(config.R)]
W2 = [None for r in range(config.R)]
b2 = [None for r in range(config.R)]
logits = [None for r in range(config.R)]
probs = [None for r in range(config.R)]
top_buckets = [None for i in range(config.R)]

##################################
params = [np.load(config.model_loc+'r_'+str(r)+'_epoch_'+str(config.eval_epoch)+'.npz') for r in range(config.R)]
W1_tmp = [params[r]['W1'] for r in range(config.R)]
b1_tmp = [params[r]['b1'] for r in range(config.R)]
W2_tmp = [params[r]['W2'] for r in range(config.R)]
b2_tmp = [params[r]['b2'] for r in range(config.R)]

################# Data Loader ####################
eval_files = glob.glob(config.tfrecord_loc+'*_test.tfrecords')

dataset = tf.data.TFRecordDataset(eval_files)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    map_func=_parse_function, batch_size=config.batch_size))

iterator = dataset.make_initializable_iterator()
next_y_idxs, next_y_vals, next_x_idxs, next_x_vals = iterator.get_next()
x = [tf.SparseTensor(tf.stack([next_x_idxs.indices[:,0], tf.gather(query_lookup[r], next_x_idxs.values)], axis=-1),
    next_x_vals.values, [config.batch_size, config.feat_hash_dim]) for r in range(config.R)]

############################## Create Graph ################################
for r in range(config.R):
    with tf.device('/gpu:'+str(r//config.R_per_gpu)):
        ######
        # W1[r] = tf.Variable(tf.truncated_normal([config.feat_hash_dim, config.hidden_dim], stddev=0.05, dtype=tf.float32))
        # b1[r] = tf.Variable(tf.truncated_normal([config.hidden_dim], stddev=0.05, dtype=tf.float32))
        # hidden_layer[r] = tf.nn.relu(tf.sparse_tensor_dense_matmul(x[r],W1[r])+b1[r])
        # #
        # W2[r] = tf.Variable(tf.truncated_normal([config.hidden_dim, config.B], stddev=0.05, dtype=tf.float32))
        # b2[r] = tf.Variable(tf.truncated_normal([config.B], stddev=0.05, dtype=tf.float32))
        ######
        W1[r] = tf.constant(W1_tmp[r])
        b1[r] = tf.constant(b1_tmp[r])
        hidden_layer[r] = tf.nn.relu(tf.sparse_tensor_dense_matmul(x[r],W1[r])+b1[r])
        #
        W2[r] = tf.constant(W2_tmp[r])
        b2[r] = tf.constant(b2_tmp[r])
        ######
        logits[r] = tf.matmul(hidden_layer[r],W2[r])+b2[r]
        # probs[r] = tf.sigmoid(logits[r])
        top_buckets[r] = tf.nn.top_k(logits[r], k=config.topk, sorted=True)

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
sess = tf.Session(config=tf_config)
sess.run(tf.global_variables_initializer())


################# Load Eval Files #####################
n_check = 1000
count = 0
score_sum = [0.0,0.0,0.0]

##### Run Graph Optimizer on first batch (might take ~50s) ####
sess.run(iterator.initializer)
top_buckets_, y_idxs = sess.run([top_buckets, next_y_idxs])

###### Re-initialize the data loader ####
sess.run(iterator.initializer)


def process_scores(inp):
    R = inp.shape[0]
    topk = inp.shape[2]
    ##
    scores = {}
    freqs = {}
    ##
    for r in range(config.R):
        for k in range(topk):
            val = inp[r,0,k]
            ##
            for key in inv_lookup[r,counts[r,int(inp[r,1,k])]:counts[r,int(inp[r,1,k])+1]]:
                if key in scores:
                    scores[key] += val
                    freqs[key] += 1
                else:
                    scores[key] = val
                    freqs[key] = 1
    ##
    i = 0
    while True:
        candidates = np.array([key for key in scores if freqs[key]>=config.minfreq-i])
        if len(candidates)>=5:
            break
        i += 1
    scores = np.array([scores[key] for key in candidates])
    # ##
    top_idxs = np.argpartition(scores, -5)[-5:]
    temp = np.argsort(-scores[top_idxs])
    return candidates[top_idxs[temp]]

p = Pool(config.n_cores)

begin_time = time.time()

with open(config.logfile, 'a', encoding='utf-8') as fw:
    while True:
        try:
            top_buckets_, y_idxs = sess.run([top_buckets, next_y_idxs])
            top_buckets_ = np.array(top_buckets_)
            top_buckets_ = np.transpose(top_buckets_, (2,0,1,3))
            preds = p.map(process_scores, top_buckets_)
            ##
            curr_batch_size = y_idxs[2][0]
            labels = [[] for i in range(curr_batch_size)]
            for j in range(len(y_idxs[0])):
                labels[y_idxs[0][j,0]].append(y_idxs[1][j])
            ##
            for i in range(curr_batch_size):
                true_labels = labels[i]
                #### P@1
                if preds[i][0] in true_labels:
                    score_sum[0] += 1
                #### P@3
                score_sum[1] += len(np.intersect1d(preds[i][:3],true_labels))/min(len(true_labels),3)
                #### P@5
                score_sum[2] += len(np.intersect1d(preds[i],true_labels))/min(len(true_labels),5)
                count += 1
                if count%n_check==0:
                    print('P@1 for',count,'points:',score_sum[0]/count, file=fw)
                    print('P@3 for',count,'points:',score_sum[1]/count, file=fw)
                    print('P@5 for',count,'points:',score_sum[2]/count, file=fw)
                    print('time_elapsed: ',time.time()-begin_time, file=fw)
        except tf.errors.OutOfRangeError:
            print('overall P@1 for',count,'points:',score_sum[0]/count, file=fw)
            print('overall P@3 for',count,'points:',score_sum[1]/count, file=fw)
            print('overall P@5 for',count,'points:',score_sum[2]/count, file=fw)
            print('time_elapsed: ',time.time()-begin_time, file=fw)
            break

p.close()
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
class train_config:
    train_data_loc = '../data/Amazon-3M/'
    tfrecord_loc = '../data/Amazon-3M/tfrecords/'
    model_save_loc = '../saved_models/Amazon-3M/b_10000/'
    query_lookups_loc = '../lookups/Amazon-3M/b_100000/'
    lookups_loc = '../lookups/Amazon-3M/b_10000/'
    logfile = '../logs/Amazon-3M/b_10000/'
    ####
    feat_dim_orig = 337067
    n_classes = 2812281
    ####
    n_cores = 4 # core count for the TFRecord data loader
    B = 10000
    batch_size = 2000
    n_epochs = 10
    load_epoch = 0
    feat_hash_dim = 100000
    hidden_dim = 4096
    # Only used if training multiple repetitions from the same script
    R_per_gpu = 2

class eval_config:
    query_lookups_loc = '../lookups/Amazon-3M/b_100000/'
    lookups_loc = '../lookups/Amazon-3M/b_10000/'
    model_loc = '../saved_models/Amazon-3M/b_10000/'
    eval_data_loc = '../data/Amazon-3M/'
    tfrecord_loc = '../data/Amazon-3M/tfrecords/'
    logfile = '../logs/Amazon-3M/eval_logs.txt'
    ###
    feat_dim_orig = 337067
    n_classes = 2812281
    ###
    B = 10000
    R = 16
    eval_epoch = 10
    R_per_gpu = 4
    num_gpus = 4 # R/R_per_gpu gpus
    n_cores = 32 # core count for parallelizable operations
    batch_size = 720
    feat_hash_dim = 100000
    hidden_dim = 4096
    ### only used by approx_eval.py (ignore if you are using evaluate.py)
    topk = 50 # how many top buckets to take per model
    minfreq = 2 # min number of times a class must appear across the models' top buckets to be a candidate
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
from config import eval_config as config
import tensorflow as tf
import time
import numpy as np
import logging
import argparse
import os
import json
import glob
from utils import _parse_function
from multiprocessing import Pool

try:
    from util import gather_batch
    from util import gather_K
except:
    print('**********************CANNOT IMPORT GATHER***************************')
    exit()

# expose all num_gpus GPUs; model r is placed on gpu r//R_per_gpu below
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

############################## load lookups ################################
## inv_lookup and counts are NOT needed if you are using the C++ gather function (see further for details)

N = config.n_classes

lookup = np.zeros([config.R,config.n_classes]).astype(int)
inv_lookup = np.zeros([config.R,config.n_classes]).astype(int)
counts = np.zeros([config.R,config.B+1]).astype(int)
for r in range(config.R):
    lookup[r] = np.load(config.lookups_loc+'bucket_order_'+str(r)+'.npy')[:N]
    inv_lookup[r] = np.load(config.lookups_loc+'class_order_'+str(r)+'.npy')[:N]
    counts[r] = np.load(config.lookups_loc+'counts_'+str(r)+'.npy')[:config.B+1]

query_lookup = np.empty([config.R, config.feat_dim_orig], dtype=int)
for r in range(config.R):
    query_lookup[r] = np.load(config.query_lookups_loc+'bucket_order_'+str(r)+'.npy')

##################### create empty lists for future tensors ################
W1 = [None for r in range(config.R)]
b1 = [None for r in range(config.R)]
hidden_layer = [None for r in range(config.R)]
W2 = [None for r in range(config.R)]
b2 = [None for r in range(config.R)]
logits = [None for r in range(config.R)]
probs = [None for r in range(config.R)]
top_buckets = [None for i in range(config.R)]

####################### load saved weights ############################
#### If you just want to test the code for bugs, don't load these.
#### Just use random weights when creating the TF graph (shown later)

params = [np.load(config.model_loc+'r_'+str(r)+'_epoch_'+str(config.eval_epoch)+'.npz') for r in range(config.R)]
W1_tmp = [params[r]['W1'] for r in range(config.R)]
b1_tmp = [params[r]['b1'] for r in range(config.R)]
W2_tmp = [params[r]['W2'] for r in range(config.R)]
b2_tmp = [params[r]['b2'] for r in range(config.R)]

################# Create TF Data Loader ####################
eval_files = glob.glob(config.tfrecord_loc+'*_test.tfrecords')

dataset = tf.data.TFRecordDataset(eval_files)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    map_func=_parse_function, batch_size=config.batch_size, num_parallel_calls=4))
# dataset = dataset.prefetch(buffer_size=10)
iterator = dataset.make_initializable_iterator()
next_y_idxs, next_y_vals, next_x_idxs, next_x_vals = iterator.get_next()
x = [tf.SparseTensor(tf.stack([next_x_idxs.indices[:,0], tf.gather(query_lookup[r], next_x_idxs.values)], axis=-1),
    next_x_vals.values, [config.batch_size, config.feat_hash_dim]) for r in range(config.R)]

############################## Create Graph ################################

#### Uncomment these if you are using placeholders and writing your own data loader
# x_idxs = tf.placeholder(tf.int64, [None, 2])
# # x_vals = tf.ones_like(x_idxs[:,0], dtype=tf.float32)
# x_vals = tf.placeholder(tf.float32, [None,])
# x = tf.SparseTensor(x_idxs, x_vals, [config.batch_size, config.feat_hash_dim])
####


for r in range(config.R):
    with tf.device('/gpu:'+str(r//config.R_per_gpu)):
        ###### Random weight initialization to test for bugs
        # W1[r] = tf.Variable(tf.truncated_normal([config.feat_hash_dim, config.hidden_dim], stddev=0.05, dtype=tf.float32))
        # b1[r] = tf.Variable(tf.truncated_normal([config.hidden_dim], stddev=0.05, dtype=tf.float32))
        # hidden_layer[r] = tf.nn.relu(tf.sparse_tensor_dense_matmul(x[r],W1[r])+b1[r])
        # #
        # W2[r] = tf.Variable(tf.truncated_normal([config.hidden_dim, config.B], stddev=0.05, dtype=tf.float32))
        # b2[r] = tf.Variable(tf.truncated_normal([config.B], stddev=0.05, dtype=tf.float32))
        ###### Load weights into tensors (tf.constant takes less memory than tf.Variable)
        W1[r] = tf.constant(W1_tmp[r])
        b1[r] = tf.constant(b1_tmp[r])
        hidden_layer[r] = tf.nn.relu(tf.sparse_tensor_dense_matmul(x[r],W1[r])+b1[r])
        #
        W2[r] = tf.constant(W2_tmp[r])
        b2[r] = tf.constant(b2_tmp[r])
        ######
        logits[r] = tf.matmul(hidden_layer[r],W2[r])+b2[r]
        probs[r] = tf.sigmoid(logits[r])


tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
sess = tf.Session(config=tf_config)
sess.run(tf.global_variables_initializer())

################ (Uncomment this snippet if you want to use python multiprocessing)
# def process_logits(inp):
#     R = inp.shape[0]
#     B = inp.shape[1]
#     ##
#     scores = np.zeros(config.n_classes, dtype=float)
#     ##
#     for r in range(R):
#         for b in range(B):
#             val = inp[r,b]
#             scores[inv_lookup[r, counts[r,b]:counts[r,b+1]]] += val
#     ##
#     top_idxs = np.argpartition(scores, -5)[-5:]
#     temp = np.argsort(-scores[top_idxs])
#     return top_idxs[temp]

# p = Pool(config.n_cores)
#################

################# Evaluation begins #####################
n_check = 100
count = 0
overall_count = 0
score_sum = [0.0,0.0,0.0]

begin_time = time.time()

sess.run(iterator.initializer) # initialize TF data loader

with open(config.logfile, 'a', encoding='utf-8') as fw:
    while True:
        try:
            logits_, y_idxs = sess.run([logits, next_y_idxs])
            logits_ = np.array(logits_)
            logits_ = np.transpose(logits_, (1,0,2))
            logits_ = np.ascontiguousarray(logits_)
            curr_batch_size = y_idxs[2][0]
            ## C++ gather function (faster)
            scores = np.zeros([curr_batch_size, N], dtype=np.float32)
            top_preds = np.zeros([curr_batch_size, 5], dtype=np.int64)
            gather_batch(logits_, lookup, scores, top_preds, config.R, config.B, N, curr_batch_size, config.n_cores)
            ## python multiprocessing (~5x slower than C++ gather)
            # top_preds = p.map(process_logits, logits_)
            ## get true labels
            labels = [[] for i in range(curr_batch_size)]
            for j in range(len(y_idxs[0])):
                labels[y_idxs[0][j,0]].append(y_idxs[1][j])
            ##
            for i in range(curr_batch_size):
                true_labels = labels[i]
                sorted_preds = top_preds[i]
                #### P@1
                if sorted_preds[0] in true_labels:
                    score_sum[0] += 1
                #### P@3
                score_sum[1] += len(np.intersect1d(sorted_preds[:3],true_labels))/min(len(true_labels),3)
                #### P@5
                score_sum[2] += len(np.intersect1d(sorted_preds,true_labels))/min(len(true_labels),5)
                count += 1
                if count%n_check==0:
                    print('P@1 for',count,'points:',score_sum[0]/count, file=fw)
                    print('P@3 for',count,'points:',score_sum[1]/count, file=fw)
                    print('P@5 for',count,'points:',score_sum[2]/count, file=fw)
                    print('time_elapsed: ',time.time()-begin_time, file=fw)
        except tf.errors.OutOfRangeError: # this happens after the last batch of data is loaded
            print('overall P@1 for',count,'points:',score_sum[0]/count, file=fw)
            print('overall P@3 for',count,'points:',score_sum[1]/count, file=fw)
            print('overall P@5 for',count,'points:',score_sum[2]/count, file=fw)
            print('time_elapsed: ',time.time()-begin_time, file=fw)
            break

--------------------------------------------------------------------------------
/src/preprocess.py:
--------------------------------------------------------------------------------
from utils import input_example, create_tfrecords, create_universal_lookups, create_query_lookups
import glob
import time
from multiprocessing import Pool
from config import train_config as config

######## Create TF Records ##########
begin_time = time.time()
train_files = glob.glob(config.train_data_loc+'*.txt')
for file in train_files:
    nothing = create_tfrecords(file)

print('elapsed_time:', time.time()-begin_time)

########## Prepare Label lookups (for MACH grouping)
begin_time = time.time()
p = Pool(32)
p.map(create_universal_lookups, list(range(32)))
p.close()
p.join()
print('elapsed_time:', time.time()-begin_time)

# ########## Prepare input idx lookups (for feature hashing)
begin_time = time.time()
p = Pool(32)
p.map(create_query_lookups, list(range(32)))
p.close()
p.join()
print('elapsed_time:', time.time()-begin_time)

--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
from config import train_config as config
import tensorflow as tf
import glob
import argparse
import time
import numpy as np
import logging
from utils import _parse_function

parser = argparse.ArgumentParser()
parser.add_argument("--repetition", help="which repetition?", default=0)
parser.add_argument("--gpu", default=0)
parser.add_argument("--gpu_usage", default=0.45)
args = parser.parse_args()

if not args.gpu=='all':
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

r = int(args.repetition) # which repetition

############################## Test code from here ################################
lookup = tf.constant(np.load(config.lookups_loc+'bucket_order_'+str(r)+'.npy'))
query_lookup = tf.constant(np.load(config.query_lookups_loc+'bucket_order_'+str(r)+'.npy'))

train_files = glob.glob(config.tfrecord_loc+'*_train.tfrecords')

dataset = tf.data.TFRecordDataset(train_files)
# dataset = dataset.map(_parse_function, num_parallel_calls=4)
# dataset = dataset.batch(config.batch_size)
dataset = dataset.apply(tf.contrib.data.map_and_batch(
    map_func=_parse_function, batch_size=config.batch_size))
dataset = dataset.prefetch(buffer_size=1000)
dataset = dataset.shuffle(buffer_size=1000)
# dataset = dataset.repeat(config.n_epochs)
iterator = dataset.make_initializable_iterator()
next_y_idxs, next_y_vals, next_x_idxs, next_x_vals = iterator.get_next()
###############
x_idxs = tf.stack([next_x_idxs.indices[:,0], tf.gather(query_lookup, next_x_idxs.values)], axis=-1)
x_vals = next_x_vals.values
x = tf.SparseTensor(x_idxs, x_vals, [config.batch_size, config.feat_hash_dim])
####
y_idxs = tf.stack([next_y_idxs.indices[:,0], tf.gather(lookup, next_y_idxs.values)], axis=-1)
y_vals = next_y_vals.values
y = tf.SparseTensor(y_idxs, y_vals, [config.batch_size, config.B])
y_ = tf.sparse_tensor_to_dense(y, validate_indices=False)
###############
if config.load_epoch>0:
    params=np.load(config.model_save_loc+'r_'+str(r)+'_epoch_'+str(config.load_epoch)+'.npz')
    #
    W1_tmp = tf.placeholder(tf.float32, shape=[config.feat_hash_dim, config.hidden_dim])
    b1_tmp = tf.placeholder(tf.float32, shape=[config.hidden_dim])
    W1 = tf.Variable(W1_tmp)
    b1 = tf.Variable(b1_tmp)
    hidden_layer = tf.nn.relu(tf.sparse_tensor_dense_matmul(x,W1)+b1)
    #
    W2_tmp = tf.placeholder(tf.float32, shape=[config.hidden_dim, config.B])
    b2_tmp = tf.placeholder(tf.float32, shape=[config.B])
    W2 = tf.Variable(W2_tmp)
    b2 = tf.Variable(b2_tmp)
    logits = tf.matmul(hidden_layer,W2)+b2
else:
    W1 = tf.Variable(tf.truncated_normal([config.feat_hash_dim, config.hidden_dim], stddev=0.05, dtype=tf.float32))
    b1 = tf.Variable(tf.truncated_normal([config.hidden_dim], stddev=0.05, dtype=tf.float32))
    hidden_layer = tf.nn.relu(tf.sparse_tensor_dense_matmul(x,W1)+b1)
    #
    W2 = tf.Variable(tf.truncated_normal([config.hidden_dim, config.B], stddev=0.05, dtype=tf.float32))
    b2 = tf.Variable(tf.truncated_normal([config.B], stddev=0.05, dtype=tf.float32))
    logits = tf.matmul(hidden_layer,W2)+b2

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=y_))
train_op = tf.train.AdamOptimizer().minimize(loss)

sess = tf.Session(config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    gpu_options=tf.GPUOptions(allow_growth=True, per_process_gpu_memory_fraction=float(args.gpu_usage))))

if config.load_epoch==0:
    sess.run(tf.global_variables_initializer())
else:
    sess.run(tf.global_variables_initializer(),
        feed_dict = {
            W1_tmp:params['W1'],
            b1_tmp:params['b1'],
            W2_tmp:params['W2'],
            b2_tmp:params['b2']})
    del params

begin_time = time.time()
total_time = 0
logging.basicConfig(filename = config.logfile+'logs_'+str(r), level=logging.INFO)
n_check=100

for curr_epoch in range(config.load_epoch+1,config.load_epoch+config.n_epochs+1):
    sess.run(iterator.initializer)
    count = 0
    while True:
        try:
            sess.run(train_op)
            count += 1
            if count%n_check==0:
                _, train_loss = sess.run([train_op, loss])
                time_diff = time.time()-begin_time
                total_time += time_diff
                logging.info('finished '+str(count)+' steps. Time elapsed for last '+str(n_check)+' steps: '+str(time_diff)+' s')
                logging.info('train_loss: '+str(train_loss))
                begin_time = time.time()
                count+=1
        except tf.errors.OutOfRangeError:
            break
    logging.info('###################################')
    logging.info('finished epoch '+str(curr_epoch))
    logging.info('total time elapsed so far: '+str(total_time))
    logging.info('###################################')
    if curr_epoch%5==0:
        params = sess.run([W1,b1,W2,b2])
        np.savez_compressed(config.model_save_loc+'r_'+str(r)+'_epoch_'+str(curr_epoch)+'.npz',
            W1=params[0],
            b1=params[1],
            W2=params[2],
            b2=params[3])
        del params
    begin_time = time.time()

--------------------------------------------------------------------------------
/src/train.sh:
--------------------------------------------------------------------------------
tmux new -d -s 0 'python3 train.py --repetition=0 --gpu=0 --gpu_usage=0.45'
tmux new -d -s 1 'python3 train.py --repetition=1 --gpu=0 --gpu_usage=0.45'
tmux new -d -s 2 'python3 train.py --repetition=2 --gpu=1 --gpu_usage=0.45'
tmux new -d -s 3 'python3 train.py --repetition=3 --gpu=1 --gpu_usage=0.45'
tmux new -d -s 4 'python3 train.py --repetition=4 --gpu=2 --gpu_usage=0.45'
tmux new -d -s 5 'python3 train.py --repetition=5 --gpu=2 --gpu_usage=0.45'
tmux new -d -s 6 'python3 train.py --repetition=6 --gpu=3 --gpu_usage=0.45'
tmux new -d -s 7 'python3 train.py --repetition=7 --gpu=3 --gpu_usage=0.45'
tmux new -d -s 8 'python3 train.py --repetition=8 --gpu=4 --gpu_usage=0.45'
tmux new -d -s 9 'python3 train.py --repetition=9 --gpu=4 --gpu_usage=0.45'
tmux new -d -s 10 'python3 train.py --repetition=10 --gpu=5 --gpu_usage=0.45'
tmux new -d -s 11 'python3 train.py --repetition=11 --gpu=5 --gpu_usage=0.45'
tmux new -d -s 12 'python3 train.py --repetition=12 --gpu=6 --gpu_usage=0.45'
tmux new -d -s 13 'python3 train.py --repetition=13 --gpu=6 --gpu_usage=0.45'
tmux new -d -s 14 'python3 train.py --repetition=14 --gpu=7 --gpu_usage=0.45'
tmux new -d -s 15 'python3 train.py --repetition=15 --gpu=7 --gpu_usage=0.45'
--------------------------------------------------------------------------------
/src/util/Gather.cpp:
--------------------------------------------------------------------------------
#include <queue>
#include "Gather.h"
#include <vector>
using namespace std;
// raw - M, R, B
// indices - R, N
// result - M, N
void cgather_batch(float* raw, long* lookup, float* result, long* top_preds, int R, int B, int N, int batch_size, int n_threads)
{
    vector<priority_queue<pair<float, long>>> q(batch_size);
    #pragma omp parallel for num_threads(n_threads)
    for(int idx = 0; idx < batch_size; ++idx)
    {
        const int preds_offset = idx * 5;
        const int scores_offset = idx * N;
        for(int rdx = 0; rdx < R; ++rdx)
        {
            const int idx_offset = rdx * N;
            const int raw_offset = idx * R * B + rdx * B;

            for(int kdx = 0; kdx < N; ++kdx)
            {
                result[scores_offset + kdx] += raw[raw_offset + lookup[idx_offset + kdx]];
            }
        }
        // filling the queue
        for(int i = 0; i < N; ++i)
        {
            if(q[idx].size()<5)
                q[idx].push(pair<float, long>(-result[scores_offset + i], i));
            else if(q[idx].top().first > -result[scores_offset + i]){
                q[idx].pop();
                q[idx].push(pair<float, long>(-result[scores_offset + i], i));
            }
        }
        // getting the top 5 classes
        for(long i = 4; i >= 0; --i)
        {
            top_preds[preds_offset + i] = q[idx].top().second;
            q[idx].pop();
        }
    }
}

void cgather_K(float* raw, long* lookup, float* result, long* top_preds, int R, int B, int N, int batch_size, int n_threads)
{
    vector<priority_queue<pair<float, long>>> q(batch_size);
    for(int idx = 0; idx < batch_size; ++idx)
    {
        const int preds_offset = idx * 5;
        const int scores_offset = idx * N;
        for(int rdx = 0; rdx < R; ++rdx)
        {
            const int idx_offset = rdx * N;
            const int raw_offset = idx * R * B + rdx * B;

            #pragma omp parallel for num_threads(n_threads)
            for(int kdx = 0; kdx < N; ++kdx)
            {
                result[scores_offset + kdx] += raw[raw_offset + lookup[idx_offset + kdx]];
            }
        }
        // filling the queue
        for(long i = 0; i < N; ++i)
        {
            if(q[idx].size()<5)
                q[idx].push(pair<float, long>(-result[scores_offset + i], i));
            else if(q[idx].top().first > -result[scores_offset + i]){
                q[idx].pop();
                q[idx].push(pair<float, long>(-result[scores_offset + i], i));
            }
        }
        // getting the top 5 classes
        for(long i = 4; i >= 0; --i)
        {
            top_preds[preds_offset + i] = q[idx].top().second;
            q[idx].pop();
        }
    }
}

--------------------------------------------------------------------------------
/src/util/Gather.h:
--------------------------------------------------------------------------------
void cgather_batch(float*, long*, float*, long*, int, int, int, int, int);
void cgather_K(float*, long*, float*, long*, int, int, int, int, int);
--------------------------------------------------------------------------------
/src/util/Makefile:
--------------------------------------------------------------------------------
all: clean
	python3 setup.py build_ext --inplace

clean:
	rm -rf build
	rm -rf __pycache__
	rm -rf util.cpp
	rm -rf util.*.so
--------------------------------------------------------------------------------
/src/util/setup.py:
--------------------------------------------------------------------------------
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy

setup(ext_modules = cythonize(Extension(
    "util",                             # the extension name
    sources=["util.pyx", "Gather.cpp"], # the Cython source and additional C++ source files
    language="c++",                     # generate and compile C++ code
    include_dirs=[numpy.get_include()],
    extra_compile_args=["-std=c++11", "-fopenmp", "-fopenmp-simd"],
    extra_link_args=["-fopenmp", "-fopenmp-simd"]
)))
--------------------------------------------------------------------------------
/src/util/util.pyx:
--------------------------------------------------------------------------------
import numpy as np
cimport numpy as np
import cython

cdef extern from "Gather.h":
    void cgather_batch(float*, long*, float*, long*, int, int, int, int, int) except +
    void cgather_K(float*, long*, float*, long*, int, int, int, int, int) except +

@cython.boundscheck(False)
def gather_batch(np.ndarray[float, ndim=3, mode="c"] raw,
                 np.ndarray[long, ndim=2, mode="c"] indices,
                 np.ndarray[float, ndim=2, mode="c"] scores,
                 np.ndarray[long, ndim=2, mode="c"] top_preds,
                 int R, int B, int N, int batch_size, int n_threads):
    cgather_batch(&raw[0,0,0], &indices[0,0], &scores[0,0], &top_preds[0,0], R, B, N, batch_size, n_threads)

@cython.boundscheck(False)
def gather_K(np.ndarray[float, ndim=3, mode="c"] raw,
             np.ndarray[long, ndim=2, mode="c"] indices,
             np.ndarray[float, ndim=2, mode="c"] scores,
             np.ndarray[long, ndim=2, mode="c"] top_preds,
             int R, int B, int N, int batch_size, int n_threads):
    cgather_K(&raw[0,0,0], &indices[0,0], &scores[0,0], &top_preds[0,0], R, B, N, batch_size, n_threads)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
from config import train_config as config
from multiprocessing import Pool
from sklearn.utils import murmurhash3_32 as mmh3
import tensorflow as tf
import glob
import time
import numpy as np

def create_universal_lookups(r):
    counts = np.zeros(config.B+1, dtype=int)
    bucket_order = np.zeros(config.n_classes, dtype=int)
    #
    for i in range(config.n_classes):
        bucket = mmh3(i,seed=r)%config.B
        bucket_order[i] = bucket
        counts[bucket+1] += 1
    #
    counts = np.cumsum(counts)
    rolling_counts = np.zeros(config.B, dtype=int)
    class_order = np.zeros(config.n_classes,dtype=int)
    for i in range(config.n_classes):
        temp = bucket_order[i]
        class_order[counts[temp]+rolling_counts[temp]] = i
        rolling_counts[temp] += 1
    np.save(config.lookups_loc+'class_order_'+str(r)+'.npy', class_order)
    np.save(config.lookups_loc+'counts_'+str(r)+'.npy',counts)
    np.save(config.lookups_loc+'bucket_order_'+str(r)+'.npy', bucket_order)

def create_query_lookups(r):
    bucket_order = np.zeros(config.feat_dim_orig, dtype=int)
    #
    for i in range(config.feat_dim_orig):
        bucket = mmh3(i,seed=r)%config.feat_hash_dim
        bucket_order[i] = bucket
    np.save(config.query_lookups_loc+'bucket_order_'+str(r)+'.npy', bucket_order)

def input_example(labels, label_vals, inp_idxs, inp_vals): # for writing TFRecords
    labels_list = tf.train.Int64List(value = labels)
    label_vals_list = tf.train.FloatList(value = label_vals)
    inp_idxs_list = tf.train.Int64List(value = inp_idxs)
    inp_vals_list = tf.train.FloatList(value = inp_vals)
    # Create a dictionary with above lists individually wrapped in Feature
    feature = {
        'labels': tf.train.Feature(int64_list = labels_list),
        'label_vals': tf.train.Feature(float_list = label_vals_list),
        'input_idxs': tf.train.Feature(int64_list = inp_idxs_list),
        'input_vals': tf.train.Feature(float_list = inp_vals_list)
    }
    # Create Example object with features
    example = tf.train.Example(features = tf.train.Features(feature=feature))
    return example

def create_tfrecords(file):
    f = open(file, 'r', encoding = 'utf-8')
    header = f.readline()
    write_loc = config.tfrecord_loc+file.split('/')[-1].split('.')[0]
    with tf.python_io.TFRecordWriter(write_loc+'.tfrecords') as writer:
        for line in f:
            itms = line.strip().split()
            y_idxs = [int(itm) for itm in itms[0].split(',')]
            y_vals = [1.0 for itm in range(len(y_idxs))]
            x_idxs = [int(itm.split(':')[0]) for itm in itms[1:]]
            x_vals = [float(itm.split(':')[1]) for itm in itms[1:]]
            ############################
            tf_example = input_example(y_idxs, y_vals, x_idxs, x_vals)
            writer.write(tf_example.SerializeToString())

def _parse_function(example_proto): # for reading TFRecords
    features = {"labels": tf.VarLenFeature(tf.int64),
                "label_vals": tf.VarLenFeature(tf.float32),
                "input_idxs": tf.VarLenFeature(tf.int64),
                "input_vals": tf.VarLenFeature(tf.float32)
                }
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features["labels"], parsed_features["label_vals"], parsed_features["input_idxs"], parsed_features["input_vals"]

--------------------------------------------------------------------------------