├── HAN.py ├── HMCN.py ├── Linear_Model.py ├── Logger_morning.py ├── OHCNN.py ├── OHCNN_fast.py ├── README.md ├── TextCNN.py ├── clean_runs.py ├── conf.py ├── feature_dataset.py ├── features_test.py ├── fig ├── HiLAP_architecture.jpg └── prediction_animation.gif ├── loadData.py ├── main.py ├── model.py ├── readData_fungo.py ├── readData_nyt.py ├── readData_rcv1.py ├── readData_yelp.py ├── tree.py ├── util.py └── yelp ├── Taxonomy_100 └── yelp_data_100.csv.sample /HAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | """ 7 | Hierarchical Attention Networks for Document Classification 8 | https://www.cs.cmu.edu/~hovy/papers/16HLT-hierarchical-attention-networks.pdf 9 | """ 10 | 11 | 12 | def batch_matmul_bias(seq, weight, bias, nonlinearity=''): 13 | s = None 14 | bias_dim = bias.size() 15 | for i in range(seq.size(0)): 16 | _s = torch.mm(seq[i], weight) 17 | _s_bias = _s + bias.expand(bias_dim[0], _s.size()[0]).transpose(0, 1) 18 | if nonlinearity == 'tanh': 19 | _s_bias = torch.tanh(_s_bias) 20 | _s_bias = _s_bias.unsqueeze(0) 21 | if s is None: 22 | s = _s_bias 23 | else: 24 | s = torch.cat((s, _s_bias), 0) 25 | return s.squeeze() 26 | 27 | 28 | def batch_matmul(seq, weight, nonlinearity=''): 29 | s = None 30 | for i in range(seq.size(0)): 31 | _s = torch.mm(seq[i], weight) 32 | if nonlinearity == 'tanh': 33 | _s = torch.tanh(_s) 34 | _s = _s.unsqueeze(0) 35 | if s is None: 36 | s = _s 37 | else: 38 | s = torch.cat((s, _s), 0) 39 | return s.squeeze() 40 | 41 | 42 | def attention_mul(rnn_outputs, att_weights): 43 | attn_vectors = None 44 | for i in range(rnn_outputs.size(0)): 45 | h_i = rnn_outputs[i] 46 | a_i = att_weights[i].unsqueeze(1).expand_as(h_i) 47 | h_i = a_i * h_i 48 | h_i = h_i.unsqueeze(0) 49 | if attn_vectors is None: 50 | attn_vectors = h_i 51 | else: 52 | attn_vectors = torch.cat((attn_vectors, h_i), 0) 53 | return torch.sum(attn_vectors, 0).unsqueeze(0) 54 | 55 | 56 | class HAN(nn.Module): 57 | def __init__(self, args, n_classes, word_vec=None, num_tokens=None, embed_size=None): 58 | super(HAN, self).__init__() 59 | self.args = args 60 | if word_vec is None: 61 | assert num_tokens is not None and embed_size is not None 62 | self.num_tokens = num_tokens 63 | self.embed_size = embed_size 64 | else: 65 | self.num_tokens = word_vec.shape[0] 66 | self.embed_size = word_vec.shape[1] 67 | self.word_gru_hidden = args.word_gru_hidden_size 68 | self.lookup = nn.Embedding(self.num_tokens, self.embed_size) 69 | if args.pretrained_word_embed: 70 | self.lookup.weight = nn.Parameter(torch.from_numpy(word_vec).float()) 71 | self.lookup.weight.requires_grad = args.update_word_embed 72 | self.word_gru = nn.GRU(self.embed_size, self.word_gru_hidden, bidirectional=True) 73 | self.weight_W_word = nn.Parameter(torch.Tensor(2 * self.word_gru_hidden, 2 * self.word_gru_hidden)) 74 | self.bias_word = nn.Parameter(torch.Tensor(2 * self.word_gru_hidden, 1)) 75 | self.weight_proj_word = nn.Parameter(torch.Tensor(2 * self.word_gru_hidden, 1)) 76 | nn.init.uniform(self.weight_W_word, -0.1, 0.1) 77 | nn.init.uniform(self.bias_word, -0.1, 0.1) 78 | nn.init.uniform(self.weight_proj_word, -0.1, 0.1) 79 | # sentence level 80 | self.sent_gru_hidden = args.sent_gru_hidden_size 81 | self.word_gru_hidden = args.word_gru_hidden_size 82 | self.sent_gru = nn.GRU(2 * self.word_gru_hidden, self.sent_gru_hidden, bidirectional=True) 83 | self.weight_W_sent = 
nn.Parameter(torch.Tensor(2 * self.sent_gru_hidden, 2 * self.sent_gru_hidden)) 84 | self.bias_sent = nn.Parameter(torch.Tensor(2 * self.sent_gru_hidden, 1)) 85 | self.weight_proj_sent = nn.Parameter(torch.Tensor(2 * self.sent_gru_hidden, 1)) 86 | C = n_classes 87 | self.fc1 = nn.Linear(2 * self.sent_gru_hidden, C) 88 | nn.init.uniform(self.bias_sent, -0.1, 0.1) 89 | nn.init.uniform(self.weight_W_sent, -0.1, 0.1) 90 | nn.init.uniform(self.weight_proj_sent, -0.1, 0.1) 91 | 92 | def forward(self, mini_batch, fc=False): 93 | max_sents, batch_size, max_tokens = mini_batch.size() 94 | word_attn_vectors = None 95 | state_word = self.init_hidden(mini_batch.size()[1]) 96 | for i in range(max_sents): 97 | embed = mini_batch[i, :, :].transpose(0, 1) 98 | embedded = self.lookup(embed) 99 | output_word, state_word = self.word_gru(embedded, state_word) 100 | word_squish = batch_matmul_bias(output_word, self.weight_W_word, self.bias_word, nonlinearity='tanh') 101 | # logger.debug(word_squish.size()) torch.Size([20, 2, 200]) 102 | word_attn = batch_matmul(word_squish, self.weight_proj_word) 103 | # logger.debug(word_attn.size()) torch.Size([20, 2]) 104 | word_attn_norm = F.softmax(word_attn.transpose(1, 0), dim=-1) 105 | word_attn_vector = attention_mul(output_word, word_attn_norm.transpose(1, 0)) 106 | if word_attn_vectors is None: 107 | word_attn_vectors = word_attn_vector 108 | else: 109 | word_attn_vectors = torch.cat((word_attn_vectors, word_attn_vector), 0) 110 | # logger.debug(word_attn_vectors.size()) torch.Size([1, 2, 200]) 111 | state_sent = self.init_hidden(mini_batch.size()[1]) 112 | output_sent, state_sent = self.sent_gru(word_attn_vectors, state_sent) 113 | # logger.debug(output_sent.size()) torch.Size([8, 2, 200]) 114 | sent_squish = batch_matmul_bias(output_sent, self.weight_W_sent, self.bias_sent, nonlinearity='tanh') 115 | # logger.debug(sent_squish.size()) torch.Size([8, 2, 200]) 116 | if len(sent_squish.size()) == 2: 117 | sent_squish = sent_squish.unsqueeze(0) 118 | sent_attn = batch_matmul(sent_squish, self.weight_proj_sent) 119 | if len(sent_attn.size()) == 1: 120 | sent_attn = sent_attn.unsqueeze(0) 121 | # logger.debug(sent_attn.size()) torch.Size([8, 2]) 122 | sent_attn_norm = F.softmax(sent_attn.transpose(1, 0), dim=-1) 123 | # logger.debug(sent_attn_norm.size()) torch.Size([2, 8]) 124 | sent_attn_vectors = attention_mul(output_sent, sent_attn_norm.transpose(1, 0)) 125 | # logger.debug(sent_attn_vectors.size()) torch.Size([1, 2, 200]) 126 | x = sent_attn_vectors.squeeze(0) 127 | if fc: 128 | x = self.fc1(x) 129 | return x 130 | 131 | def init_hidden(self, batch_size, hidden_dim=None): 132 | if hidden_dim is None: 133 | hidden_dim = self.sent_gru_hidden 134 | if self.args.gpu: 135 | return Variable(torch.zeros(2, batch_size, hidden_dim)).cuda() 136 | return Variable(torch.zeros(2, batch_size, hidden_dim)) 137 | -------------------------------------------------------------------------------- /HMCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from IPython import embed 6 | 7 | """ 8 | Hierarchical Multi-Label Classification Networks 9 | http://proceedings.mlr.press/v80/wehrmann18a.html 10 | """ 11 | 12 | 13 | class HMCN(nn.Module): 14 | def __init__(self, base_model, args, neuron_each_local_l2, total_class, in_dim): 15 | super(HMCN, self).__init__() 16 | 17 | neuron_each_layer = [384] * len(neuron_each_local_l2) 18 | 
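# note: one global layer and one local (l1 -> l2) head are created per hierarchy level, with all hidden sizes fixed to 384;
# neuron_each_local_l2 holds the per-level local output sizes, and self.beta below mixes the global and local predictions in forward()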
neuron_each_local_l1 = [384] * len(neuron_each_local_l2) 19 | self.beta = 0.5 20 | 21 | self.args = args 22 | self.base_model = base_model 23 | 24 | self.layer_num = len(neuron_each_layer) 25 | self.linear_layers = nn.ModuleList([]) 26 | self.local_linear_l1 = nn.ModuleList([]) 27 | self.local_linear_l2 = nn.ModuleList([]) 28 | self.batchnorms = nn.ModuleList([]) 29 | self.batchnorms_local_1 = nn.ModuleList([]) 30 | for idx, neuron_number in enumerate(neuron_each_layer): 31 | if idx == 0: 32 | self.linear_layers.append(nn.Linear(in_dim, neuron_number)) 33 | else: 34 | self.linear_layers.append( 35 | nn.Linear(neuron_each_layer[idx - 1] + in_dim, neuron_number)) 36 | self.batchnorms.append(nn.BatchNorm1d(neuron_number)) 37 | 38 | for idx, neuron_number in enumerate(neuron_each_local_l1): 39 | self.local_linear_l1.append( 40 | nn.Linear(neuron_each_layer[idx], neuron_each_local_l1[idx])) 41 | self.batchnorms_local_1.append( 42 | nn.BatchNorm1d(neuron_each_local_l1[idx])) 43 | for idx, neuron_number in enumerate(neuron_each_local_l2): 44 | self.local_linear_l2.append( 45 | nn.Linear(neuron_each_local_l1[idx], neuron_each_local_l2[idx])) 46 | 47 | self.final_linear_layer = nn.Linear( 48 | neuron_each_layer[-1] + in_dim, total_class) 49 | 50 | def forward(self, x): 51 | x = self.base_model(x, False) 52 | local_outputs = [] 53 | output = x 54 | for layer_idx, layer in enumerate(self.linear_layers): 55 | if layer_idx == 0: 56 | output = layer(output) 57 | output = F.relu(output) 58 | else: 59 | output = layer(torch.cat([output, x], dim=1)) 60 | output = F.relu(output) 61 | 62 | local_output = self.local_linear_l1[layer_idx](output) 63 | local_output = F.relu(local_output) 64 | local_output = self.local_linear_l2[layer_idx](local_output) 65 | local_outputs.append(local_output) 66 | 67 | global_outputs = F.sigmoid( 68 | self.final_linear_layer(torch.cat([output, x], dim=1))) 69 | local_outputs = F.sigmoid(torch.cat(local_outputs, dim=1)) 70 | 71 | output = self.beta * global_outputs + (1 - self.beta) * local_outputs 72 | 73 | return output 74 | -------------------------------------------------------------------------------- /Linear_Model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | """ 5 | Linear model used for feature input (instead of text) 6 | """ 7 | 8 | 9 | class Linear_Model(nn.Module): 10 | def __init__(self, args, input_dim, n_classes): 11 | super(Linear_Model, self).__init__() 12 | self.dropout = nn.Dropout(args.dropout) 13 | self.fc1 = nn.Linear(input_dim, args.n_hidden) 14 | self.fc2 = nn.Linear(args.n_hidden, n_classes) 15 | 16 | def forward(self, x, fc=False): 17 | # x = self.dropout(x) 18 | x = F.relu(self.fc1(x)) 19 | if fc: 20 | x = self.fc2(x) 21 | return x 22 | -------------------------------------------------------------------------------- /Logger_morning.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def myLogger(name=None, log_path='./tmp.log'): 5 | logger = logging.getLogger(name) 6 | if len(logger.handlers) != 0: 7 | print(f'[Logger_morning.py]reuse logger:{name}') 8 | return logger 9 | logging.basicConfig(level=logging.DEBUG, 10 | format='%(asctime)s %(filename)s[line:%(lineno)d][%(funcName)s] %(levelname)s-> %(message)s', 11 | datefmt='%a %d %b %Y %H:%M:%S', filename=log_path, filemode='a') 12 | # define a new Handler to log to console as well 13 | console = logging.StreamHandler() 14 | # set a 
format which is the same for console use 15 | formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d][%(funcName)s] %(levelname)s-> %(message)s', 16 | datefmt='%a %d %b %Y %H:%M:%S') 17 | # tell the handler to use this format 18 | console.setFormatter(formatter) 19 | # add the handler to the root logger 20 | logger.addHandler(console) 21 | logger.info(f'create new logger:{name}') 22 | logger.info(f'saving log to {log_path}') 23 | return logger 24 | 25 | 26 | if __name__ == '__main__': 27 | logger = myLogger('111', log_path='runs/') 28 | logger.setLevel(10) 29 | logger.info('This is info message') 30 | logger.warning('This is warning message') 31 | 32 | logger = myLogger('111') 33 | logger.setLevel(10) 34 | logger.info('T2') 35 | logger.warning('This2') 36 | -------------------------------------------------------------------------------- /OHCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Effective Use of Word Order for Text Categorization with Convolutional Neural Networks 7 | http://www.anthology.aclweb.org/N/N15/N15-1011.pdf 8 | """ 9 | 10 | 11 | class OHCNN(nn.Module): 12 | 13 | def __init__(self, args, n_classes): 14 | super(OHCNN, self).__init__() 15 | self.args = args 16 | D = 30000 17 | C = n_classes 18 | Ci = 1 19 | Co = 1000 20 | self.Co = Co 21 | self.n_pool = 10 22 | if args.mode == 'ohcnn-seq': 23 | Ks = [3] 24 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D), stride=1, padding=(K - 1, 0)) for K in Ks]) 25 | else: 26 | Ks = [1] 27 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D), stride=1) for K in Ks]) 28 | self.dropout = nn.Dropout(0.5) 29 | # self.lrn = nn.LocalResponseNorm(2) 30 | self.fc1 = nn.Linear(len(Ks) * Co * self.n_pool, C) 31 | 32 | def forward(self, x): 33 | x = x.unsqueeze(1) # (N, Ci, W, D) 34 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks) 35 | # x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks) 36 | # [(N, Co * n_pool), ...]*len(Ks) 37 | x = [F.avg_pool1d(i, int(i.size(2) / self.n_pool)).view(-1, self.n_pool * self.Co) for i in x] 38 | x = torch.cat(x, 1) 39 | x = self.dropout(x) # (N, len(Ks)*Co) 40 | # response norm 41 | x /= (1 + x.pow(2).sum(1)).sqrt().view(-1, 1) 42 | logit = self.fc1(x) # (N, C) 43 | return logit 44 | -------------------------------------------------------------------------------- /OHCNN_fast.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | """ 8 | Effective Use of Word Order for Text Categorization with Convolutional Neural Networks 9 | http://www.anthology.aclweb.org/N/N15/N15-1011.pdf 10 | use equivalent embeddings for faster speed (2~3x) 11 | """ 12 | 13 | 14 | class OHCNN_fast(nn.Module): 15 | 16 | def __init__(self, unk_idx, n_classes, vocab_size): 17 | super(OHCNN_fast, self).__init__() 18 | # D = 30001 19 | print(f'vocab_size:{vocab_size}') 20 | D = vocab_size 21 | C = n_classes 22 | Co = 1000 23 | self.Co = Co 24 | self.n_pool = 10 25 | self.embed = nn.Embedding(D, Co) 26 | self.bias = nn.Parameter(torch.Tensor(1, Co, 1)) 27 | self.dropout = nn.Dropout(0.5) 28 | self.fc1 = nn.Linear(Co * self.n_pool, C) 29 | self.unk_idx = unk_idx 30 | # init as in cnn 31 | stdv = 1. 
/ math.sqrt(D) 32 | self.embed.weight.data.uniform_(-stdv, stdv) 33 | if self.bias is not None: 34 | self.bias.data.uniform_(-stdv, stdv) 35 | 36 | def forward(self, x, fc=False): 37 | # (32, 256, 20) 38 | sent_len = x.size(1) 39 | x = x.view(x.size(0), -1) 40 | x_embed = self.embed(x) # (N, W * D, Co) 41 | # deal with unk in the region 42 | x = (x != self.unk_idx).float().unsqueeze(-1) * x_embed 43 | x = x.view(x.size(0), sent_len, -1, self.Co) # (N, W, D, Co) 44 | x = F.relu(x.sum(2).permute(0, 2, 1) + self.bias) # (N, Co, W) 45 | x = F.avg_pool1d(x, int(x.size(2) / self.n_pool)).view(-1, self.n_pool * self.Co) # (N, n_pool * Co) 46 | x = self.dropout(x) 47 | # response norm 48 | x /= (1 + x.pow(2).sum(1)).sqrt().view(-1, 1) 49 | if fc: 50 | x = self.fc1(x) # (N, C) 51 | return x 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo provides the code for the paper ["Hierarchical Text Classification with Reinforced Label Assignment"](https://arxiv.org/abs/1908.10419) (EMNLP 2019). 2 | 3 | ![prediction_animation](fig/prediction_animation.gif) 4 | 5 | ![HiLAP_architecture](fig/HiLAP_architecture.jpg) 6 | 7 | ## Abstract 8 | 9 | While existing hierarchical text classification (HTC) methods attempt to capture label hierarchies for model training, they either make local decisions regarding each label or completely ignore the hierarchy information during inference. To solve the mismatch between training and inference as well as modeling label dependencies in a more principled way, we formulate HTC as a Markov decision process and propose to learn a **L**abel **A**ssignment **P**olicy via deep reinforcement learning to determine *where to place* an object and *when to stop* the assignment process. The proposed method, **HiLAP**, explores the hierarchy during both training and inference time in a *consistent* manner and makes *inter-dependent* decisions. As a general framework, HiLAP can incorporate different neural encoders as *base models* for end-to-end training. Experiments on five public datasets and four base models show that HiLAP yields an average improvement of 33.4% in Macro-F1 over flat classifiers and outperforms state-of-the-art HTC methods by a large margin. 10 | 11 | ## Model 12 | 13 | `model.py`: The main model of HiLAP. 14 | 15 | `TextCNN.py`: Our implementation of "Convolutional Neural Networks for Sentence Classification" EMNLP 2014. 16 | 17 | `OHCNN(_fast).py`: Our implementation of "Effective Use of Word Order for Text Categorization with Convolutional Neural Networks" NAACL 2015. 18 | 19 | `HAN.py`: Our implementation of "Hierarchical Attention Networks for Document Classification" NAACL 2016. 20 | 21 | `HMCN.py`: Our implementation of "Hierarchical Multi-Label Classification Networks" ICML 2018. 22 | 23 | ## Requirements 24 | 25 | Python **3** 26 | 27 | PyTorch **0.3** 28 | 29 | ## Data 30 | 31 | Due to copyright issues, we can't directly release the datasets used in our experiments.
32 | Instead, we provide the links to the five data sources (the first two may require a license): 33 | 34 | - RCV1 [original release](http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm), [text data](https://trec.nist.gov/data/reuters/reuters.html) (**update:** download the text data and convert it to docs.txt with the format "docid content") 35 | - [NYT](https://catalog.ldc.upenn.edu/LDC2008T19) 36 | - [Yelp](https://www.yelp.com/dataset/challenge) (**update:** the latest release is different from what we used; please send an email if you need the version we used) 37 | - [FunGO](https://dtai.cs.kuleuven.be/clus/hmcdatasets/) 38 | 39 | Please check `readData_*.py` to see how to use our scripts to process and generate the datasets from the original data. 40 | 41 | ## Run 42 | All the parameters in `conf.py` have default values. Change the parameters `mode`, `base_model`, and `dataset` and then run `main.py` to train or test under different settings (e.g., `python main.py --mode hilap --base_model textcnn --dataset rcv1`). To test a trained model, set `load_model=model_file` and `isTrain=False` in `conf.py` and run `main.py`. 43 | 44 | ## Cite 45 | 46 | ``` 47 | @inproceedings{mao-etal-2019-hierarchical, 48 | title = "Hierarchical Text Classification with Reinforced Label Assignment", 49 | author = "Mao, Yuning and 50 | Tian, Jingjing and 51 | Han, Jiawei and 52 | Ren, Xiang", 53 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 54 | month = nov, 55 | year = "2019", 56 | address = "Hong Kong, China", 57 | publisher = "Association for Computational Linguistics", 58 | url = "https://www.aclweb.org/anthology/D19-1042", 59 | doi = "10.18653/v1/D19-1042", 60 | pages = "445--455", 61 | } 62 | ``` 63 | 64 | -------------------------------------------------------------------------------- /TextCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Convolutional Neural Networks for Sentence Classification 7 | http://www.aclweb.org/anthology/D14-1181 8 | """ 9 | 10 | 11 | class TextCNN(nn.Module): 12 | def __init__(self, args, word_vec, n_classes): 13 | super(TextCNN, self).__init__() 14 | # V = args.embed_num 15 | # D = args.embed_dim 16 | # C = args.class_num 17 | # Ci = 1 18 | # Co = args.kernel_num 19 | # Ks = args.kernel_sizes 20 | V = word_vec.shape[0] 21 | D = word_vec.shape[1] 22 | C = n_classes 23 | Ci = 1 24 | self.Co = 1000 25 | Ks = [3, 4, 5] 26 | 27 | self.embed = nn.Embedding(V, D) 28 | if args.pretrained_word_embed: 29 | self.embed.weight = nn.Parameter(torch.from_numpy(word_vec).float()) 30 | self.embed.weight.requires_grad = args.update_word_embed 31 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, self.Co, (K, D)) for K in Ks]) 32 | self.dropout = nn.Dropout(0.)
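# note: the dropout probability is 0., i.e. dropout is effectively disabled here;
# fc1 below maps the len(Ks) * Co concatenated max-pooled features to class logits and is only applied when forward() is called with fc=True, the same convention used by HAN, OHCNN_fast, and Linear_Model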
33 | self.fc1 = nn.Linear(len(Ks) * self.Co, C) 34 | 35 | def forward(self, x, fc=False): 36 | x = self.embed(x) # (N, W, D) 37 | x = x.unsqueeze(1) # (N, Ci, W, D) 38 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks) 39 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks) 40 | x = torch.cat(x, 1) 41 | x = self.dropout(x) # (N, len(Ks)*Co) 42 | if fc: 43 | x = self.fc1(x) # (N, C) 44 | return x 45 | -------------------------------------------------------------------------------- /clean_runs.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | """ 5 | remove folders under runs/ by folder size and date 6 | """ 7 | 8 | 9 | def get_size(start_path='.'): 10 | total_size = 0 11 | for dirpath, dirnames, filenames in os.walk(start_path): 12 | for f in filenames: 13 | fp = os.path.join(dirpath, f) 14 | total_size += os.path.getsize(fp) 15 | return total_size 16 | 17 | 18 | now = datetime.datetime.now().strftime('%b%d') 19 | for dirpath, dirnames, filenames in os.walk('runs2/'): 20 | for dirname in dirnames: 21 | path = os.path.join(dirpath, dirname) 22 | size = get_size(path) 23 | # print(path, size) 24 | # rm folder that < 100KB and not today 25 | if (size < 100000 or 'del' in path) and now not in path: 26 | print('rm {} with size {}KB'.format(path, size / 1000)) 27 | os.system("rm -rf '{}'".format(path)) 28 | 29 | for filename in filenames: 30 | path = os.path.join(dirpath, filename) 31 | if 'del' in path and now not in path: 32 | print(f'rm {path}') 33 | os.system("rm -rf '{}'".format(path)) 34 | break 35 | -------------------------------------------------------------------------------- /conf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def print_config(config, logger=None): 5 | config = vars(config) 6 | info = "Running with the following configs:\n" 7 | for k, v in config.items(): 8 | info += "\t{} : {}\n".format(k, str(v)) 9 | if not logger: 10 | print("\n" + info + "\n") 11 | else: 12 | logger.info("\n" + info + "\n") 13 | 14 | 15 | def conf(): 16 | ap = argparse.ArgumentParser() 17 | # change the parameters to test different method/base_model/dataset combinations 18 | ap.add_argument('--mode', default='sl', choices=['hilap', 'sl', 'hilap-sl', 'hmcn']) 19 | ap.add_argument('--base_model', default='textcnn', choices=['han', 'textcnn', 'ohcnn-bow-fast', 'raw']) 20 | ap.add_argument('--dataset', default='rcv1', choices=['yelp', 'rcv1', 'nyt', 'cellcycle_FUN']) 21 | ap.add_argument('--isTrain', default=False, help='True for continuing training') 22 | ap.add_argument('--load_model', default=None) 23 | ap.add_argument('--remark', default='del', help='reminder of this run') 24 | 25 | # most of following parameters do not need any changes 26 | ap.add_argument('--lr', default=1e-3, help='learning rate 1e-3 for ohcnn, textcnn, 1e-1 for han') 27 | ap.add_argument('--l2_weight', default=1e-6, help='weight decay of optimizer') 28 | ap.add_argument('--save_every', default=10, help='evaluate and save model every k epochs') 29 | ap.add_argument('--num_epoch', default=50) 30 | ap.add_argument('--word_gru_hidden_size', default=50) 31 | ap.add_argument('--sent_gru_hidden_size', default=50) 32 | ap.add_argument('--hist_embed_size', default=50) 33 | ap.add_argument('--update_beta_every', default=500) 34 | ap.add_argument('--pretrained_word_embed', default=True) 35 | ap.add_argument('--update_word_embed', 
default=True) 36 | ap.add_argument('--allow_stay', default=True, 37 | help='if sample_mode=random, has to be False in case select prob=0->nan') 38 | ap.add_argument('--sample_mode', default='normal', choices=['choose_max', 'random', 'normal']) 39 | ap.add_argument('--batch_size', default=32) 40 | ap.add_argument('--batch_size_test', default=32) 41 | ap.add_argument('--log_level', default=20) 42 | ap.add_argument('--stat_check', default=False, help='calculate and print some stats of the data') 43 | ap.add_argument('--max_tokens', default=256, help='max size of tokens') 44 | ap.add_argument('--debug', default=False, help='if True, run some epochs on the FIRST batch') 45 | ap.add_argument('--gamma', default=.9, help='discounting factor') 46 | ap.add_argument('--beta', default=2, help='weight of entropy') 47 | ap.add_argument('--use_cur_class_embed', default=True, help='add embedding of current class to state embedding') 48 | ap.add_argument('--use_l1', default=True) 49 | ap.add_argument('--use_l2', default=True, help='only valid when use_l1=True') 50 | ap.add_argument('--l1_size', default=500, help='output size of l1. only valid when use_l2=True') 51 | ap.add_argument('--class_embed_size', default=50) 52 | ap.add_argument('--softmax', default=True, choices='softmax or sigmoid') 53 | ap.add_argument('--sl_ratio', default=1, help='[0 = off] for rl-taxo: sl_loss = rl_loss + sl_ratio * sl_loss') 54 | ap.add_argument('--global_ratio', default=0.5, 55 | help='[0 = off]for step-sl: sl_loss = (1-global_ratio) * local_loss + global_ratio * global_loss') 56 | ap.add_argument('--gpu', default=True) 57 | ap.add_argument('--n_rollouts', default=1) 58 | ap.add_argument('--reward', default='f1', choices=['01', '01-1', 'f1', 'direct', 'taxo']) 59 | ap.add_argument('--early_stop', default=False, help='for rl-taxo only') 60 | ap.add_argument('--baseline', default='greedy', choices=[None, 'avg', 'greedy']) 61 | ap.add_argument('--avg_reward_mode', default='batch', choices=['off', 'each', 'batch'], 62 | help='if n_step=1, cannot be each ->nan') 63 | ap.add_argument('--min_mem', default=3000, help='minimum gpu memory requirement (MB)') 64 | # outdated parameters 65 | ap.add_argument('--allow_up', default=False, help='not used anymore') 66 | ap.add_argument('--use_history', default=False, help='not used anymore') 67 | ap.add_argument('--multi_label', default=True, help='whether predict multi labels. 
valid for sl and step-sl') 68 | ap.add_argument('--split_multi', default=False, 69 | help='split one sample with m labels to m samples, only valid when filter_ancestors=True') 70 | ap.add_argument('--mix_flat_probs', default=False, help='add flat prob to help rl') 71 | args = ap.parse_args() 72 | args.use_history &= (args.mode == 'rl') 73 | if args.dataset == 'rcv1': 74 | args.ohcnn_data = '_rcv1_len256_padded' # where to load ohcnn cache 75 | args.n_steps_sl = 4 # steps of step-sl 76 | args.n_steps = 17 # steps of rl 77 | args.output_every = 723 # output metrics every k batches 78 | elif args.dataset == 'yelp': 79 | args.ohcnn_data = '_yelp_root_100_5_10_len256_padded' 80 | args.n_steps_sl = 4 81 | args.n_steps = 10 82 | args.output_every = 2730 83 | elif args.dataset == 'nyt': 84 | args.ohcnn_data = '_nyt' 85 | args.n_steps_sl = 3 86 | args.n_steps = 20 87 | args.output_every = 789 88 | elif 'FUN' in args.dataset: 89 | args.ohcnn_data = '_cellcycle_FUN_root_padded' 90 | args.n_steps_sl = 6 91 | args.n_steps = 45 92 | args.output_every = 100 93 | args.n_hidden = 150 94 | args.l1_size = 1000 95 | args.class_embed_size = 1000 96 | args.dropout = 0 97 | elif 'GO' in args.dataset: 98 | args.ohcnn_data = '_cellcycle_GO_root_padded' 99 | args.n_steps_sl = 14 100 | args.n_steps = 45 101 | args.output_every = 100 102 | args.n_hidden = 150 103 | args.l1_size = 1000 104 | args.class_embed_size = 1000 105 | args.dropout = 0 106 | if args.n_steps == 1: 107 | args.avg_reward_mode = 'off' 108 | if args.mode in ['sl', 'hmcn']: 109 | args.filter_ancestors = False # 'only use lowest labels as gold for training' 110 | else: 111 | args.filter_ancestors = True 112 | return args 113 | -------------------------------------------------------------------------------- /feature_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.datasets import fetch_rcv1 4 | from sklearn.feature_extraction.text import TfidfVectorizer 5 | from sklearn.preprocessing import MultiLabelBinarizer 6 | from torch.utils.data import Dataset 7 | 8 | from readData_fungo import read_fungo 9 | from readData_nyt import read_nyt 10 | from readData_yelp import read_yelp 11 | 12 | 13 | def rcv1_test(rcv1): 14 | X_train = rcv1.data[:23149] 15 | Y_train = rcv1.target[:23149] 16 | X_test = rcv1.data[23149:] 17 | Y_test = rcv1.target[23149:] 18 | return X_train, Y_train, X_test, Y_test 19 | 20 | 21 | def yelp_test(): 22 | subtree_name = 'root' 23 | X_train, X_test, train_ids, test_ids, business_dict, nodes = read_yelp(subtree_name, 5, 10) 24 | print(f'#training={len(train_ids)} #test={len(test_ids)}') 25 | n_tokens = 256 26 | print(f'use only first {n_tokens} tokens') 27 | X_train = [' '.join(i.split()[:n_tokens]) for i in X_train] 28 | X_test = [' '.join(i.split()[:n_tokens]) for i in X_test] 29 | print('fit_transform...') 30 | tf = TfidfVectorizer() 31 | X_train = tf.fit_transform(X_train) 32 | X_test = tf.transform(X_test) 33 | Y_train = [business_dict[bid]['categories'] for bid in train_ids] 34 | Y_test = [business_dict[bid]['categories'] for bid in test_ids] 35 | mlb = MultiLabelBinarizer() 36 | Y_train = mlb.fit_transform(Y_train) 37 | Y_test = mlb.transform(Y_test) 38 | return X_train, Y_train, X_test, Y_test, train_ids, test_ids 39 | 40 | 41 | def nyt_test(): 42 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_nyt() 43 | print(f'#training={len(train_ids)} #test={len(test_ids)}') 44 | n_tokens = 256 45 | print(f'use only first 
{n_tokens} tokens') 46 | X_train = [' '.join(i.split()[:n_tokens]) for i in X_train] 47 | X_test = [' '.join(i.split()[:n_tokens]) for i in X_test] 48 | print('fit_transform...') 49 | tf = TfidfVectorizer() 50 | X_train = tf.fit_transform(X_train) 51 | X_test = tf.transform(X_test) 52 | Y_train = [id2doc[bid]['categories'] for bid in train_ids] 53 | Y_test = [id2doc[bid]['categories'] for bid in test_ids] 54 | mlb = MultiLabelBinarizer() 55 | Y_train = mlb.fit_transform(Y_train) 56 | Y_test = mlb.transform(Y_test) 57 | return X_train, Y_train, X_test, Y_test, train_ids, test_ids 58 | 59 | 60 | def fungo_test(data_name): 61 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_fungo(data_name) 62 | Y_train = [id2doc[bid]['categories'] for bid in train_ids] 63 | Y_test = [id2doc[bid]['categories'] for bid in test_ids] 64 | # Actually here Y is not used. We use id2doc for labels. 65 | mlb = MultiLabelBinarizer() 66 | Y = mlb.fit_transform(np.concatenate([Y_train, Y_test])) 67 | Y_train = Y[:len(Y_train)] 68 | Y_test = Y[-len(Y_test):] 69 | 70 | return X_train, Y_train, X_test, Y_test, train_ids, test_ids 71 | 72 | 73 | def my_collate(batch): 74 | features = torch.FloatTensor([item[0] for item in batch]) 75 | labels = [item[1] for item in batch] 76 | return [features, labels] 77 | 78 | 79 | class featureDataset(Dataset): 80 | def __init__(self, data_name, train=True): 81 | self.train = train 82 | self.data = data_name 83 | if data_name == 'rcv1': 84 | self.rcv1 = fetch_rcv1() 85 | X_train, Y_train, X_test, Y_test = rcv1_test(self.rcv1) 86 | if train: 87 | self.samples = X_train 88 | else: 89 | self.samples = X_test 90 | else: 91 | if data_name == 'yelp': 92 | X_train, Y_train, X_test, Y_test, train_ids, test_ids = yelp_test() 93 | elif data_name == 'nyt': 94 | X_train, Y_train, X_test, Y_test, train_ids, test_ids = nyt_test() 95 | else: 96 | X_train, Y_train, X_test, Y_test, train_ids, test_ids = fungo_test(data_name) 97 | if train: 98 | self.samples = X_train 99 | self.ids = train_ids 100 | else: 101 | self.samples = X_test 102 | self.ids = test_ids 103 | 104 | def __len__(self): 105 | if 'FUN' in self.data or 'GO' in self.data: 106 | return len(self.samples) 107 | if self.data == 'fungo': 108 | return len(self.samples) 109 | return self.samples.shape[0] 110 | 111 | def __getitem__(self, item): 112 | if self.data == 'rcv1': 113 | if self.train: 114 | vector, label = self.samples[item].todense().tolist()[0], str(int(self.rcv1.sample_id[item])) 115 | else: 116 | vector, label = self.samples[item].todense().tolist()[0], str(int(self.rcv1.sample_id[item + 23149])) 117 | elif self.data in ['yelp', 'nyt']: 118 | vector, label = self.samples[item].todense().tolist()[0], self.ids[item].strip().strip('\n') 119 | elif 'FUN' in self.data or 'GO' in self.data or self.data == 'fungo': 120 | vector, label = self.samples[item], self.ids[item] 121 | return vector, label 122 | -------------------------------------------------------------------------------- /features_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.feature_extraction.text import TfidfVectorizer 3 | from sklearn.metrics import f1_score 4 | from sklearn.multiclass import OneVsRestClassifier 5 | from sklearn.preprocessing import MultiLabelBinarizer 6 | from sklearn.svm import LinearSVC 7 | 8 | from conf import conf 9 | from feature_dataset import fungo_test 10 | from loadData import load_data_rcv1 11 | from readData_fungo import read_fungo 12 | from 
readData_yelp import read_yelp 13 | from tree import Tree 14 | 15 | 16 | def evaluate(pred, Y): 17 | print(f1_score(Y, pred, average='micro')) 18 | print(f1_score(Y, pred, average='macro')) 19 | print(f1_score(Y, pred, average='samples')) 20 | # print(classification_report(Y, pred)) 21 | 22 | 23 | def print_some(pred, Y, k=500): 24 | pred_tmp = pred[:k].todense().tolist() 25 | from sklearn.datasets import fetch_rcv1 26 | rcv1 = fetch_rcv1() 27 | for tmp, y in zip(pred_tmp, Y.todense().tolist()): 28 | for i in range(len(tmp)): 29 | if tmp[i] == 1: 30 | print(rcv1.target_names[i], end=' ') 31 | print() 32 | for i in range(len(tmp)): 33 | if y[i] == 1: 34 | print(rcv1.target_names[i], end=' ') 35 | print('---') 36 | 37 | 38 | def run(X_train, Y_train, X_test, Y_test): 39 | print('start training') 40 | model = OneVsRestClassifier(LinearSVC(loss='hinge'), n_jobs=5) 41 | # model = OneVsRestClassifier(LogisticRegression(solver='lbfgs', multi_class='multinomial')) 42 | model.fit(X_train, Y_train) 43 | pred = model.predict(X_train) 44 | print('eval training') 45 | # print_some(pred, Y_train, 50) 46 | evaluate(pred, Y_train) 47 | print('eval testing') 48 | pred = model.predict(X_test) 49 | evaluate(pred, Y_test) 50 | # print_some(pred, Y_test, 50) 51 | 52 | 53 | def rcv1_test(): 54 | from sklearn.datasets import fetch_rcv1 55 | rcv1 = fetch_rcv1() 56 | X_train = rcv1.data[:23149] 57 | Y_train = rcv1.target[:23149] 58 | X_test = rcv1.data[23149:] 59 | Y_test = rcv1.target[23149:] 60 | print(Y_train[:2]) 61 | print(rcv1.target_names[34], rcv1.target_names[59]) 62 | return X_train, Y_train, X_test, Y_test 63 | 64 | 65 | def yelp_test(): 66 | subtree_name = 'root' 67 | X_train, X_test, train_ids, test_ids, business_dict, nodes = read_yelp(subtree_name, 5, 10) 68 | print(f'#training={len(train_ids)} #test={len(test_ids)}') 69 | n_tokens = 256 70 | print(f'use only first {n_tokens} tokens') 71 | X_train = [' '.join(i.split()[:n_tokens]) for i in X_train] 72 | X_test = [' '.join(i.split()[:n_tokens]) for i in X_test] 73 | print('fit_transform...') 74 | tf = TfidfVectorizer() 75 | X_train = tf.fit_transform(X_train) 76 | X_test = tf.transform(X_test) 77 | Y_train = [business_dict[bid]['categories'] for bid in train_ids] 78 | Y_test = [business_dict[bid]['categories'] for bid in test_ids] 79 | mlb = MultiLabelBinarizer() 80 | Y_train = mlb.fit_transform(Y_train) 81 | Y_test = mlb.transform(Y_test) 82 | return X_train, Y_train, X_test, Y_test 83 | 84 | 85 | def fungo_test_wrapper(name='cellcycle_FUN'): 86 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_fungo(name) 87 | X_train, X_test = np.array(X_train), np.array(X_test) 88 | id2doc_train = id2doc 89 | args = conf() 90 | # id2doc_train = filter_ancestors(id2doc, nodes) 91 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Top') 92 | mlb = MultiLabelBinarizer(classes=tree.class_idx) 93 | Y_train = mlb.fit_transform([tree.id2doc_ancestors[docid]['class_idx'] for docid in train_ids]) 94 | Y_test = mlb.transform([tree.id2doc_ancestors[docid]['class_idx'] for docid in test_ids]) 95 | return X_train, Y_train, X_test, Y_test 96 | 97 | 98 | if __name__ == '__main__': 99 | X_train, Y_train, X_test, Y_test = rcv1_test() 100 | # X_train, Y_train, X_test, Y_test = fungo_test_wrapper('cellcycle_FUN') 101 | run(X_train, Y_train, X_test, Y_test) 102 | -------------------------------------------------------------------------------- /fig/HiLAP_architecture.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/morningmoni/HiLAP/19b63121c37c6a1e3367062925f2b1d7ad9cc87f/fig/HiLAP_architecture.jpg -------------------------------------------------------------------------------- /fig/prediction_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/morningmoni/HiLAP/19b63121c37c6a1e3367062925f2b1d7ad9cc87f/fig/prediction_animation.gif -------------------------------------------------------------------------------- /loadData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import spacy 8 | from nltk.tokenize import sent_tokenize 9 | from tqdm import tqdm 10 | 11 | from Logger_morning import myLogger 12 | from readData_rcv1 import read_rcv1 13 | from readData_yelp import read_yelp 14 | from readData_nyt import read_nyt 15 | 16 | spacy_en = spacy.load('en') 17 | logger = myLogger('exp') 18 | 19 | 20 | def tokenizer(text): # create a tokenizer function 21 | # return text.lower().split() 22 | return [tok.text.lower() for tok in spacy_en.tokenizer(text)] 23 | 24 | 25 | def read_word_embed(pre_trained_path, biovec=False): 26 | logger.info('loading pre-trained embedding from {}'.format(pre_trained_path)) 27 | if not biovec: 28 | with open(pre_trained_path) as f: 29 | words, vectors = zip(*[line.strip().split(' ', 1) for line in f]) 30 | wv = np.loadtxt(vectors) 31 | else: 32 | with open(pre_trained_path + 'types.txt') as f: 33 | words = [line.strip() for line in f] 34 | wv = np.loadtxt(pre_trained_path + 'vectors.txt') 35 | return words, wv 36 | 37 | 38 | def prepare_word(pre_trained_path, vocab, biovec): 39 | words, wv = read_word_embed(pre_trained_path, biovec) 40 | unknown_vector = np.random.random_sample((wv.shape[1],)) 41 | word_set = set(words) 42 | unknown_words = list(set(vocab).difference(set(words))) 43 | logger.info('there are {} OOV words'.format(len(unknown_words))) 44 | word_index = {w: i for i, w in enumerate(words)} 45 | unknown_word_vectors = [np.add.reduce([wv[word_index[w]] if w in word_set else unknown_vector 46 | for w in word.split(' ')]) 47 | for word in unknown_words] 48 | wv = np.vstack((wv, unknown_word_vectors)) 49 | words = list(words) + unknown_words 50 | # Normalize each row (word vector) in the matrix to sum-up to 1 51 | row_norm = np.sum(np.abs(wv) ** 2, axis=-1) ** (1. 
/ 2) 52 | wv /= row_norm[:, np.newaxis] 53 | 54 | word_index = {w: i for i, w in enumerate(words)} 55 | return wv, word_index 56 | 57 | 58 | def prepare_word_filter(pre_trained_path, vocab, biovec, shared_unk=False, norm=True): 59 | words, wv = read_word_embed(pre_trained_path, biovec) 60 | known_words = set(vocab) & set(words) 61 | unknown_words = set(vocab).difference(set(words)) 62 | logger.info('there are {} OOV words'.format(len(unknown_words))) 63 | 64 | word_index = {w: i for i, w in enumerate(words)} 65 | new_word_index = {} 66 | 67 | filter_idx = [] 68 | if shared_unk: 69 | ct = 1 70 | else: 71 | ct = 0 72 | for word in known_words: 73 | new_word_index[word] = ct 74 | filter_idx.append(word_index[word]) 75 | ct += 1 76 | for word in unknown_words: 77 | if shared_unk: 78 | new_word_index[word] = 0 79 | else: 80 | new_word_index[word] = ct 81 | ct += 1 82 | 83 | wv = wv[filter_idx] 84 | if shared_unk: 85 | unknown_vector = np.random.random_sample((wv.shape[1],)) - .5 86 | wv = np.vstack((unknown_vector, wv)) 87 | else: 88 | unknown_vectors = np.random.random_sample((len(unknown_words), wv.shape[1])) - .5 89 | wv = np.vstack((wv, unknown_vectors)) 90 | 91 | # Normalize each row (word vector) in the matrix to sum-up to 1 92 | if norm: 93 | row_norm = np.sum(np.abs(wv) ** 2, axis=-1) ** (1. / 2) 94 | wv /= row_norm[:, np.newaxis] 95 | 96 | return wv, new_word_index 97 | 98 | 99 | def tokenize_text_normal(X_l): 100 | X = [] 101 | vocab = set() 102 | for text in tqdm(X_l): 103 | X.append([]) 104 | for sent in sent_tokenize(text): 105 | t = tokenizer(sent) 106 | vocab.update(t) 107 | X[-1].append(t) 108 | return X, vocab 109 | 110 | 111 | def oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=None, word_min_sup=5, count_test=False, filter=False): 112 | if not vocab_size: 113 | logger.info('word_min_sup={}'.format(word_min_sup)) 114 | vocab = vocab_train | vocab_test 115 | logger.info( 116 | "[before]vocab_train:{} vocab_test:{} vocab:{}".format(len(vocab_train), len(vocab_test), len(vocab))) 117 | word_count = defaultdict(int) 118 | if filter: 119 | word_w_num = set() 120 | with open('rcv1/rcv1_stopword.txt') as f: 121 | stop_word = set([line.strip() for line in f]) 122 | for doc in X_train: 123 | for sent in doc: 124 | for word in sent: 125 | if filter: 126 | if word in stop_word: 127 | continue 128 | if any(c.isdigit() for c in word): 129 | word_w_num.add(word) 130 | continue 131 | word_count[word] += 1 132 | if filter: 133 | logger.info('removeNumbers {} {}'.format(len(word_w_num), word_w_num)) 134 | if count_test: 135 | for doc in X_test: 136 | for sent in doc: 137 | for word in sent: 138 | word_count[word] += 1 139 | if vocab_size: 140 | logger.info(f'limit vocab_size={vocab_size}') 141 | vocab_by_freq = set([k for k in sorted(word_count, key=word_count.get, reverse=True)][:vocab_size]) 142 | X_train = [[[word if word in vocab_by_freq else 'UNK' for word in sent] for sent in doc] for doc in 143 | X_train] 144 | X_test = [[[word if word in vocab_by_freq else 'UNK' for word in sent] for sent in doc] for doc in 145 | X_test] 146 | else: 147 | X_train = [[[word if word_count[word] >= word_min_sup else 'UNK' for word in sent] for sent in doc] for doc in 148 | X_train] 149 | X_test = [[[word if word_count[word] >= word_min_sup else 'UNK' for word in sent] for sent in doc] for doc in 150 | X_test] 151 | logger.info('Tokenization may be poor. 
Next are some examples:') 152 | logger.info(X_train[:5]) 153 | vocab_train = set([word for doc in X_train for sent in doc for word in sent]) 154 | vocab_test = set([word for doc in X_test for sent in doc for word in sent]) 155 | vocab = vocab_train | vocab_test 156 | logger.info( 157 | "[after]vocab_train:{} vocab_test:{} vocab:{}".format(len(vocab_train), len(vocab_test), len(vocab))) 158 | return X_train, X_test, vocab 159 | 160 | 161 | def load_data_rcv1_onehot(suffix): 162 | if os.path.exists('preload_data{}_onehot.pkl'.format(suffix)): 163 | logger.warn('loading from preload_data{}_onehot.pkl'.format(suffix)) 164 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load( 165 | open('preload_data{}_onehot.pkl'.format(suffix), 'rb')) 166 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes 167 | 168 | _, _, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load( 169 | open('preload_data{}.pkl'.format(suffix), 'rb')) 170 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb')) 171 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=30000, filter=True) 172 | word_index = {w: ct for ct, w in enumerate(vocab)} 173 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train] 174 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test] 175 | logger.info('saving preload_data{}_onehot.pkl'.format(suffix)) 176 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes 177 | pickle.dump(res, open('preload_data{}_onehot.pkl'.format(suffix), 'wb')) 178 | return res 179 | 180 | 181 | def load_data_rcv1(embedding_path, suffix): 182 | if os.path.exists('preload_data{}.pkl'.format(suffix)): 183 | logger.warn('loading from preload_data{}.pkl'.format(suffix)) 184 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load( 185 | open('preload_data{}.pkl'.format(suffix), 'rb')) 186 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes 187 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_rcv1() 188 | if os.path.exists('text_tokenized{}.pkl'.format(suffix)): 189 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb')) 190 | else: 191 | X_train, vocab_train = tokenize_text_normal(X_train) 192 | X_test, vocab_test = tokenize_text_normal(X_test) 193 | res = X_train, vocab_train, X_test, vocab_test 194 | pickle.dump(res, open('text_tokenized{}.pkl'.format(suffix), 'wb')) 195 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, count_test=True) 196 | wv, word_index = prepare_word_filter(embedding_path, vocab, biovec=False) 197 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train] 198 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test] 199 | logger.info('saving preload_data{}.pkl'.format(suffix)) 200 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes 201 | pickle.dump(res, open('preload_data{}.pkl'.format(suffix), 'wb')) 202 | return res 203 | 204 | 205 | def load_data_nyt_onehot(suffix): 206 | if os.path.exists('preload_data{}_onehot.pkl'.format(suffix)): 207 | logger.warn('loading from preload_data{}_onehot.pkl'.format(suffix)) 208 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, 
nodes = pickle.load( 209 | open('preload_data{}_onehot.pkl'.format(suffix), 'rb')) 210 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes 211 | 212 | _, _, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load( 213 | open('preload_data{}.pkl'.format(suffix), 'rb')) 214 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb')) 215 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=30000, filter=True) 216 | word_index = {w: ct for ct, w in enumerate(vocab)} 217 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train] 218 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test] 219 | logger.info('saving preload_data{}_onehot.pkl'.format(suffix)) 220 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes 221 | pickle.dump(res, open('preload_data{}_onehot.pkl'.format(suffix), 'wb')) 222 | return res 223 | 224 | 225 | def load_data_nyt(embedding_path, suffix): 226 | if os.path.exists('preload_data{}.pkl'.format(suffix)): 227 | logger.warn('loading from preload_data{}.pkl'.format(suffix)) 228 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load( 229 | open('preload_data{}.pkl'.format(suffix), 'rb')) 230 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes 231 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_nyt() 232 | if os.path.exists('text_tokenized{}.pkl'.format(suffix)): 233 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb')) 234 | else: 235 | X_train, vocab_train = tokenize_text_normal(X_train) 236 | X_test, vocab_test = tokenize_text_normal(X_test) 237 | res = X_train, vocab_train, X_test, vocab_test 238 | pickle.dump(res, open('text_tokenized{}.pkl'.format(suffix), 'wb')) 239 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, count_test=True) 240 | wv, word_index = prepare_word_filter(embedding_path, vocab, biovec=False) 241 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train] 242 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test] 243 | logger.info('saving preload_data{}.pkl'.format(suffix)) 244 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes 245 | pickle.dump(res, open('preload_data{}.pkl'.format(suffix), 'wb')) 246 | return res 247 | 248 | 249 | def load_data_yelp(embedding_path, suffix, root, min_reviews=1, max_reviews=10): 250 | """ 251 | suffix: used to distinguish different pkl files 252 | root: which subtree to use e.g., root (use all nodes as classes, 1004 - 1 in total), Hotels & Travel 253 | min_reviews: remove businesses that have < min_reviews 254 | max_reviews: use at most max_reviews for a business 255 | """ 256 | if os.path.exists('preload_data{}.pkl'.format(suffix)): 257 | logger.warn('loading from preload_data{}.pkl'.format(suffix)) 258 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load( 259 | open('preload_data{}.pkl'.format(suffix), 'rb')) 260 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes 261 | logger.warn('reviews min_sup={}, max_sup={}'.format(min_reviews, max_reviews)) 262 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_yelp(root, min_reviews, max_reviews) 263 | # TODO need 
to make sure every time train_ids, test_ids are the same 264 | if os.path.exists('text_tokenized{}.pkl'.format(suffix)): 265 | logger.warn('loading from text_tokenized{}.pkl'.format(suffix)) 266 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb')) 267 | else: 268 | X_train, vocab_train = tokenize_text_normal(X_train) 269 | X_test, vocab_test = tokenize_text_normal(X_test) 270 | res = X_train, vocab_train, X_test, vocab_test 271 | pickle.dump(res, open('text_tokenized{}.pkl'.format(suffix), 'wb')) 272 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=30000) 273 | wv, word_index = prepare_word_filter(embedding_path, vocab, biovec=False) 274 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train] 275 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test] 276 | logger.info('saving preload_data{}.pkl'.format(suffix)) 277 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes 278 | pickle.dump(res, open('preload_data{}.pkl'.format(suffix), 'wb')) 279 | return res 280 | 281 | 282 | def filter_ancestors(id2doc, nodes): 283 | logger.info('keep only lowest label in a path...') 284 | id2doc_na = defaultdict(dict) 285 | labels_ct = [] 286 | lowest_labels_ct = [] 287 | for bid in id2doc: 288 | lowest_labels = [] 289 | cat_set = set(id2doc[bid]['categories']) 290 | for label in id2doc[bid]['categories']: 291 | if len(set(nodes[label]['children']) & cat_set) == 0: 292 | lowest_labels.append(label) 293 | labels_ct.append(len(cat_set)) 294 | lowest_labels_ct.append(len(lowest_labels)) 295 | id2doc_na[bid]['categories'] = lowest_labels 296 | logger.info('#labels') 297 | logger.info(pd.Series(labels_ct).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98])) 298 | logger.info('#lowest labels') 299 | logger.info(pd.Series(lowest_labels_ct).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98])) 300 | return id2doc_na 301 | 302 | 303 | def split_multi(X_train, train_ids, id2doc_train, id2doc): 304 | logger.info('split one sample with m labels to m samples...') 305 | X_train_new = [] 306 | train_ids_new = [] 307 | id2doc_train_new = defaultdict(dict) 308 | for X, did in zip(X_train, train_ids): 309 | ct = 0 310 | for label in id2doc_train[did]['categories']: 311 | newID = '{}-{}'.format(did, ct) 312 | id2doc[newID] = id2doc[did] 313 | X_train_new.append(X) 314 | train_ids_new.append(newID) 315 | id2doc_train_new[newID]['categories'] = [label] 316 | ct += 1 317 | return np.array(X_train_new), np.array(train_ids_new), id2doc_train_new, id2doc 318 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.preprocessing import MultiLabelBinarizer 8 | from tensorboardX import SummaryWriter 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from Linear_Model import Linear_Model 14 | from Logger_morning import myLogger 15 | from conf import conf, print_config 16 | from feature_dataset import featureDataset, my_collate 17 | from readData_fungo import read_fungo 18 | 19 | args = conf() 20 | if args.load_model is None or args.isTrain: 21 | comment = 
f'_{args.dataset}_{args.base_model}_{args.mode}_{args.remark}' 22 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 23 | log_dir = os.path.join('runs', current_time + comment) 24 | else: 25 | log_dir = ''.join(args.load_model[:args.load_model.rfind('/')]) 26 | print(f'reuse dir: {log_dir}') 27 | logger = myLogger(name='exp', log_path=log_dir + '.log') 28 | # incompatible with logger... 29 | writer = SummaryWriter(log_dir=log_dir) 30 | writer.add_text('Parameters', str(vars(args))) 31 | print_config(args, logger) 32 | logger.setLevel(args.log_level) 33 | 34 | from HMCN import HMCN 35 | from HAN import HAN 36 | from OHCNN_fast import OHCNN_fast 37 | from TextCNN import TextCNN 38 | from loadData import load_data_yelp, filter_ancestors, load_data_rcv1, split_multi, \ 39 | load_data_rcv1_onehot, load_data_nyt_onehot, load_data_nyt 40 | from model import Policy 41 | from tree import Tree 42 | from util import get_gpu_memory_map, save_checkpoint, check_doc_size, gen_minibatch_from_cache, gen_minibatch, \ 43 | save_minibatch, contains_nan 44 | 45 | 46 | def finish_episode(policy, update=True): 47 | policy_loss = [] 48 | all_cum_rewards = [] 49 | for i in range(args.n_rollouts): 50 | rewards = [] 51 | R = np.zeros(len(policy.rewards[i][0])) 52 | for r in policy.rewards[i][::-1]: 53 | R = r + args.gamma * R 54 | rewards.insert(0, R) 55 | all_cum_rewards.extend(rewards) 56 | rewards = torch.Tensor(rewards) # (length, batch_size) 57 | # logger.warning(f'original {rewards}') 58 | if args.baseline == 'avg': 59 | rewards -= policy.baseline_reward 60 | elif args.baseline == 'greedy': 61 | rewards_greedy = [] 62 | R = np.zeros(len(policy.rewards_greedy[0])) 63 | for r in policy.rewards_greedy[::-1]: 64 | R = r + args.gamma * R 65 | rewards_greedy.insert(0, R) 66 | rewards_greedy = torch.Tensor(rewards_greedy) 67 | rewards -= rewards_greedy 68 | # logger.warning(f'after baseline {rewards}') 69 | if args.avg_reward_mode == 'batch': 70 | rewards = Variable((rewards - rewards.mean()) / (rewards.std() + float(np.finfo(np.float32).eps))) 71 | elif args.avg_reward_mode == 'each': 72 | # mean/std is separate for each in the batch 73 | rewards = Variable((rewards - rewards.mean(dim=0)) / (rewards.std(dim=0) + float(np.finfo(np.float32).eps))) 74 | else: 75 | rewards = Variable(rewards) 76 | if args.gpu: 77 | rewards = rewards.cuda() 78 | for log_prob, reward in zip(policy.saved_log_probs[i], rewards): 79 | policy_loss.append(-log_prob * reward) 80 | # logger.warning(f'after mean_std {rewards}') 81 | if update: 82 | tree.n_update += 1 83 | try: 84 | policy_loss = torch.cat(policy_loss).mean() 85 | except Exception as e: 86 | logger.error(e) 87 | entropy = torch.cat(policy.entropy_l).mean() 88 | writer.add_scalar('data/policy_loss', policy_loss, tree.n_update) 89 | writer.add_scalar('data/entropy_loss', policy.beta * entropy.data[0], tree.n_update) 90 | policy_loss += policy.beta * entropy 91 | if args.sl_ratio > 0: 92 | policy_loss += args.sl_ratio * policy.sl_loss 93 | writer.add_scalar('data/sl_loss', args.sl_ratio * policy.sl_loss, tree.n_update) 94 | writer.add_scalar('data/total_loss', policy_loss.data[0], tree.n_update) 95 | optimizer.zero_grad() 96 | policy_loss.backward() 97 | if contains_nan(policy.class_embed.weight.grad): 98 | logger.error('nan in class_embed.weight.grad!') 99 | else: 100 | optimizer.step() 101 | policy.update_baseline(np.mean(np.concatenate(all_cum_rewards))) 102 | policy.finish_episode() 103 | 104 | 105 | def calc_sl_loss(probs, doc_ids, update=True, return_var=False): 106 | 
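# builds multi-hot targets over tree.class_idx from each document's gold class indices, scores the predicted
# per-class probabilities with the global criterion, and backpropagates through the optimizer when update=True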
mlb = MultiLabelBinarizer(classes=tree.class_idx) 107 | y_l = [tree.id2doc_ancestors[docid]['class_idx'] for docid in doc_ids] 108 | y_true = mlb.fit_transform(y_l) 109 | if args.gpu: 110 | y_true = Variable(torch.from_numpy(y_true)).cuda().float() 111 | else: 112 | y_true = Variable(torch.from_numpy(y_true)).float() 113 | loss = criterion(probs, y_true) 114 | if update: 115 | tree.n_update += 1 116 | optimizer.zero_grad() 117 | loss.backward() 118 | optimizer.step() 119 | if return_var: 120 | return loss 121 | return loss.data[0] 122 | 123 | 124 | def get_cur_size(tokens): 125 | if args.base_model != 'han': 126 | return tokens.size()[0] 127 | else: 128 | return tokens.size()[1] 129 | 130 | 131 | def forward_step_sl(tokens, doc_ids, flat_probs_only=False): 132 | # TODO can reuse logits 133 | if args.global_ratio > 0: 134 | probs = policy.base_model(tokens, True) 135 | global_loss = calc_sl_loss(probs, doc_ids, update=False, return_var=True) 136 | else: 137 | probs = None 138 | global_loss = 0 139 | if flat_probs_only: 140 | policy.sl_loss = global_loss 141 | return global_loss, probs 142 | policy.doc_vec = None 143 | cur_batch_size = get_cur_size(tokens) 144 | cur_class_batch = np.zeros(cur_batch_size, dtype=int) 145 | for t in range(args.n_steps_sl): 146 | next_classes_batch = tree.p2c_batch(cur_class_batch) 147 | next_classes_batch_true, indices, next_class_batch_true, doc_ids = tree.get_next(cur_class_batch, 148 | next_classes_batch, 149 | doc_ids) 150 | policy.step_sl(tokens, cur_class_batch, next_classes_batch, next_classes_batch_true) 151 | cur_class_batch = next_class_batch_true 152 | policy.duplicate_doc_vec(indices) 153 | policy.sl_loss /= args.n_steps 154 | writer.add_scalar('data/step-sl_sl_loss', (1 - args.global_ratio) * policy.sl_loss, tree.n_update) 155 | policy.sl_loss = (1 - args.global_ratio) * policy.sl_loss + args.global_ratio * global_loss 156 | writer.add_scalar('data/flat_sl_loss', args.global_ratio * global_loss, tree.n_update) 157 | return global_loss, probs 158 | 159 | 160 | def train_step_sl(): 161 | policy.train() 162 | for i in range(1, args.num_epoch + 1): 163 | g = select_data('train' + args.ohcnn_data, shuffle=True) 164 | loss_total = 0 165 | tree.cur_epoch = i 166 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)): 167 | if 'FUN' in args.dataset or 'GO' in args.dataset: 168 | tokens = Variable(tokens).cuda() 169 | global_loss, flat_probs = forward_step_sl(tokens, doc_ids, flat_probs_only=(args.global_ratio == 1)) 170 | optimizer.zero_grad() 171 | policy.sl_loss.backward() 172 | optimizer.step() 173 | tree.n_update += 1 174 | loss_total += policy.sl_loss.data[0] 175 | if ct % args.output_every == 0 and ct != 0: 176 | if args.global_ratio > 0: 177 | logger.info( 178 | f'loss_cur:{policy.sl_loss.data[0]} global_loss:{args.global_ratio * global_loss.data[0]}') 179 | else: 180 | logger.info(f'loss_cur:{policy.sl_loss.data[0]} global_loss:off') 181 | logger.info(f'[{i}:{ct}] loss_avg:{loss_total / ct}') 182 | writer.add_scalar('data/sl_loss', global_loss, tree.n_update) 183 | writer.add_scalar('data/loss_avg', loss_total / ct, tree.n_update) 184 | policy.sl_loss = 0 185 | if i % args.save_every == 0: 186 | eval_save_model(i, datapath='train' + args.ohcnn_data, save=True, output=False) 187 | test_step_sl('test' + args.ohcnn_data, save_prob=False) 188 | 189 | 190 | def test_step_sl(data_path, save_prob=False, output=True): 191 | logger.info('test starts') 192 | policy.eval() 193 | g = select_data(data_path) 194 | pred_l = [] 195 | target_l = [] 196 | for ct, 
(tokens, doc_ids) in tqdm(enumerate(g)): 197 | if 'FUN' in args.dataset or 'GO' in args.dataset: 198 | tokens = Variable(tokens).cuda() 199 | real_doc_ids = [i for i in doc_ids] 200 | policy.doc_vec = None 201 | cur_batch_size = get_cur_size(tokens) 202 | cur_class_batch = np.zeros(cur_batch_size, dtype=int) 203 | for _ in range(args.n_steps_sl): 204 | next_classes_batch = tree.p2c_batch(cur_class_batch) 205 | probs = policy.step_sl(tokens, cur_class_batch, next_classes_batch, None, sigmoid=True) 206 | indices, next_class_batch_pred, doc_ids = tree.get_next_by_probs(cur_class_batch, next_classes_batch, 207 | doc_ids, probs, save_prob) 208 | cur_class_batch = next_class_batch_pred 209 | policy.duplicate_doc_vec(indices) 210 | last_did = None 211 | for c, did in zip(cur_class_batch, doc_ids): 212 | if last_did != did: 213 | pred_l.append([]) 214 | if c != 0: 215 | pred_l[-1].append(c) 216 | last_did = did 217 | target_l.extend(real_doc_ids) 218 | if save_prob: 219 | logger.info(f'saving {writer.file_writer.get_logdir()}/{data_path}.tree.id2prob2.pkl') 220 | pickle.dump(tree.id2prob, open(f'{writer.file_writer.get_logdir()}/{data_path}.tree.id2prob2.pkl', 'wb')) 221 | tree.id2prob.clear() 222 | return 223 | return evaluate(pred_l, target_l, output=output) 224 | 225 | 226 | def output_log(cur_class_batch, doc_ids, acc, i, ct): 227 | if args.dataset in ['yelp', 'rcv1']: 228 | label_key = 'categories' 229 | try: 230 | logger.info('pred{}{} real{}{} pred_h{} real_h{}'.format(cur_class_batch[:3], 231 | [tree.id2name[tree.idx2id[cur]] for cur in 232 | cur_class_batch[:3]], 233 | [tree.id2doc_ancestors[docid]['class_idx'] for 234 | docid in doc_ids[:3]], 235 | [tree.id2doc_ancestors[docid][label_key] for docid 236 | in doc_ids[:3]], 237 | tree.h_batch(cur_class_batch[:3]), 238 | tree.h_doc_batch(doc_ids[:3]))) 239 | except Exception as e: 240 | logger.warning(e) 241 | writer.add_scalar('data/acc', np.mean(acc), tree.n_update) 242 | writer.add_scalar('data/beta', policy.beta, tree.n_update) 243 | logger.info('single-label acc for epoch {} batch {}: {}'.format(i, ct, np.mean(acc))) 244 | if (cur_class_batch == cur_class_batch[0]).all(): 245 | logger.error('predictions in a batch are all the same! [{}]'.format(cur_class_batch[0])) 246 | writer.add_text('error', 'predictions in a batch are all the same! 
[{}]'.format(cur_class_batch[0]), 247 | tree.n_update) 248 | if not args.debug: 249 | exit(1) 250 | 251 | 252 | def train_taxo(): 253 | policy.train() 254 | for i in range(1, args.num_epoch + 1): 255 | g = select_data('train' + args.ohcnn_data, shuffle=True) 256 | pred_l = [] 257 | target_l = [] 258 | tree.cur_epoch = i 259 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)): 260 | if 'FUN' in args.dataset or 'GO' in args.dataset: 261 | tokens = Variable(tokens).cuda() 262 | flat_probs = None 263 | if args.sl_ratio > 0: 264 | _, flat_probs = forward_step_sl(tokens, doc_ids, flat_probs_only=(args.global_ratio == 1)) 265 | if not args.mix_flat_probs: 266 | flat_probs = None 267 | policy.doc_vec = None 268 | cur_batch_size = get_cur_size(tokens) 269 | 270 | # greedy baseline 271 | if args.baseline == 'greedy': 272 | tree.taken_actions = [set() for _ in range(cur_batch_size)] 273 | cur_class_batch = np.zeros(cur_batch_size, dtype=int) 274 | next_classes_batch = [set() for _ in range(cur_batch_size)] 275 | for t in range(args.n_steps): 276 | next_classes_batch, next_classes_batch_np, _ = tree.get_next_candidates(cur_class_batch, 277 | next_classes_batch) 278 | choices, m = policy.step(tokens, cur_class_batch, next_classes_batch_np, test=True, 279 | flat_probs=flat_probs) 280 | cur_class_batch = next_classes_batch_np[ 281 | np.arange(len(next_classes_batch_np)), choices.data.cpu().numpy()] 282 | tree.update_actions(cur_class_batch) 283 | policy.rewards_greedy.append(tree.calc_reward(t < args.n_steps - 1, cur_class_batch, doc_ids)) 284 | tree.last_R = None 285 | for _ in range(args.n_rollouts): 286 | policy.saved_log_probs.append([]) 287 | policy.rewards.append([]) 288 | tree.taken_actions = [set() for _ in range(cur_batch_size)] 289 | cur_class_batch = np.zeros(cur_batch_size, dtype=int) 290 | next_classes_batch = [set() for _ in range(cur_batch_size)] 291 | for t in range(args.n_steps): 292 | next_classes_batch, next_classes_batch_np, all_stop = tree.get_next_candidates(cur_class_batch, 293 | next_classes_batch) 294 | if args.early_stop and all_stop: 295 | break 296 | choices, m = policy.step(tokens, cur_class_batch, next_classes_batch_np, flat_probs=flat_probs) 297 | cur_class_batch = next_classes_batch_np[ 298 | np.arange(len(next_classes_batch_np)), choices.data.cpu().numpy()] 299 | tree.update_actions(cur_class_batch) 300 | policy.saved_log_probs[-1].append(m.log_prob(choices)) 301 | policy.rewards[-1].append(tree.calc_reward(t < args.n_steps - 1, cur_class_batch, doc_ids)) 302 | tree.last_R = None 303 | tree.remove_stop() 304 | pred_l.extend(tree.taken_actions) 305 | target_l.extend(doc_ids) 306 | finish_episode(policy, update=True) 307 | if ct % args.output_every == 0 and ct != 0: 308 | logger.info(f'epoch {i} batch {ct}') 309 | eval_save_model(i, pred_l, target_l, output=False) 310 | if i % args.save_every == 0: 311 | eval_save_model(i, pred_l, target_l, save=True, output=False) 312 | test_taxo('test' + args.ohcnn_data, save_prob=False) 313 | 314 | 315 | def select_data(data_path, shuffle=False): 316 | if 'FUN' in args.dataset or 'GO' in args.dataset: 317 | isTrain = False 318 | if 'train' in data_path: 319 | isTrain = True 320 | test_dataset = featureDataset(args.dataset, isTrain) 321 | g = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=my_collate, shuffle=shuffle) 322 | elif args.base_model == 'ohcnn-bow-fast': 323 | g = gen_minibatch_from_cache(logger, args, tree, args.batch_size, name=data_path, shuffle=shuffle) 324 | else: 325 | if 'test' in data_path: 326 | g = 
gen_minibatch(logger, args, word_index, X_test, test_ids, args.batch_size, shuffle=shuffle) 327 | else: 328 | g = gen_minibatch(logger, args, word_index, X_train, train_ids, args.batch_size, shuffle=shuffle) 329 | return g 330 | 331 | 332 | def test_taxo(data_path, save_prob=False): 333 | logger.info('test starts') 334 | policy.eval() 335 | g = select_data(data_path) 336 | pred_l = [] 337 | target_l = [] 338 | if save_prob: 339 | args.n_steps = tree.n_class - 1 340 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)): 341 | if 'FUN' in args.dataset or 'GO' in args.dataset: 342 | tokens = Variable(tokens).cuda() 343 | flat_probs = None 344 | if args.sl_ratio > 0 and args.mix_flat_probs: 345 | _, flat_probs = forward_step_sl(tokens, doc_ids, flat_probs_only=True) 346 | policy.doc_vec = None 347 | cur_batch_size = get_cur_size(tokens) 348 | tree.taken_actions = [set() for _ in range(cur_batch_size)] 349 | cur_class_batch = np.zeros(cur_batch_size, dtype=int) 350 | next_classes_batch = [set() for _ in range(cur_batch_size)] 351 | for _ in range(args.n_steps): 352 | next_classes_batch, next_classes_batch_np, all_stop = tree.get_next_candidates(cur_class_batch, 353 | next_classes_batch, 354 | save_prob) 355 | if all_stop: 356 | if save_prob: 357 | logger.error('should not enter') 358 | break 359 | choices, m = policy.step(tokens, cur_class_batch, next_classes_batch_np, test=True, flat_probs=flat_probs) 360 | cur_class_batch = next_classes_batch_np[ 361 | np.arange(len(next_classes_batch_np)), choices.data.cpu().numpy()] 362 | if save_prob: 363 | for did, idx, p_ in zip(doc_ids, cur_class_batch, 364 | m.probs.gather(-1, choices.unsqueeze(-1)).squeeze(-1).data.cpu().numpy()): 365 | assert 0 < idx < 104, idx 366 | if idx in tree.id2prob[did]: 367 | logger.warning(f'[{did}][{idx}] already existed!') 368 | tree.id2prob[did][idx] = p_ 369 | tree.update_actions(cur_class_batch) 370 | tree.remove_stop() 371 | pred_l.extend(tree.taken_actions) 372 | target_l.extend(doc_ids) 373 | if save_prob: 374 | logger.info(f'saving {writer.file_writer.get_logdir()}/{data_path}.tree.id2prob-rl.pkl') 375 | pickle.dump(tree.id2prob, open(f'{writer.file_writer.get_logdir()}/{data_path}.tree.id2prob-rl.pkl', 'wb')) 376 | tree.id2prob.clear() 377 | return evaluate(pred_l, target_l) 378 | 379 | 380 | def eval_save_model(i, pred_l=None, target_l=None, datapath=None, save=False, output=True): 381 | if args.mode == 'hilap-sl': 382 | test_f = test_step_sl 383 | elif args.mode == 'hilap': 384 | test_f = test_taxo 385 | else: 386 | test_f = test_sl 387 | if pred_l: 388 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = evaluate(pred_l, target_l, output=output) 389 | elif datapath: 390 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = test_f(datapath, output=output) 391 | else: 392 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = test_f(X_train, train_ids) 393 | writer.add_scalar('data/micro_train', f1_aa, tree.n_update) 394 | writer.add_scalar('data/macro_train', f1_aa_macro, tree.n_update) 395 | writer.add_scalar('data/samples_train', f1_aa_s, tree.n_update) 396 | if not save: 397 | return 398 | if args.mode in ['hilap', 'hilap-sl']: 399 | save_checkpoint({ 400 | 'state_dict': policy.state_dict(), 401 | 'optimizer': optimizer.state_dict(), 402 | }, writer.file_writer.get_logdir(), f'epoch{i}_{f1_aa}_{f1_aa_macro}_{f1_aa_s}.pth.tar', logger, True) 403 | else: 404 | save_checkpoint({ 405 | 'state_dict': model.state_dict(), 406 | 'optimizer': optimizer.state_dict(), 407 | }, 
writer.file_writer.get_logdir(), f'epoch{i}_{f1_aa}_{f1_aa_macro}_{f1_aa_s}.pth.tar', logger, True) 408 | 409 | 410 | def decode(pred_l, target_l, doc_ids, probs): 411 | cur_class_batch = None 412 | target_l.extend(doc_ids) 413 | if args.multi_label: 414 | if args.mode != 'hmcn': 415 | probs = torch.sigmoid(probs) 416 | preds = (probs >= .5).int().data.cpu().numpy() 417 | for pred in preds: 418 | idx = np.nonzero(pred)[0] 419 | if len(idx) == 0: 420 | pred_l.append([]) 421 | else: 422 | pred_l.append(idx + 1) 423 | else: 424 | cur_class_batch = torch.max(probs, 1)[1].data.cpu().numpy() + 1 425 | pred_l.extend(cur_class_batch) 426 | return pred_l, target_l, cur_class_batch 427 | 428 | 429 | def train_hmcn(): 430 | model.train() 431 | for i in range(1, args.num_epoch + 1): 432 | g = select_data('train' + args.ohcnn_data, shuffle=True) 433 | loss_total = 0 434 | pred_l = [] 435 | target_l = [] 436 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)): 437 | if 'FUN' in args.dataset or 'GO' in args.dataset: 438 | tokens = Variable(tokens).cuda() 439 | probs = model(tokens) 440 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs) 441 | loss = calc_sl_loss(probs, doc_ids, update=True) 442 | if ct % 50 == 0 and ct != 0: 443 | logger.info('loss: {}'.format(loss)) 444 | if args.multi_label: 445 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l)) 446 | else: 447 | acc = tree.acc(np.array(pred_l), np.array(target_l)) 448 | logger.info('acc for epoch {} batch {}: {}'.format(i, ct, acc)) 449 | writer.add_scalar('data/loss', loss, tree.n_update) 450 | writer.add_scalar('data/acc', acc, tree.n_update) 451 | if not args.multi_label and (cur_class_batch == cur_class_batch[0]).all(): 452 | logger.error('predictions in a batch are all the same! [{}]'.format(cur_class_batch[0])) 453 | writer.add_text('error', 'predictions in a batch are all the same! 
[{}]'.format(cur_class_batch[0]), 454 | tree.n_update) 455 | # exit(1) 456 | loss_total += loss 457 | loss_avg = loss_total / (ct + 1) 458 | if args.multi_label: 459 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l)) 460 | else: 461 | acc = tree.acc(np.array(pred_l), np.array(target_l)) 462 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc)) 463 | if not args.multi_label: 464 | pred_l = [[label] for label in pred_l] 465 | if i % args.save_every == 0: 466 | eval_save_model(i, pred_l, target_l, save=True) 467 | 468 | 469 | def test_hmcn(): 470 | logger.info('testing starts') 471 | model.eval() 472 | g = select_data('test' + args.ohcnn_data) 473 | loss_total = 0 474 | pred_l = [] 475 | target_l = [] 476 | probs_l = [] 477 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)): 478 | if 'FUN' in args.dataset or 'GO' in args.dataset: 479 | tokens = Variable(tokens).cuda() 480 | probs = model(tokens) 481 | probs_l.append(probs.data.cpu().numpy()) 482 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs) 483 | loss = calc_sl_loss(probs, doc_ids, update=False) 484 | loss_total += loss 485 | probs_np = np.concatenate(probs_l, axis=0) 486 | logger.info('saving probs to {}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update)) 487 | pickle.dump((probs_np, target_l), 488 | open('{}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update), 'wb')) 489 | loss_avg = loss_total / (ct + 1) 490 | if args.multi_label: 491 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l)) 492 | else: 493 | acc = tree.acc(np.array(pred_l), np.array(target_l)) 494 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc)) 495 | if not args.multi_label: 496 | pred_l = [[label] for label in pred_l] 497 | inc = 0 498 | for p in pred_l: 499 | p = set(p) 500 | exist = False 501 | for l in p: 502 | cur = tree.c2p_idx[l][0] 503 | while cur != 0: 504 | if cur not in p: 505 | inc += 1 506 | exist = True 507 | break 508 | cur = tree.c2p_idx[cur][0] 509 | if exist: 510 | break 511 | print(inc) 512 | # exit() 513 | return evaluate(pred_l, target_l) 514 | 515 | 516 | def train_sl(): 517 | model.train() 518 | for i in range(1, args.num_epoch + 1): 519 | g = select_data('train' + args.ohcnn_data, shuffle=True) 520 | loss_total = 0 521 | pred_l = [] 522 | target_l = [] 523 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)): 524 | probs = model(tokens, True) 525 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs) 526 | loss = calc_sl_loss(probs, doc_ids, update=True) 527 | if ct % args.output_every == 0 and ct != 0: 528 | logger.info('sl_loss: {}'.format(loss)) 529 | if args.multi_label: 530 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l)) 531 | else: 532 | acc = tree.acc(np.array(pred_l), np.array(target_l)) 533 | logger.info('acc for epoch {} batch {}: {}'.format(i, ct, acc)) 534 | writer.add_scalar('data/sl_loss', loss, tree.n_update) 535 | writer.add_scalar('data/acc', acc, tree.n_update) 536 | if not args.multi_label and (cur_class_batch == cur_class_batch[0]).all(): 537 | logger.error('predictions in a batch are all the same! [{}]'.format(cur_class_batch[0])) 538 | writer.add_text('error', 'predictions in a batch are all the same! 
[{}]'.format(cur_class_batch[0]), 539 | tree.n_update) 540 | # exit(1) 541 | loss_total += loss 542 | loss_avg = loss_total / (ct + 1) 543 | if args.multi_label: 544 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l)) 545 | else: 546 | acc = tree.acc(np.array(pred_l), np.array(target_l)) 547 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc)) 548 | if not args.multi_label: 549 | pred_l = [[label] for label in pred_l] 550 | if i % args.save_every == 0: 551 | # eval_save_model(i, pred_l, target_l, save=True) 552 | test_sl() 553 | 554 | 555 | def test_sl(): 556 | logger.info('testing starts') 557 | model.eval() 558 | g = select_data('test' + args.ohcnn_data) 559 | loss_total = 0 560 | pred_l = [] 561 | target_l = [] 562 | probs_l = [] 563 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)): 564 | probs = model(tokens, True) 565 | probs_l.append(probs.data.cpu().numpy()) 566 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs) 567 | loss = calc_sl_loss(probs, doc_ids, update=False) 568 | loss_total += loss 569 | # probs_np = np.concatenate(probs_l, axis=0) 570 | # logger.info('saving probs to {}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update)) 571 | # pickle.dump((probs_np, target_l), 572 | # open('{}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update), 'wb')) 573 | loss_avg = loss_total / (ct + 1) 574 | if args.multi_label: 575 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l)) 576 | else: 577 | acc = tree.acc(np.array(pred_l), np.array(target_l)) 578 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc)) 579 | if not args.multi_label: 580 | pred_l = [[label] for label in pred_l] 581 | return evaluate(pred_l, target_l) 582 | 583 | 584 | def evaluate(pred_l, test_ids, save_path=None, output=True): 585 | acc = round(tree.acc_multi(pred_l, test_ids), 4) 586 | res = tree.calc_f1(pred_l, test_ids, save_path, output) 587 | if output: 588 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = [round(i, 4) for i in res] 589 | logger.info( 590 | f'acc:{acc} f1_s:{f1_aa_s} micro-f1:{f1} {f1_a} {f1_aa} macro-f1:{f1_macro} {f1_a_macro} {f1_aa_macro}') 591 | if f1_aa > tree.miF[0]: 592 | tree.miF = (f1_aa, f1_aa_macro, tree.cur_epoch) 593 | if f1_aa_macro > tree.maF[1]: 594 | tree.maF = (f1_aa, f1_aa_macro, tree.cur_epoch) 595 | logger.warning(f'best: {tree.miF}, {tree.maF}') 596 | return f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s 597 | else: 598 | f1_a, f1_a_macro, f1_a_s = [round(i, 4) for i in res] 599 | return 0, 0, f1_a, 0, 0, f1_a_macro, f1_a_s 600 | 601 | 602 | if args.dataset == 'rcv1': 603 | if 'oh' in args.base_model: 604 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_rcv1_onehot('_rcv1_ptAll') 605 | else: 606 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_rcv1( 607 | '../datasets/glove.6B.50d.txt', '_rcv1_ptAll') 608 | if args.filter_ancestors: 609 | id2doc_train = filter_ancestors(id2doc, nodes) 610 | if args.split_multi: 611 | X_train, train_ids, id2doc_train, id2doc = split_multi(X_train, train_ids, id2doc_train, id2doc) 612 | else: 613 | id2doc_train = id2doc 614 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Root') 615 | 616 | elif args.dataset == 'yelp': 617 | subtree_name = 'root' 618 | min_reviews = 5 619 | max_reviews = 10 620 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_yelp('../datasets/glove.6B.50d.txt', 621 | 
'_yelp_root_100_{}_{}'.format( 622 | min_reviews, max_reviews), 623 | root=subtree_name, 624 | min_reviews=min_reviews, 625 | max_reviews=max_reviews) 626 | logger.warning(f'{len(X_train)} {len(train_ids)} {len(X_test)} {len(test_ids)}') 627 | # save_minibatch(logger, args, word_index, X_train, train_ids, 32, name='train_yelp_root_100_5_10_len256_padded') 628 | # save_minibatch(logger, args, word_index, X_test, test_ids, 32, name='test_yelp_root_100_5_10_len256_padded') 629 | if args.filter_ancestors: 630 | id2doc_train = filter_ancestors(id2doc, nodes) 631 | else: 632 | id2doc_train = id2doc 633 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname=subtree_name) 634 | elif args.dataset == 'nyt': 635 | if 'oh' in args.base_model: 636 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_nyt_onehot('_nyt_ptAll') 637 | else: 638 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_nyt( 639 | '../datasets/glove.6B.50d.txt', '_nyt_ptAll') 640 | if args.filter_ancestors: 641 | id2doc_train = filter_ancestors(id2doc, nodes) 642 | if args.split_multi: 643 | X_train, train_ids, id2doc_train, id2doc = split_multi(X_train, train_ids, id2doc_train, id2doc) 644 | else: 645 | id2doc_train = id2doc 646 | save_minibatch(logger, args, word_index, X_train, train_ids, 32, name='train_nyt') 647 | save_minibatch(logger, args, word_index, X_test, test_ids, 32, name='test_nyt') 648 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Top') 649 | elif 'FUN' in args.dataset or 'GO' in args.dataset: 650 | X_train, _, train_ids, test_ids, id2doc, nodes = read_fungo(args.dataset) 651 | if args.filter_ancestors: 652 | id2doc_train = filter_ancestors(id2doc, nodes) 653 | else: 654 | id2doc_train = id2doc 655 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Top') 656 | else: 657 | logger.error('No such dataset: {}'.format(args.dataset)) 658 | exit(1) 659 | if args.stat_check: 660 | check_doc_size(X_train, logger) 661 | check_doc_size(X_test, logger) 662 | if args.gpu: 663 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 664 | sorted_gpu_info = get_gpu_memory_map() 665 | for gpu_id, (mem_left, util) in sorted_gpu_info: 666 | if mem_left >= args.min_mem: 667 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 668 | logger.info('use gpu:{} with {} MB left, util {}%'.format(gpu_id, mem_left, util)) 669 | break 670 | else: 671 | logger.warn(f'no gpu has memory left >= {args.min_mem} MB, exiting...') 672 | exit() 673 | else: 674 | torch.set_num_threads(10) 675 | if 'cnn' in args.base_model: 676 | if args.base_model == 'textcnn': 677 | model = TextCNN(args, word_vec=wv, n_classes=tree.n_class - 1) 678 | in_dim = 3000 679 | elif args.base_model == 'ohcnn-bow-fast': 680 | model = OHCNN_fast(word_index['UNK'], n_classes=tree.n_class - 1, vocab_size=len(word_index)) 681 | in_dim = 10000 682 | if args.mode == 'hmcn': 683 | local_output_size = tree.get_layer_node_number() 684 | model = HMCN(model, args, local_output_size, tree.n_class - 1, in_dim) 685 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_weight) 686 | if args.gpu: 687 | logger.info(model.cuda()) 688 | elif args.base_model == 'han': 689 | model = HAN(args, word_vec=wv, n_classes=tree.n_class - 1) 690 | in_dim = args.sent_gru_hidden_size * 2 691 | if args.mode == 'sl': 692 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 
weight_decay=args.l2_weight) 693 | if args.gpu: 694 | logger.info(model.cuda()) 695 | elif args.base_model == 'raw': 696 | local_output_size = tree.get_layer_node_number() 697 | if args.dataset == 'rcv1': 698 | model = HMCN(None, args, local_output_size, tree.n_class - 1, 47236) 699 | elif args.dataset == 'yelp': 700 | model = HMCN(None, args, local_output_size, tree.n_class - 1, 146587) 701 | elif args.dataset == 'nyt': 702 | model = HMCN(None, args, local_output_size, tree.n_class - 1, 102755) 703 | elif 'FUN' in args.dataset or 'GO' in args.dataset: 704 | n_features = len(X_train[0]) 705 | logger.info(f'n_features={n_features}') 706 | if args.mode == 'hmcn': 707 | model = HMCN(None, args, local_output_size, tree.n_class - 1, n_features) 708 | else: 709 | model = Linear_Model(args, n_features, tree.n_class - 1) 710 | X_train, X_test, train_ids, test_ids = None, None, None, None 711 | in_dim = args.n_hidden 712 | if args.gpu: 713 | model.cuda() 714 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_weight) 715 | 716 | base_model = model 717 | if args.mode in ['hilap', 'hilap-sl']: 718 | if args.mode == 'hilap': 719 | policy = Policy(args, tree.n_class + 1, base_model, in_dim) 720 | else: 721 | policy = Policy(args, tree.n_class, base_model, in_dim) 722 | if args.gpu: 723 | logger.info(policy.cuda()) 724 | optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr, weight_decay=args.l2_weight) 725 | for name, param in policy.named_parameters(): 726 | logger.info('{} {} {}'.format(name, type(param.data), param.size())) 727 | 728 | if args.mode == 'hmcn': 729 | criterion = torch.nn.BCELoss() 730 | else: 731 | criterion = torch.nn.BCEWithLogitsLoss() 732 | if args.load_model: 733 | if os.path.isfile(args.load_model): 734 | checkpoint = torch.load(args.load_model) 735 | load_optimizer = True 736 | if args.mode in ['sl', 'hmcn']: 737 | model.load_state_dict(checkpoint['state_dict']) 738 | else: 739 | policy_dict = policy.state_dict() 740 | load_from_sl = False 741 | if load_from_sl: 742 | for i in list(checkpoint['state_dict'].keys()): 743 | checkpoint['state_dict']['base_model.' + i] = checkpoint['state_dict'].pop(i) 744 | checkpoint['state_dict']['class_embed.weight'] = torch.cat( 745 | [policy_dict['class_embed.weight'][-1:], checkpoint['state_dict']['base_model.fc2.weight'], 746 | policy_dict['class_embed.weight'][-1:]]) 747 | checkpoint['state_dict']['class_embed_bias.weight'] = torch.cat( 748 | [policy_dict['class_embed_bias.weight'][-1:], 749 | checkpoint['state_dict']['base_model.fc2.bias'].view(-1, 1), 750 | policy_dict['class_embed_bias.weight'][-1:]]) 751 | load_optimizer = False 752 | elif checkpoint['state_dict']['class_embed.weight'].size()[0] == \ 753 | policy_dict['class_embed.weight'].size()[0] - 1: 754 | logger.warning('try loading pretrained x for rl-taxo. 
class_embed also loaded.') 755 | load_optimizer = False 756 | # checkpoint['state_dict']['class_embed.weight'] = policy_dict['class_embed.weight'] 757 | checkpoint['state_dict']['class_embed.weight'] = torch.cat( 758 | [checkpoint['state_dict']['class_embed.weight'], policy_dict['class_embed.weight'][-1:]]) 759 | checkpoint['state_dict']['class_embed_bias.weight'] = torch.cat( 760 | [checkpoint['state_dict']['class_embed_bias.weight'], policy_dict['class_embed_bias.weight'][-1:]]) 761 | logger.warning(checkpoint['state_dict']['class_embed.weight'].size()) 762 | policy.load_state_dict(checkpoint['state_dict']) 763 | if load_optimizer: 764 | optimizer.load_state_dict(checkpoint['optimizer']) 765 | else: 766 | logger.warning('optimizer not loaded') 767 | logger.info("loaded checkpoint '{}' ".format(args.load_model)) 768 | else: 769 | logger.error("no checkpoint found at '{}'".format(args.load_model)) 770 | exit(1) 771 | if args.stat_check: 772 | evaluate([[tree.id2idx[id2doc_train[did]['categories'][0]]] for did in train_ids], train_ids) 773 | evaluate([[tree.id2idx[id2doc_train[did]['categories'][0]]] for did in test_ids], test_ids) 774 | if not args.load_model or args.isTrain: 775 | if args.mode == 'hilap-sl': 776 | train_step_sl() 777 | # test_tmp('test' + args.ohcnn_data) 778 | test_step_sl('test' + args.ohcnn_data, save_prob=False) 779 | elif args.mode == 'hilap': 780 | train_taxo() 781 | # test_taxo('train' + args.ohcnn_data, save_prob=False) 782 | # test_taxo('test' + args.ohcnn_data, save_prob=False) 783 | elif args.mode == 'hmcn': 784 | train_hmcn() 785 | else: 786 | train_sl() 787 | test_sl() 788 | else: 789 | if args.mode == 'hilap': 790 | # test_step_sl('train' + args.ohcnn_data, save_prob=False) 791 | # test_step_sl('test' + args.ohcnn_data, save_prob=True) 792 | test_taxo('train' + args.ohcnn_data, save_prob=False) 793 | test_taxo('test' + args.ohcnn_data, save_prob=False) 794 | elif args.mode == 'hilap-sl': 795 | # test_sl(X_test, test_ids) # for testing global_flat with additional local_loss 796 | # test_step_sl('train' + args.ohcnn_data, save_prob=True) 797 | # test_tmp('test' + args.ohcnn_data) 798 | test_step_sl('train' + args.ohcnn_data, save_prob=False) 799 | test_step_sl('test' + args.ohcnn_data, save_prob=False) 800 | elif args.mode == 'hmcn': 801 | test_hmcn() 802 | else: 803 | test_sl() 804 | writer.close() 805 | logger.info(f'log_dir: {log_dir}') 806 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torch.distributions import Categorical 9 | 10 | 11 | class Policy(nn.Module): 12 | def __init__(self, args, n_class, base_model, in_dim): 13 | super(Policy, self).__init__() 14 | self.args = args 15 | self.L = 0.02 16 | self.baseline_reward = 0 17 | self.entropy_l = [] 18 | self.beta = args.beta 19 | self.beta_decay_rate = .9 20 | self.n_update = 0 21 | self.class_embed = nn.Embedding(n_class, args.class_embed_size) 22 | self.class_embed_bias = nn.Embedding(n_class, 1) 23 | 24 | stdv = 1. 
/ np.sqrt(self.class_embed.weight.size(1)) 25 | self.class_embed.weight.data.uniform_(-stdv, stdv) 26 | self.class_embed_bias.weight.data.uniform_(-stdv, stdv) 27 | 28 | self.saved_log_probs = [] 29 | self.rewards = [] 30 | self.rewards_greedy = [] 31 | self.doc_vec = None 32 | self.base_model = base_model 33 | 34 | self.criterion = torch.nn.BCEWithLogitsLoss() 35 | self.sl_loss = 0 36 | 37 | if self.args.use_history: 38 | self.state_hist = None 39 | self.output_hist = None 40 | self.hist_gru = nn.GRU(args.class_embed_size, args.class_embed_size, bidirectional=True) 41 | if self.args.use_cur_class_embed: 42 | in_dim += self.args.class_embed_size 43 | if self.args.use_history: 44 | in_dim += args.hist_embed_size * 2 45 | if self.args.use_l2: 46 | self.l1 = nn.Linear(in_dim, args.l1_size) 47 | self.l2 = nn.Linear(args.l1_size, args.class_embed_size) 48 | elif self.args.use_l1: 49 | self.l1 = nn.Linear(in_dim, args.class_embed_size) 50 | 51 | def update_baseline(self, target): 52 | # a moving average baseline, not used anymore 53 | self.baseline_reward = self.L * target + (1 - self.L) * self.baseline_reward 54 | 55 | def finish_episode(self): 56 | self.sl_loss = 0 57 | self.n_update += 1 58 | if self.n_update % self.args.update_beta_every == 0: 59 | self.beta *= self.beta_decay_rate 60 | self.entropy_l = [] 61 | del self.rewards[:] 62 | del self.saved_log_probs[:] 63 | del self.rewards_greedy[:] 64 | 65 | def forward(self, cur_class_batch, next_classes_batch): 66 | cur_class_embed = self.class_embed(cur_class_batch) # (batch, 50) 67 | next_classes_embed = self.class_embed(next_classes_batch) # (batch, max_choices, 50) 68 | nb = self.class_embed_bias(next_classes_batch).squeeze(-1) 69 | states_embed = self.doc_vec 70 | if self.args.use_cur_class_embed: 71 | states_embed = torch.cat((states_embed, cur_class_embed), 1) 72 | if self.args.use_history: 73 | states_embed = torch.cat((states_embed, self.output_hist.squeeze()), 1) 74 | if not self.args.use_l1: 75 | return torch.bmm(next_classes_embed, states_embed.unsqueeze(-1)).squeeze(-1) + nb 76 | if self.args.use_l2: 77 | h1 = F.relu(self.l1(states_embed)) 78 | h2 = F.relu(self.l2(h1)) 79 | else: 80 | h2 = F.relu(self.l1(states_embed)) 81 | h2 = h2.unsqueeze(-1) # (batch, 50, 1) 82 | probs = torch.bmm(next_classes_embed, h2).squeeze(-1) + nb 83 | if self.args.use_history: 84 | self.output_hist, self.state_hist = self.hist_gru(cur_class_embed.unsqueeze(0), self.state_hist) 85 | return probs 86 | 87 | def duplicate_doc_vec(self, indices): 88 | assert self.doc_vec is not None 89 | assert len(indices) > 0 90 | self.doc_vec = self.doc_vec[indices] 91 | 92 | def duplicate_reward(self, indices): 93 | assert len(indices) > 0 94 | self.saved_log_probs[-1] = [[probs[i] for i in indices] for probs in self.saved_log_probs[-1]] 95 | self.rewards[-1] = [[R[i] for i in indices] for R in self.rewards[-1]] 96 | 97 | def generate_doc_vec(self, mini_batch): 98 | self.doc_vec = self.base_model(mini_batch) 99 | 100 | def generate_logits(self, mini_batch, cur_class_batch, next_classes_batch): 101 | if self.doc_vec is None: 102 | self.generate_doc_vec(mini_batch) 103 | if self.args.gpu: 104 | cur_class_batch = Variable(torch.from_numpy(cur_class_batch)).cuda() 105 | next_classes_batch = Variable(torch.from_numpy(next_classes_batch)).cuda() 106 | else: 107 | cur_class_batch = Variable(torch.from_numpy(cur_class_batch)) 108 | next_classes_batch = Variable(torch.from_numpy(next_classes_batch)) 109 | logits = self(cur_class_batch, next_classes_batch) 110 | # mask padding 
relations 111 | logits = (next_classes_batch == 0).float() * -99999 + (next_classes_batch != 0).float() * logits 112 | return logits 113 | 114 | def step_sl(self, mini_batch, cur_class_batch, next_classes_batch, next_classes_batch_true, sigmoid=True): 115 | logits = self.generate_logits(mini_batch, cur_class_batch, next_classes_batch) 116 | if not sigmoid: 117 | return logits 118 | if next_classes_batch_true is not None: 119 | if self.args.gpu: 120 | y_true = Variable(torch.from_numpy(next_classes_batch_true)).cuda().float() 121 | else: 122 | y_true = Variable(torch.from_numpy(next_classes_batch_true)).float() 123 | self.sl_loss += self.criterion(logits, y_true) 124 | return F.sigmoid(logits) 125 | 126 | def step(self, mini_batch, cur_class_batch, next_classes_batch, test=False, flat_probs=None): 127 | logits = self.generate_logits(mini_batch, cur_class_batch, next_classes_batch) 128 | if self.args.softmax: 129 | probs = F.softmax(logits, dim=-1) 130 | else: 131 | probs = F.sigmoid(logits) 132 | if not test: 133 | # + epsilon to avoid log(0) 134 | self.entropy_l.append(torch.mean(torch.log(probs + 1e-32) * probs)) 135 | next_classes_batch = Variable(torch.from_numpy(next_classes_batch)).cuda() 136 | probs = probs + (next_classes_batch != 0).float() * 1e-16 137 | m = Categorical(probs) 138 | if test or self.args.sample_mode == 'choose_max': 139 | action = torch.max(probs, 1)[1] 140 | elif self.args.sample_mode == 'random': 141 | if random.random() < 1.2: 142 | if self.args.gpu: 143 | action = Variable(torch.zeros(probs.size()[0]).long().random_(0, probs.size()[1])).cuda() 144 | else: 145 | action = Variable(torch.zeros(probs.size()[0]).long().random_(0, probs.size()[1])) 146 | else: 147 | action = m.sample() 148 | else: 149 | action = m.sample() 150 | return action, m 151 | 152 | # not used anymore 153 | def init_hist(self, tokens_size): 154 | if self.args.gpu: 155 | self.state_hist = self.init_hidden(tokens_size, self.args.hist_embed_size).cuda() 156 | first_input = Variable( 157 | torch.from_numpy(np.zeros((1, tokens_size, self.args.class_embed_size)))).cuda().float() 158 | else: 159 | self.state_hist = self.init_hidden(tokens_size, self.args.hist_embed_size) 160 | first_input = Variable( 161 | torch.from_numpy((np.zeros(1, tokens_size, self.args.class_embed_size)))).float() 162 | self.output_hist, self.state_hist = self.hist_gru(first_input, self.state_hist) 163 | -------------------------------------------------------------------------------- /readData_fungo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | from IPython import embed 8 | 9 | 10 | def replace_nan_by_mean(X): 11 | # Obtain mean of columns as you need, nanmean is just convenient. 12 | col_mean = np.nanmean(X, axis=0) 13 | # Find indicies that you need to replace 14 | inds = np.where(np.isnan(X)) 15 | # Place column means in the indices. Align the arrays using take 16 | X[inds] = np.take(col_mean, inds[1]) 17 | return X 18 | 19 | 20 | def read_fungo_(f_in): 21 | """ 22 | X: features 23 | Y: labels 24 | C: for *_FUN: in format ancestor1/ancestor2/.../label 25 | for *_GO: in format parent/child [CAUTION: since it is a DAG, its size is larger than C_slash] 26 | C_slash: unique labels for *_GO [somehow have 3 more labels than HMCN's paper.] 
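    Rows after @DATA are comma-separated feature values ('?' marks a missing value
    and is later replaced by the column mean) followed by the sample's classes
    joined by '@'.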
27 | 28 | """ 29 | ct = 0 30 | A = 0 31 | C = 0 32 | C_set = set() 33 | flag = False 34 | X = [] 35 | Y = [] 36 | with open(f_in) as f: 37 | for line in f: 38 | if line.startswith('@ATTRIBUTE'): 39 | if '/' in line: 40 | # print(line.split(',')) 41 | C = line.strip().split(',') 42 | C[0] = C[0].split()[-1] 43 | # print([i.split('/')[-1] for i in C][-10:]) 44 | C_slash = set([i.split('/')[-1] for i in C]) 45 | else: 46 | A += 1 47 | if flag: 48 | ct += 1 49 | data = line.strip().split(',') 50 | classes = data[-1].split('@') 51 | # convert ? to nan 52 | X.append([float(i) if i != '?' else np.nan for i in data[:-1]]) 53 | Y.append(classes) 54 | C_set.update(classes) 55 | if line.startswith('@DATA'): 56 | flag = True 57 | X = np.array(X) 58 | X = replace_nan_by_mean(X) 59 | # print(f'[{f_in}] #features={A}, #Classes={len(C)}, #C_slash={len(C_slash)}, #C_appear={len(C_set)}, #samples={ct}') 60 | return X, Y, C, C_slash 61 | 62 | 63 | def read_fungo_all(name): 64 | valid = read_fungo_( 65 | f"./protein_datasets/{name}.valid.arff") 66 | test = read_fungo_( 67 | f"./protein_datasets/{name}.test.arff") 68 | train = read_fungo_( 69 | f"./protein_datasets/{name}.train.arff") 70 | return train, valid, test 71 | 72 | 73 | def construct_hier_dag(name): 74 | hierarchy = defaultdict(set) 75 | train_classes = read_fungo_( 76 | f"./protein_datasets/{name}.train.arff")[2] 77 | valid_classes = read_fungo_( 78 | f"./protein_datasets/{name}.valid.arff")[2] 79 | test_classes = read_fungo_( 80 | f"./protein_datasets/{name}.test.arff")[2] 81 | for t in [train_classes, valid_classes, test_classes]: 82 | for y in t: 83 | parent, child = y.split('/') 84 | if parent == 'root': 85 | parent = 'Top' 86 | hierarchy[parent].add(child) 87 | if not child in hierarchy: 88 | hierarchy[child] = set() 89 | return hierarchy 90 | 91 | 92 | def construct_hierarchy(name): 93 | # Construct hierarchy from train/valid/test. 
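    # For *_FUN data every class string already encodes its full path, so each
    # prefix becomes a node: e.g. a hypothetical label '01/02/05' yields edges
    # Top -> '01', '01' -> '01/02', '01/02' -> '01/02/05'.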
94 | hierarchy = defaultdict(set) 95 | train_classes = read_fungo_( 96 | f"./protein_datasets/{name}.train.arff")[2] 97 | valid_classes = read_fungo_( 98 | f"./protein_datasets/{name}.valid.arff")[2] 99 | test_classes = read_fungo_( 100 | f"./protein_datasets/{name}.test.arff")[2] 101 | for t in [train_classes, valid_classes, test_classes]: 102 | for y in t: 103 | hier_list = y.split('/') 104 | # Add a pseudo node: Top 105 | if len(hier_list) == 1: 106 | hierarchy['Top'].add(hier_list[0]) 107 | hierarchy[hier_list[0]] = set() 108 | continue 109 | for l in range(1, len(hier_list)): 110 | parent = '/'.join(hier_list[:l]) 111 | child = '/'.join(hier_list[:l + 1]) 112 | if l == 1: 113 | hierarchy['Top'].add(parent) 114 | if not child in hierarchy[parent]: 115 | hierarchy[parent].add(child) 116 | if not child in hierarchy: 117 | hierarchy[child] = set() 118 | return hierarchy 119 | 120 | 121 | def get_all_ancestor_nodes(hierarchy, node): 122 | node_list = set() 123 | 124 | def dfs(node): 125 | if node != 'Top': 126 | node_list.add(node) 127 | parents = hierarchy[node]['parent'] 128 | for parent in parents: 129 | dfs(parent) 130 | 131 | dfs(node) 132 | return node_list 133 | 134 | 135 | def read_go(name): 136 | if os.path.exists(f'./protein_datasets/{name}.pkl'): 137 | return pickle.load(open(f'./protein_datasets/{name}.pkl', 'rb')) 138 | p2c = defaultdict(list) 139 | id2doc = defaultdict(lambda: defaultdict(list)) 140 | nodes = defaultdict(lambda: defaultdict(list)) 141 | random.seed(42) 142 | 143 | train, valid, test = read_fungo_all(name) 144 | hierarchy = construct_hier_dag(name) 145 | for parent in hierarchy: 146 | for child in hierarchy[parent]: 147 | p2c[parent].append(child) 148 | for label in p2c: 149 | for children in p2c[label]: 150 | nodes[label]['children'].append(children) 151 | nodes[children]['parent'].append(label) 152 | 153 | train_data = np.concatenate([train[0], valid[0]]) 154 | train[1].extend(valid[1]) 155 | 156 | X_train = [] 157 | X_test = [] 158 | train_ids = [] 159 | test_ids = [] 160 | 161 | for idx, (feature, classes) in enumerate(zip(train_data, train[1])): 162 | X_train.append(feature) 163 | train_ids.append(idx) 164 | for class_ in classes: 165 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_) 166 | for label in ancestor_nodes: 167 | if not label in id2doc[idx]['categories']: 168 | id2doc[idx]['categories'].append(label) 169 | 170 | for idx, (feature, classes) in enumerate(zip(test[0], test[1])): 171 | X_test.append(feature) 172 | test_ids.append(idx + train_data.shape[0]) 173 | for class_ in classes: 174 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_) 175 | for label in ancestor_nodes: 176 | if not label in id2doc[idx + train_data.shape[0]]['categories']: 177 | id2doc[idx + train_data.shape[0]]['categories'].append(label) 178 | res = X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes) 179 | pickle.dump(res, open(f'./protein_datasets/{name}.pkl', 'wb')) 180 | return res 181 | 182 | 183 | def read_fun(name): 184 | if os.path.exists(f'./protein_datasets/{name}.pkl'): 185 | return pickle.load(open(f'./protein_datasets/{name}.pkl', 'rb')) 186 | p2c = defaultdict(list) 187 | id2doc = defaultdict(lambda: defaultdict(list)) 188 | nodes = defaultdict(lambda: defaultdict(list)) 189 | random.seed(42) 190 | 191 | train, valid, test = read_fungo_all(name) 192 | hierarchy = construct_hierarchy(name) 193 | for parent in hierarchy: 194 | for child in hierarchy[parent]: 195 | p2c[parent].append(child) 196 | for label in p2c: 197 | for children in 
p2c[label]: 198 | nodes[label]['children'].append(children) 199 | nodes[children]['parent'].append(label) 200 | 201 | train_data = np.concatenate([train[0], valid[0]]) 202 | train[1].extend(valid[1]) 203 | 204 | X_train = [] 205 | X_test = [] 206 | train_ids = [] 207 | test_ids = [] 208 | 209 | for idx, (feature, classes) in enumerate(zip(train_data, train[1])): 210 | X_train.append(feature) 211 | train_ids.append(idx) 212 | for class_ in classes: 213 | hier_list = class_.split('/') 214 | for l in range(1, len(hier_list) + 1): 215 | label = '/'.join(hier_list[:l]) 216 | if not label in id2doc[idx]['categories']: 217 | id2doc[idx]['categories'].append(label) 218 | 219 | for idx, (feature, classes) in enumerate(zip(test[0], test[1])): 220 | X_test.append(feature) 221 | test_ids.append(idx + train_data.shape[0]) 222 | for class_ in classes: 223 | # For each sample, we treat all nodes in the path as labels. 224 | hier_list = class_.split('/') 225 | for l in range(1, len(hier_list) + 1): 226 | label = '/'.join(hier_list[:l]) 227 | if not label in id2doc[idx + train_data.shape[0]]['categories']: 228 | id2doc[idx + train_data.shape[0]]['categories'].append(label) 229 | res = X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes) 230 | pickle.dump(res, open(f'./protein_datasets/{name}.pkl', 'wb')) 231 | return res 232 | 233 | 234 | def read_fungo(name): 235 | if 'FUN' in name: 236 | return read_fun(name) 237 | return read_go(name) 238 | 239 | 240 | def process_for_cssag(name, train_data, train_labels, test_data, test_labels): 241 | if 'FUN' in name: 242 | hierarchy = construct_hierarchy(name) 243 | else: 244 | hierarchy = construct_hier_dag(name) 245 | nodes = defaultdict(lambda: defaultdict(list)) 246 | p2c = defaultdict(list) 247 | for parent in hierarchy: 248 | for child in hierarchy[parent]: 249 | p2c[parent].append(child) 250 | for label in p2c: 251 | for children in p2c[label]: 252 | nodes[label]['children'].append(children) 253 | nodes[children]['parent'].append(label) 254 | label2id = {} 255 | for k in hierarchy: 256 | if k not in label2id: 257 | label2id[k] = len(label2id) 258 | with open('./cssag/' + name + '.hier', 'w') as OUT: 259 | for k in hierarchy: 260 | OUT.write(str(label2id[k])) 261 | for c in hierarchy[k]: 262 | OUT.write(' ' + str(label2id[c])) 263 | OUT.write('\n') 264 | with open('./cssag/' + name + '.train.x', 'w') as OUT: 265 | for l in range(train_data.shape[0]): 266 | x = train_data[l] 267 | OUT.write('0') 268 | for idx, i in enumerate(x): 269 | OUT.write(' ' + str(idx + 1) + ':' + str(i)) 270 | OUT.write('\n') 271 | with open('./cssag/' + name + '.test.x', 'w') as OUT: 272 | for l in range(test_data.shape[0]): 273 | x = test_data[l] 274 | OUT.write('0') 275 | for idx, i in enumerate(x): 276 | OUT.write(' ' + str(idx + 1) + ':' + str(i)) 277 | OUT.write('\n') 278 | if 'FUN' in name: 279 | with open('./cssag/' + name + '.train.y', 'w') as OUT: 280 | for classes in train_labels: 281 | # OUT.write(str(label2id['Top'])) 282 | labels = set() 283 | for class_ in classes: 284 | hier_list = class_.split('/') 285 | for l in range(1, len(hier_list) + 1): 286 | label = '/'.join(hier_list[:l]) 287 | labels.add(label) 288 | for i, label in enumerate(labels): 289 | if i > 0: 290 | OUT.write(',') 291 | OUT.write(str(label2id[label])) 292 | OUT.write('\n') 293 | with open('./cssag/' + name + '.test.y', 'w') as OUT: 294 | for classes in test_labels: 295 | # OUT.write(str(label2id['Top'])) 296 | labels = set() 297 | for class_ in classes: 298 | hier_list = class_.split('/') 299 | 
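                    # As in read_fun, expand the path so every ancestor prefix
                    # of the label is written out for this sample.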
for l in range(1, len(hier_list) + 1): 300 | label = '/'.join(hier_list[:l]) 301 | labels.add(label) 302 | for i, label in enumerate(labels): 303 | if i > 0: 304 | OUT.write(',') 305 | OUT.write(str(label2id[label])) 306 | OUT.write('\n') 307 | 308 | elif 'GO' in name: 309 | with open('./cssag/' + name + '.train.y', 'w') as OUT: 310 | for classes in train_labels: 311 | labels = set() 312 | for class_ in classes: 313 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_) 314 | for label in ancestor_nodes: 315 | labels.add(label) 316 | for i, label in enumerate(labels): 317 | if i > 0: 318 | OUT.write(',') 319 | OUT.write(str(label2id[label])) 320 | OUT.write('\n') 321 | with open('./cssag/' + name + '.test.y', 'w') as OUT: 322 | for classes in test_labels: 323 | labels = set() 324 | for class_ in classes: 325 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_) 326 | for label in ancestor_nodes: 327 | labels.add(label) 328 | for i, label in enumerate(labels): 329 | if i > 0: 330 | OUT.write(',') 331 | OUT.write(str(label2id[label])) 332 | OUT.write('\n') 333 | 334 | 335 | if __name__ == '__main__': 336 | # res = read_go() 337 | # embed() 338 | # exit() 339 | name = 'eisen_FUN' 340 | train, valid, test = read_fungo_all(name) 341 | train_data = np.concatenate([train[0], valid[0]]) 342 | train[1].extend(valid[1]) 343 | train_labels = train[1] 344 | process_for_cssag(name, train_data, train_labels, test[0], test[1]) 345 | embed() 346 | exit() 347 | -------------------------------------------------------------------------------- /readData_nyt.py: -------------------------------------------------------------------------------- 1 | import random 2 | import xml.dom.minidom 3 | from collections import defaultdict 4 | 5 | from tqdm import tqdm 6 | 7 | sample_ratio = 0.02 8 | train_ratio = 0.7 9 | min_per_node = 50 10 | 11 | 12 | def read_nyt_ids(file_path1, file_path2): 13 | ids = [[], []] 14 | random.seed(42) 15 | for cnt, file_path in enumerate([file_path1, file_path2]): 16 | filelist_path = './datasets/NYT_annotated_corpus/data/filelist_' + file_path + '.txt' 17 | with open(filelist_path, 'r') as fin: 18 | for file_cnt, file_name in tqdm(enumerate(fin)): 19 | if random.random() < sample_ratio: 20 | ids[cnt].append(file_name[:-5]) 21 | return ids 22 | 23 | 24 | def construct_hierarchy(file_path1, file_path2): 25 | hierarchy = defaultdict(set) 26 | for cnt, file_path in enumerate([file_path1, file_path2]): 27 | print('Processing file %d...' 
% cnt) 28 | filelist_path = './datasets/NYT_annotated_corpus/data/filelist_' + file_path + '.txt' 29 | with open(filelist_path, 'r') as fin: 30 | for file_cnt, file_name in tqdm(enumerate(fin)): 31 | file_name = file_name.strip('\n') 32 | xml_path = './datasets/NYT_annotated_corpus/data/accum' + file_path + '/' + file_name 33 | try: 34 | dom = xml.dom.minidom.parse(xml_path) 35 | root = dom.documentElement 36 | tags = root.getElementsByTagName('classifier') 37 | for tag in tags: 38 | type = tag.getAttribute('type') 39 | if type != 'taxonomic_classifier': 40 | continue 41 | hier_path = tag.firstChild.data 42 | hier_list = hier_path.split('/') 43 | for l in range(1, len(hier_list)): 44 | parent = '/'.join(hier_list[:l]) 45 | child = '/'.join(hier_list[:l + 1]) 46 | if not child in hierarchy[parent]: 47 | hierarchy[parent].add(child) 48 | if not child in hierarchy: 49 | hierarchy[child] = set() 50 | except: 51 | print('Something went wrong...') 52 | continue 53 | write_hierarchy_to_file(hierarchy, './datasets/nyt/nyt_hier') 54 | 55 | 56 | def write_hierarchy_to_file(hierarchy, filepath): 57 | with open(filepath, 'w') as fout: 58 | nodes = hierarchy.keys() 59 | for parent in nodes: 60 | for child in hierarchy[parent]: 61 | fout.write(parent + '\t' + child + '\n') 62 | 63 | 64 | def trim_nyt_tree(): 65 | ids_1, ids_2 = read_nyt_ids('1987-02', '2003-07') 66 | label_cnt = defaultdict(int) 67 | for id_idx in tqdm(range(len(ids_1) + len(ids_2))): 68 | if id_idx < len(ids_1): 69 | doc_id = ids_1[id_idx] 70 | xml_path = './datasets/NYT_annotated_corpus/data/accum1987-02/' + str(doc_id) + '.xml' 71 | else: 72 | doc_id = ids_2[id_idx - len(ids_1)] 73 | xml_path = './datasets/NYT_annotated_corpus/data/accum2003-07/' + str(doc_id) + '.xml' 74 | try: 75 | dom = xml.dom.minidom.parse(xml_path) 76 | root = dom.documentElement 77 | tags = root.getElementsByTagName('classifier') 78 | for tag in tags: 79 | type = tag.getAttribute('type') 80 | if type != 'taxonomic_classifier': 81 | continue 82 | hier_path = tag.firstChild.data 83 | hier_list = hier_path.split('/') 84 | for l in range(1, len(hier_list) + 1): 85 | label = '/'.join(hier_list[:l]) 86 | label_cnt[label] += 1 87 | except: 88 | print('Something went wrong...') 89 | continue 90 | label_cnt = {label: label_cnt[label] for label in label_cnt if label_cnt[label] > min_per_node} 91 | with open('nyt/nyt_trimed_hier', 'w') as fout: 92 | for label in label_cnt: 93 | fout.write(label + '\n') 94 | return label_cnt 95 | 96 | 97 | def read_nyt(): 98 | p2c = defaultdict(list) 99 | id2doc = defaultdict(lambda: defaultdict(list)) 100 | nodes = defaultdict(lambda: defaultdict(list)) 101 | random.seed(42) 102 | 103 | trimmed_nodes = set() 104 | with open('nyt/nyt_trimed_hier', 'r') as fin: 105 | for line in fin: 106 | trimmed_nodes.add(line.strip('\n')) 107 | with open('nyt/nyt_hier', 'r') as f: 108 | for line in f: 109 | line = line.strip('\n') 110 | parent, child = line.split('\t') 111 | if parent in trimmed_nodes and child in trimmed_nodes: 112 | p2c[parent].append(child) 113 | for label in p2c: 114 | for children in p2c[label]: 115 | nodes[label]['children'].append(children) 116 | nodes[children]['parent'].append(label) 117 | 118 | ids_1, ids_2 = read_nyt_ids('1987-02', '2003-07') 119 | X_train = [] 120 | X_test = [] 121 | train_ids = [] 122 | test_ids = [] 123 | for id_idx in tqdm(range(len(ids_1) + len(ids_2))): 124 | if id_idx < len(ids_1): 125 | doc_id = ids_1[id_idx] 126 | xml_path = './datasets/NYT_annotated_corpus/data/accum1987-02/' + str(doc_id) + '.xml' 
127 | else: 128 | doc_id = ids_2[id_idx - len(ids_1)] 129 | xml_path = './datasets/NYT_annotated_corpus/data/accum2003-07/' + str(doc_id) + '.xml' 130 | try: 131 | dom = xml.dom.minidom.parse(xml_path) 132 | root = dom.documentElement 133 | tags = root.getElementsByTagName('p') 134 | text = '' 135 | for tag in tags[1:]: 136 | text += tag.firstChild.data 137 | if text == '': 138 | print('Something went wrong with text...') 139 | continue 140 | tags = root.getElementsByTagName('classifier') 141 | for tag in tags: 142 | type = tag.getAttribute('type') 143 | if type != 'taxonomic_classifier': 144 | continue 145 | hier_path = tag.firstChild.data 146 | hier_list = hier_path.split('/') 147 | for l in range(1, len(hier_list) + 1): 148 | label = '/'.join(hier_list[:l]) 149 | if label in trimmed_nodes and not label in id2doc[doc_id]['categories']: 150 | id2doc[doc_id]['categories'].append(label) 151 | if not doc_id in id2doc: 152 | continue 153 | if random.random() < train_ratio: 154 | train_ids.append(doc_id) 155 | X_train.append(text) 156 | else: 157 | test_ids.append(doc_id) 158 | X_test.append(text) 159 | except: 160 | print('Something went wrong...') 161 | continue 162 | 163 | return X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes) 164 | -------------------------------------------------------------------------------- /readData_rcv1.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from tqdm import tqdm 3 | 4 | 5 | def read_rcv1_ids(filepath): 6 | ids = set() 7 | with open(filepath) as f: 8 | new_doc = True 9 | for line in f: 10 | line_split = line.strip().split() 11 | if new_doc and len(line_split) == 2: 12 | tmp, did = line_split 13 | if tmp == '.I': 14 | ids.add(did) 15 | new_doc = False 16 | else: 17 | print(line_split) 18 | print('maybe error') 19 | elif len(line_split) == 0: 20 | new_doc = True 21 | print('{} samples in {}'.format(len(ids), filepath)) 22 | return ids 23 | 24 | 25 | def read_rcv1(): 26 | p2c = defaultdict(list) 27 | id2doc = defaultdict(lambda: defaultdict(list)) 28 | nodes = defaultdict(lambda: defaultdict(list)) 29 | with open('rcv1/rcv1.topics.hier.orig.txt') as f: 30 | for line in f: 31 | start = line.find('parent: ') + len('parent: ') 32 | end = line.find(' ', start) 33 | parent = line[start:end] 34 | start = line.find('child: ') + len('child: ') 35 | end = line.find(' ', start) 36 | child = line[start:end] 37 | start = line.find('child-description: ') + len('child-description: ') 38 | end = line.find('\n', start) 39 | child_desc = line[start:end] 40 | p2c[parent].append(child) 41 | for label in p2c: 42 | if label == 'None': 43 | continue 44 | for children in p2c[label]: 45 | nodes[label]['children'].append(children) 46 | nodes[children]['parent'].append(label) 47 | 48 | with open('rcv1/rcv1-v2.topics.qrels') as f: 49 | for line in f: 50 | cat, doc_id, _ = line.strip().split() 51 | id2doc[doc_id]['categories'].append(cat) 52 | X_train = [] 53 | X_test = [] 54 | train_ids = [] 55 | test_ids = [] 56 | train_id_set = read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_train.dat') 57 | test_id_set = read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt0.dat') 58 | test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt1.dat') 59 | test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt2.dat') 60 | test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt3.dat') 61 | print('len(test) 
total={}'.format(len(test_id_set))) 62 | n_not_found = 0 63 | with open('rcv1/docs.txt') as f: 64 | for line in tqdm(f): 65 | doc_id, text = line.strip().split(maxsplit=1) 66 | if doc_id in train_id_set: 67 | train_ids.append(doc_id) 68 | X_train.append(text) 69 | elif doc_id in test_id_set: 70 | test_ids.append(doc_id) 71 | X_test.append(text) 72 | else: 73 | n_not_found += 1 74 | print('there are {} that cannot be found in official tokenized rcv1'.format(n_not_found)) 75 | print('len(train_ids)={} len(test_ids)={}'.format(len(train_ids), len(test_ids))) 76 | return X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes) 77 | 78 | 79 | if __name__ == '__main__': 80 | read_rcv1() 81 | -------------------------------------------------------------------------------- /readData_yelp.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import random 3 | 4 | from tqdm import tqdm 5 | 6 | 7 | def read_yelp_files(root, review_min_sup): 8 | nodes = {} 9 | with open('yelp/Taxonomy_100') as f: 10 | for line in f: 11 | node = ast.literal_eval(line) 12 | nodes[node['title']] = node 13 | subtree_name = root 14 | node_set = set() 15 | q = [subtree_name] 16 | nextq = [] 17 | print(q) 18 | while len(q) > 0: 19 | cur_name = q.pop(0) 20 | node_set.add(cur_name) 21 | print(nodes[cur_name]['children'], end=',') 22 | nextq.extend(nodes[cur_name]['children']) 23 | if len(q) == 0: 24 | q = nextq 25 | nextq = [] 26 | print() 27 | print('keep {} nodes in total'.format(len(node_set))) 28 | print(node_set) 29 | nodes = {node: nodes[node] for node in node_set} 30 | 31 | business_dict = {} 32 | ct = 0 33 | below_min_sup = 0 34 | with open('yelp/yelp_data_100.csv') as f: 35 | for line in tqdm(f): 36 | business = ast.literal_eval(line) 37 | if len(business['reviews']) < review_min_sup: 38 | below_min_sup += 1 39 | continue 40 | labels = set(business['categories']) & node_set 41 | if len(labels) == 0: 42 | continue 43 | if len(labels) != len(business['categories']): 44 | ct += 1 45 | # print('there are labels on other subtree: {} {}'.format(labels, business['categories'])) 46 | business['categories'] = labels 47 | business_dict[business['business_id']] = business 48 | print('keep {} business for tree {}'.format(len(business_dict), subtree_name)) 49 | print('there are {} that have labels on other subtrees'.format(ct)) 50 | print('there are {} that are filtered because of min_sup'.format(below_min_sup)) 51 | return nodes, business_dict 52 | 53 | 54 | def read_yelp(root, min_reviews=1, max_reviews=10): 55 | nodes, business_dict = read_yelp_files(root, min_reviews) 56 | random.seed(42) 57 | train_ids = [] 58 | test_ids = [] 59 | X_train = [] 60 | X_test = [] 61 | 62 | for bid in business_dict: 63 | reviews_concat = ' '.join(business_dict[bid]['reviews'][:max_reviews]) 64 | if random.random() > 0.3: 65 | train_ids.append(bid) 66 | X_train.append(reviews_concat) 67 | else: 68 | test_ids.append(bid) 69 | X_test.append(reviews_concat) 70 | 71 | return X_train, X_test, train_ids, test_ids, business_dict, nodes 72 | 73 | 74 | if __name__ == '__main__': 75 | read_yelp() 76 | -------------------------------------------------------------------------------- /tree.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | from collections import defaultdict, Counter 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from scipy import sparse 8 | from sklearn.metrics import f1_score 9 | from 
sklearn.preprocessing import MultiLabelBinarizer 10 | 11 | 12 | class Tree: 13 | def __init__(self, args, train_ids, test_ids, id2doc=None, id2doc_a=None, nodes=None, X_train=None, 14 | X_test=None, rootname=None): 15 | # child, parent, ancestor 16 | self.c2p = defaultdict(list) 17 | self.p2c = defaultdict(list) 18 | self.p2c_idx = defaultdict(list) 19 | self.c2p_idx = defaultdict(list) 20 | self.c2a = defaultdict(set) 21 | 22 | # real id to auto-increment id 23 | self.id2idx = {} 24 | self.idx2id = {} 25 | # real id to label name 26 | self.id2name = {} 27 | self.name2id = {} 28 | # real id to height 29 | self.id2h = {} 30 | # doc id to doc obj 31 | self.id2doc = id2doc 32 | self.id2doc_ancestors = id2doc_a 33 | self.nodes = nodes # for yelp 34 | self.rootname = rootname 35 | self.id2doc_h = {} 36 | self.train_ids = train_ids 37 | self.test_ids = test_ids 38 | self.X_train = X_train 39 | self.X_test = X_test 40 | self.taken_actions = None 41 | self.n_update = 0 # global update times 42 | self.data_cache = None # loaded from pkl 43 | self.last_R = None # R at last time step 44 | self.id2prob = defaultdict(dict) # [id][class] = p 45 | 46 | self.args = args 47 | self.logger = logging.getLogger('exp') 48 | self.miF = (0, 0) 49 | self.maF = (0, 0) 50 | self.cur_epoch = 0 51 | self.read_mapping() 52 | self.read_edges() 53 | for parent in self.p2c: 54 | for child in self.p2c[parent]: 55 | self.p2c_idx[self.id2idx[parent]].append(self.id2idx[child]) 56 | self.c2p_idx[self.id2idx[child]].append(self.id2idx[parent]) 57 | self.logger.info('# class = {}'.format(len(self.name2id))) 58 | self.n_class = len(self.name2id) 59 | self.class_idx = list(range(1, self.n_class)) 60 | 61 | self.p2c_idx_np = self.pad_p2c_idx() 62 | if args.mode != 'sl' and args.mode != 'hmcn': 63 | self.next_true_bin, self.next_true = self.generate_next_true(keep_cur=True) 64 | self.logger.info('{} terms have more than 1 parent'.format(sum([len(v) > 1 for v in self.c2p.values()]))) 65 | leaves = set(self.c2p) - set(self.p2c) 66 | self.logger.info('{} terms are leaves'.format(len(leaves))) 67 | for term in self.c2p: 68 | ancestor_q = [i for i in self.c2p[term]] 69 | while len(ancestor_q) != 0: 70 | cur = ancestor_q.pop(0) 71 | if cur in self.c2p: 72 | ancestor_q.extend(self.c2p[cur]) 73 | # exclude root in one's ancestors 74 | self.c2a[term].add(cur) 75 | 76 | # stat info 77 | # note that only the first path is considered 78 | for k in self.id2name: 79 | cur = k 80 | h = 0 81 | while cur in self.c2p: 82 | h += 1 83 | cur = self.c2p[cur][0] 84 | self.id2h[k] = h 85 | if args.stat_check: 86 | self.logger.info('node height description:') 87 | self.logger.info(pd.Series(list(self.id2h.values())).describe()) 88 | self.doc_check(self.id2doc, 'all') 89 | 90 | def pad_p2c_idx(self): 91 | col = max([len(c) for c in self.p2c_idx.values()]) + 2 92 | res = np.zeros((len(self.name2id), col), dtype=int) 93 | for row_i in range(len(self.name2id)): 94 | # stay at cur node 95 | if self.args.allow_stay: 96 | res[row_i, len(self.p2c_idx[row_i])] = row_i 97 | if self.args.allow_up: 98 | if row_i in self.c2p_idx: 99 | res[row_i, len(self.p2c_idx[row_i]) + 1] = self.c2p_idx[row_i][0] 100 | # next level node 101 | res[row_i, :len(self.p2c_idx[row_i])] = self.p2c_idx[row_i] 102 | return res 103 | 104 | def get_next_candidates(self, last_selections, cur_candidates, nonstop=False): 105 | all_stop = True 106 | for candi, sel in zip(cur_candidates, last_selections): 107 | if sel != 0: # 0 is not in candi 108 | candi.remove(sel) 109 | if sel == self.n_class: # 
stop action 110 | candi.clear() 111 | else: 112 | candi.update(self.p2c_idx[sel]) 113 | all_stop = False 114 | if not nonstop: 115 | candi.add(self.n_class) 116 | return cur_candidates, self.pad_candidates(cur_candidates), all_stop 117 | 118 | def update_actions(self, cur_class_batch): 119 | for taken, a in zip(self.taken_actions, cur_class_batch): 120 | taken.add(a) 121 | 122 | def pad_candidates(self, cur_candidates): 123 | col = max([len(c) for c in cur_candidates]) 124 | res = np.zeros((len(cur_candidates), col), dtype=int) 125 | for row_i, c in enumerate(cur_candidates): 126 | res[row_i, :len(c)] = list(c) 127 | return res 128 | 129 | def doc_check(self, docids, name): 130 | avg_h = [] 131 | n_classes = [] 132 | hs = [] 133 | label_key = 'categories' 134 | classes = [] 135 | for docid in docids: 136 | if len(self.id2doc[docid][label_key]) != len(set(self.id2doc[docid][label_key])): 137 | self.logger.error(label_key) 138 | self.logger.error(self.id2doc[docid][label_key]) 139 | exit(1) 140 | classes.extend(self.id2doc[docid][label_key]) 141 | h = [self.id2h[self.name2id[term]] for term in self.id2doc[docid][label_key]] 142 | hs.append(h) 143 | self.id2doc_h[docid] = h 144 | avg_h.append(np.mean(h)) 145 | n_classes.append(len(h)) 146 | if self.args.stat_check: 147 | self.logger.info('check info of classes for each doc in {}'.format(name)) 148 | self.logger.info('heights of classes') 149 | self.logger.info(hs[:5]) 150 | self.logger.info('avg_height of labels') 151 | self.logger.info(pd.Series(avg_h).describe()) 152 | self.logger.info('label count of documents') 153 | self.logger.info(pd.Series(n_classes).describe()) 154 | self.logger.info('node support') 155 | self.logger.info(Counter(classes)) 156 | 157 | def h_doc_batch(self, docids): 158 | return [self.id2doc_h[docid] for docid in docids] 159 | 160 | def h_batch(self, ids): 161 | return [self.id2h[self.idx2id[vid]] for vid in ids] 162 | 163 | def p2c_batch(self, ids): 164 | # ids is virtual 165 | res = self.p2c_idx_np[ids] 166 | # remove columns full of zeros 167 | return res[:, ~np.all(res == 0, axis=0)] 168 | 169 | def generate_next_true(self, keep_cur=False): 170 | self.logger.info('keep_cur={}'.format(keep_cur)) 171 | next_true_bin = defaultdict(lambda: defaultdict(list)) 172 | next_true = defaultdict(lambda: defaultdict(list)) 173 | for did in self.id2doc_ancestors: 174 | class_idx_set = set(self.id2doc_ancestors[did]['class_idx']) 175 | class_idx_set.add(0) 176 | for c in class_idx_set: 177 | for idx, next_c in enumerate(self.p2c_idx[c]): 178 | if next_c in class_idx_set: 179 | next_true_bin[did][c].append(1) 180 | next_true[did][c].append(next_c) 181 | else: 182 | next_true_bin[did][c].append(0) 183 | # if lowest label 184 | if len(next_true[did][c]) == 0: 185 | next_true[did][c].append(c) 186 | if self.args.allow_stay: 187 | next_true_bin[did][c].append(1) 188 | elif keep_cur and c != 0: 189 | # append 1 only for loss calculation 190 | next_true_bin[did][c].append(1) 191 | return next_true_bin, next_true 192 | 193 | def get_next(self, cur_class_batch, next_classes_batch, doc_ids): 194 | assert len(cur_class_batch) == len(doc_ids) 195 | next_classes_batch_true = np.zeros(next_classes_batch.shape) 196 | indices = [] 197 | next_class_batch_true = [] 198 | for ct, (c, did) in enumerate(zip(cur_class_batch, doc_ids)): 199 | nt = self.next_true_bin[did][c] 200 | if len(self.next_true[did][c]) == 0: 201 | print(ct, did, c) 202 | print(nt, self.next_true[did][c]) 203 | print(self.id2doc_ancestors[did]) 204 | exit(-1) 205 | 
next_classes_batch_true[ct][:len(nt)] = nt 206 | for idx in self.next_true[did][c]: 207 | indices.append(ct) 208 | next_class_batch_true.append(idx) 209 | doc_ids = [doc_ids[idx] for idx in indices] 210 | return next_classes_batch_true, indices, np.array(next_class_batch_true), doc_ids 211 | 212 | def get_next_by_probs(self, cur_class_batch, next_classes_batch, doc_ids, probs, save_prob): 213 | assert len(cur_class_batch) == len(doc_ids) == len(doc_ids) == len(probs) 214 | indices = [] 215 | next_class_batch_pred = [] 216 | if save_prob: 217 | thres = 0 218 | else: 219 | thres = 0.5 220 | preds = (probs > thres).int().data.cpu().numpy() 221 | for ct, (c, next_classes, did, pred, p) in enumerate( 222 | zip(cur_class_batch, next_classes_batch, doc_ids, preds, probs)): 223 | # need allow_stay=True, filter last one (cur) to avoid duplication 224 | next_pred = np.nonzero(pred)[0] 225 | if not self.args.multi_label: 226 | idx_above_thres = np.argsort(p.data.cpu().numpy()[next_pred]) 227 | for idx in idx_above_thres[::-1]: 228 | if next_classes[next_pred[idx]] != c: 229 | next_pred = [next_pred[idx]] 230 | break 231 | else: 232 | if len(next_pred) != 0 and next_classes[next_pred[-1]] == c: 233 | next_pred = next_pred[:-1] 234 | # if no next > threshold, stay at current class 235 | if len(next_pred) == 0: 236 | p_selected = [] 237 | next_pred = [c] 238 | else: 239 | p_selected = p.data.cpu().numpy()[next_pred] 240 | next_pred = next_classes[next_pred] 241 | # indices remember where one is from; idx is virtual class idx 242 | for idx in next_pred: 243 | indices.append(ct) 244 | next_class_batch_pred.append(idx) 245 | if save_prob: 246 | for idx, p_ in zip(next_pred, p_selected): 247 | if idx in self.id2prob[did]: 248 | self.logger.warning(f'[{did}][{idx}] already existed!') 249 | self.id2prob[did][idx] = p_ 250 | doc_ids = [doc_ids[idx] for idx in indices] 251 | return indices, np.array(next_class_batch_pred), doc_ids 252 | 253 | def get_flat_idx_each_layer(self): 254 | flat_idx_each_layer = [[0]] 255 | idx2layer_idx = {} 256 | idx2layer_idx[0] = (0, 0) 257 | for i in range(self.args.n_steps_sl): 258 | flat_idx_each_layer.append([]) 259 | current_nodes = flat_idx_each_layer[i] 260 | for current_node in current_nodes: 261 | child_nodes = self.p2c_idx[current_node] 262 | for (in_layer_idx, child_node) in enumerate(child_nodes): 263 | flat_idx_each_layer[i + 1].append(child_node) 264 | idx2layer_idx[child_node] = (i, in_layer_idx) 265 | 266 | self.flat_idx_each_layer = flat_idx_each_layer 267 | self.idx2layer_idx = idx2layer_idx 268 | return flat_idx_each_layer, idx2layer_idx 269 | 270 | def get_layer_node_number(self): 271 | layer_numer = max(self.id2h.values()) 272 | local_output_number = [0] * layer_numer 273 | for id in self.id2h: 274 | if self.id2h[id] == 0: 275 | continue 276 | local_output_number[self.id2h[id] - 1] += 1 277 | assert sum(local_output_number) == self.n_class - 1 278 | return local_output_number 279 | 280 | def calc_reward(self, notLast, actions, ids): 281 | if self.args.reward == '01': 282 | return self.calc_reward_01(actions, ids) 283 | elif self.args.reward == 'f1': 284 | return self.calc_reward_f1(actions, ids) 285 | elif self.args.reward == '01-1': 286 | return self.calc_reward_neg(actions, ids) 287 | elif self.args.reward == 'direct': 288 | return self.calc_reward_direct(notLast, actions, ids) 289 | elif self.args.reward == 'taxo': 290 | return self.calc_reward_taxo(actions, ids) 291 | else: 292 | raise NotImplementedError 293 | 294 | def calc_reward_f1(self, actions, ids): 
295 | R = [] 296 | for a, i, taken in zip(actions, ids, self.taken_actions): 297 | taken.add(a) 298 | y_l_a = self.id2doc_ancestors[i]['class_idx'] 299 | correct = set(y_l_a) & taken 300 | p = len(correct) / len(taken) 301 | r = len(correct) / len(y_l_a) 302 | f1 = 2 * p * r / (p + r + 1e-32) 303 | R.append(f1) 304 | R = np.array(R) 305 | res = np.copy(R) 306 | if self.last_R is not None: 307 | res -= self.last_R 308 | self.last_R = R 309 | return res 310 | 311 | def calc_reward_direct(self, notLast, actions, ids): 312 | if notLast: 313 | return [0] * len(actions) 314 | R = [] 315 | for a, i, taken in zip(actions, ids, self.taken_actions): 316 | if a in self.id2doc[i]['class_idx']: 317 | R.append(1) 318 | else: 319 | R.append(0) 320 | return R 321 | 322 | def calc_reward_taxo(self, actions, ids): 323 | R = [] 324 | for taken, a, i in zip(self.taken_actions, actions, ids): 325 | if a in self.id2doc_ancestors[i]['class_idx']: 326 | R.append(1) 327 | elif a == self.n_class: 328 | R.append(0) 329 | else: 330 | R.append(-1) 331 | return R 332 | 333 | def calc_reward_01(self, actions, ids): 334 | R = [] 335 | label_key = 'categories' 336 | for a, i, taken in zip(actions, ids, self.taken_actions): 337 | if a in self.id2doc[i]['class_idx']: 338 | R.append(1) 339 | continue 340 | if a in taken: 341 | R.append(0) 342 | continue 343 | for label in self.id2doc[i][label_key]: 344 | if self.idx2id[a] in self.c2a[self.name2id[label]]: 345 | R.append(1) 346 | break 347 | else: 348 | R.append(0) 349 | taken.add(a) 350 | return R 351 | 352 | def calc_reward_neg(self, actions, ids): 353 | # actions : virtual id 354 | # ids: doc id for mesh, yelp | real labels for review 355 | R = [] 356 | label_key = 'categories' 357 | assert len(actions) == len(ids) == len(self.taken_actions) 358 | if self.args.dataset in ['mesh', 'yelp', 'rcv1']: 359 | for a, i, taken in zip(actions, ids, self.taken_actions): 360 | if a in self.id2doc[i]['class_idx']: 361 | R.append(1) 362 | continue 363 | for label in self.id2doc[i][label_key]: 364 | if self.idx2id[a] in self.c2a[self.name2id[label]]: 365 | if a in taken: 366 | R.append(0) 367 | else: 368 | R.append(1) 369 | break 370 | else: 371 | R.append(-1) 372 | taken.add(a) 373 | return R 374 | 375 | def calc_f1(self, pred_l, id_l, save_path=None, output=True): 376 | assert len(pred_l) == len(id_l) 377 | if not output: 378 | mlb = MultiLabelBinarizer(classes=self.class_idx) 379 | y_l_a = [self.id2doc_ancestors[docid]['class_idx'] for docid in id_l] 380 | y_true_a = mlb.fit_transform(y_l_a) 381 | y_pred = mlb.transform(pred_l) 382 | f1_a = f1_score(y_true_a, y_pred, average='micro') 383 | f1_a_macro = f1_score(y_true_a, y_pred, average='macro') 384 | f1_a_s = f1_score(y_true_a, y_pred, average='samples') 385 | self.logger.info(f'micro:{f1_a:.4f} macro:{f1_a_macro:.4f} samples:{f1_a_s:.4f}') 386 | return f1_a, f1_a_macro, f1_a_s 387 | # pred_l_a contains all ancestors (except root) 388 | pred_l_a = [] 389 | for pred in pred_l: 390 | pred_l_a.append([]) 391 | for i in pred: 392 | pred_l_a[-1].append(i) 393 | # TODO can be refactored 394 | pred_l_a[-1].extend( 395 | [self.id2idx[self.name2id[term]] for term in self.c2a[self.id2name[self.idx2id[i]]]]) 396 | y_l = [self.id2doc[docid]['class_idx'] for docid in id_l] 397 | if self.id2doc_ancestors: 398 | y_l_a = [self.id2doc_ancestors[docid]['class_idx'] for docid in id_l] 399 | else: 400 | y_l_a = y_l 401 | 402 | self.logger.info('measuring f1 of {} samples...'.format(len(pred_l))) 403 | avg_len = np.mean([len(p) for p in pred_l]) 404 | 
self.logger.info(f'len(pred): {avg_len}') 405 | # self.logger.info( 406 | # 'docid:{} pred:{} real:{} pred_a:{} real_a:{}'.format(id_l[:5], pred_l[:5], y_l[:5], pred_l_a[:5], 407 | # y_l_a[:5])) 408 | # self.logger.info('Counter(pred_l):{}'.format(Counter(flatten(pred_l)))) 409 | # self.logger.info('Counter(y_l):{}'.format(Counter(flatten(y_l)))) 410 | # self.logger.info('Counter(y_l_a):{}'.format(Counter(flatten(y_l_a)))) 411 | # self.logger.info(self.idx2id) 412 | 413 | calc_leaf_only = False 414 | if calc_leaf_only: 415 | y_true_leaf = [[idx for idx in i if len(self.p2c_idx[idx]) == 0] for i in y_l] 416 | y_pred_leaf = [[idx for idx in i if len(self.p2c_idx[idx]) == 0] for i in pred_l] 417 | mlb = MultiLabelBinarizer() 418 | y_true = sparse.csr_matrix(mlb.fit_transform(y_true_leaf)) 419 | print(len(mlb.classes_), 36504) 420 | y_pred = sparse.csr_matrix(mlb.transform(y_pred_leaf)) 421 | print(f1_score(y_true, y_pred, average='micro')) 422 | print(f1_score(y_true, y_pred, average='macro')) 423 | print(f1_score(y_true, y_pred, average='samples')) 424 | 425 | mlb = MultiLabelBinarizer(classes=self.class_idx) 426 | y_true = sparse.csr_matrix(mlb.fit_transform(y_l)) 427 | y_pred = sparse.csr_matrix(mlb.transform(pred_l)) 428 | f1 = f1_score(y_true, y_pred, average='micro') 429 | f1_macro = f1_score(y_true, y_pred, average='macro') 430 | 431 | calc_other_f1 = True 432 | if calc_other_f1: 433 | y_true_a = sparse.csr_matrix(mlb.transform(y_l_a)) 434 | y_pred_a = sparse.csr_matrix(mlb.transform(pred_l_a)) 435 | if save_path is not None: 436 | self.logger.info('saving to {}/preds.pkl'.format(save_path)) 437 | pickle.dump((y_pred, id_l, y_true_a), open('{}/preds.pkl'.format(save_path), 'wb')) 438 | f1_a = f1_score(y_true_a, y_pred, average='micro') 439 | f1_aa = f1_score(y_true_a, y_pred_a, average='micro') 440 | f1_a_macro = f1_score(y_true_a, y_pred, average='macro') 441 | f1_aa_macro = f1_score(y_true_a, y_pred_a, average='macro') 442 | f1_aa_s = f1_score(y_true_a, y_pred_a, average='samples') 443 | # from sklearn.metrics import classification_report 444 | # print(classification_report(y_true, y_pred)) 445 | else: 446 | f1_a = 0 447 | f1_aa = 0 448 | f1_a_macro = 0 449 | f1_aa_macro = 0 450 | f1_aa_s = 0 451 | return f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s 452 | 453 | def acc(self, pred_l, id_l): 454 | R = [] 455 | for pred, i in zip(pred_l, id_l): 456 | R.append(int(pred in self.id2doc[i]['class_idx'])) 457 | return np.mean(R) 458 | 459 | def acc_multi(self, pred_l, id_l): 460 | R = [] 461 | for preds, i in zip(pred_l, id_l): 462 | for pred in preds: 463 | R.append(int(pred in self.id2doc[i]['class_idx'])) 464 | return np.mean(R) 465 | 466 | def read_mapping(self): 467 | for idx, c in enumerate(self.nodes): 468 | if c == self.rootname: 469 | root_idx = idx 470 | self.id2name[c] = c 471 | self.name2id[c] = c 472 | self.id2idx[c] = idx 473 | self.idx2id[idx] = c 474 | # put root to idx 0 475 | self.idx2id[root_idx] = self.idx2id[0] 476 | self.id2idx[self.idx2id[root_idx]] = root_idx 477 | self.id2idx[self.rootname] = 0 478 | self.idx2id[0] = self.rootname 479 | # remove root from labels and add class_idx 480 | for bid in self.id2doc: 481 | self.id2doc[bid]['categories'] = [c for c in self.id2doc[bid]['categories'] if c != self.rootname] 482 | self.id2doc[bid]['class_idx'] = [self.id2idx[c] for c in self.id2doc[bid]['categories']] 483 | if self.id2doc_ancestors is None: 484 | return 485 | for bid in self.id2doc_ancestors: 486 | self.id2doc_ancestors[bid]['categories'] = [c for 
c in self.id2doc_ancestors[bid]['categories'] if 487 | c != self.rootname] 488 | self.id2doc_ancestors[bid]['class_idx'] = [self.id2idx[c] for c in self.id2doc_ancestors[bid]['categories']] 489 | 490 | def read_edges(self): 491 | for parent in self.nodes: 492 | for child in self.nodes[parent]['children']: 493 | self.c2p[child].append(parent) 494 | self.p2c[parent].append(child) 495 | 496 | def remove_stop(self): 497 | for taken in self.taken_actions: 498 | taken.discard(self.n_class) 499 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import pickle 4 | import subprocess 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from torch.autograd import Variable 10 | from tqdm import tqdm 11 | 12 | 13 | def isnan(x): 14 | return x != x 15 | 16 | 17 | def contains_nan(x): 18 | return isnan(x).any() 19 | 20 | 21 | def explode(x): 22 | return (x > 10).any() 23 | 24 | 25 | def eu_dist(x): 26 | return sum((x[0] - x[1]) ** 2) / len(x[0]) 27 | 28 | 29 | def get_gpu_memory_map(): 30 | result = subprocess.check_output( 31 | [ 32 | 'nvidia-smi', '--query-gpu=memory.free,utilization.gpu', 33 | '--format=csv,nounits,noheader' 34 | ], encoding='utf-8') 35 | gpu_info = [eval(x) for x in result.strip().split('\n')] 36 | gpu_info = dict(zip(range(len(gpu_info)), gpu_info)) 37 | sorted_gpu_info = sorted(gpu_info.items(), key=lambda kv: kv[1][0], reverse=True) 38 | sorted_gpu_info = sorted(sorted_gpu_info, key=lambda kv: kv[1][1]) 39 | print(f'gpu_id, (mem_left, util): {sorted_gpu_info}') 40 | return sorted_gpu_info 41 | 42 | 43 | def save_checkpoint(state, modelpath, modelname, logger=None, del_others=True): 44 | if del_others: 45 | for dirpath, dirnames, filenames in os.walk(modelpath): 46 | for filename in filenames: 47 | path = os.path.join(dirpath, filename) 48 | if path.endswith('pth.tar'): 49 | if logger is None: 50 | print(f'rm {path}') 51 | else: 52 | logger.warning(f'rm {path}') 53 | os.system("rm -rf '{}'".format(path)) 54 | break 55 | path = os.path.join(modelpath, modelname) 56 | if logger is None: 57 | print('saving model to {}...'.format(path)) 58 | else: 59 | logger.warning('saving model to {}...'.format(path)) 60 | try: 61 | torch.save(state, path) 62 | except Exception as e: 63 | logger.error(e) 64 | 65 | 66 | def flatten(x): 67 | if isinstance(x, collections.Iterable): 68 | return [a for i in x for a in flatten(i)] 69 | else: 70 | return [x] 71 | 72 | 73 | def check_doc_size(X_train, logger): 74 | n_sent = [] 75 | n_words = [] 76 | n_words_per_doc = [] 77 | for doc in X_train: 78 | n_sent.append(len(doc)) 79 | words_per_doc = 0 80 | for sent in doc: 81 | n_words.append(len(sent)) 82 | words_per_doc += len(sent) 83 | n_words_per_doc.append(words_per_doc) 84 | logger.info('#sent in a document') 85 | logger.info(pd.Series(n_sent).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98])) 86 | logger.info('#words in a sent') 87 | logger.info(pd.Series(n_words).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98])) 88 | logger.info('#words in a document') 89 | logger.info(pd.Series(n_words_per_doc).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98])) 90 | 91 | 92 | def pad_batch(mini_batch): 93 | mini_batch_size = len(mini_batch) 94 | max_sent_len = min(np.max([len(x) for x in mini_batch]), 10) 95 | max_token_len = min(np.max([len(val) for sublist in mini_batch for val in sublist]), 
50) 96 | main_matrix = np.zeros((mini_batch_size, max_sent_len, max_token_len), dtype=np.int) 97 | for i in range(main_matrix.shape[0]): 98 | for j in range(main_matrix.shape[1]): 99 | for k in range(main_matrix.shape[2]): 100 | try: 101 | main_matrix[i, j, k] = mini_batch[i][j][k] 102 | except IndexError: 103 | pass 104 | return Variable(torch.from_numpy(main_matrix).transpose(0, 1)) 105 | 106 | 107 | def pad_batch_nosent_fast(args, word_index, mini_batch, region, stride): 108 | mini_batch_size = len(mini_batch) 109 | n_tokens = min(args.max_tokens, max([sum([len(sent) for sent in doc]) for doc in mini_batch])) 110 | main_matrix = np.zeros((mini_batch_size, n_tokens, region), dtype=np.int) 111 | unk_idx = word_index['UNK'] 112 | main_matrix.fill(unk_idx) 113 | for i in range(mini_batch_size): 114 | sent_cat = [unk_idx] * (region - 1) + [word for sent in mini_batch[i] for word in sent] # padded 115 | # sent_cat = [word for sent in mini_batch[i] for word in sent] 116 | idx = 0 117 | ct = 0 118 | last_set = set() 119 | while ct < n_tokens and idx < len(sent_cat): 120 | word_set = set() # words in current region 121 | for region_idx, word in enumerate(sent_cat[idx: idx + region]): 122 | if word in word_set: 123 | main_matrix[i][ct][region_idx] = unk_idx 124 | continue 125 | if word != unk_idx: 126 | word_set.add(word) 127 | main_matrix[i][ct][region_idx] = word 128 | if last_set == word_set: 129 | ct -= 1 130 | last_set = word_set 131 | idx += stride 132 | ct += 1 133 | 134 | return main_matrix 135 | 136 | 137 | # region is for bow-cnn. need to covert vectors to multi-hot 138 | def pad_batch_nosent(mini_batch, word_index, onehot=False, region=None, stride=None): 139 | mini_batch_size = len(mini_batch) 140 | n_tokens = min(256, max([sum([len(sent) for sent in doc]) for doc in mini_batch])) 141 | if onehot: 142 | main_matrix = np.zeros((mini_batch_size, n_tokens, 30000), dtype=np.float32) 143 | unk_idx = word_index['UNK'] 144 | for i in range(mini_batch_size): 145 | if not region: 146 | ct = 0 147 | for sent in mini_batch[i]: 148 | for word in sent: 149 | if word != unk_idx: 150 | if word > unk_idx: 151 | word -= 1 152 | main_matrix[i][ct][word] = 1 153 | ct += 1 154 | if ct == n_tokens: 155 | break 156 | if ct == n_tokens: 157 | break 158 | else: 159 | sent_cat = [unk_idx] * (region - 1) + [word for sent in mini_batch[i] for word in sent] 160 | idx = 0 161 | ct = 0 162 | last_set = set() 163 | while ct < n_tokens and idx < len(sent_cat): 164 | word_set = set() 165 | for word in sent_cat[idx: idx + region]: 166 | if word != unk_idx: 167 | if word > unk_idx: 168 | word -= 1 169 | word_set.add(word) 170 | main_matrix[i][ct][word] = 1 171 | # variable-stride 172 | if last_set == word_set: 173 | ct -= 1 174 | last_set = word_set 175 | idx += stride 176 | ct += 1 177 | else: 178 | main_matrix = np.zeros((mini_batch_size, n_tokens), dtype=np.int) 179 | for i in range(mini_batch_size): 180 | ct = 0 181 | for sent in mini_batch[i]: 182 | for word in sent: 183 | main_matrix[i][ct] = word 184 | ct += 1 185 | if ct == n_tokens: 186 | break 187 | if ct == n_tokens: 188 | break 189 | return Variable(torch.from_numpy(main_matrix)) 190 | 191 | 192 | def iterate_minibatches(args, inputs, targets, batchsize, shuffle): 193 | assert inputs.shape[0] == targets.shape[0] 194 | if args.debug: 195 | for _ in range(300): 196 | yield inputs[:batchsize], targets[:batchsize] 197 | return 198 | if shuffle: 199 | indices = np.arange(inputs.shape[0]) 200 | np.random.shuffle(indices) 201 | for start_idx in range(0, 
inputs.shape[0] - batchsize + 1, batchsize): 202 | if shuffle: 203 | excerpt = indices[start_idx:start_idx + batchsize] 204 | else: 205 | excerpt = slice(start_idx, start_idx + batchsize) 206 | yield inputs[excerpt], targets[excerpt] 207 | if start_idx + batchsize < inputs.shape[0]: 208 | if shuffle: 209 | excerpt = indices[start_idx + batchsize:] 210 | else: 211 | excerpt = slice(start_idx + batchsize, start_idx + batchsize * 2) 212 | yield inputs[excerpt], targets[excerpt] 213 | 214 | 215 | def iterate_minibatches_order(args, inputs, targets, batchsize): 216 | assert inputs.shape[0] == targets.shape[0] 217 | if args.debug: 218 | for _ in range(300): 219 | yield inputs[:batchsize], targets[:batchsize] 220 | return 221 | indices = np.argsort([-len(doc) for doc in inputs]) 222 | for start_idx in range(0, inputs.shape[0] - batchsize + 1, batchsize): 223 | excerpt = indices[start_idx:start_idx + batchsize] 224 | yield inputs[excerpt], targets[excerpt] 225 | if start_idx + batchsize < inputs.shape[0]: 226 | excerpt = indices[start_idx + batchsize:] 227 | yield inputs[excerpt], targets[excerpt] 228 | 229 | 230 | def gen_minibatch(logger, args, word_index, tokens, labels, mini_batch_size, shuffle=False): 231 | logger.info('# batches = {}'.format(len(tokens) / mini_batch_size)) 232 | # for token, label in iterate_minibatches(tokens, labels, mini_batch_size, shuffle=shuffle): 233 | for token, label in iterate_minibatches_order(args, tokens, labels, mini_batch_size): 234 | if args.base_model == 'textcnn': 235 | token = pad_batch_nosent(token, word_index) 236 | elif args.base_model == 'ohcnn-seq': 237 | token = pad_batch_nosent(token, word_index, onehot=True) 238 | elif args.base_model == 'ohcnn-bow': 239 | token = pad_batch_nosent(token, word_index, onehot=True, region=20, stride=2) 240 | elif args.base_model == 'ohcnn-bow-fast': 241 | main_matrix = pad_batch_nosent_fast(args, word_index, token, region=20, stride=2) 242 | token = Variable(torch.from_numpy(main_matrix)) 243 | else: 244 | token = pad_batch(token) 245 | if args.gpu: 246 | yield token.cuda(), label 247 | else: 248 | yield token, label 249 | 250 | 251 | def gen_minibatch_from_cache(logger, args, tree, mini_batch_size, name, shuffle): 252 | pkl_path = '{}_{}.pkl'.format(name, mini_batch_size) 253 | if not os.path.exists(pkl_path): 254 | logger.error('{} NOT FOUND'.format(pkl_path)) 255 | exit(-1) 256 | if 'train' in name: 257 | if tree.data_cache is not None: 258 | (token_l, label_l) = tree.data_cache 259 | logger.info('loaded from tree.data_cache') 260 | else: 261 | (token_l, label_l) = pickle.load(open(pkl_path, 'rb')) 262 | tree.data_cache = (token_l, label_l) 263 | else: 264 | (token_l, label_l) = pickle.load(open(pkl_path, 'rb')) 265 | logger.info('loaded {} batches from {}'.format(len(label_l), pkl_path)) 266 | if args.debug: 267 | for _ in range(1000): 268 | token = Variable(torch.from_numpy(token_l[0])) 269 | label = label_l[0] 270 | if args.gpu: 271 | yield token.cuda(), label 272 | else: 273 | yield token, label 274 | return 275 | if shuffle: 276 | indices = np.arange(len(token_l)) 277 | np.random.shuffle(indices) 278 | for i in indices: 279 | token = token_l[i] 280 | label = label_l[i] 281 | token = Variable(torch.from_numpy(token)) 282 | if args.gpu: 283 | yield token.cuda(), label 284 | else: 285 | yield token, label 286 | else: 287 | for token, label in zip(token_l, label_l): 288 | # out of memory 289 | if mini_batch_size > 32: 290 | new_batch_size = mini_batch_size // 2 291 | for i in range(0, mini_batch_size, 
new_batch_size): 292 | token_v = Variable(torch.from_numpy(token[i:i + new_batch_size])) 293 | label_v = label[i:i + new_batch_size] 294 | if args.gpu: 295 | yield token_v.cuda(), label_v 296 | else: 297 | yield token_v, label_v 298 | else: 299 | token = Variable(torch.from_numpy(token)) 300 | if args.gpu: 301 | yield token.cuda(), label 302 | else: 303 | yield token, label 304 | 305 | 306 | def save_minibatch(logger, args, word_index, tokens, labels, mini_batch_size, name=''): 307 | filename = '{}_{}.pkl'.format(name, mini_batch_size) 308 | if os.path.exists(filename): 309 | logger.warning(f'skipped since {filename} existed') 310 | return 311 | token_l = [] 312 | label_l = [] 313 | for token, label in tqdm(iterate_minibatches_order(args, tokens, labels, mini_batch_size)): 314 | token = pad_batch_nosent_fast(args, word_index, token, region=20, stride=2) 315 | token_l.append(token) 316 | label_l.append(label) 317 | pickle.dump((token_l, label_l), open(filename, 'wb')) 318 | -------------------------------------------------------------------------------- /yelp/yelp_data_100.csv.sample: -------------------------------------------------------------------------------- 1 | {'reviews': ['They make a plan, for your mouth to get healthy & stay healthy! Wonderful service.', "I have extreme anxiety when it comes to the dentist. I had a really bad experience with my childhood dentist. This made it so when I was an adult I didn't go as often and I should have. I moved to Ahwatukee 9 years ago and a friend recommended this office to me. I've been going regulalry for routine cleaning since. I went last month for a cleaning and found out I needed a crown. Sensing my anxiety he assured me I would be fine. I scheduled my appt. for today. I went in so nervous I wanted to leave. He came in, explaining everything he would be doing. He put me at ease and I've never felt better. I would highly recommend Dr. Kode to anyone looking for a dentist. He's restored my faith in dentists. I will go back in a few weeks for the permanent crown and I'm not a bit nervous about it. Thank you so much for your patience with me and the absolute best experience I've had from a dentist.", "I've only been here once but my husband has been going for a while and I plan to obviously continue. I'm a complete wimp at the dentist and they were pretty great. The staff is friendly and relaxed, which helps. They have been very honest and the fancy office and TV's are just the icing on top. I would recommend any of my family or friends here!!", 'Great experience with the hygenists and pediatric dentist! Clean office, good with kids and timely. The dentist was great. We didnt get the feeling of being "upsold" like other practices. The staff was friendly enough. \n\nWe will definitely be back!', "My husband and I went to this dentist as a recommendation from my sister-in-law. My husband and I went as we were not happy with our current dentsit. What a blessing! We both went and they worked on both of us and did our work that we needed that day! My husband needed a veneer\\/crown with no scheduling down the road because of finances. I had priced the crown and the out of pocket cost and this office was 1\\/2 of what I was quoted from many different offices in the area. We have great dental insurance which has never been an issue but they were concerned about us and not $$$$$. I appreciate that. Everyone in the office was super nice. 
I went into the office the next day to get 3 fillings and everyone came by to ask how my husband was and how he liked his work. I don't know about anyone else but I will pay extra for great customer service. Here you get the most caring staff and you can understand why when you meet the dentist, Dentist Kobe. He is great at his craft and he sincerely cares about you. God has blessed him and us by introducing us. I highly recommend this dental office to anyone looking for a dentist in the Phoenix area.", 'Recently called Dental By Design because I was recommended by two different people to go and check them out. I called twice and was told they would call me back to schedule my appointment, no call back. Horrible service.', "Dr. Curry has been my dentist for almost 12 years now. She's the definition of awesomeness. AND everyone there always has a smile! Thanks so much", 'Very nice staff Gentle Hygienist. I love the fact that they will call you to schedule an appointment when they have cancellations that might work for you. Off the top of my head, I can\'t recall that lady\'s name that calls from the office all the time, but I just heart her! She is such a gem. \n\nI\'ve been to dental offices that are far more technologically advanced with headphones, satellite TV, computer applications that the staff uses to show and explain what\'s going on with your teeth. Since I\'m an engineer, I love gadgets and tech stuff. \n \nIn addition, the main reason I did not give this place a four is the fact that I have to ask "what is that instrument used to detect?", "what does that tell us?", "what does that mean?". I absolutely detest going to an office where I\'m paying you to do work for me and I leave there not understanding what you did for me. So, every time I"ve been to this office, I ask those questions. I don\'t like it. the information should flow freely. It helps when your patient is informed and can make better decisions. \n\nAll-in-All, I enjoy my visits here.', "This place is dishonest and I would never recommend anyone I know to go there. They tried to make me feel bad about my teeth in order to pressure me into their services. They were rude and tried to take advantage of me. They lied to me and said I had multiple cavities which I don't according to the dentist I go to now. They had a really nice office and equipment but I wasn't going to pay that for them. I'm going to pay someone to be honest with me and to take good care of my teeth, not just look at what's in my wallet.", "I am not giving them one star based on their work but I up front told them that I had 2000 left on my insurance and 3 days to use it but I was limited on money in the future and wanted to be sure that after the 3 days that I was all paid up. I would pay daily for what the insurance didn't cover. They called the insurance everyday and told me what it would cost for the work that day which I paid in full everyday. The final day I paid them and made sure I owed them nothing and they said I'm all paid up. Then months later they claim I owe them over $300 and after talking to the insurance company they claimed the dentist got a quote on one thing but charged them for something different. Well how is that my fault they knew i was not going to be able to pay a bill later yet they screwed up and charged me for it and now they are giving me bad credit over something that I had no control over. 
Now had the lady at the desk who called my insurance know what to ask the insurance company I would have gladly paid what I needed to that day. You do not dock the customers for your screw ups and that is not how a business keeps customers. I would recommend looking else where unless you too want to be billed for their incompetence. There are plenty of dentists that will not try to pull the wool over your eyes. Tell you one thing but mean another.", "Dr Kahn and Misty were very clear and explained everything for me. It was a super friendly environment. I'll definitely be making this my dentist.", 'WOW, all they want is $$$$$$$. I sure don\'t mind to pay for what I need but trying to squeeze more money from our insurance AND extra from patient? Not cool! Their rates are much more than most places and their dental chair office is like an assembly line. Horrible experiences and when issues were brought up, they talk down on you or act like it was a "mistype" error.\nI would not recommend this place if it was the last dental office on earth.\nExpensive, overpriced, low quality of work and sanitary.', "Going to the dentist office has always been a source of anguish for me, ever since I was young. Perhaps due to some of the abnormal experiences I had gone through up until now. Not having insurance or being covered by my parents plan. As a result of growing up close to the boarder I often made trips to Mexico for my dentistry work.\n\nMy friend who works there suggested I'd come in for a cleaning and after a few months of reluctancy I finally decided to make an appointment. In all honestly I was visibly uncomfortable which was largely due to the lapse of time between visits.\n\nHowever, Dental By Design made every effort to make me feel as calm and comfortable as possible. The office itself is very clean, the front line were all amazingly nice, pleasant and helpful. They worked with my budget because at the time I was uninsured and I was able to fit into a payment plan that worked best for me.\n\nI'm not sure if bed side manner is a term associated with dentistry but Dr. Curry handled me with care. She spoke me all the through my assessment and made sure I understood her, outlining steps to make moving forward.\n\nI can't say I'll look forward to the dentist again, but I will say I look forward to being in a welcoming, knowledgeable, and safe atmosphere again with individuals who do great work.", 'I live 24 miles away but still come back for Dr. Kode and his lovely assistants. This place does it all. Friendly, latest technology, all levels of service and has great office systems\\/processes in place. They even will provide you a rating of your teeth from 1 to 100, but you have to ask for it. Keep up the good work Dental by Design.', 'I have been seeing Dr Kode since he took over the practice. Have never met a more knowledgeable professional with such great bedside or chair-side manner. He is also a great community member.', "Just like most people, I dislike going to the dentist, but I've been going to Dental by Design for over a year and recently got a crown done. It was fair in price and things went very well through the entire experience. The staff is very friendly and professional and they get you in and out at your set appointment time.", "This dental office is one of the best that I have ever been to. When I first moved to the Phoenix area, I probably tried out about 8 different dentists before finally listening to my co-worker's referral and coming here. Dr. Kode is an awesome dentist. 
He is highly skilled, and uses up-to-date high tech equipment. Marianne is also the best dental hygienists that I have had. She is thorough, gentle, and very knowledgeable about everything!!! She is also a perfectionist, which I can appreciate-being one myself, and never misses anything. Sheryl, Sandy, and Heather are also fantastic, and always friendly! I don't know the rest of the staff as well, but everyone has always been overly nice. \n\nIn addition to being extremely competent, this dental office is also very patient-oriented. They have events (like movie nights) for patients, and always have really cool giveaways. They are also involved in community events and charity work. \n\nMoreover, they are also very compassionate and considerate of patients who do not have dental insurance. They are very reasonably priced, which is rare for this quality of service. \n\nI moved to California for school almost 2 years ago, but I am still driving to Arizona twice a year just to see these guys! They are totally worth it! I tried out a dentist in California who charged me twice as much for a cleaning and did a horrible job. Being in graduate school, I don't have time or money for trial-and-error dentists, and it's good to know that this office is there for me, even though I have to drive about 350 miles to get there. :)\n\nThank you!!!", 'Our family has been going to Dr.Kode even before he formed Dental by Design, over 10 years now. He\'s been there for us through all our family dental "fun"; from the headgear and braces to the wisdom teeth and deep tissue cleaning. He\'s done it all. His staff from the front desk (shout out to Sandy) to the everyone behind the scenes like his best hygienist, MaryAnn take care of our family like we\'re one of their own. Face it, no one likes going to the dentist. But Dental by Design has made the experience a pleasant one. It\'s also very apparent that they appreciate their patients, as they are always holding events and contests to thank them.\nThank you, Dr,Kode and Dental by Design for making my teeth and my family\'s teeth shine :)', 'Been going to Dr. Kode for 7 years now. Very competent and capable dentist, great problem solver and great communicator. Love his staff - Laci is Da Bomb!\n\nAnd... this is coming from the son of a Dentist!', 'Thereza Wright is the best hygienist ive had the pleasure of working with. She is along with the rest of the staff is awesome, understanding, patient and friendly. They have superior service and down to earth vibes in the office putting you at ease when you are sitting in the chair. Going to the dentist has never been easier when you go to these guys and get handled by the team that works there.', 'I have been going here for 5 years now. The staff is always so friendly and there is always a cold water for you to help yourself to. The confirmation process is so convenient because it is texted and emailed to me. Dr. Kode is a licensed Invisalign dentist so I did Invisalign. He did a wonderful job. I have so much more self confidence now. He even threw in an ipad and free teeth whitening. They always do a great job!', "2 bad experiences I thought I would give them the benefit of the doubt but I actually really hated it. I felt like they wanted my money but the work I had done was shoddy. If you want the experience of a used car dealership with the feeling of just being a number like at the DMV. Try it you may like it. Not for me especially since there is a dentist practically on every corner I don't have to settle.", 'Dr. 
Kode has been my dentist for about 6 years. As well as receiving my dental check ups and care from him, I also did the Invisalign program with him with great success. His staff is very personable, his office is clean and appointments are always handled professionally. Dr Kode is very knowledgable, personable, kind and gentle with all of his treatments. I have recommended him and his staff to several friends and would recommend them to anyone looking for dentist in the Ahwatukee area.'], 'business_id': 'FYWN1wneV18bWNgQjJ2GNg', 'categories': ['Dentists', 'General Dentistry', 'Health & Medical', 'Oral Surgeons', 'Cosmetic Dentists', 'Orthodontists']} 2 | {'reviews': ["I would be grateful to those of you willing to repeat after me slowly: S T E P H E N NEEDS M E D I C A T E D.\nFor every person who loves him let's him go on with himself as he is is just as bad as him. He is an absolute Monster.", 'My family and I have been going to Stephens for many years . I wish I could give them more than 5 stars. We are always treated with the utmost respect , and our stylist is awesome. She is the best we as a family have ever had. Her color, cut, and advice always makes it a pleasant experience. And LOVE the products...', 'Amanda was an absolute peach! So sweet and talented too! I\'d never gotten my hair professionally colored before and she did an amazing job! She gave me her expert opinion on "sombr\xc3\xa9" (subtle ombr\xc3\xa9) and I just love it. Never once did I feel rushed and she made me feel right at home. We chatted the whole time (I probably talked her ear off haha I just felt so comfortable and was having a great time!). My cut and color are fantastic and I\'m super happy with it all :) I\'m from Pittsburgh but I currently live in NYC. Regardless, I plan to come back to Amanda for my next cut and my second coloring experience! Highly recommend!!!', 'I specifically requested Stephen as I thought since he owned the business and had been in business for as long as he has, he would be excellent at what he does. Before washing my hair we discussed what I wanted, which was a simple trim of my current style. I told him it was important to keep as much length as possible. I then had my hair washed and sat down for him to cut my hair. He prepared to make his first cut and had quite a bit of length in his hand. I asked him if that is what he planned on cutting off and he said yes. I reminded him that I wanted a trim only and wanted to keep my length. He then proceeded to tell me he was not going to cut my hair because I would not let him have space to be creative. He said "I\'ll dry your hair for you though." Really? How nice of him. As I left I asked the two women at the front desk if he does that often and they said "sometimes".', "Paid $115.00 for about 8 highlights that dont even look nice. Then she didn't even blow dry my hair nicely. So I spent over a hundred dollars and had to go home and fix my own hair. I will never go back there again.", 'My daughter recently went to have her hair done at Stephen\'s at South Hills Village. During her visit, one of the owners struck up a conversation with her and proceeded to express his unsolicited opinions on her college choices, making comments that were rude, insulting and totally unprofessional. I called the salon to explain what had happened and let them know how upset we were. A manager called me back to pass along that the owner "said he was sorry". He did not even have the courtesy to respond to us personally. 
I\'ve been a Stephen\'s customer for almost 15 years and I know some great stylists there. BUT if this is the kind of person that is benefitting from my patronage there - I won\'t be back.', "I have been going to Stephen Szabo Salon for 15+ years and have nothing but great things to say about them. I have gone to several of the stylists and have enjoyed the Brazilian blowout, different color variations, straightening, etc. There are many salon options to choose from in the area, but Stephen's is by far the BEST in the South Hills! Thank you for always taking care of your clients :-)", 'I went here to get my hair done for a wedding with the bridal party. There were 7 girls including the bride...so there were ALOT of us. You would think they would have anticipated that we would need attention and time, but it was utter chaos. It\'s like they hadn\'t known for months we were coming on that particular Saturday morning to get our hair done. It was so rushed and chaotic that none of us, including the bride, enjoyed the experience. \n\nI had even helped them a great deal by straightening my hair beforehand. My hair is really curly and I figured it would be easier to manage and style if I straightened it. I did warn them that I straightened it and not to put any water on it b\\/c the curls would spring up like perennials in Spring. Now to their credit, the woman who was styling my hair didn\'t put any water on it, but she was using hairspray and a curling iron (which I told her may not be necessary) and she made a comment, "wow you weren\'t kidding when you said your hair was curly". b\\/c as she was spraying the hairspray on the "death by fumes" setting and curling the hair (fire hazard anyone?) my hair was holding on to the curls made by the curling iron with such intensity that the hair spray wasn\'t even needed. Umm, did I stutter when I said that my hair is very curly. \n\nPerhaps I\'m being too hard on them, but I had so wanted and expected it to be a good experience (not even great) and I was so disappointed. They didn\'t do anything that I couldn\'t have managed on my own with a curling iron and a large can of Aquanet and a few bobby pins.', "I'm a regular here for many years. Totally reliable and affordable. Many good stylists.", "I have been a client here for almost 12 years now. My stylist is Leigh Ann. She is the best you will find. She accommodates my hectic schedule as best as she can. She also takes care of my daughter and my son who are in their 20's. Leigh Ann has been in the business for more than 30 years. She keeps up to speed with the latest trends, cuts and coloring. Her personality is awesome. They are currently remodeling now and it looks fabulous. I would recommend them very highly.", 'Ok: I have been going to Stephen on\\/off for 10 years. I have lived through the insults and even staff dragging a comb through my shampoo drenched hair. The latest visit revealed that he is now requesting staff not use conditioner after shampooing. what does he expect when detergent is added to color treated hair? "If only I would have use his personalized shampoo!" Mary T, I will gladly and slowly repeat after you: "STEPHEN NEEDS MEDICATED." whew! I feel much better. 
On a good note: I do like the environment: staff and hairdressers are very friendly.'], 'business_id': 'He-G7vWjzVUysIKrfNbPUQ', 'categories': ['Hair Stylists', 'Hair Salons', "Men's Hair Salons", 'Blow Dry/Out Services', 'Hair Extensions', 'Beauty & Spas']} 3 | --------------------------------------------------------------------------------
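
Note on the sample above: yelp_data_100.csv.sample stores one record per line, each a Python dict literal with 'reviews' (a list of review strings), 'business_id', and 'categories' (the label set used for the taxonomy). Below is a minimal loading sketch, assuming each line is a valid dict literal as the sample suggests; load_yelp_sample and its path argument are illustrative names only, not the repository's own loader.

import ast

def load_yelp_sample(path):
    """Parse one dict literal per line and return (business_id, reviews, categories) tuples."""
    records = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = ast.literal_eval(line)  # safe evaluation of the dict literal
            records.append((rec['business_id'], rec['reviews'], rec['categories']))
    return records

# Example usage (hypothetical path):
# for business_id, reviews, categories in load_yelp_sample('yelp/yelp_data_100.csv.sample'):
#     print(business_id, len(reviews), categories)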