├── .gitignore
├── README.md
├── cnn.jpg
├── code
│   ├── cnn_test.py
│   ├── cnn_train.py
│   ├── header.py
│   ├── main.py
│   ├── models
│   │   ├── classifier.py
│   │   ├── cnn_encoder.py
│   │   ├── embedding_layer.py
│   │   ├── header.py
│   │   └── xmlCNN.py
│   ├── precision_k.py
│   └── test_manik.m
└── utils
    ├── data_dive.py
    ├── data_helpers.py
    ├── fiddle_clusters.py
    ├── futils.py
    ├── loss.py
    ├── process_eurlex.py
    └── w2v.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *saved_models
2 | datasets
3 | *.npy
4 | *.npz
5 | *.pyc
6 | embedding_weights/*
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repo is from my past life. I am now co-founder and CTO @ [Merlin](getmerlin.in). Check out our [blog](getmerlin.in/blog) and [extension](https://chromewebstore.google.com/detail/merlin-1-click-access-to/camppjleccjaphfdbohjdohecfnoikec)
2 |
3 | # XML-CNN
4 | PyTorch implementation of the paper [Deep Learning for Extreme Multi-label Text Classification](http://nyc.lti.cs.cmu.edu/yiming/Publications/jliu-sigir17.pdf) with dynamic pooling
5 |
6 | ## Dependencies
7 |
8 | * NLTK (stopwords)
9 | * PyTorch >= 0.3.1
10 | * Gensim
11 | * Matplotlib
12 |
13 | ![](cnn.jpg)
14 |
15 | Directory Structure:
16 |
17 | ```
18 | +-- code
19 | | +-- cnn_test.py
20 | | +-- cnn_train.py
21 | | +-- header.py
22 | | +-- main.py
23 | | +-- models
24 | | | +-- classifier.py
25 | | | +-- cnn_encoder.py
26 | | | +-- embedding_layer.py
27 | | | +-- header.py
28 | | | +-- xmlCNN.py
29 | | +-- precision_k.py
30 | | +-- score_matrix.mat
31 | | +-- test_manik.m
32 | +-- datasets
33 | | +-- Folder where datasets need to be kept
34 | +-- embedding_weights
35 | +-- saved_models
36 | | +-- Directory created by default where models are saved
37 | +-- utils
38 | | +-- data_dive.py
39 | | +-- data_helpers.py
40 | | +-- fiddle_clusters.py
41 | | +-- futils.py
42 | | +-- loss.py
43 | | +-- process_eurlex.py
44 | | +-- w2v.py
45 | +-- README.md
46 | ```
47 | GloVe embeddings are needed by default as pre-training for the model. They can be downloaded from [here](https://nlp.stanford.edu/data/glove.6B.zip) and placed in the ```embedding_weights``` directory. The default embedding dimension is 300, with 6 billion (6B) tokens. Otherwise you can set --model_variation = 0 to start from scratch.
48 |
49 | Sample dataset RCV can be downloaded from [here](http://cse.iitk.ac.in/users/siddsax/rcv.p). A model trained on RCV1 can be downloaded from the link below.
50 |
51 | [Trained model](http://cse.iitk.ac.in/users/siddsax/rcvSaved.pt)
52 |
53 | Precision Scores
54 |
55 | | 1 | 2 | 3 | 4 | 5 |
56 | |---|---|---|---|---|
57 | |0.96 | 0.8848 | 0.7809 | 0.6436 | 0.5457 |
58 |
59 | Note: The scores are slightly higher because the tested dataset is a subset of the full test dataset.
60 |
61 | The procedure to train and test the model is as follows. By default the code doesn't plot graphs; plotting can be enabled with a [Visdom](https://github.com/facebookresearch/visdom) server running and by turning on the flag --d
62 | ```bash
63 | python main.py # train a model
64 | python main.py --mn=rcv # train a model and save in directory rcv [inside saved_models]
65 | ```
66 | This will create multiple files inside the folder ```saved_models/rcv``` in the above case. Checkpoints are saved after every
67 | ```save_step``` epochs; this can be changed with the ``--ss`` command-line option.
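For example (the value 5 below is only an illustration; the default ```save_step``` in ```main.py``` is 10):

```bash
# save the model under saved_models/rcv and write a checkpoint every 5 epochs
python main.py --mn=rcv --ss=5
```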
Also a checkpoint is made according to the best test precision@1 score and the best training-batch precision@1.
68 |
69 | ```bash
70 | python main.py --lm=$DESTINATION OF SAVED MODEL # This resumes training from the given checkpoint
71 | ```
72 |
73 | In order to test the model, run
74 | ```bash
75 | python main.py --lm=$DESTINATION OF SAVED MODEL --tr=0
76 | ```
77 |
78 | This will first print the training error and then the test error, while also saving a score_matrix.mat in the folder the model was loaded from; this file can be used to run the test scripts from [here](https://drive.google.com/open?id=0B3lPMIHmG6vGN0hSQjFJUHZ0YTg)
79 |
80 |
--------------------------------------------------------------------------------
/cnn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddsax/XML-CNN/78473661ae1d588c30740e1893a611bea568f197/cnn.jpg
--------------------------------------------------------------------------------
/code/cnn_test.py:
--------------------------------------------------------------------------------
1 | from header import * 2 | from collections import OrderedDict 3 | from sklearn.metrics import log_loss 4 | 5 | # def pass(a, b, model, x_tr, Y, params): 6 | # # e_emb = model.embedding_layer.forward(x_tr[i:i+params.mb_size].view(params.mb_size, x_te.shape[1])) 7 | # # Y[i:i+params.mb_size,:] = model.classifier(e_emb).data 8 | # e_emb = model.embedding_layer.forward(x_tr[a:b].view(params.mb_size, x_tr.shape[1])) 9 | # Y[a:b,:] = model.classifier(e_emb).data 10 | 11 | # return Y 12 | 13 | def test_class(x_te, y_te, params, model=None, x_tr=None, y_tr=None, embedding_weights=None, verbose=True, save=True ): 14 | 15 | 16 | if(model==None): 17 | if(embedding_weights is None): 18 | print("Error: Embedding weights needed!") 19 | exit() 20 | else: 21 | model = xmlCNN(params, embedding_weights) 22 | # state_dict = torch.load(params.load_model + "/model_best", map_location=lambda storage, loc: storage) 23 | # new_state_dict = OrderedDict() 24 | # for k, v in state_dict.items(): 25 | # name = k[7:] 26 | # new_state_dict[name] = v 27 | # model.load_state_dict(new_state_dict) 28 | # del new_state_dict 29 | model = load_model(model, params.load_model) 30 | 31 | if(torch.cuda.is_available()): 32 | params.dtype_f = torch.cuda.FloatTensor 33 | params.dtype_i = torch.cuda.LongTensor 34 | model = model.cuda() 35 | else: 36 | params.dtype_f = torch.FloatTensor 37 | params.dtype_i = torch.LongTensor 38 | 39 | if(x_tr is not None and y_tr is not None): 40 | x_tr, _ = load_batch_cnn(x_tr, y_tr, params, batch=False) 41 | Y = np.zeros(y_tr.shape) 42 | rem = x_tr.shape[0]%params.mb_size 43 | for i in range(0, x_tr.shape[0] - rem, params.mb_size ): 44 | e_emb = model.embedding_layer.forward(x_tr[i:i+params.mb_size].view(params.mb_size, x_te.shape[1])) 45 | Y[i:i+params.mb_size,:] = model.classifier(e_emb).data 46 | if(rem): 47 | e_emb = model.embedding_layer.forward(x_tr[-rem:].view(rem, x_te.shape[1])) 48 | Y[-rem:, :] = model.classifier(e_emb).data 49 | loss = log_loss(y_tr, Y) 50 | prec = precision_k(y_tr.todense(), Y, 5) 51 | print('Test Loss; Precision Scores [1->5] {} {} {} {} {} Cross Entropy {};'.format(prec[0], prec[1], prec[2], prec[3], prec[4],loss)) 52 | 53 | 54 | x_te, _ = load_batch_cnn(x_te, y_te, params, batch=False) 55 | Y2 = np.zeros(y_te.shape) 56 | rem = x_te.shape[0]%params.mb_size 57 | for i in range(0,x_te.shape[0] - rem,params.mb_size): 58 | e_emb = model.embedding_layer.forward(x_te[i:i+params.mb_size].view(params.mb_size,
x_te.shape[1])) 59 | Y2[i:i+params.mb_size,:] = model.classifier(e_emb).data 60 | 61 | if(rem): 62 | e_emb = model.embedding_layer.forward(x_te[-rem:].view(rem, x_te.shape[1])) 63 | Y2[-rem:,:] = model.classifier(e_emb).data 64 | 65 | loss = log_loss(y_te, Y2) # Reverse of pytorch 66 | #print("A") 67 | prec = precision_k(y_te.todense(), Y2, 5) # Reverse of pytorch 68 | print('Test Loss; Precision Scores [1->5] {} {} {} {} {} Cross Entropy {};'.format(prec[0], prec[1], prec[2], prec[3], prec[4],loss)) 69 | 70 | if(save): 71 | Y_probabs2 = sparse.csr_matrix(Y2) 72 | sio.savemat('/'.join(params.load_model.split('/')[-1]) + '/score_matrix.mat' , {'score_matrix': Y_probabs2}) 73 | 74 | return prec[0], loss 75 | -------------------------------------------------------------------------------- /code/cnn_train.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | from cnn_test import * 3 | 4 | # --------------------------------------------------------------------------------- 5 | 6 | def train(x_tr, y_tr, x_te, y_te, embedding_weights, params): 7 | 8 | viz = Visdom() 9 | loss_best = float('Inf') 10 | bestTotalLoss = float('Inf') 11 | best_test_acc = 0 12 | max_grad = 0 13 | 14 | num_mb = np.ceil(params.N/params.mb_size) 15 | 16 | model = xmlCNN(params, embedding_weights) 17 | if(torch.cuda.is_available()): 18 | print("--------------- Using GPU! ---------") 19 | model.params.dtype_f = torch.cuda.FloatTensor 20 | model.params.dtype_i = torch.cuda.LongTensor 21 | 22 | model = model.cuda() 23 | else: 24 | model.params.dtype_f = torch.FloatTensor 25 | model.params.dtype_i = torch.LongTensor 26 | print("=============== Using CPU =========") 27 | 28 | optimizer = optim.Adam(filter(lambda p: p.requires_grad,model.parameters()), lr=params.lr) 29 | print(model);print("%"*100) 30 | 31 | if params.dataparallel: 32 | model = nn.DataParallel(model) 33 | 34 | if(len(params.load_model)): 35 | params.model_name = params.load_model 36 | print(params.load_model) 37 | model, optimizer, init = load_model(model, params.load_model, optimizer=optimizer) 38 | else: 39 | init = 0 40 | iteration = 0 41 | # =============================== TRAINING ==================================== 42 | for epoch in range(init, params.num_epochs): 43 | totalLoss = 0.0 44 | 45 | for i in range(int(num_mb)): 46 | # ------------------ Load Batch Data --------------------------------------------------------- 47 | batch_x, batch_y = load_batch_cnn(x_tr, y_tr, params) 48 | # ----------------------------------------------------------------------------------- 49 | loss, output = model.forward(batch_x, batch_y) 50 | loss = loss.mean().squeeze() 51 | # -------------------------------------------------------------------- 52 | 53 | totalLoss += loss.data 54 | 55 | if i % int(num_mb/12) == 0: 56 | print('Iter-{}; Loss: {:.4}; best_loss: {:.4}; max_grad: {}:'.format(i, loss.data, loss_best, max_grad)) 57 | if not os.path.exists('../saved_models/' + params.model_name ): 58 | os.makedirs('../saved_models/' + params.model_name) 59 | save_model(model, optimizer, epoch, params.model_name + "/model_best_batch") 60 | if(loss best_test_acc): 116 | best_test_loss = test_ce_loss 117 | best_test_acc = test_prec_acc 118 | print("This acc is better than the previous recored test acc:- {} ; while CELoss:- {}".format(best_test_acc, best_test_loss)) 119 | if not os.path.exists('../saved_models/' + params.model_name ): 120 | os.makedirs('../saved_models/' + params.model_name) 121 | save_model(model, optimizer, 
epoch, params.model_name + "/model_best_test") 122 | 123 | if epoch % params.save_step == 0: 124 | save_model(model, optimizer, epoch, params.model_name + "/model_" + str(epoch)) 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /code/header.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | import torch.optim as optim 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.gridspec as gridspec 8 | import os 9 | from torch.autograd import Variable 10 | import sys 11 | import numpy as np 12 | sys.path.append('../utils/') 13 | sys.path.append('models') 14 | import data_helpers 15 | 16 | from w2v import * 17 | from embedding_layer import embedding_layer 18 | from cnn_encoder import cnn_encoder 19 | from sklearn import preprocessing 20 | from sklearn.decomposition import PCA 21 | import scipy.io as sio 22 | from scipy import sparse 23 | import argparse 24 | from visdom import Visdom 25 | from sklearn.externals import joblib 26 | from futils import * 27 | from loss import loss 28 | from xmlCNN import xmlCNN 29 | import timeit 30 | from precision_k import precision_k -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | from cnn_train import * 3 | from cnn_test import * 4 | import pdb 5 | 6 | # ------------------------ Params ------------------------------------------------------------------------------- 7 | parser = argparse.ArgumentParser(description='Process some integers.') 8 | 9 | parser.add_argument('--zd', dest='Z_dim', type=int, default=100, help='Latent layer dimension') 10 | parser.add_argument('--mb', dest='mb_size', type=int, default=20, help='Size of minibatch, changing might result in latent layer variance overflow') 11 | # parser.add_argument('--hd', dest='h_dim', type=int, default=600, help='hidden layer dimension') 12 | parser.add_argument('--lr', dest='lr', type=int, default=1e-3, help='Learning Rate') 13 | parser.add_argument('--p', dest='plot_flg', type=int, default=0, help='1 to plot, 0 to not plot') 14 | parser.add_argument('--e', dest='num_epochs', type=int, default=100, help='step for displaying loss') 15 | 16 | parser.add_argument('--d', dest='disp_flg', type=int, default=0, help='display graphs') 17 | parser.add_argument('--sve', dest='save', type=int, default=1, help='save models or not') 18 | parser.add_argument('--ss', dest='save_step', type=int, default=10, help='gap between model saves') 19 | parser.add_argument('--mn', dest='model_name', type=str, default='', help='model name') 20 | parser.add_argument('--tr', dest='training', type=int, default=1, help='model name') 21 | parser.add_argument('--lm', dest='load_model', type=str, default="", help='model name') 22 | parser.add_argument('--ds', dest='data_set', type=str, default="rcv", help='dataset name') 23 | 24 | parser.add_argument('--pp', dest='pp_flg', type=int, default=0, help='1 is for min-max pp, 2 is for gaussian pp, 0 for none') 25 | parser.add_argument('--loss', dest='loss_type', type=str, default="BCELoss", help='Loss') 26 | 27 | parser.add_argument('--hidden_dims', type=int, default=512, help='hidden layer dimension') 28 | parser.add_argument('--sequence_length',help='max sequence length of a document', type=int,default=500) 29 | 
parser.add_argument('--embedding_dim', help='dimension of word embedding representation', type=int, default=300) 30 | parser.add_argument('--model_variation', help='model variation: CNN-rand or CNN-pretrain', type=str, default='pretrain') 31 | parser.add_argument('--pretrain_type', help='pretrain model: GoogleNews or glove', type=str, default='glove') 32 | parser.add_argument('--vocab_size', help='size of vocabulary keeping the most frequent words', type=int, default=30000) 33 | parser.add_argument('--drop_prob', help='Dropout probability', type=int, default=.3) 34 | parser.add_argument('--load_data', help='Load Data or not', type=int, default=0) 35 | parser.add_argument('--mg', dest='multi_gpu', type=int, default=0, help='1 for 2 gpus and 0 for normal') 36 | parser.add_argument('--filter_sizes', help='number of filter sizes (could be a list of integer)', type=int, default=[2, 4, 8], nargs='+') 37 | parser.add_argument('--num_filters', help='number of filters (i.e. kernels) in CNN model', type=int, default=32) 38 | parser.add_argument('--pooling_units', help='number of pooling units in 1D pooling layer', type=int, default=32) 39 | parser.add_argument('--pooling_type', help='max or average', type=str, default='max') 40 | parser.add_argument('--model_type', help='glove or GoogleNews', type=str, default='glove') 41 | parser.add_argument('--num_features', help='50, 100, 200, 300', type=int, default=300) 42 | parser.add_argument('--dropouts', help='0 for not using, 1 for using', type=int, default=0) 43 | parser.add_argument('--clip', help='gradient clipping', type=float, default=1000) 44 | parser.add_argument('--dataset_gpu', help='load dataset in full to gpu', type=int, default=1) 45 | parser.add_argument('--dp', dest='dataparallel', help='to train on multiple GPUs or not', type=int, default=0) 46 | 47 | 48 | params = parser.parse_args() 49 | 50 | if(len(params.model_name)==0): 51 | params.model_name = "Gen_data_CNN_Z_dim-{}_mb_size-{}_hidden_dims-{}_preproc-{}_loss-{}_sequence_length-{}_embedding_dim-{}_params.vocab_size={}".format(params.Z_dim, params.mb_size, params.hidden_dims, params.pp_flg, params.loss_type, params.sequence_length, params.embedding_dim, params.vocab_size) 52 | 53 | print('Saving Model to: ' + params.model_name) 54 | 55 | # ------------------ data ---------------------------------------------- 56 | params.data_path = '../datasets/' + params.data_set 57 | x_tr, x_te, y_tr, y_te, params.vocabulary, params.vocabulary_inv, params = save_load_data(params, save=params.load_data) 58 | 59 | params = update_params(params) 60 | # ----------------------- Loss ------------------------------------ 61 | params.loss_fn = torch.nn.BCELoss(size_average=False) 62 | # -------------------------- Params --------------------------------------------- 63 | if params.model_variation=='pretrain': 64 | embedding_weights = load_word2vec(params) 65 | else: 66 | embedding_weights = None 67 | 68 | if torch.cuda.is_available(): 69 | params.dtype = torch.cuda.FloatTensor 70 | else: 71 | params.dtype = torch.FloatTensor 72 | 73 | 74 | if(params.training): 75 | train(x_tr, y_tr, x_te, y_te, embedding_weights, params) 76 | 77 | else: 78 | test_class(x_te, y_te, params, x_tr=x_tr, y_tr=y_tr, embedding_weights=embedding_weights) 79 | -------------------------------------------------------------------------------- /code/models/classifier.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | class classifier(nn.Module): 3 | def __init__(self, params): 
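        # Fully connected output head used for classification: Linear(params.h_dim -> params.H_dim) + ReLU,
        # then Linear(params.H_dim -> params.y_dim) + Sigmoid, giving independent per-label probabilities.
        # A Dropout module is instantiated below when params.dropouts is set (see forward()).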
4 | super(classifier, self).__init__() 5 | self.params = params 6 | if(self.params.dropouts): 7 | self.drp = nn.Dropout(.5) 8 | self.l1 = nn.Linear(params.h_dim, params.H_dim) 9 | self.l2 = nn.Linear(params.H_dim, params.y_dim) 10 | self.relu = nn.ReLU() 11 | self.sigmoid = nn.Sigmoid() 12 | torch.nn.init.xavier_uniform_(self.l1.weight) 13 | 14 | def forward(self, H): 15 | H = self.l1(H) 16 | H = self.relu(H) 17 | H = self.l2(H) 18 | H = self.sigmoid(H) 19 | return H -------------------------------------------------------------------------------- /code/models/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | 3 | def out_size(l_in, kernel_size, padding=0, dilation=1, stride=1): 4 | a = l_in + 2*padding - dilation*(kernel_size - 1) - 1 5 | b = int(a/stride) 6 | return b + 1 7 | 8 | class cnn_encoder(torch.nn.Module): 9 | 10 | def __init__(self, params): 11 | super(cnn_encoder, self).__init__() 12 | self.params = params 13 | self.conv_layers = nn.ModuleList() 14 | self.pool_layers = nn.ModuleList() 15 | fin_l_out_size = 0 16 | 17 | if(params.dropouts): 18 | self.drp = nn.Dropout(p=.25) 19 | self.drp5 = nn.Dropout(p=.5) 20 | 21 | for fsz in params.filter_sizes: 22 | l_out_size = out_size(params.sequence_length, fsz, stride=2) 23 | pool_size = l_out_size // params.pooling_units 24 | l_conv = nn.Conv1d(params.embedding_dim, params.num_filters, fsz, stride=2) 25 | torch.nn.init.xavier_uniform_(l_conv.weight) 26 | if params.pooling_type == 'average': 27 | l_pool = nn.AvgPool1d(pool_size, stride=None, count_include_pad=True) 28 | pool_out_size = (int((l_out_size - pool_size)/pool_size) + 1)*params.num_filters 29 | elif params.pooling_type == 'max': 30 | l_pool = nn.MaxPool1d(2, stride=1) 31 | pool_out_size = (int(l_out_size*params.num_filters - 2) + 1) 32 | fin_l_out_size += pool_out_size 33 | 34 | self.conv_layers.append(l_conv) 35 | self.pool_layers.append(l_pool) 36 | 37 | self.fin_layer = nn.Linear(fin_l_out_size, params.hidden_dims) 38 | self.out_layer = nn.Linear(params.hidden_dims, params.y_dim) 39 | torch.nn.init.xavier_uniform_(self.fin_layer.weight) 40 | torch.nn.init.xavier_uniform_(self.out_layer.weight) 41 | 42 | def forward(self, inputs): 43 | #o0 = self.drp(self.bn_1(inputs)).permute(0,2,1) 44 | o0 = inputs.permute(0,2,1)# self.bn_1(inputs.permute(0,2,1)) 45 | if(self.params.dropouts): 46 | o0 = self.drp(o0) 47 | conv_out = [] 48 | 49 | for i in range(len(self.params.filter_sizes)): 50 | o = self.conv_layers[i](o0) 51 | o = o.view(o.shape[0], 1, o.shape[1]*o.shape[2]) 52 | o = self.pool_layers[i](o) 53 | o = nn.functional.relu(o) 54 | o = o.view(o.shape[0],-1) 55 | conv_out.append(o) 56 | del o 57 | if len(self.params.filter_sizes)>1: 58 | o = torch.cat(conv_out,1) 59 | else: 60 | o = conv_out[0] 61 | 62 | o = self.fin_layer(o) 63 | o = nn.functional.relu(o) 64 | if(self.params.dropouts): 65 | o = self.drp5(o) 66 | o = self.out_layer(o) 67 | o = torch.nn.functional.sigmoid(o) 68 | return o 69 | -------------------------------------------------------------------------------- /code/models/embedding_layer.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | 3 | class embedding_layer(torch.nn.Module): 4 | 5 | def __init__(self, params, embedding_weights): 6 | super(embedding_layer, self).__init__() 7 | self.l = nn.Embedding(params.vocab_size, params.embedding_dim) 8 | if params.model_variation == 'pretrain': 9 | 
self.l.weight.data.copy_(torch.from_numpy(embedding_weights)) 10 | self.l.weight.requires_grad=False 11 | 12 | def forward(self, inputs): 13 | o = self.l(inputs) 14 | return o 15 | -------------------------------------------------------------------------------- /code/models/header.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | import torch.optim as optim 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.gridspec as gridspec 8 | import os 9 | from torch.autograd import Variable 10 | import sys 11 | import numpy as np 12 | sys.path.append('../../utils/') 13 | sys.path.append('models/') 14 | import data_helpers 15 | 16 | from w2v import * 17 | from embedding_layer import embedding_layer 18 | from sklearn import preprocessing 19 | from sklearn.decomposition import PCA 20 | import scipy.io as sio 21 | from scipy import sparse 22 | import argparse 23 | from visdom import Visdom 24 | from sklearn.externals import joblib 25 | from futils import * 26 | from loss import loss -------------------------------------------------------------------------------- /code/models/xmlCNN.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | from cnn_encoder import cnn_encoder 3 | 4 | class xmlCNN(nn.Module): 5 | def __init__(self, params, embedding_weights): 6 | super(xmlCNN, self).__init__() 7 | self.params = params 8 | self.embedding_layer = embedding_layer(params, embedding_weights) 9 | self.classifier = cnn_encoder(params) 10 | 11 | def forward(self, batch_x, batch_y): 12 | # ----------- Encode (X, Y) -------------------------------------------- 13 | e_emb = self.embedding_layer.forward(batch_x) 14 | Y = self.classifier.forward(e_emb) 15 | loss = self.params.loss_fn(Y, batch_y) 16 | 17 | if(loss<0): 18 | print(cross_entropy) 19 | print(Y[0:100]) 20 | print(batch_y[0:100]) 21 | sys.exit() 22 | 23 | return loss.view(-1,1), Y -------------------------------------------------------------------------------- /code/precision_k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | def precision_k(true_mat, score_mat,k): 4 | p = np.zeros((k,1)) 5 | rank_mat = np.argsort(score_mat) 6 | backup = np.copy(score_mat) 7 | for k in range(k): 8 | score_mat = np.copy(backup) 9 | for i in range(rank_mat.shape[0]): 10 | score_mat[i][rank_mat[i, :-(k+1)]] = 0 11 | score_mat = np.ceil(score_mat) 12 | kk = np.argwhere(score_mat>0) 13 | mat = np.multiply(score_mat, true_mat) 14 | num = np.sum(mat,axis=1) 15 | p[k] = np.mean(num/(k+1)) 16 | 17 | return np.around(p, decimals=4) 18 | -------------------------------------------------------------------------------- /code/test_manik.m: -------------------------------------------------------------------------------- 1 | addpath('/scratch/work/saxenas2/fastxml/manik/Tools/matlab/') 2 | addpath('/scratch/work/saxenas2/fastxml/manik/tools/') 3 | addpath('/scratch/work/saxenas2/fastxml/manik/Tools/metrics/') 4 | addpath('/scratch/work/saxenas2/fastxml/manik/FastXML/') 5 | 6 | A = .55; 7 | B = 1.5; 8 | 9 | load score_matrix.mat 10 | [I, J, S] = find(score_matrix); 11 | [sorted_I, idx] = sort(I); 12 | J = J(idx); 13 | S = S(idx); 14 | score_matrix = sparse(J, sorted_I, S); 15 | 16 | load ty.mat 17 | [I, J, S] = find(ty); 18 | [sorted_I, idx] = sort(I); 19 | J = J(idx); 20 | S = S(idx); 21 | ty = 
sparse(J, sorted_I, S); 22 | ip = inv_propensity(ty,A,B); 23 | 24 | [metrics] = get_all_metrics(score_matrix , ty, ip) 25 | disp(metrics) 26 | 27 | % -------- For RCV1 His neural net-------- 28 | 29 | % prec 96.58 89.82 79.66 65.28 55.15 30 | % nDCG 96.58 92.51 90.96 91.01 91.46 31 | % prec_wt 86.22 86.25 87.38 87.70 88.48 32 | % nDCG_wt 86.22 86.24 87.00 87.21 87.65 33 | 34 | % ----------------------------------------- 35 | 36 | % prec 93.26 86.08 75.64 62.28 52.79 37 | % nDCG 93.26 88.84 86.81 87.18 87.84 38 | % prec_wt 73.04 76.45 78.40 80.02 81.59 39 | % nDCG_wt 73.04 75.62 77.04 78.06 78.96 40 | 41 | % prec 95.50 87.29 76.72 63.20 53.59 42 | % nDCG 95.50 90.29 88.17 88.53 89.18 43 | % prec_wt 72.24 76.67 79.44 81.27 82.96 44 | % nDCG_wt 72.24 75.59 77.59 78.76 79.73 45 | 46 | 47 | % ---------- Initialized weights with Dropouts ------------- 48 | % Best for test ------------------- 49 | % prec 94.06 84.04 73.35 60.90 51.89 50 | % nDCG 94.06 87.45 84.92 85.63 86.51 51 | % prec_wt 70.89 73.01 74.81 77.17 79.28 52 | % nDCG_wt 70.89 72.50 73.76 75.21 76.40 53 | 54 | % Best for train ------------------- 55 | % prec 93.62 84.88 74.66 61.41 52.02 56 | % nDCG 93.62 88.00 86.00 86.34 86.98 57 | % prec_wt 71.90 75.07 77.10 78.54 80.01 58 | % nDCG_wt 71.90 74.30 75.76 76.67 77.52 59 | 60 | 61 | % ---------------- base_model_with_test_saving_after_each_run ------ 62 | % model_best_batch 63 | % prec 94.49 86.20 75.71 62.40 52.84 64 | % nDCG 94.49 89.23 87.11 87.53 88.16 65 | % prec_wt 72.40 76.11 78.32 80.02 81.60 66 | % nDCG_wt 72.40 75.21 76.81 77.88 78.79 67 | 68 | % model_best_for_test 69 | % prec 94.98 86.05 75.65 62.45 53.06 70 | % nDCG 94.98 89.21 87.08 87.54 88.29 71 | % prec_wt 71.91 75.42 77.85 79.69 81.58 72 | % nDCG_wt 71.91 74.57 76.30 77.47 78.54 73 | 74 | 75 | % --------------- L1 loss ---------------------- 76 | model_best_for_test 77 | bad!!! 78 | 79 | model_best_batch 80 | bad!!! 81 | 82 | % ------------------ Ablation -------------- 83 | % prec 94.59 87.66 77.32 63.61 53.84 84 | % nDCG 94.59 90.38 88.51 88.81 89.37 85 | % prec_wt 74.26 77.99 80.16 81.68 83.20 86 | % nDCG_wt 74.26 77.08 78.66 79.63 80.50 -------------------------------------------------------------------------------- /utils/data_dive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # import torch 4 | import timeit 5 | import argparse 6 | import numpy as np 7 | import time 8 | # import torch.nn as nn 9 | # import torch.optim as optim 10 | import matplotlib.pyplot as plt 11 | # import torch.autograd as autograd 12 | from sklearn import preprocessing 13 | # from torch.autograd import Variable 14 | from sklearn.decomposition import PCA 15 | import matplotlib.gridspec as gridspec 16 | 17 | # this file is to explore the generated data and the data that already exist to see how much similarity do they share. 
18 | # It prits some stats and qualitative results 19 | 20 | new_data_x_file = "../datasets/Gen_data_Z_dim-200_mb_size-100_h_dim-600_preproc-1_beta-1.01_final_ly-Sigmoid_loss-BCELoss/new_x.npy" 21 | new_data_y_file = "../datasets/Gen_data_Z_dim-200_mb_size-100_h_dim-600_preproc-1_beta-1.01_final_ly-Sigmoid_loss-BCELoss/new_y.npy" 22 | actual_data_x_file = "../datasets/Eurlex/eurlex_docs/x_tr.npy" 23 | actual_data_y_file = "../datasets/Eurlex/eurlex_docs/y_tr.npy" 24 | indx2word_file = "../datasets/Eurlex/eurlex_docs/feature_names.txt" 25 | indx2label = "../datasets/Eurlex/eurlex_docs/label_set.txt" 26 | K = 10 27 | # ---------------------------------------------------------------------------- 28 | 29 | new_data_x = np.load(new_data_x_file) 30 | new_data_y = np.load(new_data_y_file) 31 | actual_data_x = np.load(actual_data_x_file) 32 | actual_data_y = np.load(actual_data_y_file) 33 | f = open(indx2label, 'r') 34 | temp = f.read().splitlines() 35 | labels = [] 36 | for i in temp: 37 | labels.append(i.split(":")[1]) 38 | f = open(indx2word_file, 'r') 39 | temp = f.read().splitlines() 40 | words = [] 41 | for i in temp: 42 | words.append(i.split(":")[1]) 43 | 44 | print("Shapes: new_x: {}; new_y: {}; original_x: {}; original_y: {};".format(new_data_x.shape, \ 45 | new_data_y.shape, actual_data_x.shape, actual_data_y.shape)) 46 | print("Num Words: {}; Num Labels: {};".format(len(labels), len(words))) 47 | 48 | for data_pt_num in range(K): 49 | data_pt_labels = np.argwhere(new_data_y[data_pt_num]==1) 50 | label_names = [] 51 | for label in data_pt_labels.tolist(): 52 | # print(label) 53 | label_names.append(labels[label[0]]) 54 | print("Labels in the data point : {}".format(label_names)) 55 | 56 | data_pt_words = np.argsort(new_data_x[data_pt_num])[-10:] 57 | word_names = [] 58 | for word in data_pt_words.tolist(): 59 | word_names.append(words[word]) 60 | print("Top 10 words in the data point : {}".format(word_names)) 61 | 62 | # Nearest Data point in actual data 63 | indx = -1 64 | closest = 1e10 65 | # print(actual_data_y) 66 | for i in range(len(actual_data_y)): 67 | dist = -len(np.intersect1d(np.argwhere(actual_data_y[i]==1), np.argwhere(new_data_y[data_pt_num]==1))) 68 | # print(np.argwhere(actual_data_y[i]==1)) 69 | # print(np.argwhere(new_data_y[data_pt_num]==1)) 70 | if(dist n): 67 | #y_te = y_te.resize((np.shape(y_te)[0], np.shape(y_tr)[1])) 68 | Y = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, N)) 69 | elif(N < n): 70 | Y = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n)) 71 | Y = Y[:, :N] 72 | else: 73 | Y = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n)) 74 | return [x_text, Y, m, n] 75 | 76 | 77 | def build_vocab(sentences, params, vocab_size=50000): 78 | word_counts = Counter(itertools.chain(*sentences)) 79 | vocabulary_inv = [x[0] for x in word_counts.most_common(vocab_size)] 80 | vocabulary = {x: i for i, x in enumerate(vocabulary_inv)} 81 | # append symbol to the vocabulary 82 | vocabulary[''] = len(vocabulary) 83 | vocabulary_inv.append('') 84 | vocabulary[params.go_token] = len(vocabulary) 85 | vocabulary_inv.append(params.go_token) 86 | vocabulary[params.end_token] = len(vocabulary) 87 | vocabulary_inv.append(params.end_token) 88 | 89 | return [vocabulary, vocabulary_inv] 90 | 91 | 92 | def build_input_data(sentences, vocabulary): 93 | x = np.array([[vocabulary[word] if word in vocabulary else vocabulary[''] for word in sentence] for sentence in sentences]) 94 | #x = np.array([[vocabulary[word] if word in vocabulary else len(vocabulary) for word 
in sentence] for sentence in sentences]) 95 | return x 96 | 97 | 98 | def load_data(params, max_length=500, vocab_size=50000): 99 | # Load and preprocess data 100 | with open(os.path.join(params.data_path), 'rb') as fin: 101 | [train, test, vocab, catgy] = pickle.load(fin) 102 | 103 | # dirty trick to prevent errors happen when test is empty 104 | if len(test) == 0: 105 | test[:5] = train[:5] 106 | 107 | trn_sents, Y_trn, m, n = load_data_and_labels(train) 108 | tst_sents, Y_tst, m, n = load_data_and_labels(test, M=m, N=n) 109 | sents_padded_sets, params.sequence_length = pad_sentences([trn_sents, tst_sents] , padding_word=params.pad_token, max_length=max_length) 110 | # tst_sents_padded = pad_sentences(tst_sents, padding_word=params.pad_token, max_length=max_length) 111 | vocabulary, vocabulary_inv = build_vocab(sents_padded_sets[0] + sents_padded_sets[1], params, vocab_size=vocab_size) 112 | X_trn = build_input_data(sents_padded_sets[0], vocabulary) 113 | X_tst = build_input_data(sents_padded_sets[1], vocabulary) 114 | return X_trn, Y_trn, X_tst, Y_tst, vocabulary, vocabulary_inv, params 115 | # return X_trn, Y_trn, vocabulary, vocabulary_inv 116 | 117 | 118 | def batch_iter(data, batch_size, num_epochs): 119 | """ 120 | Generates a batch iterator for a dataset. 121 | """ 122 | data = np.array(data) 123 | data_size = len(data) 124 | num_batches_per_epoch = int(len(data)/batch_size) + 1 125 | for epoch in range(num_epochs): 126 | # Shuffle the data at each epoch 127 | shuffle_indices = np.random.permutation(np.arange(data_size)) 128 | shuffled_data = data[shuffle_indices] 129 | for batch_num in range(num_batches_per_epoch): 130 | start_index = batch_num * batch_size 131 | end_index = min((batch_num + 1) * batch_size, data_size) 132 | yield shuffled_data[start_index:end_index] 133 | -------------------------------------------------------------------------------- /utils/fiddle_clusters.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('utils/') 3 | sys.path.append('models/') 4 | import numpy as np 5 | import os 6 | from sklearn import preprocessing 7 | from sklearn.decomposition import PCA 8 | import argparse 9 | from sklearn.cluster import KMeans 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import matplotlib.gridspec as gridspec 13 | import cPickle 14 | from sklearn.metrics import silhouette_score 15 | from dpmeans import * 16 | from sklearn.decomposition import PCA 17 | from sklearn.decomposition import TruncatedSVD 18 | import scipy.io as sio 19 | x_tr = np.load('datasets/Eurlex/eurlex_docs/x_tr.npy') 20 | y_tr = np.load('datasets/Eurlex/eurlex_docs/y_tr.npy') 21 | x_te = np.load('datasets/Eurlex/eurlex_docs/x_te.npy') 22 | y_te = np.load('datasets/Eurlex/eurlex_docs/y_te.npy') 23 | 24 | n = np.shape(x_tr)[0] 25 | m = np.shape(y_tr)[1] 26 | 27 | 28 | 29 | # ------ Making Adjacency ------------------ 30 | dct = {} 31 | for i in range(m): 32 | dct[i] = np.argwhere(y_tr[:,i]==1) 33 | 34 | adjacency_mat = np.zeros((m,m)) 35 | check_mat = np.zeros((m,m)) 36 | for i in range(m): 37 | for j in range(m): 38 | adjacency_mat[i,j] = len(np.intersect1d(dct[i],dct[j])) 39 | adjacency_mat[j, i] = adjacency_mat[i,j] 40 | check_mat[i,j] = check_mat[j,i] = 1 41 | # adjacency_mat[i, i] = len(dct[i]) 42 | # check_mat[i,i] = 1 43 | 44 | print(i) 45 | np.save('adjacency_mat', adjacency_mat) 46 | adjacency_mat = sparse.csr_matrix(adjacency_mat) 47 | sio.savemat('adjacency_mat', adjacency_mat) 48 | 
print((check_mat==0).any()) 49 | print(adjacency_mat[:100,:100]) 50 | # ----------------------------------------- 51 | 52 | # ------------- PP --------------------------------------- 53 | adjacency_mat = np.load('/scratch/work/saxenas2/CVAE_XML/adjacency_mat.npy') 54 | pp = preprocessing.MinMaxScaler() 55 | scaler = pp.fit(adjacency_mat) 56 | adjacency_mat = scaler.transform(adjacency_mat) 57 | # ------------------------------------------------------- 58 | 59 | # ----------------------- cluster + score --------------- 60 | clusters = [2, 4, 6, 8, 10, 12, 15, 18, 21, 24, 27, 30] 61 | scores = [] 62 | scores_silhoette = [] 63 | for cluster_no in clusters: 64 | print(cluster_no) 65 | kmeans = KMeans(n_clusters=cluster_no, random_state=0).fit(adjacency_mat) 66 | scores.append(kmeans.score(adjacency_mat)) 67 | label = kmeans.labels_ 68 | scores_silhoette.append(silhouette_score(adjacency_mat, label, metric='euclidean')) 69 | with open('classifier_' + str(cluster_no) + '.pkl', 'wb') as fid: 70 | cPickle.dump(kmeans, fid) 71 | # --------------------------------------------------------- 72 | 73 | # scores = [] 74 | # for cluster_no in clusters: 75 | # with open('classifier_'+ str(cluster_no) + '.pkl', 'rb') as fid: 76 | # kmeans = cPickle.load(fid) 77 | # label = kmeans.labels_ 78 | # scores.append(silhouette_score(adjacency_mat, label, metric='euclidean')) 79 | 80 | matplotlib.pyplot.plot(clusters, scores) 81 | plt.show() 82 | 83 | # ---------------------- Explore Clusters ------------------------- 84 | cluster_no = 30 85 | # with open('clusterings/classifier_'+ str(cluster_no) + '.pkl', 'rb') as fid: 86 | with open('classifier_'+ str(cluster_no) + '.pkl', 'rb') as fid: 87 | kmeans = cPickle.load(fid) 88 | 89 | y_pred = kmeans.predict(adjacency_mat) 90 | clusters = {} 91 | y_of_cluster = {} 92 | for i in range(cluster_no): 93 | clusters[i] = np.argwhere(y_pred==i) 94 | y_of_cluster[i] = y_tr[:, clusters[i]] 95 | # y_of_cluster[i] = np.array(y_of_cluster[i][:,0]) 96 | x = np.sum(y_tr, 0) 97 | y = np.sum(y_of_cluster[i], 0) 98 | mean_labels = np.mean(np.sum(y_of_cluster[i], 0)) 99 | top5_labels = np.argsort(y)[-10:] 100 | top5_label_counts = np.sort(y)[-10:] 101 | num_tail_labels_1 = len(np.argwhere(x[clusters[i]]<=1)) 102 | num_tail_labels_2 = len(np.argwhere(x[clusters[i]]<=2)) 103 | num_tail_labels_5 = len(np.argwhere(x[clusters[i]]<=5)) 104 | 105 | print("No. of Labels {6}; Mean No. 
of Labels {0}; top 5 labels {1}, top 5 label counts {2}; num tail labels(1) \ 106 | {3}; num tail labels(2) {4}; num tail labels(5) {5}".format(mean_labels, top5_labels, top5_label_counts, 107 | num_tail_labels_1, num_tail_labels_2, num_tail_labels_5, len(clusters[i]))) 108 | # ---------------------- Explore Clusters ------------------------- 109 | -------------------------------------------------------------------------------- /utils/futils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import timeit 5 | import argparse 6 | import numpy as np 7 | import time 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import matplotlib.pyplot as plt 11 | import torch.autograd as autograd 12 | from sklearn import preprocessing 13 | from torch.autograd import Variable 14 | from sklearn.decomposition import PCA 15 | import matplotlib.gridspec as gridspec 16 | import data_helpers 17 | import scipy 18 | import subprocess 19 | from scipy import sparse 20 | def weights_init(m): 21 | if(torch.__version__=='0.4.0'): 22 | torch.nn.init.xavier_uniform_(m) 23 | else: 24 | torch.nn.init.xavier_uniform(m) 25 | def get_gpu_memory_map(boom, name=False): 26 | result = subprocess.check_output( 27 | [ 28 | 'nvidia-smi', '--query-gpu=memory.used', 29 | '--format=csv,nounits,noheader' 30 | ]) 31 | gpu_memory = [int(x) for x in result.strip().split('\n')] 32 | gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 33 | if(name): 34 | print("In " + str(name) + " Print: {0}; Mem(1): {1}; Mem(2): {2}; Mem(3): {3}; Mem(4): {4}".format( boom, gpu_memory_map[0], \ 35 | gpu_memory_map[1], gpu_memory_map[2], gpu_memory_map[3])) 36 | else: 37 | print("Print: {0}; Mem(1): {1}; Mem(2): {2}; Mem(3): {3}; Mem(4): {4}".format( boom, gpu_memory_map[0], \ 38 | gpu_memory_map[1], gpu_memory_map[2], gpu_memory_map[3])) 39 | return boom+1 40 | 41 | 42 | def count_parameters(model): 43 | a = 0 44 | for p in model.parameters(): 45 | if p.requires_grad: 46 | a += p.numel() 47 | return a 48 | 49 | def effective_k(k, d): 50 | return (k - 1) * d + 1 51 | 52 | def load_model(model, name, optimizer=None): 53 | if(torch.cuda.is_available()): 54 | checkpoint = torch.load(name) 55 | else: 56 | checkpoint = torch.load(name, map_location=lambda storage, loc: storage) 57 | 58 | model.load_state_dict(checkpoint['state_dict']) 59 | 60 | if optimizer is not None: 61 | optimizer.load_state_dict(checkpoint['optimizer']) 62 | init = checkpoint['epoch'] 63 | return model, optimizer, init 64 | else: 65 | return model 66 | 67 | def save_model(model, optimizer, epoch, name): 68 | 69 | 70 | checkpoint = { 71 | 'state_dict': model.state_dict(), 72 | 'optimizer': optimizer.state_dict(), 73 | 'epoch': epoch 74 | } 75 | torch.save(checkpoint, "../saved_models/" + name) 76 | def sample_z(mu, log_var, params, dtype_f): 77 | eps = Variable(torch.randn(params.batch_size, params.Z_dim).type(dtype_f)) 78 | k = torch.exp(log_var / 2) * eps 79 | return mu + k 80 | def gen_model_file(params): 81 | data_name = params.data_path.split('/')[-2] 82 | fs_string = '-'.join([str(fs) for fs in params.filter_sizes]) 83 | file_name = 'data-%s_sl-%d_ed-%d_fs-%s_nf-%d_pu-%d_pt-%s_hd-%d_bs-%d_model-%s_pretrain-%s_beta-%s' % \ 84 | (data_name, params.sequence_length, params.embedding_dim, 85 | fs_string, params.num_filters, params.pooling_units, 86 | params.pooling_type, params.hidden_dims, params.batch_size, 87 | params.model_variation, params.pretrain_type, params.beta) 88 | return 
file_name 89 | 90 | def sample_z(mu, log_var, params): 91 | eps = Variable(torch.randn(log_var.shape[0], params.Z_dim).type(params.dtype)) 92 | k = torch.exp(log_var / 2) * eps 93 | return mu + k 94 | 95 | def load_data(X, Y, params, batch=True): 96 | if(batch): 97 | a = np.random.randint(0,params.N, size=params.mb_size) 98 | if isinstance(X, scipy.sparse.csr.csr_matrix) or isinstance(X, scipy.sparse.csc.csc_matrix): 99 | X, c = X[a].todense(), Y[a].todense() 100 | else: 101 | X, c = X[a], Y[a] 102 | 103 | else: 104 | if isinstance(X, scipy.sparse.csr.csr_matrix) or isinstance(X, scipy.sparse.csc.csc_matrix): 105 | X, c = X.todense(), Y.todense() 106 | else: 107 | X, c = X, Y 108 | 109 | X = Variable(torch.from_numpy(X.astype('float32')).type(params.dtype)) 110 | Y = Variable(torch.from_numpy(c.astype('float32')).type(params.dtype)) 111 | return X,Y 112 | 113 | def write_grads(model, thefile): 114 | grads = [] 115 | for key, value in model.named_parameters(): 116 | if(value.grad is not None): 117 | grads.append(value.grad.mean().squeeze().cpu().numpy()) 118 | 119 | thefile = open('gradient_classifier.txt', 'a+') 120 | for item in grads: 121 | thefile.write("%s " % item) 122 | thefile.write("\n" % item) 123 | thefile.close() 124 | 125 | def save_load_data(params, save=0): 126 | params.pad_token = "" 127 | params.go_token = '' 128 | params.end_token = '' 129 | 130 | if(save): 131 | print("Loading Data") 132 | ##################################################### 133 | params.data_path += '.p' 134 | x_tr, y_tr, x_te, y_te, vocabulary, vocabulary_inv, params = data_helpers.load_data(params, max_length=params.sequence_length, vocab_size=params.vocab_size) 135 | x_tr = x_tr.astype(np.int32) 136 | x_te = x_te.astype(np.int32) 137 | y_tr = y_tr.astype(np.int32) 138 | y_te = y_te.astype(np.int32) 139 | ##################################################### 140 | params.data_path = params.data_path[:-2] 141 | if not os.path.exists(params.data_path): 142 | os.makedirs(params.data_path) 143 | 144 | x_tr = sparse.csr_matrix(x_tr) 145 | x_te = sparse.csr_matrix(x_te) 146 | sparse.save_npz(params.data_path + '/x_train', x_tr) 147 | sparse.save_npz(params.data_path + '/y_train', y_tr) 148 | sparse.save_npz(params.data_path + '/y_test', y_te) 149 | sparse.save_npz(params.data_path + '/x_test', x_te) 150 | np.save(params.data_path + '/vocab', vocabulary) 151 | np.save(params.data_path + '/vocab_inv', vocabulary_inv) 152 | 153 | x_tr = sparse.load_npz(params.data_path + '/x_train.npz') 154 | y_tr = sparse.load_npz(params.data_path + '/y_train.npz') 155 | x_te = sparse.load_npz(params.data_path + '/x_test.npz') 156 | y_te = sparse.load_npz(params.data_path + '/y_test.npz') 157 | 158 | vocabulary = np.load(params.data_path + '/vocab.npy').item() 159 | vocabulary_inv = np.load(params.data_path + '/vocab_inv.npy') 160 | params.X_dim = x_tr.shape[1] 161 | params.y_dim = y_tr.shape[1] 162 | params.N = x_tr.shape[0] 163 | params.vocab_size = len(vocabulary) 164 | params.classes = y_tr.shape[1] 165 | 166 | return x_tr, x_te, y_tr, y_te, vocabulary, vocabulary_inv, params 167 | 168 | def load_batch_cnn(x_tr, y_tr, params, batch=True, batch_size=0, decoder_word_input=None, decoder_target=None, testing=0): 169 | 170 | indexes = 0 # for scope 171 | if(batch): 172 | if(batch_size): 173 | params.go_row = np.ones((batch_size,1))*params.vocabulary[params.go_token] 174 | params.end_row = np.ones((batch_size,1))*params.vocabulary[params.end_token] 175 | indexes = np.array(np.random.randint(x_tr.shape[0], size=batch_size)) 176 
| x_tr, y_tr = x_tr[indexes,:], y_tr[indexes,:] 177 | else: 178 | params.go_row = np.ones((params.mb_size,1))*params.vocabulary[params.go_token] 179 | params.end_row = np.ones((params.mb_size,1))*params.vocabulary[params.end_token] 180 | indexes = np.array(np.random.randint(x_tr.shape[0], size=params.mb_size)) 181 | x_tr, y_tr = x_tr[indexes,:], y_tr[indexes,:] 182 | else: 183 | params.go_row = np.ones((x_tr.shape[0],1))*params.vocabulary[params.go_token] 184 | params.end_row = np.ones((x_tr.shape[0],1))*params.vocabulary[params.end_token] 185 | 186 | x_tr = x_tr.todense() 187 | y_tr = y_tr.todense() 188 | 189 | x_tr = Variable(torch.from_numpy(x_tr.astype('int')).type(params.dtype_i)) 190 | if(testing==0): 191 | y_tr = Variable(torch.from_numpy(y_tr.astype('float')).type(params.dtype_f)) 192 | 193 | return x_tr, y_tr 194 | 195 | def update_params(params): 196 | if(len(params.model_name)==0): 197 | params.model_name = gen_model_file(params) 198 | params.decoder_kernels = [(400, params.Z_dim + params.hidden_dims + params.embedding_dim, 3), 199 | (450, 400, 3), 200 | (500, 450, 3)] 201 | params.decoder_dilations = [1, 2, 4] 202 | params.decoder_paddings = [effective_k(w, params.decoder_dilations[i]) - 1 203 | for i, (_, _, w) in enumerate(params.decoder_kernels)] 204 | 205 | return params 206 | 207 | 208 | def sample_word_from_distribution(params, distribution): 209 | ix = np.random.choice(range(params.vocab_size), p=distribution.view(-1)) 210 | x = np.zeros((params.vocab_size, 1)) 211 | x[ix] = 1 212 | return params.vocabulary_inv[np.argmax(x)] 213 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import timeit 5 | import argparse 6 | import numpy as np 7 | import time 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import matplotlib.pyplot as plt 11 | import torch.autograd as autograd 12 | from sklearn import preprocessing 13 | from torch.autograd import Variable 14 | from sklearn.decomposition import PCA 15 | import matplotlib.gridspec as gridspec 16 | import pdb 17 | 18 | def isnan(x): 19 | return x != x 20 | 21 | class loss: 22 | 23 | def MSLoss(self, X_sample, X): 24 | t = torch.mean(torch.norm((X_sample - X),1),dim=0) 25 | return t 26 | 27 | def BCELoss(self, y_pred, y, eps = 1e-25): 28 | t = torch.nn.functional.binary_cross_entropy(y_pred, y)*y.shape[-1] 29 | return t 30 | 31 | def L1Loss(self, X_sample, X): 32 | t = torch.sum(torch.mean(torch.abs(X_sample - X),dim=0)) 33 | return t 34 | -------------------------------------------------------------------------------- /utils/process_eurlex.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | from numpy import genfromtxt 4 | 5 | #bashCommand = "java -cp ~/Downloads/weka-3-8-2/weka.jar weka.core.converters.CSVSaver -i eurlex_nA-5k_CV1-10_train.arff > eurlex_nA-5k_CV1-10_train.csv" 6 | #process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) 7 | #output, error = process.communicate() 8 | 9 | with open('eurlex_nA-5k_CV1-10_train.csv') as f: 10 | lines = f.read().splitlines()[0] 11 | 12 | a = genfromtxt('eurlex_nA-5k_CV1-10_train.csv', delimiter=',') 13 | words = lines.split(',')[1:] 14 | doc_id = {} 15 | doc_id_inv = {} 16 | 17 | words_dict = {} 18 | for i, w in enumerate(words): 19 | words_dict[i] = w 20 | 21 | with open('feature_names.txt', 'w') as 
f: 22 | for key, value in words_dict.items(): 23 | f.write('%s:%s\n' % (key, value)) 24 | 25 | for i in range(1, len(a[:,0])): 26 | doc_id_inv[a[i,0]] = i-1 27 | doc_id[i-1] = a[i,0] 28 | # doc_id_list = doc_id. 29 | x_tr = a[1:,1:] 30 | np.save('words',words) 31 | np.save('doc_id',doc_id) # dictionary 32 | np.save('doc_id_inv',doc_id_inv) # dictionary 33 | np.save('x_tr',x_tr) 34 | 35 | 36 | labels_data_pt = genfromtxt('/u/79/wa.saxenas2/unix/Downloads/eurlex_id2class/id2class_eurlex_eurovoc.qrels', delimiter=' ')[:,1] 37 | with open('/u/79/wa.saxenas2/unix/Downloads/eurlex_id2class/id2class_eurlex_eurovoc.qrels') as f: 38 | lines = f.read().splitlines() 39 | 40 | label_names = [] 41 | for line in lines: 42 | label_names.append(line.split(' ')[0]) 43 | 44 | 45 | label_set = {} 46 | label_set_inv = {} 47 | count = 0 48 | # data_map = {} 49 | # data_count = 0 50 | for i in range(np.shape(labels_data_pt)[0]): 51 | if label_names[i] not in label_set.keys(): 52 | label_set[label_names[i]] = count 53 | label_set_inv[count] = label_names[i] 54 | count+=1 55 | print(count) 56 | # if labels[i] not in data_map.keys() and labels[i] in doc_id_list: 57 | # data_map[labels[i]] = data_count 58 | # data_count+=1 59 | 60 | np.save('label_set', label_set) # dictionary 61 | np.save('label_set_inv', label_set_inv) # dictionary 62 | 63 | with open('label_set.txt', 'w') as f: 64 | for key, value in label_set_inv.items(): 65 | f.write('%s:%s\n' % (key, value)) 66 | 67 | y_tr = np.zeros((np.shape(x_tr)[0], count)) 68 | y_tr_named = {} 69 | for i in range(np.shape(labels_data_pt)[0]): 70 | if labels_data_pt[i] in doc_id_inv.keys(): 71 | y_tr[doc_id_inv[labels_data_pt[i]], label_set[label_names[i]]] = 1 72 | if doc_id_inv[labels_data_pt[i]] not in y_tr_named.keys(): 73 | y_tr_named[doc_id_inv[labels_data_pt[i]]] = [] 74 | y_tr_named[doc_id_inv[labels_data_pt[i]]].append(label_names[i]) 75 | np.save('y_tr', y_tr) 76 | 77 | with open('y_tr_named.txt', 'w') as f: 78 | for key, value in y_tr_named.items(): 79 | f.write('%s:%s\n' % (key, value)) 80 | -------------------------------------------------------------------------------- /utils/w2v.py: -------------------------------------------------------------------------------- 1 | from gensim.models import word2vec 2 | from os.path import join, exists, split 3 | import os 4 | import numpy as np 5 | 6 | def train_word2vec(sentence_matrix, vocabulary_inv, 7 | num_features=300, min_word_count=1, context=10): 8 | """ 9 | Trains, saves, loads Word2Vec model 10 | Returns initial weights for embedding layer. 
11 | 12 | inputs: 13 | sentence_matrix # int matrix: num_sentences x max_sentence_len 14 | vocabulary_inv # dict {str:int} 15 | num_features # Word vector dimensionality 16 | min_word_count # Minimum word count 17 | context # Context window size 18 | """ 19 | model_dir = '../embedding_weights' 20 | model_name = "{:d}features_{:d}minwords_{:d}context".format(num_features, min_word_count, context) 21 | model_name = join(model_dir, model_name) 22 | if exists(model_name): 23 | embedding_model = word2vec.Word2Vec.load(model_name) 24 | #print 'Loading existing Word2Vec model \'%s\'' % split(model_name)[-1] 25 | else: 26 | # Set values for various parameters 27 | num_workers = 2 # Number of threads to run in parallel 28 | downsampling = 1e-3 # Downsample setting for frequent words 29 | 30 | # Initialize and train the model 31 | print( "Training Word2Vec model...") 32 | sentences = [[vocabulary_inv[w] for w in s] for s in sentence_matrix] 33 | embedding_model = word2vec.Word2Vec(sentences, workers=num_workers, \ 34 | size=num_features, min_count = min_word_count, \ 35 | window = context, sample = downsampling) 36 | 37 | # If we don't plan to train the model any further, calling 38 | # init_sims will make the model much more memory-efficient. 39 | embedding_model.init_sims(replace=True) 40 | 41 | # Saving the model for later use. You can load it later using Word2Vec.load() 42 | if not exists(model_dir): 43 | os.mkdir(model_dir) 44 | print ('Saving Word2Vec model' + str(split(model_name)[-1])) 45 | embedding_model.save(model_name) 46 | 47 | # add unknown words 48 | embedding_weights = [np.array([embedding_model[w] if w in embedding_model\ 49 | else np.random.uniform(-0.25,0.25,embedding_model.vector_size)\ 50 | for w in vocabulary_inv])] 51 | return embedding_weights 52 | 53 | 54 | def load_word2vec(params): 55 | """ 56 | loads Word2Vec model 57 | Returns initial weights for embedding layer. 58 | 59 | inputs: 60 | model_type # GoogleNews / glove 61 | vocabulary_inv # dict {str:int} 62 | num_features # Word vector dimensionality 63 | """ 64 | 65 | model_dir = '../embedding_weights' 66 | 67 | if params.model_type == 'GoogleNews': 68 | model_name = join(model_dir, 'GoogleNews-vectors-negative300.bin.gz') 69 | assert(params.num_features == 300) 70 | assert(exists(model_name)) 71 | print('Loading existing Word2Vec model (GoogleNews-300)') 72 | embedding_model = word2vec.Word2Vec.load_word2vec_format(model_name, binary=True) 73 | 74 | elif params.model_type == 'glove': 75 | model_name = join(model_dir, 'glove.6B.%dd.txt' % (params.num_features)) 76 | print(model_name) 77 | assert(exists(model_name)) 78 | print('Loading existing Word2Vec model (Glove.6B.%dd)' % (params.num_features)) 79 | 80 | # dictionary, where key is word, value is word vectors 81 | embedding_model = {} 82 | for line in open(model_name, 'r'): 83 | tmp = line.strip().split() 84 | word, vec = tmp[0], map(float, tmp[1:]) 85 | assert(len(vec) == params.num_features) 86 | if word not in embedding_model: 87 | embedding_model[word] = vec 88 | assert(len(embedding_model) == 400000) 89 | 90 | else: 91 | raise ValueError('Unknown pretrain model type: %s!' 
% (params.model_type)) 92 | 93 | embedding_weights = [embedding_model[w] if w in embedding_model 94 | else np.random.uniform(-0.25, 0.25, params.num_features) 95 | for w in params.vocabulary_inv] 96 | embedding_weights = np.array(embedding_weights).astype('float32') 97 | 98 | return embedding_weights 99 | 100 | 101 | if __name__=='__main__': 102 | import data_helpers 103 | print("Loading data...") 104 | x, _, _, params.vocabulary_inv = data_helpers.load_data() 105 | w = train_word2vec(x, params.vocabulary_inv) 106 | 107 | --------------------------------------------------------------------------------
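As a supplement to the README above, here is a minimal, hypothetical sketch of how the modules in ```code/models``` fit together for one training step, bypassing ```main.py```. It is not part of the repo: it assumes the repo's own dependencies are installed, a PyTorch version new enough for ```torch.randint``` (>= 0.4), and that it is run from the ```code/``` directory so the ```sys.path``` setup mirrors ```code/header.py```. The concrete values (label count 103, batch of 20, 'CNN-rand') are placeholders; in ```main.py``` these come from the loaded dataset and command-line flags.

```python
# Hypothetical usage sketch -- not part of the original repository.
import sys
sys.path.append('../utils/')   # same path additions as code/header.py
sys.path.append('models')
import argparse
import torch
import torch.optim as optim

from xmlCNN import xmlCNN      # pulls in cnn_encoder / embedding_layer defined above

params = argparse.Namespace(
    vocab_size=30000, embedding_dim=300, sequence_length=500,     # defaults from main.py
    filter_sizes=[2, 4, 8], num_filters=32, pooling_units=32, pooling_type='max',
    hidden_dims=512, dropouts=0,
    y_dim=103,                      # placeholder label count; normally y_tr.shape[1]
    model_variation='CNN-rand',     # any value other than 'pretrain' skips loading GloVe vectors
    loss_fn=torch.nn.BCELoss(size_average=False),
)

model = xmlCNN(params, embedding_weights=None)   # weights only needed when model_variation == 'pretrain'
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# One fake minibatch: token ids of shape (batch, sequence_length) and a binary label matrix.
batch_x = torch.randint(0, params.vocab_size, (20, params.sequence_length))
batch_y = torch.zeros(20, params.y_dim)
batch_y[:, 0] = 1

loss, Y = model(batch_x, batch_y)   # xmlCNN.forward returns (loss, per-label probabilities)
optimizer.zero_grad()
loss.mean().backward()
optimizer.step()
print(loss.mean().item(), Y.shape)
```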