├── HAN.py
├── HMCN.py
├── Linear_Model.py
├── Logger_morning.py
├── OHCNN.py
├── OHCNN_fast.py
├── README.md
├── TextCNN.py
├── clean_runs.py
├── conf.py
├── feature_dataset.py
├── features_test.py
├── fig
│   ├── HiLAP_architecture.jpg
│   └── prediction_animation.gif
├── loadData.py
├── main.py
├── model.py
├── readData_fungo.py
├── readData_nyt.py
├── readData_rcv1.py
├── readData_yelp.py
├── tree.py
├── util.py
└── yelp
    ├── Taxonomy_100
    └── yelp_data_100.csv.sample
/HAN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 | """
7 | Hierarchical Attention Networks for Document Classification
8 | https://www.cs.cmu.edu/~hovy/papers/16HLT-hierarchical-attention-networks.pdf
9 | """
10 |
11 |
12 | def batch_matmul_bias(seq, weight, bias, nonlinearity=''):
13 | s = None
14 | bias_dim = bias.size()
15 | for i in range(seq.size(0)):
16 | _s = torch.mm(seq[i], weight)
17 | _s_bias = _s + bias.expand(bias_dim[0], _s.size()[0]).transpose(0, 1)
18 | if nonlinearity == 'tanh':
19 | _s_bias = torch.tanh(_s_bias)
20 | _s_bias = _s_bias.unsqueeze(0)
21 | if s is None:
22 | s = _s_bias
23 | else:
24 | s = torch.cat((s, _s_bias), 0)
25 | return s.squeeze()
26 |
27 |
28 | def batch_matmul(seq, weight, nonlinearity=''):
29 | s = None
30 | for i in range(seq.size(0)):
31 | _s = torch.mm(seq[i], weight)
32 | if nonlinearity == 'tanh':
33 | _s = torch.tanh(_s)
34 | _s = _s.unsqueeze(0)
35 | if s is None:
36 | s = _s
37 | else:
38 | s = torch.cat((s, _s), 0)
39 | return s.squeeze()
40 |
41 |
42 | def attention_mul(rnn_outputs, att_weights):
43 | attn_vectors = None
44 | for i in range(rnn_outputs.size(0)):
45 | h_i = rnn_outputs[i]
46 | a_i = att_weights[i].unsqueeze(1).expand_as(h_i)
47 | h_i = a_i * h_i
48 | h_i = h_i.unsqueeze(0)
49 | if attn_vectors is None:
50 | attn_vectors = h_i
51 | else:
52 | attn_vectors = torch.cat((attn_vectors, h_i), 0)
53 | return torch.sum(attn_vectors, 0).unsqueeze(0)
54 |
55 |
56 | class HAN(nn.Module):
57 | def __init__(self, args, n_classes, word_vec=None, num_tokens=None, embed_size=None):
58 | super(HAN, self).__init__()
59 | self.args = args
60 | if word_vec is None:
61 | assert num_tokens is not None and embed_size is not None
62 | self.num_tokens = num_tokens
63 | self.embed_size = embed_size
64 | else:
65 | self.num_tokens = word_vec.shape[0]
66 | self.embed_size = word_vec.shape[1]
67 | self.word_gru_hidden = args.word_gru_hidden_size
68 | self.lookup = nn.Embedding(self.num_tokens, self.embed_size)
69 | if args.pretrained_word_embed:
70 | self.lookup.weight = nn.Parameter(torch.from_numpy(word_vec).float())
71 | self.lookup.weight.requires_grad = args.update_word_embed
72 | self.word_gru = nn.GRU(self.embed_size, self.word_gru_hidden, bidirectional=True)
73 | self.weight_W_word = nn.Parameter(torch.Tensor(2 * self.word_gru_hidden, 2 * self.word_gru_hidden))
74 | self.bias_word = nn.Parameter(torch.Tensor(2 * self.word_gru_hidden, 1))
75 | self.weight_proj_word = nn.Parameter(torch.Tensor(2 * self.word_gru_hidden, 1))
76 | nn.init.uniform(self.weight_W_word, -0.1, 0.1)
77 | nn.init.uniform(self.bias_word, -0.1, 0.1)
78 | nn.init.uniform(self.weight_proj_word, -0.1, 0.1)
79 | # sentence level
80 | self.sent_gru_hidden = args.sent_gru_hidden_size
81 | self.word_gru_hidden = args.word_gru_hidden_size
82 | self.sent_gru = nn.GRU(2 * self.word_gru_hidden, self.sent_gru_hidden, bidirectional=True)
83 | self.weight_W_sent = nn.Parameter(torch.Tensor(2 * self.sent_gru_hidden, 2 * self.sent_gru_hidden))
84 | self.bias_sent = nn.Parameter(torch.Tensor(2 * self.sent_gru_hidden, 1))
85 | self.weight_proj_sent = nn.Parameter(torch.Tensor(2 * self.sent_gru_hidden, 1))
86 | C = n_classes
87 | self.fc1 = nn.Linear(2 * self.sent_gru_hidden, C)
88 | nn.init.uniform(self.bias_sent, -0.1, 0.1)
89 | nn.init.uniform(self.weight_W_sent, -0.1, 0.1)
90 | nn.init.uniform(self.weight_proj_sent, -0.1, 0.1)
91 |
92 | def forward(self, mini_batch, fc=False):
93 | max_sents, batch_size, max_tokens = mini_batch.size()
94 | word_attn_vectors = None
95 | state_word = self.init_hidden(mini_batch.size()[1])
96 | for i in range(max_sents):
97 | embed = mini_batch[i, :, :].transpose(0, 1)
98 | embedded = self.lookup(embed)
99 | output_word, state_word = self.word_gru(embedded, state_word)
100 | word_squish = batch_matmul_bias(output_word, self.weight_W_word, self.bias_word, nonlinearity='tanh')
101 | # logger.debug(word_squish.size()) torch.Size([20, 2, 200])
102 | word_attn = batch_matmul(word_squish, self.weight_proj_word)
103 | # logger.debug(word_attn.size()) torch.Size([20, 2])
104 | word_attn_norm = F.softmax(word_attn.transpose(1, 0), dim=-1)
105 | word_attn_vector = attention_mul(output_word, word_attn_norm.transpose(1, 0))
106 | if word_attn_vectors is None:
107 | word_attn_vectors = word_attn_vector
108 | else:
109 | word_attn_vectors = torch.cat((word_attn_vectors, word_attn_vector), 0)
110 | # logger.debug(word_attn_vectors.size()) torch.Size([1, 2, 200])
111 | state_sent = self.init_hidden(mini_batch.size()[1])
112 | output_sent, state_sent = self.sent_gru(word_attn_vectors, state_sent)
113 | # logger.debug(output_sent.size()) torch.Size([8, 2, 200])
114 | sent_squish = batch_matmul_bias(output_sent, self.weight_W_sent, self.bias_sent, nonlinearity='tanh')
115 | # logger.debug(sent_squish.size()) torch.Size([8, 2, 200])
116 | if len(sent_squish.size()) == 2:
117 | sent_squish = sent_squish.unsqueeze(0)
118 | sent_attn = batch_matmul(sent_squish, self.weight_proj_sent)
119 | if len(sent_attn.size()) == 1:
120 | sent_attn = sent_attn.unsqueeze(0)
121 | # logger.debug(sent_attn.size()) torch.Size([8, 2])
122 | sent_attn_norm = F.softmax(sent_attn.transpose(1, 0), dim=-1)
123 | # logger.debug(sent_attn_norm.size()) torch.Size([2, 8])
124 | sent_attn_vectors = attention_mul(output_sent, sent_attn_norm.transpose(1, 0))
125 | # logger.debug(sent_attn_vectors.size()) torch.Size([1, 2, 200])
126 | x = sent_attn_vectors.squeeze(0)
127 | if fc:
128 | x = self.fc1(x)
129 | return x
130 |
131 | def init_hidden(self, batch_size, hidden_dim=None):
132 | if hidden_dim is None:
133 | hidden_dim = self.sent_gru_hidden
134 | if self.args.gpu:
135 | return Variable(torch.zeros(2, batch_size, hidden_dim)).cuda()
136 | return Variable(torch.zeros(2, batch_size, hidden_dim))
137 |
--------------------------------------------------------------------------------
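
A minimal usage sketch for `HAN` (not part of the repo), assuming the PyTorch 0.3 environment listed in the README; the `args` fields mirror the defaults in `conf.py`, and random token ids stand in for real data. `HAN.forward` expects a LongTensor of shape (max_sents, batch_size, max_tokens) and returns class logits when called with `fc=True`.

```python
# Hedged usage sketch (assumptions: PyTorch 0.3 API, args fields as in conf.py,
# random token ids in place of a real mini-batch).
import torch
from types import SimpleNamespace
from torch.autograd import Variable
from HAN import HAN

args = SimpleNamespace(word_gru_hidden_size=50, sent_gru_hidden_size=50,
                       pretrained_word_embed=False, update_word_embed=True, gpu=False)
model = HAN(args, n_classes=103, word_vec=None, num_tokens=30000, embed_size=100)
tokens = torch.LongTensor(8, 2, 20).random_(0, 30000)  # (max_sents, batch, max_tokens)
logits = model(Variable(tokens), fc=True)              # (batch, n_classes)
print(logits.size())
```
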
/HMCN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from IPython import embed
6 |
7 | """
8 | Hierarchical Multi-Label Classification Networks
9 | http://proceedings.mlr.press/v80/wehrmann18a.html
10 | """
11 |
12 |
13 | class HMCN(nn.Module):
14 | def __init__(self, base_model, args, neuron_each_local_l2, total_class, in_dim):
15 | super(HMCN, self).__init__()
16 |
17 | neuron_each_layer = [384] * len(neuron_each_local_l2)
18 | neuron_each_local_l1 = [384] * len(neuron_each_local_l2)
19 | self.beta = 0.5
20 |
21 | self.args = args
22 | self.base_model = base_model
23 |
24 | self.layer_num = len(neuron_each_layer)
25 | self.linear_layers = nn.ModuleList([])
26 | self.local_linear_l1 = nn.ModuleList([])
27 | self.local_linear_l2 = nn.ModuleList([])
28 | self.batchnorms = nn.ModuleList([])
29 | self.batchnorms_local_1 = nn.ModuleList([])
30 | for idx, neuron_number in enumerate(neuron_each_layer):
31 | if idx == 0:
32 | self.linear_layers.append(nn.Linear(in_dim, neuron_number))
33 | else:
34 | self.linear_layers.append(
35 | nn.Linear(neuron_each_layer[idx - 1] + in_dim, neuron_number))
36 | self.batchnorms.append(nn.BatchNorm1d(neuron_number))
37 |
38 | for idx, neuron_number in enumerate(neuron_each_local_l1):
39 | self.local_linear_l1.append(
40 | nn.Linear(neuron_each_layer[idx], neuron_each_local_l1[idx]))
41 | self.batchnorms_local_1.append(
42 | nn.BatchNorm1d(neuron_each_local_l1[idx]))
43 | for idx, neuron_number in enumerate(neuron_each_local_l2):
44 | self.local_linear_l2.append(
45 | nn.Linear(neuron_each_local_l1[idx], neuron_each_local_l2[idx]))
46 |
47 | self.final_linear_layer = nn.Linear(
48 | neuron_each_layer[-1] + in_dim, total_class)
49 |
50 | def forward(self, x):
51 | x = self.base_model(x, False)
52 | local_outputs = []
53 | output = x
54 | for layer_idx, layer in enumerate(self.linear_layers):
55 | if layer_idx == 0:
56 | output = layer(output)
57 | output = F.relu(output)
58 | else:
59 | output = layer(torch.cat([output, x], dim=1))
60 | output = F.relu(output)
61 |
62 | local_output = self.local_linear_l1[layer_idx](output)
63 | local_output = F.relu(local_output)
64 | local_output = self.local_linear_l2[layer_idx](local_output)
65 | local_outputs.append(local_output)
66 |
67 | global_outputs = F.sigmoid(
68 | self.final_linear_layer(torch.cat([output, x], dim=1)))
69 | local_outputs = F.sigmoid(torch.cat(local_outputs, dim=1))
70 |
71 | output = self.beta * global_outputs + (1 - self.beta) * local_outputs
72 |
73 | return output
74 |
--------------------------------------------------------------------------------
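
A shape sketch for `HMCN` under stated assumptions (an identity stub replaces the real base model; sizes are made up). The blend in `forward` requires `total_class == sum(neuron_each_local_l2)`, since the global head and the concatenated local heads are added element-wise.

```python
# Hedged shape sketch (identity stub instead of TextCNN/HAN/OHCNN_fast).
# HMCN blends a global sigmoid head over total_class labels with per-level
# local heads, so total_class must equal sum(neuron_each_local_l2).
import torch
import torch.nn as nn
from torch.autograd import Variable
from HMCN import HMCN

class IdentityBase(nn.Module):
    def forward(self, x, fc=False):
        return x  # pretend x is already the in_dim-dimensional document feature

per_level = [4, 10, 30]                    # made-up class counts per hierarchy level
model = HMCN(IdentityBase(), args=None, neuron_each_local_l2=per_level,
             total_class=sum(per_level), in_dim=128)
probs = model(Variable(torch.randn(2, 128)))
print(probs.size())                        # (2, 44) = beta*global + (1-beta)*local
```
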
/Linear_Model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | """
5 | Linear model used for feature input (instead of text)
6 | """
7 |
8 |
9 | class Linear_Model(nn.Module):
10 | def __init__(self, args, input_dim, n_classes):
11 | super(Linear_Model, self).__init__()
12 | self.dropout = nn.Dropout(args.dropout)
13 | self.fc1 = nn.Linear(input_dim, args.n_hidden)
14 | self.fc2 = nn.Linear(args.n_hidden, n_classes)
15 |
16 | def forward(self, x, fc=False):
17 | # x = self.dropout(x)
18 | x = F.relu(self.fc1(x))
19 | if fc:
20 | x = self.fc2(x)
21 | return x
22 |
--------------------------------------------------------------------------------
/Logger_morning.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def myLogger(name=None, log_path='./tmp.log'):
5 | logger = logging.getLogger(name)
6 | if len(logger.handlers) != 0:
7 | print(f'[Logger_morning.py]reuse logger:{name}')
8 | return logger
9 | logging.basicConfig(level=logging.DEBUG,
10 | format='%(asctime)s %(filename)s[line:%(lineno)d][%(funcName)s] %(levelname)s-> %(message)s',
11 | datefmt='%a %d %b %Y %H:%M:%S', filename=log_path, filemode='a')
12 | # define a new Handler to log to console as well
13 | console = logging.StreamHandler()
14 | # set a format which is the same for console use
15 | formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d][%(funcName)s] %(levelname)s-> %(message)s',
16 | datefmt='%a %d %b %Y %H:%M:%S')
17 | # tell the handler to use this format
18 | console.setFormatter(formatter)
19 | # add the handler to the root logger
20 | logger.addHandler(console)
21 | logger.info(f'create new logger:{name}')
22 | logger.info(f'saving log to {log_path}')
23 | return logger
24 |
25 |
26 | if __name__ == '__main__':
27 | logger = myLogger('111', log_path='runs/')
28 | logger.setLevel(10)
29 | logger.info('This is info message')
30 | logger.warning('This is warning message')
31 |
32 | logger = myLogger('111')
33 | logger.setLevel(10)
34 | logger.info('T2')
35 | logger.warning('This2')
36 |
--------------------------------------------------------------------------------
/OHCNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Effective Use of Word Order for Text Categorization with Convolutional Neural Networks
7 | http://www.anthology.aclweb.org/N/N15/N15-1011.pdf
8 | """
9 |
10 |
11 | class OHCNN(nn.Module):
12 |
13 | def __init__(self, args, n_classes):
14 | super(OHCNN, self).__init__()
15 | self.args = args
16 | D = 30000
17 | C = n_classes
18 | Ci = 1
19 | Co = 1000
20 | self.Co = Co
21 | self.n_pool = 10
22 | if args.mode == 'ohcnn-seq':
23 | Ks = [3]
24 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D), stride=1, padding=(K - 1, 0)) for K in Ks])
25 | else:
26 | Ks = [1]
27 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D), stride=1) for K in Ks])
28 | self.dropout = nn.Dropout(0.5)
29 | # self.lrn = nn.LocalResponseNorm(2)
30 | self.fc1 = nn.Linear(len(Ks) * Co * self.n_pool, C)
31 |
32 | def forward(self, x):
33 | x = x.unsqueeze(1) # (N, Ci, W, D)
34 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks)
35 | # x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks)
36 | # [(N, Co * n_pool), ...]*len(Ks)
37 | x = [F.avg_pool1d(i, int(i.size(2) / self.n_pool)).view(-1, self.n_pool * self.Co) for i in x]
38 | x = torch.cat(x, 1)
39 | x = self.dropout(x) # (N, len(Ks)*Co)
40 | # response norm
41 | x /= (1 + x.pow(2).sum(1)).sqrt().view(-1, 1)
42 | logit = self.fc1(x) # (N, C)
43 | return logit
44 |
--------------------------------------------------------------------------------
/OHCNN_fast.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | """
8 | Effective Use of Word Order for Text Categorization with Convolutional Neural Networks
9 | http://www.anthology.aclweb.org/N/N15/N15-1011.pdf
10 | use equivalent embeddings for faster speed (2~3x)
11 | """
12 |
13 |
14 | class OHCNN_fast(nn.Module):
15 |
16 | def __init__(self, unk_idx, n_classes, vocab_size):
17 | super(OHCNN_fast, self).__init__()
18 | # D = 30001
19 | print(f'vocab_size:{vocab_size}')
20 | D = vocab_size
21 | C = n_classes
22 | Co = 1000
23 | self.Co = Co
24 | self.n_pool = 10
25 | self.embed = nn.Embedding(D, Co)
26 | self.bias = nn.Parameter(torch.Tensor(1, Co, 1))
27 | self.dropout = nn.Dropout(0.5)
28 | self.fc1 = nn.Linear(Co * self.n_pool, C)
29 | self.unk_idx = unk_idx
30 | # init as in cnn
31 | stdv = 1. / math.sqrt(D)
32 | self.embed.weight.data.uniform_(-stdv, stdv)
33 | if self.bias is not None:
34 | self.bias.data.uniform_(-stdv, stdv)
35 |
36 | def forward(self, x, fc=False):
37 | # (32, 256, 20)
38 | sent_len = x.size(1)
39 | x = x.view(x.size(0), -1)
40 | x_embed = self.embed(x) # (N, W * D, Co)
41 | # deal with unk in the region
42 | x = (x != self.unk_idx).float().unsqueeze(-1) * x_embed
43 | x = x.view(x.size(0), sent_len, -1, self.Co) # (N, W, D, Co)
44 | x = F.relu(x.sum(2).permute(0, 2, 1) + self.bias) # (N, Co, W)
45 | x = F.avg_pool1d(x, int(x.size(2) / self.n_pool)).view(-1, self.n_pool * self.Co) # (N, n_pool * Co)
46 | x = self.dropout(x)
47 | # response norm
48 | x /= (1 + x.pow(2).sum(1)).sqrt().view(-1, 1)
49 | if fc:
50 | x = self.fc1(x) # (N, C)
51 | return x
52 |
--------------------------------------------------------------------------------
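
An input-format sketch for `OHCNN_fast` with made-up sizes: the model consumes regions of token ids (rather than explicit one-hot region vectors), masks `unk_idx`, and sums the embeddings within each region, which is what makes it equivalent to `OHCNN` but faster.

```python
# Hedged input/shape sketch. In the sketch's terms, OHCNN_fast.forward takes a
# LongTensor of token ids of shape (N, n_regions, region_size), padded with
# unk_idx; unk positions are masked out and embeddings are summed per region.
import torch
from torch.autograd import Variable
from OHCNN_fast import OHCNN_fast

vocab_size, unk_idx = 30001, 30000
model = OHCNN_fast(unk_idx=unk_idx, n_classes=103, vocab_size=vocab_size)
x = torch.LongTensor(4, 20, 16).random_(0, vocab_size)  # (N, n_regions, region_size)
logits = model(Variable(x), fc=True)                    # (4, 103)
print(logits.size())
```
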
/README.md:
--------------------------------------------------------------------------------
1 | This repo provides the code for the paper ["Hierarchical Text Classification with Reinforced Label Assignment"](https://arxiv.org/abs/1908.10419) (EMNLP 2019).
2 |
3 |
4 |
5 |
6 |
7 | ## Abstract
8 |
9 | While existing hierarchical text classification (HTC) methods attempt to capture label hierarchies for model training, they either make local decisions regarding each label or completely ignore the hierarchy information during inference. To solve the mismatch between training and inference as well as modeling label dependencies in a more principled way, we formulate HTC as a Markov decision process and propose to learn a **L**abel **A**ssignment **P**olicy via deep reinforcement learning to determine *where to place* an object and *when to stop* the assignment process. The proposed method, **HiLAP**, explores the hierarchy during both training and inference time in a *consistent* manner and makes *inter-dependent* decisions. As a general framework, HiLAP can incorporate different neural encoders as *base models* for end-to-end training. Experiments on five public datasets and four base models show that HiLAP yields an average improvement of 33.4% in Macro-F1 over flat classifiers and outperforms state-of-the-art HTC methods by a large margin.
10 |
11 | ## Model
12 |
13 | `model.py`: The main model of HiLAP.
14 |
15 | `TextCNN.py`: Our implementation of "Convolutional Neural Networks for Sentence Classification" EMNLP 2014.
16 |
17 | `OHCNN(_fast).py`: Our implementation of "Effective Use of Word Order for Text Categorization with Convolutional Neural Networks" NAACL 2015.
18 |
19 | `HAN.py`: Our implementation of "Hierarchical Attention Networks for Document Classification" NAACL 2016.
20 |
21 | `HMCN.py`: Our implementation of "Hierarchical Multi-Label Classification Networks" ICML 2018.
22 |
23 | ## Requirements
24 |
25 | Python **3**
26 |
27 | PyTorch **0.3**
28 |
29 | ## Data
30 |
31 | Due to copyright issues, we can't directly release the datasets used in our experiments.
32 | Instead, we provide the links to the five data sources (the first two may require a license):
33 |
34 | - RCV1 [original release](http://www.ai.mit.edu/projects/jmlr/papers/volume5/lewis04a/lyrl2004_rcv1v2_README.htm), [text data](https://trec.nist.gov/data/reuters/reuters.html) (**update:** download the text data and convert it to docs.txt in the format "docid content")
35 | - [NYT](https://catalog.ldc.upenn.edu/LDC2008T19)
36 | - [Yelp](https://www.yelp.com/dataset/challenge) (**update:** the latest release is different from what we used; please send an email if you need the version we used)
37 | - [FunGO](https://dtai.cs.kuleuven.be/clus/hmcdatasets/)
38 |
39 | Please check `readData_*.py` to see how to use our scripts to process and generate the datasets from the original data.
40 |
41 | ## Run
42 | All the parameters in `conf.py` have default values. Change the parameters `mode`, `base_model`, and `dataset`, then run `main.py` to train or test under different settings. To test a trained model, set `load_model=model_file` and `isTrain=False` in `conf.py` and run `main.py`.
43 |
44 | ## Cite
45 |
46 | ```
47 | @inproceedings{mao-etal-2019-hierarchical,
48 | title = "Hierarchical Text Classification with Reinforced Label Assignment",
49 | author = "Mao, Yuning and
50 | Tian, Jingjing and
51 | Han, Jiawei and
52 | Ren, Xiang",
53 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
54 | month = nov,
55 | year = "2019",
56 | address = "Hong Kong, China",
57 | publisher = "Association for Computational Linguistics",
58 | url = "https://www.aclweb.org/anthology/D19-1042",
59 | doi = "10.18653/v1/D19-1042",
60 | pages = "445--455",
61 | }
62 | ```
63 |
64 |
--------------------------------------------------------------------------------
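
A hedged sketch of the test-time setup described in the README's Run section. The field names come from `conf.py`; the checkpoint path is a placeholder, not a file in the repo. Overriding the parsed args in code has the same effect as editing the defaults by hand.

```python
# Hedged sketch of the Run/test configuration described in the README above.
from conf import conf, print_config

args = conf()
args.mode = 'sl'                          # or 'hilap', 'hilap-sl', 'hmcn'
args.base_model = 'textcnn'               # or 'han', 'ohcnn-bow-fast', 'raw'
args.dataset = 'rcv1'                     # or 'yelp', 'nyt', 'cellcycle_FUN'
args.load_model = 'runs/your_model.pt'    # placeholder checkpoint path
args.isTrain = False                      # evaluate only; True resumes training
print_config(args)
```
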
/TextCNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | Convolutional Neural Networks for Sentence Classification
7 | http://www.aclweb.org/anthology/D14-1181
8 | """
9 |
10 |
11 | class TextCNN(nn.Module):
12 | def __init__(self, args, word_vec, n_classes):
13 | super(TextCNN, self).__init__()
14 | # V = args.embed_num
15 | # D = args.embed_dim
16 | # C = args.class_num
17 | # Ci = 1
18 | # Co = args.kernel_num
19 | # Ks = args.kernel_sizes
20 | V = word_vec.shape[0]
21 | D = word_vec.shape[1]
22 | C = n_classes
23 | Ci = 1
24 | self.Co = 1000
25 | Ks = [3, 4, 5]
26 |
27 | self.embed = nn.Embedding(V, D)
28 | if args.pretrained_word_embed:
29 | self.embed.weight = nn.Parameter(torch.from_numpy(word_vec).float())
30 | self.embed.weight.requires_grad = args.update_word_embed
31 | self.convs1 = nn.ModuleList([nn.Conv2d(Ci, self.Co, (K, D)) for K in Ks])
32 | self.dropout = nn.Dropout(0.)
33 | self.fc1 = nn.Linear(len(Ks) * self.Co, C)
34 |
35 | def forward(self, x, fc=False):
36 | x = self.embed(x) # (N, W, D)
37 | x = x.unsqueeze(1) # (N, Ci, W, D)
38 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1] # [(N, Co, W), ...]*len(Ks)
39 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks)
40 | x = torch.cat(x, 1)
41 | x = self.dropout(x) # (N, len(Ks)*Co)
42 | if fc:
43 | x = self.fc1(x) # (N, C)
44 | return x
45 |
--------------------------------------------------------------------------------
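
A usage sketch for `TextCNN` under the same assumptions as the `HAN` sketch above (a random embedding matrix stands in for pretrained vectors; `args` fields mirror `conf.py`). The input is a LongTensor of token ids of shape (N, W).

```python
# Hedged usage sketch (random word_vec stand-in; made-up sizes).
import numpy as np
import torch
from types import SimpleNamespace
from torch.autograd import Variable
from TextCNN import TextCNN

word_vec = np.random.rand(30000, 100).astype('float32')   # (V, D) stand-in
args = SimpleNamespace(pretrained_word_embed=True, update_word_embed=True)
model = TextCNN(args, word_vec, n_classes=103)
x = torch.LongTensor(2, 256).random_(0, 30000)             # (N, W)
feats = model(Variable(x))                                  # (2, 3 * 1000)
logits = model(Variable(x), fc=True)                        # (2, 103)
print(feats.size(), logits.size())
```
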
/clean_runs.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 |
4 | """
5 | remove folders under runs/ by folder size and date
6 | """
7 |
8 |
9 | def get_size(start_path='.'):
10 | total_size = 0
11 | for dirpath, dirnames, filenames in os.walk(start_path):
12 | for f in filenames:
13 | fp = os.path.join(dirpath, f)
14 | total_size += os.path.getsize(fp)
15 | return total_size
16 |
17 |
18 | now = datetime.datetime.now().strftime('%b%d')
19 | for dirpath, dirnames, filenames in os.walk('runs2/'):
20 | for dirname in dirnames:
21 | path = os.path.join(dirpath, dirname)
22 | size = get_size(path)
23 | # print(path, size)
24 | # rm folder that < 100KB and not today
25 | if (size < 100000 or 'del' in path) and now not in path:
26 | print('rm {} with size {}KB'.format(path, size / 1000))
27 | os.system("rm -rf '{}'".format(path))
28 |
29 | for filename in filenames:
30 | path = os.path.join(dirpath, filename)
31 | if 'del' in path and now not in path:
32 | print(f'rm {path}')
33 | os.system("rm -rf '{}'".format(path))
34 | break
35 |
--------------------------------------------------------------------------------
/conf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def print_config(config, logger=None):
5 | config = vars(config)
6 | info = "Running with the following configs:\n"
7 | for k, v in config.items():
8 | info += "\t{} : {}\n".format(k, str(v))
9 | if not logger:
10 | print("\n" + info + "\n")
11 | else:
12 | logger.info("\n" + info + "\n")
13 |
14 |
15 | def conf():
16 | ap = argparse.ArgumentParser()
17 | # change the parameters to test different method/base_model/dataset combinations
18 | ap.add_argument('--mode', default='sl', choices=['hilap', 'sl', 'hilap-sl', 'hmcn'])
19 | ap.add_argument('--base_model', default='textcnn', choices=['han', 'textcnn', 'ohcnn-bow-fast', 'raw'])
20 | ap.add_argument('--dataset', default='rcv1', choices=['yelp', 'rcv1', 'nyt', 'cellcycle_FUN'])
21 | ap.add_argument('--isTrain', default=False, help='True for continuing training')
22 | ap.add_argument('--load_model', default=None)
23 | ap.add_argument('--remark', default='del', help='reminder of this run')
24 |
25 | # most of following parameters do not need any changes
26 | ap.add_argument('--lr', default=1e-3, help='learning rate 1e-3 for ohcnn, textcnn, 1e-1 for han')
27 | ap.add_argument('--l2_weight', default=1e-6, help='weight decay of optimizer')
28 | ap.add_argument('--save_every', default=10, help='evaluate and save model every k epochs')
29 | ap.add_argument('--num_epoch', default=50)
30 | ap.add_argument('--word_gru_hidden_size', default=50)
31 | ap.add_argument('--sent_gru_hidden_size', default=50)
32 | ap.add_argument('--hist_embed_size', default=50)
33 | ap.add_argument('--update_beta_every', default=500)
34 | ap.add_argument('--pretrained_word_embed', default=True)
35 | ap.add_argument('--update_word_embed', default=True)
36 | ap.add_argument('--allow_stay', default=True,
37 | help='if sample_mode=random, has to be False in case select prob=0->nan')
38 | ap.add_argument('--sample_mode', default='normal', choices=['choose_max', 'random', 'normal'])
39 | ap.add_argument('--batch_size', default=32)
40 | ap.add_argument('--batch_size_test', default=32)
41 | ap.add_argument('--log_level', default=20)
42 | ap.add_argument('--stat_check', default=False, help='calculate and print some stats of the data')
43 | ap.add_argument('--max_tokens', default=256, help='max size of tokens')
44 | ap.add_argument('--debug', default=False, help='if True, run some epochs on the FIRST batch')
45 | ap.add_argument('--gamma', default=.9, help='discounting factor')
46 | ap.add_argument('--beta', default=2, help='weight of entropy')
47 | ap.add_argument('--use_cur_class_embed', default=True, help='add embedding of current class to state embedding')
48 | ap.add_argument('--use_l1', default=True)
49 | ap.add_argument('--use_l2', default=True, help='only valid when use_l1=True')
50 | ap.add_argument('--l1_size', default=500, help='output size of l1. only valid when use_l2=True')
51 | ap.add_argument('--class_embed_size', default=50)
52 | ap.add_argument('--softmax', default=True, choices='softmax or sigmoid')
53 | ap.add_argument('--sl_ratio', default=1, help='[0 = off] for rl-taxo: loss = rl_loss + sl_ratio * sl_loss')
54 | ap.add_argument('--global_ratio', default=0.5,
55 |                 help='[0 = off] for step-sl: sl_loss = (1-global_ratio) * local_loss + global_ratio * global_loss')
56 | ap.add_argument('--gpu', default=True)
57 | ap.add_argument('--n_rollouts', default=1)
58 | ap.add_argument('--reward', default='f1', choices=['01', '01-1', 'f1', 'direct', 'taxo'])
59 | ap.add_argument('--early_stop', default=False, help='for rl-taxo only')
60 | ap.add_argument('--baseline', default='greedy', choices=[None, 'avg', 'greedy'])
61 | ap.add_argument('--avg_reward_mode', default='batch', choices=['off', 'each', 'batch'],
62 | help='if n_step=1, cannot be each ->nan')
63 | ap.add_argument('--min_mem', default=3000, help='minimum gpu memory requirement (MB)')
64 | # outdated parameters
65 | ap.add_argument('--allow_up', default=False, help='not used anymore')
66 | ap.add_argument('--use_history', default=False, help='not used anymore')
67 | ap.add_argument('--multi_label', default=True, help='whether predict multi labels. valid for sl and step-sl')
68 | ap.add_argument('--split_multi', default=False,
69 | help='split one sample with m labels to m samples, only valid when filter_ancestors=True')
70 | ap.add_argument('--mix_flat_probs', default=False, help='add flat prob to help rl')
71 | args = ap.parse_args()
72 | args.use_history &= (args.mode == 'rl')
73 | if args.dataset == 'rcv1':
74 | args.ohcnn_data = '_rcv1_len256_padded' # where to load ohcnn cache
75 | args.n_steps_sl = 4 # steps of step-sl
76 | args.n_steps = 17 # steps of rl
77 | args.output_every = 723 # output metrics every k batches
78 | elif args.dataset == 'yelp':
79 | args.ohcnn_data = '_yelp_root_100_5_10_len256_padded'
80 | args.n_steps_sl = 4
81 | args.n_steps = 10
82 | args.output_every = 2730
83 | elif args.dataset == 'nyt':
84 | args.ohcnn_data = '_nyt'
85 | args.n_steps_sl = 3
86 | args.n_steps = 20
87 | args.output_every = 789
88 | elif 'FUN' in args.dataset:
89 | args.ohcnn_data = '_cellcycle_FUN_root_padded'
90 | args.n_steps_sl = 6
91 | args.n_steps = 45
92 | args.output_every = 100
93 | args.n_hidden = 150
94 | args.l1_size = 1000
95 | args.class_embed_size = 1000
96 | args.dropout = 0
97 | elif 'GO' in args.dataset:
98 | args.ohcnn_data = '_cellcycle_GO_root_padded'
99 | args.n_steps_sl = 14
100 | args.n_steps = 45
101 | args.output_every = 100
102 | args.n_hidden = 150
103 | args.l1_size = 1000
104 | args.class_embed_size = 1000
105 | args.dropout = 0
106 | if args.n_steps == 1:
107 | args.avg_reward_mode = 'off'
108 | if args.mode in ['sl', 'hmcn']:
109 | args.filter_ancestors = False # 'only use lowest labels as gold for training'
110 | else:
111 | args.filter_ancestors = True
112 | return args
113 |
--------------------------------------------------------------------------------
/feature_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.datasets import fetch_rcv1
4 | from sklearn.feature_extraction.text import TfidfVectorizer
5 | from sklearn.preprocessing import MultiLabelBinarizer
6 | from torch.utils.data import Dataset
7 |
8 | from readData_fungo import read_fungo
9 | from readData_nyt import read_nyt
10 | from readData_yelp import read_yelp
11 |
12 |
13 | def rcv1_test(rcv1):
14 | X_train = rcv1.data[:23149]
15 | Y_train = rcv1.target[:23149]
16 | X_test = rcv1.data[23149:]
17 | Y_test = rcv1.target[23149:]
18 | return X_train, Y_train, X_test, Y_test
19 |
20 |
21 | def yelp_test():
22 | subtree_name = 'root'
23 | X_train, X_test, train_ids, test_ids, business_dict, nodes = read_yelp(subtree_name, 5, 10)
24 | print(f'#training={len(train_ids)} #test={len(test_ids)}')
25 | n_tokens = 256
26 | print(f'use only first {n_tokens} tokens')
27 | X_train = [' '.join(i.split()[:n_tokens]) for i in X_train]
28 | X_test = [' '.join(i.split()[:n_tokens]) for i in X_test]
29 | print('fit_transform...')
30 | tf = TfidfVectorizer()
31 | X_train = tf.fit_transform(X_train)
32 | X_test = tf.transform(X_test)
33 | Y_train = [business_dict[bid]['categories'] for bid in train_ids]
34 | Y_test = [business_dict[bid]['categories'] for bid in test_ids]
35 | mlb = MultiLabelBinarizer()
36 | Y_train = mlb.fit_transform(Y_train)
37 | Y_test = mlb.transform(Y_test)
38 | return X_train, Y_train, X_test, Y_test, train_ids, test_ids
39 |
40 |
41 | def nyt_test():
42 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_nyt()
43 | print(f'#training={len(train_ids)} #test={len(test_ids)}')
44 | n_tokens = 256
45 | print(f'use only first {n_tokens} tokens')
46 | X_train = [' '.join(i.split()[:n_tokens]) for i in X_train]
47 | X_test = [' '.join(i.split()[:n_tokens]) for i in X_test]
48 | print('fit_transform...')
49 | tf = TfidfVectorizer()
50 | X_train = tf.fit_transform(X_train)
51 | X_test = tf.transform(X_test)
52 | Y_train = [id2doc[bid]['categories'] for bid in train_ids]
53 | Y_test = [id2doc[bid]['categories'] for bid in test_ids]
54 | mlb = MultiLabelBinarizer()
55 | Y_train = mlb.fit_transform(Y_train)
56 | Y_test = mlb.transform(Y_test)
57 | return X_train, Y_train, X_test, Y_test, train_ids, test_ids
58 |
59 |
60 | def fungo_test(data_name):
61 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_fungo(data_name)
62 | Y_train = [id2doc[bid]['categories'] for bid in train_ids]
63 | Y_test = [id2doc[bid]['categories'] for bid in test_ids]
64 | # Actually here Y is not used. We use id2doc for labels.
65 | mlb = MultiLabelBinarizer()
66 | Y = mlb.fit_transform(np.concatenate([Y_train, Y_test]))
67 | Y_train = Y[:len(Y_train)]
68 | Y_test = Y[-len(Y_test):]
69 |
70 | return X_train, Y_train, X_test, Y_test, train_ids, test_ids
71 |
72 |
73 | def my_collate(batch):
74 | features = torch.FloatTensor([item[0] for item in batch])
75 | labels = [item[1] for item in batch]
76 | return [features, labels]
77 |
78 |
79 | class featureDataset(Dataset):
80 | def __init__(self, data_name, train=True):
81 | self.train = train
82 | self.data = data_name
83 | if data_name == 'rcv1':
84 | self.rcv1 = fetch_rcv1()
85 | X_train, Y_train, X_test, Y_test = rcv1_test(self.rcv1)
86 | if train:
87 | self.samples = X_train
88 | else:
89 | self.samples = X_test
90 | else:
91 | if data_name == 'yelp':
92 | X_train, Y_train, X_test, Y_test, train_ids, test_ids = yelp_test()
93 | elif data_name == 'nyt':
94 | X_train, Y_train, X_test, Y_test, train_ids, test_ids = nyt_test()
95 | else:
96 | X_train, Y_train, X_test, Y_test, train_ids, test_ids = fungo_test(data_name)
97 | if train:
98 | self.samples = X_train
99 | self.ids = train_ids
100 | else:
101 | self.samples = X_test
102 | self.ids = test_ids
103 |
104 | def __len__(self):
105 | if 'FUN' in self.data or 'GO' in self.data:
106 | return len(self.samples)
107 | if self.data == 'fungo':
108 | return len(self.samples)
109 | return self.samples.shape[0]
110 |
111 | def __getitem__(self, item):
112 | if self.data == 'rcv1':
113 | if self.train:
114 | vector, label = self.samples[item].todense().tolist()[0], str(int(self.rcv1.sample_id[item]))
115 | else:
116 | vector, label = self.samples[item].todense().tolist()[0], str(int(self.rcv1.sample_id[item + 23149]))
117 | elif self.data in ['yelp', 'nyt']:
118 | vector, label = self.samples[item].todense().tolist()[0], self.ids[item].strip().strip('\n')
119 | elif 'FUN' in self.data or 'GO' in self.data or self.data == 'fungo':
120 | vector, label = self.samples[item], self.ids[item]
121 | return vector, label
122 |
--------------------------------------------------------------------------------
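
A hedged usage sketch for `featureDataset` together with the custom collate function. Note that `featureDataset('rcv1')` downloads RCV1 through scikit-learn's `fetch_rcv1` on first use, which can take a while.

```python
# Hedged usage sketch. my_collate stacks the dense feature vectors into a
# FloatTensor and keeps the document ids as a plain Python list.
from torch.utils.data import DataLoader
from feature_dataset import featureDataset, my_collate

train_set = featureDataset('rcv1', train=True)
loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=my_collate)
features, doc_ids = next(iter(loader))   # FloatTensor (32, n_features), list of ids
print(features.size(), doc_ids[:3])
```
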
/features_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.feature_extraction.text import TfidfVectorizer
3 | from sklearn.metrics import f1_score
4 | from sklearn.multiclass import OneVsRestClassifier
5 | from sklearn.preprocessing import MultiLabelBinarizer
6 | from sklearn.svm import LinearSVC
7 |
8 | from conf import conf
9 | from feature_dataset import fungo_test
10 | from loadData import load_data_rcv1
11 | from readData_fungo import read_fungo
12 | from readData_yelp import read_yelp
13 | from tree import Tree
14 |
15 |
16 | def evaluate(pred, Y):
17 | print(f1_score(Y, pred, average='micro'))
18 | print(f1_score(Y, pred, average='macro'))
19 | print(f1_score(Y, pred, average='samples'))
20 | # print(classification_report(Y, pred))
21 |
22 |
23 | def print_some(pred, Y, k=500):
24 | pred_tmp = pred[:k].todense().tolist()
25 | from sklearn.datasets import fetch_rcv1
26 | rcv1 = fetch_rcv1()
27 | for tmp, y in zip(pred_tmp, Y.todense().tolist()):
28 | for i in range(len(tmp)):
29 | if tmp[i] == 1:
30 | print(rcv1.target_names[i], end=' ')
31 | print()
32 | for i in range(len(tmp)):
33 | if y[i] == 1:
34 | print(rcv1.target_names[i], end=' ')
35 | print('---')
36 |
37 |
38 | def run(X_train, Y_train, X_test, Y_test):
39 | print('start training')
40 | model = OneVsRestClassifier(LinearSVC(loss='hinge'), n_jobs=5)
41 | # model = OneVsRestClassifier(LogisticRegression(solver='lbfgs', multi_class='multinomial'))
42 | model.fit(X_train, Y_train)
43 | pred = model.predict(X_train)
44 | print('eval training')
45 | # print_some(pred, Y_train, 50)
46 | evaluate(pred, Y_train)
47 | print('eval testing')
48 | pred = model.predict(X_test)
49 | evaluate(pred, Y_test)
50 | # print_some(pred, Y_test, 50)
51 |
52 |
53 | def rcv1_test():
54 | from sklearn.datasets import fetch_rcv1
55 | rcv1 = fetch_rcv1()
56 | X_train = rcv1.data[:23149]
57 | Y_train = rcv1.target[:23149]
58 | X_test = rcv1.data[23149:]
59 | Y_test = rcv1.target[23149:]
60 | print(Y_train[:2])
61 | print(rcv1.target_names[34], rcv1.target_names[59])
62 | return X_train, Y_train, X_test, Y_test
63 |
64 |
65 | def yelp_test():
66 | subtree_name = 'root'
67 | X_train, X_test, train_ids, test_ids, business_dict, nodes = read_yelp(subtree_name, 5, 10)
68 | print(f'#training={len(train_ids)} #test={len(test_ids)}')
69 | n_tokens = 256
70 | print(f'use only first {n_tokens} tokens')
71 | X_train = [' '.join(i.split()[:n_tokens]) for i in X_train]
72 | X_test = [' '.join(i.split()[:n_tokens]) for i in X_test]
73 | print('fit_transform...')
74 | tf = TfidfVectorizer()
75 | X_train = tf.fit_transform(X_train)
76 | X_test = tf.transform(X_test)
77 | Y_train = [business_dict[bid]['categories'] for bid in train_ids]
78 | Y_test = [business_dict[bid]['categories'] for bid in test_ids]
79 | mlb = MultiLabelBinarizer()
80 | Y_train = mlb.fit_transform(Y_train)
81 | Y_test = mlb.transform(Y_test)
82 | return X_train, Y_train, X_test, Y_test
83 |
84 |
85 | def fungo_test_wrapper(name='cellcycle_FUN'):
86 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_fungo(name)
87 | X_train, X_test = np.array(X_train), np.array(X_test)
88 | id2doc_train = id2doc
89 | args = conf()
90 | # id2doc_train = filter_ancestors(id2doc, nodes)
91 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Top')
92 | mlb = MultiLabelBinarizer(classes=tree.class_idx)
93 | Y_train = mlb.fit_transform([tree.id2doc_ancestors[docid]['class_idx'] for docid in train_ids])
94 | Y_test = mlb.transform([tree.id2doc_ancestors[docid]['class_idx'] for docid in test_ids])
95 | return X_train, Y_train, X_test, Y_test
96 |
97 |
98 | if __name__ == '__main__':
99 | X_train, Y_train, X_test, Y_test = rcv1_test()
100 | # X_train, Y_train, X_test, Y_test = fungo_test_wrapper('cellcycle_FUN')
101 | run(X_train, Y_train, X_test, Y_test)
102 |
--------------------------------------------------------------------------------
/fig/HiLAP_architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/morningmoni/HiLAP/19b63121c37c6a1e3367062925f2b1d7ad9cc87f/fig/HiLAP_architecture.jpg
--------------------------------------------------------------------------------
/fig/prediction_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/morningmoni/HiLAP/19b63121c37c6a1e3367062925f2b1d7ad9cc87f/fig/prediction_animation.gif
--------------------------------------------------------------------------------
/loadData.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import spacy
8 | from nltk.tokenize import sent_tokenize
9 | from tqdm import tqdm
10 |
11 | from Logger_morning import myLogger
12 | from readData_rcv1 import read_rcv1
13 | from readData_yelp import read_yelp
14 | from readData_nyt import read_nyt
15 |
16 | spacy_en = spacy.load('en')
17 | logger = myLogger('exp')
18 |
19 |
20 | def tokenizer(text): # create a tokenizer function
21 | # return text.lower().split()
22 | return [tok.text.lower() for tok in spacy_en.tokenizer(text)]
23 |
24 |
25 | def read_word_embed(pre_trained_path, biovec=False):
26 | logger.info('loading pre-trained embedding from {}'.format(pre_trained_path))
27 | if not biovec:
28 | with open(pre_trained_path) as f:
29 | words, vectors = zip(*[line.strip().split(' ', 1) for line in f])
30 | wv = np.loadtxt(vectors)
31 | else:
32 | with open(pre_trained_path + 'types.txt') as f:
33 | words = [line.strip() for line in f]
34 | wv = np.loadtxt(pre_trained_path + 'vectors.txt')
35 | return words, wv
36 |
37 |
38 | def prepare_word(pre_trained_path, vocab, biovec):
39 | words, wv = read_word_embed(pre_trained_path, biovec)
40 | unknown_vector = np.random.random_sample((wv.shape[1],))
41 | word_set = set(words)
42 | unknown_words = list(set(vocab).difference(set(words)))
43 | logger.info('there are {} OOV words'.format(len(unknown_words)))
44 | word_index = {w: i for i, w in enumerate(words)}
45 | unknown_word_vectors = [np.add.reduce([wv[word_index[w]] if w in word_set else unknown_vector
46 | for w in word.split(' ')])
47 | for word in unknown_words]
48 | wv = np.vstack((wv, unknown_word_vectors))
49 | words = list(words) + unknown_words
50 | # Normalize each row (word vector) in the matrix to sum-up to 1
51 | row_norm = np.sum(np.abs(wv) ** 2, axis=-1) ** (1. / 2)
52 | wv /= row_norm[:, np.newaxis]
53 |
54 | word_index = {w: i for i, w in enumerate(words)}
55 | return wv, word_index
56 |
57 |
58 | def prepare_word_filter(pre_trained_path, vocab, biovec, shared_unk=False, norm=True):
59 | words, wv = read_word_embed(pre_trained_path, biovec)
60 | known_words = set(vocab) & set(words)
61 | unknown_words = set(vocab).difference(set(words))
62 | logger.info('there are {} OOV words'.format(len(unknown_words)))
63 |
64 | word_index = {w: i for i, w in enumerate(words)}
65 | new_word_index = {}
66 |
67 | filter_idx = []
68 | if shared_unk:
69 | ct = 1
70 | else:
71 | ct = 0
72 | for word in known_words:
73 | new_word_index[word] = ct
74 | filter_idx.append(word_index[word])
75 | ct += 1
76 | for word in unknown_words:
77 | if shared_unk:
78 | new_word_index[word] = 0
79 | else:
80 | new_word_index[word] = ct
81 | ct += 1
82 |
83 | wv = wv[filter_idx]
84 | if shared_unk:
85 | unknown_vector = np.random.random_sample((wv.shape[1],)) - .5
86 | wv = np.vstack((unknown_vector, wv))
87 | else:
88 | unknown_vectors = np.random.random_sample((len(unknown_words), wv.shape[1])) - .5
89 | wv = np.vstack((wv, unknown_vectors))
90 |
91 | # Normalize each row (word vector) in the matrix to sum-up to 1
92 | if norm:
93 | row_norm = np.sum(np.abs(wv) ** 2, axis=-1) ** (1. / 2)
94 | wv /= row_norm[:, np.newaxis]
95 |
96 | return wv, new_word_index
97 |
98 |
99 | def tokenize_text_normal(X_l):
100 | X = []
101 | vocab = set()
102 | for text in tqdm(X_l):
103 | X.append([])
104 | for sent in sent_tokenize(text):
105 | t = tokenizer(sent)
106 | vocab.update(t)
107 | X[-1].append(t)
108 | return X, vocab
109 |
110 |
111 | def oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=None, word_min_sup=5, count_test=False, filter=False):
112 | if not vocab_size:
113 | logger.info('word_min_sup={}'.format(word_min_sup))
114 | vocab = vocab_train | vocab_test
115 | logger.info(
116 | "[before]vocab_train:{} vocab_test:{} vocab:{}".format(len(vocab_train), len(vocab_test), len(vocab)))
117 | word_count = defaultdict(int)
118 | if filter:
119 | word_w_num = set()
120 | with open('rcv1/rcv1_stopword.txt') as f:
121 | stop_word = set([line.strip() for line in f])
122 | for doc in X_train:
123 | for sent in doc:
124 | for word in sent:
125 | if filter:
126 | if word in stop_word:
127 | continue
128 | if any(c.isdigit() for c in word):
129 | word_w_num.add(word)
130 | continue
131 | word_count[word] += 1
132 | if filter:
133 | logger.info('removeNumbers {} {}'.format(len(word_w_num), word_w_num))
134 | if count_test:
135 | for doc in X_test:
136 | for sent in doc:
137 | for word in sent:
138 | word_count[word] += 1
139 | if vocab_size:
140 | logger.info(f'limit vocab_size={vocab_size}')
141 | vocab_by_freq = set([k for k in sorted(word_count, key=word_count.get, reverse=True)][:vocab_size])
142 | X_train = [[[word if word in vocab_by_freq else 'UNK' for word in sent] for sent in doc] for doc in
143 | X_train]
144 | X_test = [[[word if word in vocab_by_freq else 'UNK' for word in sent] for sent in doc] for doc in
145 | X_test]
146 | else:
147 | X_train = [[[word if word_count[word] >= word_min_sup else 'UNK' for word in sent] for sent in doc] for doc in
148 | X_train]
149 | X_test = [[[word if word_count[word] >= word_min_sup else 'UNK' for word in sent] for sent in doc] for doc in
150 | X_test]
151 | logger.info('Tokenization may be poor. Next are some examples:')
152 | logger.info(X_train[:5])
153 | vocab_train = set([word for doc in X_train for sent in doc for word in sent])
154 | vocab_test = set([word for doc in X_test for sent in doc for word in sent])
155 | vocab = vocab_train | vocab_test
156 | logger.info(
157 | "[after]vocab_train:{} vocab_test:{} vocab:{}".format(len(vocab_train), len(vocab_test), len(vocab)))
158 | return X_train, X_test, vocab
159 |
160 |
161 | def load_data_rcv1_onehot(suffix):
162 | if os.path.exists('preload_data{}_onehot.pkl'.format(suffix)):
163 | logger.warn('loading from preload_data{}_onehot.pkl'.format(suffix))
164 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load(
165 | open('preload_data{}_onehot.pkl'.format(suffix), 'rb'))
166 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes
167 |
168 | _, _, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load(
169 | open('preload_data{}.pkl'.format(suffix), 'rb'))
170 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb'))
171 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=30000, filter=True)
172 | word_index = {w: ct for ct, w in enumerate(vocab)}
173 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train]
174 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test]
175 | logger.info('saving preload_data{}_onehot.pkl'.format(suffix))
176 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes
177 | pickle.dump(res, open('preload_data{}_onehot.pkl'.format(suffix), 'wb'))
178 | return res
179 |
180 |
181 | def load_data_rcv1(embedding_path, suffix):
182 | if os.path.exists('preload_data{}.pkl'.format(suffix)):
183 | logger.warn('loading from preload_data{}.pkl'.format(suffix))
184 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load(
185 | open('preload_data{}.pkl'.format(suffix), 'rb'))
186 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes
187 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_rcv1()
188 | if os.path.exists('text_tokenized{}.pkl'.format(suffix)):
189 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb'))
190 | else:
191 | X_train, vocab_train = tokenize_text_normal(X_train)
192 | X_test, vocab_test = tokenize_text_normal(X_test)
193 | res = X_train, vocab_train, X_test, vocab_test
194 | pickle.dump(res, open('text_tokenized{}.pkl'.format(suffix), 'wb'))
195 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, count_test=True)
196 | wv, word_index = prepare_word_filter(embedding_path, vocab, biovec=False)
197 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train]
198 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test]
199 | logger.info('saving preload_data{}.pkl'.format(suffix))
200 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes
201 | pickle.dump(res, open('preload_data{}.pkl'.format(suffix), 'wb'))
202 | return res
203 |
204 |
205 | def load_data_nyt_onehot(suffix):
206 | if os.path.exists('preload_data{}_onehot.pkl'.format(suffix)):
207 | logger.warn('loading from preload_data{}_onehot.pkl'.format(suffix))
208 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load(
209 | open('preload_data{}_onehot.pkl'.format(suffix), 'rb'))
210 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes
211 |
212 | _, _, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load(
213 | open('preload_data{}.pkl'.format(suffix), 'rb'))
214 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb'))
215 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=30000, filter=True)
216 | word_index = {w: ct for ct, w in enumerate(vocab)}
217 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train]
218 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test]
219 | logger.info('saving preload_data{}_onehot.pkl'.format(suffix))
220 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes
221 | pickle.dump(res, open('preload_data{}_onehot.pkl'.format(suffix), 'wb'))
222 | return res
223 |
224 |
225 | def load_data_nyt(embedding_path, suffix):
226 | if os.path.exists('preload_data{}.pkl'.format(suffix)):
227 | logger.warn('loading from preload_data{}.pkl'.format(suffix))
228 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load(
229 | open('preload_data{}.pkl'.format(suffix), 'rb'))
230 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes
231 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_nyt()
232 | if os.path.exists('text_tokenized{}.pkl'.format(suffix)):
233 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb'))
234 | else:
235 | X_train, vocab_train = tokenize_text_normal(X_train)
236 | X_test, vocab_test = tokenize_text_normal(X_test)
237 | res = X_train, vocab_train, X_test, vocab_test
238 | pickle.dump(res, open('text_tokenized{}.pkl'.format(suffix), 'wb'))
239 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, count_test=True)
240 | wv, word_index = prepare_word_filter(embedding_path, vocab, biovec=False)
241 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train]
242 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test]
243 | logger.info('saving preload_data{}.pkl'.format(suffix))
244 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes
245 | pickle.dump(res, open('preload_data{}.pkl'.format(suffix), 'wb'))
246 | return res
247 |
248 |
249 | def load_data_yelp(embedding_path, suffix, root, min_reviews=1, max_reviews=10):
250 | """
251 | suffix: used to distinguish different pkl files
252 | root: which subtree to use e.g., root (use all nodes as classes, 1004 - 1 in total), Hotels & Travel
253 | min_reviews: remove businesses that have < min_reviews
254 | max_reviews: use at most max_reviews for a business
255 | """
256 | if os.path.exists('preload_data{}.pkl'.format(suffix)):
257 | logger.warn('loading from preload_data{}.pkl'.format(suffix))
258 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = pickle.load(
259 | open('preload_data{}.pkl'.format(suffix), 'rb'))
260 | return X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes
261 | logger.warn('reviews min_sup={}, max_sup={}'.format(min_reviews, max_reviews))
262 | X_train, X_test, train_ids, test_ids, id2doc, nodes = read_yelp(root, min_reviews, max_reviews)
263 | # TODO need to make sure every time train_ids, test_ids are the same
264 | if os.path.exists('text_tokenized{}.pkl'.format(suffix)):
265 | logger.warn('loading from text_tokenized{}.pkl'.format(suffix))
266 | X_train, vocab_train, X_test, vocab_test = pickle.load(open('text_tokenized{}.pkl'.format(suffix), 'rb'))
267 | else:
268 | X_train, vocab_train = tokenize_text_normal(X_train)
269 | X_test, vocab_test = tokenize_text_normal(X_test)
270 | res = X_train, vocab_train, X_test, vocab_test
271 | pickle.dump(res, open('text_tokenized{}.pkl'.format(suffix), 'wb'))
272 | X_train, X_test, vocab = oov2unk(X_train, X_test, vocab_train, vocab_test, vocab_size=30000)
273 | wv, word_index = prepare_word_filter(embedding_path, vocab, biovec=False)
274 | X_train = [[[word_index[word] for word in sent] for sent in doc] for doc in X_train]
275 | X_test = [[[word_index[word] for word in sent] for sent in doc] for doc in X_test]
276 | logger.info('saving preload_data{}.pkl'.format(suffix))
277 | res = np.array(X_train), np.array(X_test), np.array(train_ids), np.array(test_ids), id2doc, wv, word_index, nodes
278 | pickle.dump(res, open('preload_data{}.pkl'.format(suffix), 'wb'))
279 | return res
280 |
281 |
282 | def filter_ancestors(id2doc, nodes):
283 | logger.info('keep only lowest label in a path...')
284 | id2doc_na = defaultdict(dict)
285 | labels_ct = []
286 | lowest_labels_ct = []
287 | for bid in id2doc:
288 | lowest_labels = []
289 | cat_set = set(id2doc[bid]['categories'])
290 | for label in id2doc[bid]['categories']:
291 | if len(set(nodes[label]['children']) & cat_set) == 0:
292 | lowest_labels.append(label)
293 | labels_ct.append(len(cat_set))
294 | lowest_labels_ct.append(len(lowest_labels))
295 | id2doc_na[bid]['categories'] = lowest_labels
296 | logger.info('#labels')
297 | logger.info(pd.Series(labels_ct).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98]))
298 | logger.info('#lowest labels')
299 | logger.info(pd.Series(lowest_labels_ct).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98]))
300 | return id2doc_na
301 |
302 |
303 | def split_multi(X_train, train_ids, id2doc_train, id2doc):
304 | logger.info('split one sample with m labels to m samples...')
305 | X_train_new = []
306 | train_ids_new = []
307 | id2doc_train_new = defaultdict(dict)
308 | for X, did in zip(X_train, train_ids):
309 | ct = 0
310 | for label in id2doc_train[did]['categories']:
311 | newID = '{}-{}'.format(did, ct)
312 | id2doc[newID] = id2doc[did]
313 | X_train_new.append(X)
314 | train_ids_new.append(newID)
315 | id2doc_train_new[newID]['categories'] = [label]
316 | ct += 1
317 | return np.array(X_train_new), np.array(train_ids_new), id2doc_train_new, id2doc
318 |
--------------------------------------------------------------------------------
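
A toy illustration of `filter_ancestors` with made-up labels (importing `loadData` loads the spaCy 'en' model at module level, so it has to be installed): only labels whose children are not also assigned to the document are kept, i.e. the lowest labels on each annotated path.

```python
# Toy illustration with made-up labels and hierarchy (assumption, not repo data).
from loadData import filter_ancestors

nodes = {'root': {'children': ['A', 'B']},
         'A': {'children': ['A1']},
         'A1': {'children': []},
         'B': {'children': []}}
id2doc = {'doc1': {'categories': ['root', 'A', 'A1', 'B']}}
print(filter_ancestors(id2doc, nodes)['doc1']['categories'])  # ['A1', 'B']
```
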
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from datetime import datetime
4 |
5 | import numpy as np
6 | import torch
7 | from sklearn.preprocessing import MultiLabelBinarizer
8 | from tensorboardX import SummaryWriter
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 |
13 | from Linear_Model import Linear_Model
14 | from Logger_morning import myLogger
15 | from conf import conf, print_config
16 | from feature_dataset import featureDataset, my_collate
17 | from readData_fungo import read_fungo
18 |
19 | args = conf()
20 | if args.load_model is None or args.isTrain:
21 | comment = f'_{args.dataset}_{args.base_model}_{args.mode}_{args.remark}'
22 | current_time = datetime.now().strftime('%b%d_%H-%M-%S')
23 | log_dir = os.path.join('runs', current_time + comment)
24 | else:
25 | log_dir = ''.join(args.load_model[:args.load_model.rfind('/')])
26 | print(f'reuse dir: {log_dir}')
27 | logger = myLogger(name='exp', log_path=log_dir + '.log')
28 | # incompatible with logger...
29 | writer = SummaryWriter(log_dir=log_dir)
30 | writer.add_text('Parameters', str(vars(args)))
31 | print_config(args, logger)
32 | logger.setLevel(args.log_level)
33 |
34 | from HMCN import HMCN
35 | from HAN import HAN
36 | from OHCNN_fast import OHCNN_fast
37 | from TextCNN import TextCNN
38 | from loadData import load_data_yelp, filter_ancestors, load_data_rcv1, split_multi, \
39 | load_data_rcv1_onehot, load_data_nyt_onehot, load_data_nyt
40 | from model import Policy
41 | from tree import Tree
42 | from util import get_gpu_memory_map, save_checkpoint, check_doc_size, gen_minibatch_from_cache, gen_minibatch, \
43 | save_minibatch, contains_nan
44 |
45 |
46 | def finish_episode(policy, update=True):
47 | policy_loss = []
48 | all_cum_rewards = []
49 | for i in range(args.n_rollouts):
50 | rewards = []
51 | R = np.zeros(len(policy.rewards[i][0]))
52 | for r in policy.rewards[i][::-1]:
53 | R = r + args.gamma * R
54 | rewards.insert(0, R)
55 | all_cum_rewards.extend(rewards)
56 | rewards = torch.Tensor(rewards) # (length, batch_size)
57 | # logger.warning(f'original {rewards}')
58 | if args.baseline == 'avg':
59 | rewards -= policy.baseline_reward
60 | elif args.baseline == 'greedy':
61 | rewards_greedy = []
62 | R = np.zeros(len(policy.rewards_greedy[0]))
63 | for r in policy.rewards_greedy[::-1]:
64 | R = r + args.gamma * R
65 | rewards_greedy.insert(0, R)
66 | rewards_greedy = torch.Tensor(rewards_greedy)
67 | rewards -= rewards_greedy
68 | # logger.warning(f'after baseline {rewards}')
69 | if args.avg_reward_mode == 'batch':
70 | rewards = Variable((rewards - rewards.mean()) / (rewards.std() + float(np.finfo(np.float32).eps)))
71 | elif args.avg_reward_mode == 'each':
72 | # mean/std is separate for each in the batch
73 | rewards = Variable((rewards - rewards.mean(dim=0)) / (rewards.std(dim=0) + float(np.finfo(np.float32).eps)))
74 | else:
75 | rewards = Variable(rewards)
76 | if args.gpu:
77 | rewards = rewards.cuda()
78 | for log_prob, reward in zip(policy.saved_log_probs[i], rewards):
79 | policy_loss.append(-log_prob * reward)
80 | # logger.warning(f'after mean_std {rewards}')
81 | if update:
82 | tree.n_update += 1
83 | try:
84 | policy_loss = torch.cat(policy_loss).mean()
85 | except Exception as e:
86 | logger.error(e)
87 | entropy = torch.cat(policy.entropy_l).mean()
88 | writer.add_scalar('data/policy_loss', policy_loss, tree.n_update)
89 | writer.add_scalar('data/entropy_loss', policy.beta * entropy.data[0], tree.n_update)
90 | policy_loss += policy.beta * entropy
91 | if args.sl_ratio > 0:
92 | policy_loss += args.sl_ratio * policy.sl_loss
93 | writer.add_scalar('data/sl_loss', args.sl_ratio * policy.sl_loss, tree.n_update)
94 | writer.add_scalar('data/total_loss', policy_loss.data[0], tree.n_update)
95 | optimizer.zero_grad()
96 | policy_loss.backward()
97 | if contains_nan(policy.class_embed.weight.grad):
98 | logger.error('nan in class_embed.weight.grad!')
99 | else:
100 | optimizer.step()
101 | policy.update_baseline(np.mean(np.concatenate(all_cum_rewards)))
102 | policy.finish_episode()
103 |
104 |
105 | def calc_sl_loss(probs, doc_ids, update=True, return_var=False):
106 | mlb = MultiLabelBinarizer(classes=tree.class_idx)
107 | y_l = [tree.id2doc_ancestors[docid]['class_idx'] for docid in doc_ids]
108 | y_true = mlb.fit_transform(y_l)
109 | if args.gpu:
110 | y_true = Variable(torch.from_numpy(y_true)).cuda().float()
111 | else:
112 | y_true = Variable(torch.from_numpy(y_true)).float()
113 | loss = criterion(probs, y_true)
114 | if update:
115 | tree.n_update += 1
116 | optimizer.zero_grad()
117 | loss.backward()
118 | optimizer.step()
119 | if return_var:
120 | return loss
121 | return loss.data[0]
122 |
123 |
124 | def get_cur_size(tokens):
125 | if args.base_model != 'han':
126 | return tokens.size()[0]
127 | else:
128 | return tokens.size()[1]
129 |
130 |
131 | def forward_step_sl(tokens, doc_ids, flat_probs_only=False):
132 | # TODO can reuse logits
133 | if args.global_ratio > 0:
134 | probs = policy.base_model(tokens, True)
135 | global_loss = calc_sl_loss(probs, doc_ids, update=False, return_var=True)
136 | else:
137 | probs = None
138 | global_loss = 0
139 | if flat_probs_only:
140 | policy.sl_loss = global_loss
141 | return global_loss, probs
142 | policy.doc_vec = None
143 | cur_batch_size = get_cur_size(tokens)
144 | cur_class_batch = np.zeros(cur_batch_size, dtype=int)
145 | for t in range(args.n_steps_sl):
146 | next_classes_batch = tree.p2c_batch(cur_class_batch)
147 | next_classes_batch_true, indices, next_class_batch_true, doc_ids = tree.get_next(cur_class_batch,
148 | next_classes_batch,
149 | doc_ids)
150 | policy.step_sl(tokens, cur_class_batch, next_classes_batch, next_classes_batch_true)
151 | cur_class_batch = next_class_batch_true
152 | policy.duplicate_doc_vec(indices)
153 | policy.sl_loss /= args.n_steps
154 | writer.add_scalar('data/step-sl_sl_loss', (1 - args.global_ratio) * policy.sl_loss, tree.n_update)
155 | policy.sl_loss = (1 - args.global_ratio) * policy.sl_loss + args.global_ratio * global_loss
156 | writer.add_scalar('data/flat_sl_loss', args.global_ratio * global_loss, tree.n_update)
157 | return global_loss, probs
158 |
159 |
160 | def train_step_sl():
161 | policy.train()
162 | for i in range(1, args.num_epoch + 1):
163 | g = select_data('train' + args.ohcnn_data, shuffle=True)
164 | loss_total = 0
165 | tree.cur_epoch = i
166 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
167 | if 'FUN' in args.dataset or 'GO' in args.dataset:
168 | tokens = Variable(tokens).cuda()
169 | global_loss, flat_probs = forward_step_sl(tokens, doc_ids, flat_probs_only=(args.global_ratio == 1))
170 | optimizer.zero_grad()
171 | policy.sl_loss.backward()
172 | optimizer.step()
173 | tree.n_update += 1
174 | loss_total += policy.sl_loss.data[0]
175 | if ct % args.output_every == 0 and ct != 0:
176 | if args.global_ratio > 0:
177 | logger.info(
178 | f'loss_cur:{policy.sl_loss.data[0]} global_loss:{args.global_ratio * global_loss.data[0]}')
179 | else:
180 | logger.info(f'loss_cur:{policy.sl_loss.data[0]} global_loss:off')
181 | logger.info(f'[{i}:{ct}] loss_avg:{loss_total / ct}')
182 | writer.add_scalar('data/sl_loss', global_loss, tree.n_update)
183 | writer.add_scalar('data/loss_avg', loss_total / ct, tree.n_update)
184 | policy.sl_loss = 0
185 | if i % args.save_every == 0:
186 | eval_save_model(i, datapath='train' + args.ohcnn_data, save=True, output=False)
187 | test_step_sl('test' + args.ohcnn_data, save_prob=False)
188 |
189 |
190 | def test_step_sl(data_path, save_prob=False, output=True):
191 | logger.info('test starts')
192 | policy.eval()
193 | g = select_data(data_path)
194 | pred_l = []
195 | target_l = []
196 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
197 | if 'FUN' in args.dataset or 'GO' in args.dataset:
198 | tokens = Variable(tokens).cuda()
199 | real_doc_ids = [i for i in doc_ids]
200 | policy.doc_vec = None
201 | cur_batch_size = get_cur_size(tokens)
202 | cur_class_batch = np.zeros(cur_batch_size, dtype=int)
203 | for _ in range(args.n_steps_sl):
204 | next_classes_batch = tree.p2c_batch(cur_class_batch)
205 | probs = policy.step_sl(tokens, cur_class_batch, next_classes_batch, None, sigmoid=True)
206 | indices, next_class_batch_pred, doc_ids = tree.get_next_by_probs(cur_class_batch, next_classes_batch,
207 | doc_ids, probs, save_prob)
208 | cur_class_batch = next_class_batch_pred
209 | policy.duplicate_doc_vec(indices)
210 | last_did = None
211 | for c, did in zip(cur_class_batch, doc_ids):
212 | if last_did != did:
213 | pred_l.append([])
214 | if c != 0:
215 | pred_l[-1].append(c)
216 | last_did = did
217 | target_l.extend(real_doc_ids)
218 | if save_prob:
219 | logger.info(f'saving {writer.file_writer.get_logdir()}/{data_path}.tree.id2prob2.pkl')
220 | pickle.dump(tree.id2prob, open(f'{writer.file_writer.get_logdir()}/{data_path}.tree.id2prob2.pkl', 'wb'))
221 | tree.id2prob.clear()
222 | return
223 | return evaluate(pred_l, target_l, output=output)
224 |
225 |
226 | def output_log(cur_class_batch, doc_ids, acc, i, ct):
227 |     # id2doc stores the gold labels under 'categories' for every dataset used here
228 |     label_key = 'categories'
229 | try:
230 | logger.info('pred{}{} real{}{} pred_h{} real_h{}'.format(cur_class_batch[:3],
231 | [tree.id2name[tree.idx2id[cur]] for cur in
232 | cur_class_batch[:3]],
233 | [tree.id2doc_ancestors[docid]['class_idx'] for
234 | docid in doc_ids[:3]],
235 | [tree.id2doc_ancestors[docid][label_key] for docid
236 | in doc_ids[:3]],
237 | tree.h_batch(cur_class_batch[:3]),
238 | tree.h_doc_batch(doc_ids[:3])))
239 | except Exception as e:
240 | logger.warning(e)
241 | writer.add_scalar('data/acc', np.mean(acc), tree.n_update)
242 | writer.add_scalar('data/beta', policy.beta, tree.n_update)
243 | logger.info('single-label acc for epoch {} batch {}: {}'.format(i, ct, np.mean(acc)))
244 | if (cur_class_batch == cur_class_batch[0]).all():
245 | logger.error('predictions in a batch are all the same! [{}]'.format(cur_class_batch[0]))
246 | writer.add_text('error', 'predictions in a batch are all the same! [{}]'.format(cur_class_batch[0]),
247 | tree.n_update)
248 | if not args.debug:
249 | exit(1)
250 |
251 |
252 | def train_taxo():
253 | policy.train()
254 | for i in range(1, args.num_epoch + 1):
255 | g = select_data('train' + args.ohcnn_data, shuffle=True)
256 | pred_l = []
257 | target_l = []
258 | tree.cur_epoch = i
259 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
260 | if 'FUN' in args.dataset or 'GO' in args.dataset:
261 | tokens = Variable(tokens).cuda()
262 | flat_probs = None
263 | if args.sl_ratio > 0:
264 | _, flat_probs = forward_step_sl(tokens, doc_ids, flat_probs_only=(args.global_ratio == 1))
265 | if not args.mix_flat_probs:
266 | flat_probs = None
267 | policy.doc_vec = None
268 | cur_batch_size = get_cur_size(tokens)
269 |
270 | # greedy baseline
271 | if args.baseline == 'greedy':
272 | tree.taken_actions = [set() for _ in range(cur_batch_size)]
273 | cur_class_batch = np.zeros(cur_batch_size, dtype=int)
274 | next_classes_batch = [set() for _ in range(cur_batch_size)]
275 | for t in range(args.n_steps):
276 | next_classes_batch, next_classes_batch_np, _ = tree.get_next_candidates(cur_class_batch,
277 | next_classes_batch)
278 | choices, m = policy.step(tokens, cur_class_batch, next_classes_batch_np, test=True,
279 | flat_probs=flat_probs)
280 | cur_class_batch = next_classes_batch_np[
281 | np.arange(len(next_classes_batch_np)), choices.data.cpu().numpy()]
282 | tree.update_actions(cur_class_batch)
283 | policy.rewards_greedy.append(tree.calc_reward(t < args.n_steps - 1, cur_class_batch, doc_ids))
284 | tree.last_R = None
285 | for _ in range(args.n_rollouts):
286 | policy.saved_log_probs.append([])
287 | policy.rewards.append([])
288 | tree.taken_actions = [set() for _ in range(cur_batch_size)]
289 | cur_class_batch = np.zeros(cur_batch_size, dtype=int)
290 | next_classes_batch = [set() for _ in range(cur_batch_size)]
291 | for t in range(args.n_steps):
292 | next_classes_batch, next_classes_batch_np, all_stop = tree.get_next_candidates(cur_class_batch,
293 | next_classes_batch)
294 | if args.early_stop and all_stop:
295 | break
296 | choices, m = policy.step(tokens, cur_class_batch, next_classes_batch_np, flat_probs=flat_probs)
297 | cur_class_batch = next_classes_batch_np[
298 | np.arange(len(next_classes_batch_np)), choices.data.cpu().numpy()]
299 | tree.update_actions(cur_class_batch)
300 | policy.saved_log_probs[-1].append(m.log_prob(choices))
301 | policy.rewards[-1].append(tree.calc_reward(t < args.n_steps - 1, cur_class_batch, doc_ids))
302 | tree.last_R = None
303 | tree.remove_stop()
304 | pred_l.extend(tree.taken_actions)
305 | target_l.extend(doc_ids)
306 | finish_episode(policy, update=True)
307 | if ct % args.output_every == 0 and ct != 0:
308 | logger.info(f'epoch {i} batch {ct}')
309 | eval_save_model(i, pred_l, target_l, output=False)
310 | if i % args.save_every == 0:
311 | eval_save_model(i, pred_l, target_l, save=True, output=False)
312 | test_taxo('test' + args.ohcnn_data, save_prob=False)
313 |
314 |
315 | def select_data(data_path, shuffle=False):
316 | if 'FUN' in args.dataset or 'GO' in args.dataset:
317 | isTrain = False
318 | if 'train' in data_path:
319 | isTrain = True
320 |         dataset = featureDataset(args.dataset, isTrain)  # serves both the train and test splits
321 |         g = DataLoader(dataset, batch_size=args.batch_size, collate_fn=my_collate, shuffle=shuffle)
322 | elif args.base_model == 'ohcnn-bow-fast':
323 | g = gen_minibatch_from_cache(logger, args, tree, args.batch_size, name=data_path, shuffle=shuffle)
324 | else:
325 | if 'test' in data_path:
326 | g = gen_minibatch(logger, args, word_index, X_test, test_ids, args.batch_size, shuffle=shuffle)
327 | else:
328 | g = gen_minibatch(logger, args, word_index, X_train, train_ids, args.batch_size, shuffle=shuffle)
329 | return g
330 |
331 |
332 | def test_taxo(data_path, save_prob=False):
333 | logger.info('test starts')
334 | policy.eval()
335 | g = select_data(data_path)
336 | pred_l = []
337 | target_l = []
338 | if save_prob:
339 | args.n_steps = tree.n_class - 1
340 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
341 | if 'FUN' in args.dataset or 'GO' in args.dataset:
342 | tokens = Variable(tokens).cuda()
343 | flat_probs = None
344 | if args.sl_ratio > 0 and args.mix_flat_probs:
345 | _, flat_probs = forward_step_sl(tokens, doc_ids, flat_probs_only=True)
346 | policy.doc_vec = None
347 | cur_batch_size = get_cur_size(tokens)
348 | tree.taken_actions = [set() for _ in range(cur_batch_size)]
349 | cur_class_batch = np.zeros(cur_batch_size, dtype=int)
350 | next_classes_batch = [set() for _ in range(cur_batch_size)]
351 | for _ in range(args.n_steps):
352 | next_classes_batch, next_classes_batch_np, all_stop = tree.get_next_candidates(cur_class_batch,
353 | next_classes_batch,
354 | save_prob)
355 | if all_stop:
356 | if save_prob:
357 |                     logger.error('all_stop reached unexpectedly while save_prob is set')
358 | break
359 | choices, m = policy.step(tokens, cur_class_batch, next_classes_batch_np, test=True, flat_probs=flat_probs)
360 | cur_class_batch = next_classes_batch_np[
361 | np.arange(len(next_classes_batch_np)), choices.data.cpu().numpy()]
362 | if save_prob:
363 | for did, idx, p_ in zip(doc_ids, cur_class_batch,
364 | m.probs.gather(-1, choices.unsqueeze(-1)).squeeze(-1).data.cpu().numpy()):
365 |                     assert 0 < idx < 104, idx  # sanity check: idx must be a non-root class index (1..103 here)
366 | if idx in tree.id2prob[did]:
367 | logger.warning(f'[{did}][{idx}] already existed!')
368 | tree.id2prob[did][idx] = p_
369 | tree.update_actions(cur_class_batch)
370 | tree.remove_stop()
371 | pred_l.extend(tree.taken_actions)
372 | target_l.extend(doc_ids)
373 | if save_prob:
374 | logger.info(f'saving {writer.file_writer.get_logdir()}/{data_path}.tree.id2prob-rl.pkl')
375 | pickle.dump(tree.id2prob, open(f'{writer.file_writer.get_logdir()}/{data_path}.tree.id2prob-rl.pkl', 'wb'))
376 | tree.id2prob.clear()
377 | return evaluate(pred_l, target_l)
378 |
379 |
380 | def eval_save_model(i, pred_l=None, target_l=None, datapath=None, save=False, output=True):
381 | if args.mode == 'hilap-sl':
382 | test_f = test_step_sl
383 | elif args.mode == 'hilap':
384 | test_f = test_taxo
385 | else:
386 | test_f = test_sl
387 | if pred_l:
388 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = evaluate(pred_l, target_l, output=output)
389 | elif datapath:
390 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = test_f(datapath, output=output)
391 | else:
392 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = test_f(X_train, train_ids)
393 | writer.add_scalar('data/micro_train', f1_aa, tree.n_update)
394 | writer.add_scalar('data/macro_train', f1_aa_macro, tree.n_update)
395 | writer.add_scalar('data/samples_train', f1_aa_s, tree.n_update)
396 | if not save:
397 | return
398 | if args.mode in ['hilap', 'hilap-sl']:
399 | save_checkpoint({
400 | 'state_dict': policy.state_dict(),
401 | 'optimizer': optimizer.state_dict(),
402 | }, writer.file_writer.get_logdir(), f'epoch{i}_{f1_aa}_{f1_aa_macro}_{f1_aa_s}.pth.tar', logger, True)
403 | else:
404 | save_checkpoint({
405 | 'state_dict': model.state_dict(),
406 | 'optimizer': optimizer.state_dict(),
407 | }, writer.file_writer.get_logdir(), f'epoch{i}_{f1_aa}_{f1_aa_macro}_{f1_aa_s}.pth.tar', logger, True)
408 |
409 |
410 | def decode(pred_l, target_l, doc_ids, probs):
411 | cur_class_batch = None
412 | target_l.extend(doc_ids)
413 | if args.multi_label:
414 | if args.mode != 'hmcn':
415 | probs = torch.sigmoid(probs)
416 | preds = (probs >= .5).int().data.cpu().numpy()
417 | for pred in preds:
418 | idx = np.nonzero(pred)[0]
419 | if len(idx) == 0:
420 | pred_l.append([])
421 | else:
422 | pred_l.append(idx + 1)
423 | else:
424 | cur_class_batch = torch.max(probs, 1)[1].data.cpu().numpy() + 1
425 | pred_l.extend(cur_class_batch)
426 | return pred_l, target_l, cur_class_batch
427 |
428 |
429 | def train_hmcn():
430 | model.train()
431 | for i in range(1, args.num_epoch + 1):
432 | g = select_data('train' + args.ohcnn_data, shuffle=True)
433 | loss_total = 0
434 | pred_l = []
435 | target_l = []
436 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
437 | if 'FUN' in args.dataset or 'GO' in args.dataset:
438 | tokens = Variable(tokens).cuda()
439 | probs = model(tokens)
440 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs)
441 | loss = calc_sl_loss(probs, doc_ids, update=True)
442 | if ct % 50 == 0 and ct != 0:
443 | logger.info('loss: {}'.format(loss))
444 | if args.multi_label:
445 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l))
446 | else:
447 | acc = tree.acc(np.array(pred_l), np.array(target_l))
448 | logger.info('acc for epoch {} batch {}: {}'.format(i, ct, acc))
449 | writer.add_scalar('data/loss', loss, tree.n_update)
450 | writer.add_scalar('data/acc', acc, tree.n_update)
451 | if not args.multi_label and (cur_class_batch == cur_class_batch[0]).all():
452 | logger.error('predictions in a batch are all the same! [{}]'.format(cur_class_batch[0]))
453 | writer.add_text('error', 'predictions in a batch are all the same! [{}]'.format(cur_class_batch[0]),
454 | tree.n_update)
455 | # exit(1)
456 | loss_total += loss
457 | loss_avg = loss_total / (ct + 1)
458 | if args.multi_label:
459 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l))
460 | else:
461 | acc = tree.acc(np.array(pred_l), np.array(target_l))
462 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc))
463 | if not args.multi_label:
464 | pred_l = [[label] for label in pred_l]
465 | if i % args.save_every == 0:
466 | eval_save_model(i, pred_l, target_l, save=True)
467 |
468 |
469 | def test_hmcn():
470 | logger.info('testing starts')
471 | model.eval()
472 | g = select_data('test' + args.ohcnn_data)
473 | loss_total = 0
474 | pred_l = []
475 | target_l = []
476 | probs_l = []
477 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
478 | if 'FUN' in args.dataset or 'GO' in args.dataset:
479 | tokens = Variable(tokens).cuda()
480 | probs = model(tokens)
481 | probs_l.append(probs.data.cpu().numpy())
482 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs)
483 | loss = calc_sl_loss(probs, doc_ids, update=False)
484 | loss_total += loss
485 | probs_np = np.concatenate(probs_l, axis=0)
486 | logger.info('saving probs to {}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update))
487 | pickle.dump((probs_np, target_l),
488 | open('{}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update), 'wb'))
489 | loss_avg = loss_total / (ct + 1)
490 | if args.multi_label:
491 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l))
492 | else:
493 | acc = tree.acc(np.array(pred_l), np.array(target_l))
494 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc))
495 | if not args.multi_label:
496 | pred_l = [[label] for label in pred_l]
497 |     inc = 0  # number of samples whose predicted labels miss an ancestor of some predicted label
498 | for p in pred_l:
499 | p = set(p)
500 | exist = False
501 | for l in p:
502 | cur = tree.c2p_idx[l][0]
503 | while cur != 0:
504 | if cur not in p:
505 | inc += 1
506 | exist = True
507 | break
508 | cur = tree.c2p_idx[cur][0]
509 | if exist:
510 | break
511 |     logger.info('samples with hierarchy-inconsistent predictions: {}'.format(inc))
512 | # exit()
513 | return evaluate(pred_l, target_l)
514 |
515 |
516 | def train_sl():
517 | model.train()
518 | for i in range(1, args.num_epoch + 1):
519 | g = select_data('train' + args.ohcnn_data, shuffle=True)
520 | loss_total = 0
521 | pred_l = []
522 | target_l = []
523 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
524 | probs = model(tokens, True)
525 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs)
526 | loss = calc_sl_loss(probs, doc_ids, update=True)
527 | if ct % args.output_every == 0 and ct != 0:
528 | logger.info('sl_loss: {}'.format(loss))
529 | if args.multi_label:
530 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l))
531 | else:
532 | acc = tree.acc(np.array(pred_l), np.array(target_l))
533 | logger.info('acc for epoch {} batch {}: {}'.format(i, ct, acc))
534 | writer.add_scalar('data/sl_loss', loss, tree.n_update)
535 | writer.add_scalar('data/acc', acc, tree.n_update)
536 | if not args.multi_label and (cur_class_batch == cur_class_batch[0]).all():
537 | logger.error('predictions in a batch are all the same! [{}]'.format(cur_class_batch[0]))
538 | writer.add_text('error', 'predictions in a batch are all the same! [{}]'.format(cur_class_batch[0]),
539 | tree.n_update)
540 | # exit(1)
541 | loss_total += loss
542 | loss_avg = loss_total / (ct + 1)
543 | if args.multi_label:
544 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l))
545 | else:
546 | acc = tree.acc(np.array(pred_l), np.array(target_l))
547 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc))
548 | if not args.multi_label:
549 | pred_l = [[label] for label in pred_l]
550 | if i % args.save_every == 0:
551 | # eval_save_model(i, pred_l, target_l, save=True)
552 | test_sl()
553 |
554 |
555 | def test_sl():
556 | logger.info('testing starts')
557 | model.eval()
558 | g = select_data('test' + args.ohcnn_data)
559 | loss_total = 0
560 | pred_l = []
561 | target_l = []
562 | probs_l = []
563 | for ct, (tokens, doc_ids) in tqdm(enumerate(g)):
564 | probs = model(tokens, True)
565 | probs_l.append(probs.data.cpu().numpy())
566 | pred_l, target_l, cur_class_batch = decode(pred_l, target_l, doc_ids, probs)
567 | loss = calc_sl_loss(probs, doc_ids, update=False)
568 | loss_total += loss
569 | # probs_np = np.concatenate(probs_l, axis=0)
570 | # logger.info('saving probs to {}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update))
571 | # pickle.dump((probs_np, target_l),
572 | # open('{}/probs_{}.pkl'.format(writer.file_writer.get_logdir(), tree.n_update), 'wb'))
573 | loss_avg = loss_total / (ct + 1)
574 | if args.multi_label:
575 | acc = tree.acc_multi(np.array(pred_l), np.array(target_l))
576 | else:
577 | acc = tree.acc(np.array(pred_l), np.array(target_l))
578 | logger.info('loss_avg:{} acc:{}'.format(loss_avg, acc))
579 | if not args.multi_label:
580 | pred_l = [[label] for label in pred_l]
581 | return evaluate(pred_l, target_l)
582 |
583 |
584 | def evaluate(pred_l, test_ids, save_path=None, output=True):
585 | acc = round(tree.acc_multi(pred_l, test_ids), 4)
586 | res = tree.calc_f1(pred_l, test_ids, save_path, output)
587 | if output:
588 | f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s = [round(i, 4) for i in res]
589 | logger.info(
590 | f'acc:{acc} f1_s:{f1_aa_s} micro-f1:{f1} {f1_a} {f1_aa} macro-f1:{f1_macro} {f1_a_macro} {f1_aa_macro}')
591 | if f1_aa > tree.miF[0]:
592 | tree.miF = (f1_aa, f1_aa_macro, tree.cur_epoch)
593 | if f1_aa_macro > tree.maF[1]:
594 | tree.maF = (f1_aa, f1_aa_macro, tree.cur_epoch)
595 | logger.warning(f'best: {tree.miF}, {tree.maF}')
596 | return f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s
597 | else:
598 | f1_a, f1_a_macro, f1_a_s = [round(i, 4) for i in res]
599 | return 0, 0, f1_a, 0, 0, f1_a_macro, f1_a_s
600 |
601 |
602 | if args.dataset == 'rcv1':
603 | if 'oh' in args.base_model:
604 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_rcv1_onehot('_rcv1_ptAll')
605 | else:
606 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_rcv1(
607 | '../datasets/glove.6B.50d.txt', '_rcv1_ptAll')
608 | if args.filter_ancestors:
609 | id2doc_train = filter_ancestors(id2doc, nodes)
610 | if args.split_multi:
611 | X_train, train_ids, id2doc_train, id2doc = split_multi(X_train, train_ids, id2doc_train, id2doc)
612 | else:
613 | id2doc_train = id2doc
614 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Root')
615 |
616 | elif args.dataset == 'yelp':
617 | subtree_name = 'root'
618 | min_reviews = 5
619 | max_reviews = 10
620 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_yelp('../datasets/glove.6B.50d.txt',
621 | '_yelp_root_100_{}_{}'.format(
622 | min_reviews, max_reviews),
623 | root=subtree_name,
624 | min_reviews=min_reviews,
625 | max_reviews=max_reviews)
626 | logger.warning(f'{len(X_train)} {len(train_ids)} {len(X_test)} {len(test_ids)}')
627 | # save_minibatch(logger, args, word_index, X_train, train_ids, 32, name='train_yelp_root_100_5_10_len256_padded')
628 | # save_minibatch(logger, args, word_index, X_test, test_ids, 32, name='test_yelp_root_100_5_10_len256_padded')
629 | if args.filter_ancestors:
630 | id2doc_train = filter_ancestors(id2doc, nodes)
631 | else:
632 | id2doc_train = id2doc
633 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname=subtree_name)
634 | elif args.dataset == 'nyt':
635 | if 'oh' in args.base_model:
636 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_nyt_onehot('_nyt_ptAll')
637 | else:
638 | X_train, X_test, train_ids, test_ids, id2doc, wv, word_index, nodes = load_data_nyt(
639 | '../datasets/glove.6B.50d.txt', '_nyt_ptAll')
640 | if args.filter_ancestors:
641 | id2doc_train = filter_ancestors(id2doc, nodes)
642 | if args.split_multi:
643 | X_train, train_ids, id2doc_train, id2doc = split_multi(X_train, train_ids, id2doc_train, id2doc)
644 | else:
645 | id2doc_train = id2doc
646 | save_minibatch(logger, args, word_index, X_train, train_ids, 32, name='train_nyt')
647 | save_minibatch(logger, args, word_index, X_test, test_ids, 32, name='test_nyt')
648 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Top')
649 | elif 'FUN' in args.dataset or 'GO' in args.dataset:
650 | X_train, _, train_ids, test_ids, id2doc, nodes = read_fungo(args.dataset)
651 | if args.filter_ancestors:
652 | id2doc_train = filter_ancestors(id2doc, nodes)
653 | else:
654 | id2doc_train = id2doc
655 | tree = Tree(args, train_ids, test_ids, id2doc=id2doc_train, id2doc_a=id2doc, nodes=nodes, rootname='Top')
656 | else:
657 | logger.error('No such dataset: {}'.format(args.dataset))
658 | exit(1)
659 | if args.stat_check:
660 | check_doc_size(X_train, logger)
661 | check_doc_size(X_test, logger)
662 | if args.gpu:
663 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
664 | sorted_gpu_info = get_gpu_memory_map()
665 | for gpu_id, (mem_left, util) in sorted_gpu_info:
666 | if mem_left >= args.min_mem:
667 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
668 | logger.info('use gpu:{} with {} MB left, util {}%'.format(gpu_id, mem_left, util))
669 | break
670 | else:
671 |         logger.warning(f'no gpu has memory left >= {args.min_mem} MB, exiting...')
672 | exit()
673 | else:
674 | torch.set_num_threads(10)
675 | if 'cnn' in args.base_model:
676 | if args.base_model == 'textcnn':
677 | model = TextCNN(args, word_vec=wv, n_classes=tree.n_class - 1)
678 | in_dim = 3000
679 | elif args.base_model == 'ohcnn-bow-fast':
680 | model = OHCNN_fast(word_index['UNK'], n_classes=tree.n_class - 1, vocab_size=len(word_index))
681 | in_dim = 10000
682 | if args.mode == 'hmcn':
683 | local_output_size = tree.get_layer_node_number()
684 | model = HMCN(model, args, local_output_size, tree.n_class - 1, in_dim)
685 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_weight)
686 | if args.gpu:
687 | logger.info(model.cuda())
688 | elif args.base_model == 'han':
689 | model = HAN(args, word_vec=wv, n_classes=tree.n_class - 1)
690 | in_dim = args.sent_gru_hidden_size * 2
691 | if args.mode == 'sl':
692 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_weight)
693 | if args.gpu:
694 | logger.info(model.cuda())
695 | elif args.base_model == 'raw':
696 | local_output_size = tree.get_layer_node_number()
697 | if args.dataset == 'rcv1':
698 | model = HMCN(None, args, local_output_size, tree.n_class - 1, 47236)
699 | elif args.dataset == 'yelp':
700 | model = HMCN(None, args, local_output_size, tree.n_class - 1, 146587)
701 | elif args.dataset == 'nyt':
702 | model = HMCN(None, args, local_output_size, tree.n_class - 1, 102755)
703 | elif 'FUN' in args.dataset or 'GO' in args.dataset:
704 | n_features = len(X_train[0])
705 | logger.info(f'n_features={n_features}')
706 | if args.mode == 'hmcn':
707 | model = HMCN(None, args, local_output_size, tree.n_class - 1, n_features)
708 | else:
709 | model = Linear_Model(args, n_features, tree.n_class - 1)
710 | X_train, X_test, train_ids, test_ids = None, None, None, None
711 | in_dim = args.n_hidden
712 | if args.gpu:
713 | model.cuda()
714 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_weight)
715 |
716 | base_model = model
717 | if args.mode in ['hilap', 'hilap-sl']:
718 | if args.mode == 'hilap':
719 | policy = Policy(args, tree.n_class + 1, base_model, in_dim)
720 | else:
721 | policy = Policy(args, tree.n_class, base_model, in_dim)
722 | if args.gpu:
723 | logger.info(policy.cuda())
724 | optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr, weight_decay=args.l2_weight)
725 | for name, param in policy.named_parameters():
726 | logger.info('{} {} {}'.format(name, type(param.data), param.size()))
727 |
728 | if args.mode == 'hmcn':
729 | criterion = torch.nn.BCELoss()
730 | else:
731 | criterion = torch.nn.BCEWithLogitsLoss()
732 | if args.load_model:
733 | if os.path.isfile(args.load_model):
734 | checkpoint = torch.load(args.load_model)
735 | load_optimizer = True
736 | if args.mode in ['sl', 'hmcn']:
737 | model.load_state_dict(checkpoint['state_dict'])
738 | else:
739 | policy_dict = policy.state_dict()
740 | load_from_sl = False
741 | if load_from_sl:
742 | for i in list(checkpoint['state_dict'].keys()):
743 | checkpoint['state_dict']['base_model.' + i] = checkpoint['state_dict'].pop(i)
744 | checkpoint['state_dict']['class_embed.weight'] = torch.cat(
745 | [policy_dict['class_embed.weight'][-1:], checkpoint['state_dict']['base_model.fc2.weight'],
746 | policy_dict['class_embed.weight'][-1:]])
747 | checkpoint['state_dict']['class_embed_bias.weight'] = torch.cat(
748 | [policy_dict['class_embed_bias.weight'][-1:],
749 | checkpoint['state_dict']['base_model.fc2.bias'].view(-1, 1),
750 | policy_dict['class_embed_bias.weight'][-1:]])
751 | load_optimizer = False
752 | elif checkpoint['state_dict']['class_embed.weight'].size()[0] == \
753 | policy_dict['class_embed.weight'].size()[0] - 1:
754 | logger.warning('try loading pretrained x for rl-taxo. class_embed also loaded.')
755 | load_optimizer = False
756 | # checkpoint['state_dict']['class_embed.weight'] = policy_dict['class_embed.weight']
757 | checkpoint['state_dict']['class_embed.weight'] = torch.cat(
758 | [checkpoint['state_dict']['class_embed.weight'], policy_dict['class_embed.weight'][-1:]])
759 | checkpoint['state_dict']['class_embed_bias.weight'] = torch.cat(
760 | [checkpoint['state_dict']['class_embed_bias.weight'], policy_dict['class_embed_bias.weight'][-1:]])
761 | logger.warning(checkpoint['state_dict']['class_embed.weight'].size())
762 | policy.load_state_dict(checkpoint['state_dict'])
763 | if load_optimizer:
764 | optimizer.load_state_dict(checkpoint['optimizer'])
765 | else:
766 | logger.warning('optimizer not loaded')
767 | logger.info("loaded checkpoint '{}' ".format(args.load_model))
768 | else:
769 | logger.error("no checkpoint found at '{}'".format(args.load_model))
770 | exit(1)
771 | if args.stat_check:
772 | evaluate([[tree.id2idx[id2doc_train[did]['categories'][0]]] for did in train_ids], train_ids)
773 | evaluate([[tree.id2idx[id2doc_train[did]['categories'][0]]] for did in test_ids], test_ids)
774 | if not args.load_model or args.isTrain:
775 | if args.mode == 'hilap-sl':
776 | train_step_sl()
777 | # test_tmp('test' + args.ohcnn_data)
778 | test_step_sl('test' + args.ohcnn_data, save_prob=False)
779 | elif args.mode == 'hilap':
780 | train_taxo()
781 | # test_taxo('train' + args.ohcnn_data, save_prob=False)
782 | # test_taxo('test' + args.ohcnn_data, save_prob=False)
783 | elif args.mode == 'hmcn':
784 | train_hmcn()
785 | else:
786 | train_sl()
787 | test_sl()
788 | else:
789 | if args.mode == 'hilap':
790 | # test_step_sl('train' + args.ohcnn_data, save_prob=False)
791 | # test_step_sl('test' + args.ohcnn_data, save_prob=True)
792 | test_taxo('train' + args.ohcnn_data, save_prob=False)
793 | test_taxo('test' + args.ohcnn_data, save_prob=False)
794 | elif args.mode == 'hilap-sl':
795 | # test_sl(X_test, test_ids) # for testing global_flat with additional local_loss
796 | # test_step_sl('train' + args.ohcnn_data, save_prob=True)
797 | # test_tmp('test' + args.ohcnn_data)
798 | test_step_sl('train' + args.ohcnn_data, save_prob=False)
799 | test_step_sl('test' + args.ohcnn_data, save_prob=False)
800 | elif args.mode == 'hmcn':
801 | test_hmcn()
802 | else:
803 | test_sl()
804 | writer.close()
805 | logger.info(f'log_dir: {log_dir}')
806 |
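807 |
808 | # Illustrative sketch only (appended for readers, never called by the pipeline
809 | # above): how finish_episode() turns the per-step rewards of one rollout into
810 | # discounted returns R_t = r_t + gamma * R_{t+1}. The toy rewards below are
811 | # made up (3 steps, batch of 2).
812 | def _example_discounted_returns(gamma=0.9):
813 |     step_rewards = [np.array([1.0, 0.0]), np.array([0.0, 0.0]), np.array([0.0, 1.0])]
814 |     R = np.zeros(len(step_rewards[0]))
815 |     returns = []
816 |     for r in step_rewards[::-1]:
817 |         R = r + gamma * R
818 |         returns.insert(0, R)
819 |     # returns[0] == [1.0, 0.81]: immediate reward plus gamma-discounted future rewards
820 |     return returns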
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | from torch.distributions import Categorical
9 |
10 |
11 | class Policy(nn.Module):
12 | def __init__(self, args, n_class, base_model, in_dim):
13 | super(Policy, self).__init__()
14 | self.args = args
15 |         self.L = 0.02  # moving-average rate for the (currently unused) reward baseline, see update_baseline()
16 | self.baseline_reward = 0
17 | self.entropy_l = []
18 | self.beta = args.beta
19 | self.beta_decay_rate = .9
20 | self.n_update = 0
21 | self.class_embed = nn.Embedding(n_class, args.class_embed_size)
22 | self.class_embed_bias = nn.Embedding(n_class, 1)
23 |
24 | stdv = 1. / np.sqrt(self.class_embed.weight.size(1))
25 | self.class_embed.weight.data.uniform_(-stdv, stdv)
26 | self.class_embed_bias.weight.data.uniform_(-stdv, stdv)
27 |
28 | self.saved_log_probs = []
29 | self.rewards = []
30 | self.rewards_greedy = []
31 | self.doc_vec = None
32 | self.base_model = base_model
33 |
34 | self.criterion = torch.nn.BCEWithLogitsLoss()
35 | self.sl_loss = 0
36 |
37 | if self.args.use_history:
38 | self.state_hist = None
39 | self.output_hist = None
40 | self.hist_gru = nn.GRU(args.class_embed_size, args.class_embed_size, bidirectional=True)
41 | if self.args.use_cur_class_embed:
42 | in_dim += self.args.class_embed_size
43 | if self.args.use_history:
44 | in_dim += args.hist_embed_size * 2
45 | if self.args.use_l2:
46 | self.l1 = nn.Linear(in_dim, args.l1_size)
47 | self.l2 = nn.Linear(args.l1_size, args.class_embed_size)
48 | elif self.args.use_l1:
49 | self.l1 = nn.Linear(in_dim, args.class_embed_size)
50 |
51 | def update_baseline(self, target):
52 | # a moving average baseline, not used anymore
53 | self.baseline_reward = self.L * target + (1 - self.L) * self.baseline_reward
54 |
55 | def finish_episode(self):
56 | self.sl_loss = 0
57 | self.n_update += 1
58 | if self.n_update % self.args.update_beta_every == 0:
59 | self.beta *= self.beta_decay_rate
60 | self.entropy_l = []
61 | del self.rewards[:]
62 | del self.saved_log_probs[:]
63 | del self.rewards_greedy[:]
64 |
65 | def forward(self, cur_class_batch, next_classes_batch):
66 | cur_class_embed = self.class_embed(cur_class_batch) # (batch, 50)
67 | next_classes_embed = self.class_embed(next_classes_batch) # (batch, max_choices, 50)
68 | nb = self.class_embed_bias(next_classes_batch).squeeze(-1)
69 | states_embed = self.doc_vec
70 | if self.args.use_cur_class_embed:
71 | states_embed = torch.cat((states_embed, cur_class_embed), 1)
72 | if self.args.use_history:
73 | states_embed = torch.cat((states_embed, self.output_hist.squeeze()), 1)
74 | if not self.args.use_l1:
75 | return torch.bmm(next_classes_embed, states_embed.unsqueeze(-1)).squeeze(-1) + nb
76 | if self.args.use_l2:
77 | h1 = F.relu(self.l1(states_embed))
78 | h2 = F.relu(self.l2(h1))
79 | else:
80 | h2 = F.relu(self.l1(states_embed))
81 | h2 = h2.unsqueeze(-1) # (batch, 50, 1)
82 | probs = torch.bmm(next_classes_embed, h2).squeeze(-1) + nb
83 | if self.args.use_history:
84 | self.output_hist, self.state_hist = self.hist_gru(cur_class_embed.unsqueeze(0), self.state_hist)
85 | return probs
86 |
87 | def duplicate_doc_vec(self, indices):
88 | assert self.doc_vec is not None
89 | assert len(indices) > 0
90 | self.doc_vec = self.doc_vec[indices]
91 |
92 | def duplicate_reward(self, indices):
93 | assert len(indices) > 0
94 | self.saved_log_probs[-1] = [[probs[i] for i in indices] for probs in self.saved_log_probs[-1]]
95 | self.rewards[-1] = [[R[i] for i in indices] for R in self.rewards[-1]]
96 |
97 | def generate_doc_vec(self, mini_batch):
98 | self.doc_vec = self.base_model(mini_batch)
99 |
100 | def generate_logits(self, mini_batch, cur_class_batch, next_classes_batch):
101 | if self.doc_vec is None:
102 | self.generate_doc_vec(mini_batch)
103 | if self.args.gpu:
104 | cur_class_batch = Variable(torch.from_numpy(cur_class_batch)).cuda()
105 | next_classes_batch = Variable(torch.from_numpy(next_classes_batch)).cuda()
106 | else:
107 | cur_class_batch = Variable(torch.from_numpy(cur_class_batch))
108 | next_classes_batch = Variable(torch.from_numpy(next_classes_batch))
109 | logits = self(cur_class_batch, next_classes_batch)
110 | # mask padding relations
111 | logits = (next_classes_batch == 0).float() * -99999 + (next_classes_batch != 0).float() * logits
112 | return logits
113 |
114 | def step_sl(self, mini_batch, cur_class_batch, next_classes_batch, next_classes_batch_true, sigmoid=True):
115 | logits = self.generate_logits(mini_batch, cur_class_batch, next_classes_batch)
116 | if not sigmoid:
117 | return logits
118 | if next_classes_batch_true is not None:
119 | if self.args.gpu:
120 | y_true = Variable(torch.from_numpy(next_classes_batch_true)).cuda().float()
121 | else:
122 | y_true = Variable(torch.from_numpy(next_classes_batch_true)).float()
123 | self.sl_loss += self.criterion(logits, y_true)
124 | return F.sigmoid(logits)
125 |
126 | def step(self, mini_batch, cur_class_batch, next_classes_batch, test=False, flat_probs=None):
127 | logits = self.generate_logits(mini_batch, cur_class_batch, next_classes_batch)
128 | if self.args.softmax:
129 | probs = F.softmax(logits, dim=-1)
130 | else:
131 | probs = F.sigmoid(logits)
132 | if not test:
133 | # + epsilon to avoid log(0)
134 | self.entropy_l.append(torch.mean(torch.log(probs + 1e-32) * probs))
135 |         next_classes_batch = Variable(torch.from_numpy(next_classes_batch))
136 |         probs = probs + (next_classes_batch != 0).float().type_as(probs) * 1e-16  # type_as keeps the mask on the same device as probs (CPU or GPU)
137 | m = Categorical(probs)
138 | if test or self.args.sample_mode == 'choose_max':
139 | action = torch.max(probs, 1)[1]
140 | elif self.args.sample_mode == 'random':
141 |             if random.random() < 1.2:  # note: 1.2 > 1, so the uniform-random action branch below is always taken
142 | if self.args.gpu:
143 | action = Variable(torch.zeros(probs.size()[0]).long().random_(0, probs.size()[1])).cuda()
144 | else:
145 | action = Variable(torch.zeros(probs.size()[0]).long().random_(0, probs.size()[1]))
146 | else:
147 | action = m.sample()
148 | else:
149 | action = m.sample()
150 | return action, m
151 |
152 | # not used anymore
153 | def init_hist(self, tokens_size):
154 | if self.args.gpu:
155 | self.state_hist = self.init_hidden(tokens_size, self.args.hist_embed_size).cuda()
156 | first_input = Variable(
157 | torch.from_numpy(np.zeros((1, tokens_size, self.args.class_embed_size)))).cuda().float()
158 | else:
159 | self.state_hist = self.init_hidden(tokens_size, self.args.hist_embed_size)
160 | first_input = Variable(
161 |                 torch.from_numpy(np.zeros((1, tokens_size, self.args.class_embed_size)))).float()
162 | self.output_hist, self.state_hist = self.hist_gru(first_input, self.state_hist)
163 |
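164 |
165 | # Illustrative sketch only (not used by Policy): the padding mask applied in
166 | # generate_logits(). Candidate slots equal to 0 are padding, so they receive a
167 | # large negative logit and hence ~zero probability after softmax. The tensors
168 | # below are toy values, not real class ids.
169 | def _example_padding_mask():
170 |     next_classes_batch = torch.LongTensor([[3, 7, 0, 0]])  # 0 marks padded candidate slots
171 |     logits = torch.FloatTensor([[0.2, 1.5, 0.3, -0.1]])
172 |     masked = (next_classes_batch == 0).float() * -99999 + (next_classes_batch != 0).float() * logits
173 |     return F.softmax(masked, dim=-1)  # padded slots end up with ~0 probability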
--------------------------------------------------------------------------------
/readData_fungo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 | from IPython import embed
8 |
9 |
10 | def replace_nan_by_mean(X):
11 |     # Column means, ignoring NaN entries.
12 |     col_mean = np.nanmean(X, axis=0)
13 |     # Indices of the NaN entries that need replacing.
14 |     inds = np.where(np.isnan(X))
15 |     # Fill each NaN with its column mean; np.take aligns the means with the column indices.
16 | X[inds] = np.take(col_mean, inds[1])
17 | return X
18 |
19 |
20 | def read_fungo_(f_in):
21 | """
22 | X: features
23 | Y: labels
24 |     C: for *_FUN: in the format ancestor1/ancestor2/.../label
25 |        for *_GO: in the format parent/child [CAUTION: since the hierarchy is a DAG, C is larger than C_slash]
26 |     C_slash: unique labels for *_GO [somehow has 3 more labels than reported in the HMCN paper]
27 |
28 | """
29 | ct = 0
30 | A = 0
31 | C = 0
32 | C_set = set()
33 | flag = False
34 | X = []
35 | Y = []
36 | with open(f_in) as f:
37 | for line in f:
38 | if line.startswith('@ATTRIBUTE'):
39 | if '/' in line:
40 | # print(line.split(','))
41 | C = line.strip().split(',')
42 | C[0] = C[0].split()[-1]
43 | # print([i.split('/')[-1] for i in C][-10:])
44 | C_slash = set([i.split('/')[-1] for i in C])
45 | else:
46 | A += 1
47 | if flag:
48 | ct += 1
49 | data = line.strip().split(',')
50 | classes = data[-1].split('@')
51 | # convert ? to nan
52 | X.append([float(i) if i != '?' else np.nan for i in data[:-1]])
53 | Y.append(classes)
54 | C_set.update(classes)
55 | if line.startswith('@DATA'):
56 | flag = True
57 | X = np.array(X)
58 | X = replace_nan_by_mean(X)
59 | # print(f'[{f_in}] #features={A}, #Classes={len(C)}, #C_slash={len(C_slash)}, #C_appear={len(C_set)}, #samples={ct}')
60 | return X, Y, C, C_slash
61 |
62 |
63 | def read_fungo_all(name):
64 | valid = read_fungo_(
65 | f"./protein_datasets/{name}.valid.arff")
66 | test = read_fungo_(
67 | f"./protein_datasets/{name}.test.arff")
68 | train = read_fungo_(
69 | f"./protein_datasets/{name}.train.arff")
70 | return train, valid, test
71 |
72 |
73 | def construct_hier_dag(name):
74 | hierarchy = defaultdict(set)
75 | train_classes = read_fungo_(
76 | f"./protein_datasets/{name}.train.arff")[2]
77 | valid_classes = read_fungo_(
78 | f"./protein_datasets/{name}.valid.arff")[2]
79 | test_classes = read_fungo_(
80 | f"./protein_datasets/{name}.test.arff")[2]
81 | for t in [train_classes, valid_classes, test_classes]:
82 | for y in t:
83 | parent, child = y.split('/')
84 | if parent == 'root':
85 | parent = 'Top'
86 | hierarchy[parent].add(child)
87 | if not child in hierarchy:
88 | hierarchy[child] = set()
89 | return hierarchy
90 |
91 |
92 | def construct_hierarchy(name):
93 | # Construct hierarchy from train/valid/test.
94 | hierarchy = defaultdict(set)
95 | train_classes = read_fungo_(
96 | f"./protein_datasets/{name}.train.arff")[2]
97 | valid_classes = read_fungo_(
98 | f"./protein_datasets/{name}.valid.arff")[2]
99 | test_classes = read_fungo_(
100 | f"./protein_datasets/{name}.test.arff")[2]
101 | for t in [train_classes, valid_classes, test_classes]:
102 | for y in t:
103 | hier_list = y.split('/')
104 | # Add a pseudo node: Top
105 | if len(hier_list) == 1:
106 | hierarchy['Top'].add(hier_list[0])
107 | hierarchy[hier_list[0]] = set()
108 | continue
109 | for l in range(1, len(hier_list)):
110 | parent = '/'.join(hier_list[:l])
111 | child = '/'.join(hier_list[:l + 1])
112 | if l == 1:
113 | hierarchy['Top'].add(parent)
114 | if not child in hierarchy[parent]:
115 | hierarchy[parent].add(child)
116 | if not child in hierarchy:
117 | hierarchy[child] = set()
118 | return hierarchy
119 |
120 |
121 | def get_all_ancestor_nodes(hierarchy, node):
122 | node_list = set()
123 |
124 | def dfs(node):
125 | if node != 'Top':
126 | node_list.add(node)
127 | parents = hierarchy[node]['parent']
128 | for parent in parents:
129 | dfs(parent)
130 |
131 | dfs(node)
132 | return node_list
133 |
134 |
135 | def read_go(name):
136 | if os.path.exists(f'./protein_datasets/{name}.pkl'):
137 | return pickle.load(open(f'./protein_datasets/{name}.pkl', 'rb'))
138 | p2c = defaultdict(list)
139 | id2doc = defaultdict(lambda: defaultdict(list))
140 | nodes = defaultdict(lambda: defaultdict(list))
141 | random.seed(42)
142 |
143 | train, valid, test = read_fungo_all(name)
144 | hierarchy = construct_hier_dag(name)
145 | for parent in hierarchy:
146 | for child in hierarchy[parent]:
147 | p2c[parent].append(child)
148 | for label in p2c:
149 | for children in p2c[label]:
150 | nodes[label]['children'].append(children)
151 | nodes[children]['parent'].append(label)
152 |
153 | train_data = np.concatenate([train[0], valid[0]])
154 | train[1].extend(valid[1])
155 |
156 | X_train = []
157 | X_test = []
158 | train_ids = []
159 | test_ids = []
160 |
161 | for idx, (feature, classes) in enumerate(zip(train_data, train[1])):
162 | X_train.append(feature)
163 | train_ids.append(idx)
164 | for class_ in classes:
165 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_)
166 | for label in ancestor_nodes:
167 | if not label in id2doc[idx]['categories']:
168 | id2doc[idx]['categories'].append(label)
169 |
170 | for idx, (feature, classes) in enumerate(zip(test[0], test[1])):
171 | X_test.append(feature)
172 | test_ids.append(idx + train_data.shape[0])
173 | for class_ in classes:
174 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_)
175 | for label in ancestor_nodes:
176 | if not label in id2doc[idx + train_data.shape[0]]['categories']:
177 | id2doc[idx + train_data.shape[0]]['categories'].append(label)
178 | res = X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes)
179 | pickle.dump(res, open(f'./protein_datasets/{name}.pkl', 'wb'))
180 | return res
181 |
182 |
183 | def read_fun(name):
184 | if os.path.exists(f'./protein_datasets/{name}.pkl'):
185 | return pickle.load(open(f'./protein_datasets/{name}.pkl', 'rb'))
186 | p2c = defaultdict(list)
187 | id2doc = defaultdict(lambda: defaultdict(list))
188 | nodes = defaultdict(lambda: defaultdict(list))
189 | random.seed(42)
190 |
191 | train, valid, test = read_fungo_all(name)
192 | hierarchy = construct_hierarchy(name)
193 | for parent in hierarchy:
194 | for child in hierarchy[parent]:
195 | p2c[parent].append(child)
196 | for label in p2c:
197 | for children in p2c[label]:
198 | nodes[label]['children'].append(children)
199 | nodes[children]['parent'].append(label)
200 |
201 | train_data = np.concatenate([train[0], valid[0]])
202 | train[1].extend(valid[1])
203 |
204 | X_train = []
205 | X_test = []
206 | train_ids = []
207 | test_ids = []
208 |
209 | for idx, (feature, classes) in enumerate(zip(train_data, train[1])):
210 | X_train.append(feature)
211 | train_ids.append(idx)
212 | for class_ in classes:
213 | hier_list = class_.split('/')
214 | for l in range(1, len(hier_list) + 1):
215 | label = '/'.join(hier_list[:l])
216 | if not label in id2doc[idx]['categories']:
217 | id2doc[idx]['categories'].append(label)
218 |
219 | for idx, (feature, classes) in enumerate(zip(test[0], test[1])):
220 | X_test.append(feature)
221 | test_ids.append(idx + train_data.shape[0])
222 | for class_ in classes:
223 | # For each sample, we treat all nodes in the path as labels.
224 | hier_list = class_.split('/')
225 | for l in range(1, len(hier_list) + 1):
226 | label = '/'.join(hier_list[:l])
227 | if not label in id2doc[idx + train_data.shape[0]]['categories']:
228 | id2doc[idx + train_data.shape[0]]['categories'].append(label)
229 | res = X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes)
230 | pickle.dump(res, open(f'./protein_datasets/{name}.pkl', 'wb'))
231 | return res
232 |
233 |
234 | def read_fungo(name):
235 | if 'FUN' in name:
236 | return read_fun(name)
237 | return read_go(name)
238 |
239 |
240 | def process_for_cssag(name, train_data, train_labels, test_data, test_labels):
241 | if 'FUN' in name:
242 | hierarchy = construct_hierarchy(name)
243 | else:
244 | hierarchy = construct_hier_dag(name)
245 | nodes = defaultdict(lambda: defaultdict(list))
246 | p2c = defaultdict(list)
247 | for parent in hierarchy:
248 | for child in hierarchy[parent]:
249 | p2c[parent].append(child)
250 | for label in p2c:
251 | for children in p2c[label]:
252 | nodes[label]['children'].append(children)
253 | nodes[children]['parent'].append(label)
254 | label2id = {}
255 | for k in hierarchy:
256 | if k not in label2id:
257 | label2id[k] = len(label2id)
258 | with open('./cssag/' + name + '.hier', 'w') as OUT:
259 | for k in hierarchy:
260 | OUT.write(str(label2id[k]))
261 | for c in hierarchy[k]:
262 | OUT.write(' ' + str(label2id[c]))
263 | OUT.write('\n')
264 | with open('./cssag/' + name + '.train.x', 'w') as OUT:
265 | for l in range(train_data.shape[0]):
266 | x = train_data[l]
267 | OUT.write('0')
268 | for idx, i in enumerate(x):
269 | OUT.write(' ' + str(idx + 1) + ':' + str(i))
270 | OUT.write('\n')
271 | with open('./cssag/' + name + '.test.x', 'w') as OUT:
272 | for l in range(test_data.shape[0]):
273 | x = test_data[l]
274 | OUT.write('0')
275 | for idx, i in enumerate(x):
276 | OUT.write(' ' + str(idx + 1) + ':' + str(i))
277 | OUT.write('\n')
278 | if 'FUN' in name:
279 | with open('./cssag/' + name + '.train.y', 'w') as OUT:
280 | for classes in train_labels:
281 | # OUT.write(str(label2id['Top']))
282 | labels = set()
283 | for class_ in classes:
284 | hier_list = class_.split('/')
285 | for l in range(1, len(hier_list) + 1):
286 | label = '/'.join(hier_list[:l])
287 | labels.add(label)
288 | for i, label in enumerate(labels):
289 | if i > 0:
290 | OUT.write(',')
291 | OUT.write(str(label2id[label]))
292 | OUT.write('\n')
293 | with open('./cssag/' + name + '.test.y', 'w') as OUT:
294 | for classes in test_labels:
295 | # OUT.write(str(label2id['Top']))
296 | labels = set()
297 | for class_ in classes:
298 | hier_list = class_.split('/')
299 | for l in range(1, len(hier_list) + 1):
300 | label = '/'.join(hier_list[:l])
301 | labels.add(label)
302 | for i, label in enumerate(labels):
303 | if i > 0:
304 | OUT.write(',')
305 | OUT.write(str(label2id[label]))
306 | OUT.write('\n')
307 |
308 | elif 'GO' in name:
309 | with open('./cssag/' + name + '.train.y', 'w') as OUT:
310 | for classes in train_labels:
311 | labels = set()
312 | for class_ in classes:
313 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_)
314 | for label in ancestor_nodes:
315 | labels.add(label)
316 | for i, label in enumerate(labels):
317 | if i > 0:
318 | OUT.write(',')
319 | OUT.write(str(label2id[label]))
320 | OUT.write('\n')
321 | with open('./cssag/' + name + '.test.y', 'w') as OUT:
322 | for classes in test_labels:
323 | labels = set()
324 | for class_ in classes:
325 | ancestor_nodes = get_all_ancestor_nodes(nodes, class_)
326 | for label in ancestor_nodes:
327 | labels.add(label)
328 | for i, label in enumerate(labels):
329 | if i > 0:
330 | OUT.write(',')
331 | OUT.write(str(label2id[label]))
332 | OUT.write('\n')
333 |
334 |
335 | if __name__ == '__main__':
336 | # res = read_go()
337 | # embed()
338 | # exit()
339 | name = 'eisen_FUN'
340 | train, valid, test = read_fungo_all(name)
341 | train_data = np.concatenate([train[0], valid[0]])
342 | train[1].extend(valid[1])
343 | train_labels = train[1]
344 | process_for_cssag(name, train_data, train_labels, test[0], test[1])
345 | embed()
346 | exit()
347 |
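348 |
349 | # Illustrative sketch only: how a hypothetical *_FUN label path such as '1/2/3'
350 | # is expanded into parent->child edges under the pseudo root 'Top', mirroring
351 | # construct_hierarchy() above.
352 | def _example_fun_path_expansion(path='1/2/3'):
353 |     hier_list = path.split('/')
354 |     edges = [('Top', hier_list[0])]
355 |     for l in range(1, len(hier_list)):
356 |         edges.append(('/'.join(hier_list[:l]), '/'.join(hier_list[:l + 1])))
357 |     # edges == [('Top', '1'), ('1', '1/2'), ('1/2', '1/2/3')]
358 |     return edges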
--------------------------------------------------------------------------------
/readData_nyt.py:
--------------------------------------------------------------------------------
1 | import random
2 | import xml.dom.minidom
3 | from collections import defaultdict
4 |
5 | from tqdm import tqdm
6 |
7 | sample_ratio = 0.02
8 | train_ratio = 0.7
9 | min_per_node = 50
10 |
11 |
12 | def read_nyt_ids(file_path1, file_path2):
13 | ids = [[], []]
14 | random.seed(42)
15 | for cnt, file_path in enumerate([file_path1, file_path2]):
16 | filelist_path = './datasets/NYT_annotated_corpus/data/filelist_' + file_path + '.txt'
17 | with open(filelist_path, 'r') as fin:
18 | for file_cnt, file_name in tqdm(enumerate(fin)):
19 | if random.random() < sample_ratio:
20 | ids[cnt].append(file_name[:-5])
21 | return ids
22 |
23 |
24 | def construct_hierarchy(file_path1, file_path2):
25 | hierarchy = defaultdict(set)
26 | for cnt, file_path in enumerate([file_path1, file_path2]):
27 | print('Processing file %d...' % cnt)
28 | filelist_path = './datasets/NYT_annotated_corpus/data/filelist_' + file_path + '.txt'
29 | with open(filelist_path, 'r') as fin:
30 | for file_cnt, file_name in tqdm(enumerate(fin)):
31 | file_name = file_name.strip('\n')
32 | xml_path = './datasets/NYT_annotated_corpus/data/accum' + file_path + '/' + file_name
33 | try:
34 | dom = xml.dom.minidom.parse(xml_path)
35 | root = dom.documentElement
36 | tags = root.getElementsByTagName('classifier')
37 | for tag in tags:
38 | type = tag.getAttribute('type')
39 | if type != 'taxonomic_classifier':
40 | continue
41 | hier_path = tag.firstChild.data
42 | hier_list = hier_path.split('/')
43 | for l in range(1, len(hier_list)):
44 | parent = '/'.join(hier_list[:l])
45 | child = '/'.join(hier_list[:l + 1])
46 | if not child in hierarchy[parent]:
47 | hierarchy[parent].add(child)
48 | if not child in hierarchy:
49 | hierarchy[child] = set()
50 |                 except Exception as e:
51 |                     print('Failed to parse {}: {}'.format(xml_path, e))
52 |                     continue
53 | write_hierarchy_to_file(hierarchy, './datasets/nyt/nyt_hier')
54 |
55 |
56 | def write_hierarchy_to_file(hierarchy, filepath):
57 | with open(filepath, 'w') as fout:
58 | nodes = hierarchy.keys()
59 | for parent in nodes:
60 | for child in hierarchy[parent]:
61 | fout.write(parent + '\t' + child + '\n')
62 |
63 |
64 | def trim_nyt_tree():
65 | ids_1, ids_2 = read_nyt_ids('1987-02', '2003-07')
66 | label_cnt = defaultdict(int)
67 | for id_idx in tqdm(range(len(ids_1) + len(ids_2))):
68 | if id_idx < len(ids_1):
69 | doc_id = ids_1[id_idx]
70 | xml_path = './datasets/NYT_annotated_corpus/data/accum1987-02/' + str(doc_id) + '.xml'
71 | else:
72 | doc_id = ids_2[id_idx - len(ids_1)]
73 | xml_path = './datasets/NYT_annotated_corpus/data/accum2003-07/' + str(doc_id) + '.xml'
74 | try:
75 | dom = xml.dom.minidom.parse(xml_path)
76 | root = dom.documentElement
77 | tags = root.getElementsByTagName('classifier')
78 | for tag in tags:
79 | type = tag.getAttribute('type')
80 | if type != 'taxonomic_classifier':
81 | continue
82 | hier_path = tag.firstChild.data
83 | hier_list = hier_path.split('/')
84 | for l in range(1, len(hier_list) + 1):
85 | label = '/'.join(hier_list[:l])
86 | label_cnt[label] += 1
87 |         except Exception as e:
88 |             print('Failed to parse {}: {}'.format(xml_path, e))
89 |             continue
90 | label_cnt = {label: label_cnt[label] for label in label_cnt if label_cnt[label] > min_per_node}
91 | with open('nyt/nyt_trimed_hier', 'w') as fout:
92 | for label in label_cnt:
93 | fout.write(label + '\n')
94 | return label_cnt
95 |
96 |
97 | def read_nyt():
98 | p2c = defaultdict(list)
99 | id2doc = defaultdict(lambda: defaultdict(list))
100 | nodes = defaultdict(lambda: defaultdict(list))
101 | random.seed(42)
102 |
103 | trimmed_nodes = set()
104 | with open('nyt/nyt_trimed_hier', 'r') as fin:
105 | for line in fin:
106 | trimmed_nodes.add(line.strip('\n'))
107 | with open('nyt/nyt_hier', 'r') as f:
108 | for line in f:
109 | line = line.strip('\n')
110 | parent, child = line.split('\t')
111 | if parent in trimmed_nodes and child in trimmed_nodes:
112 | p2c[parent].append(child)
113 | for label in p2c:
114 | for children in p2c[label]:
115 | nodes[label]['children'].append(children)
116 | nodes[children]['parent'].append(label)
117 |
118 | ids_1, ids_2 = read_nyt_ids('1987-02', '2003-07')
119 | X_train = []
120 | X_test = []
121 | train_ids = []
122 | test_ids = []
123 | for id_idx in tqdm(range(len(ids_1) + len(ids_2))):
124 | if id_idx < len(ids_1):
125 | doc_id = ids_1[id_idx]
126 | xml_path = './datasets/NYT_annotated_corpus/data/accum1987-02/' + str(doc_id) + '.xml'
127 | else:
128 | doc_id = ids_2[id_idx - len(ids_1)]
129 | xml_path = './datasets/NYT_annotated_corpus/data/accum2003-07/' + str(doc_id) + '.xml'
130 | try:
131 | dom = xml.dom.minidom.parse(xml_path)
132 | root = dom.documentElement
133 | tags = root.getElementsByTagName('p')
134 | text = ''
135 | for tag in tags[1:]:
136 | text += tag.firstChild.data
137 | if text == '':
138 |                 print('Empty article text in {}, skipping...'.format(xml_path))
139 | continue
140 | tags = root.getElementsByTagName('classifier')
141 | for tag in tags:
142 | type = tag.getAttribute('type')
143 | if type != 'taxonomic_classifier':
144 | continue
145 | hier_path = tag.firstChild.data
146 | hier_list = hier_path.split('/')
147 | for l in range(1, len(hier_list) + 1):
148 | label = '/'.join(hier_list[:l])
149 | if label in trimmed_nodes and not label in id2doc[doc_id]['categories']:
150 | id2doc[doc_id]['categories'].append(label)
151 | if not doc_id in id2doc:
152 | continue
153 | if random.random() < train_ratio:
154 | train_ids.append(doc_id)
155 | X_train.append(text)
156 | else:
157 | test_ids.append(doc_id)
158 | X_test.append(text)
159 |         except Exception as e:
160 |             print('Failed to parse {}: {}'.format(xml_path, e))
161 |             continue
162 |
163 | return X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes)
164 |
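165 |
166 | # Illustrative sketch only: the trimming rule used by trim_nyt_tree() above. A
167 | # taxonomy label survives only if it is attached to more than min_per_node of
168 | # the sampled documents; the counts below are made up.
169 | def _example_trim_rule():
170 |     label_cnt = {'Top/News': 120, 'Top/News/Obscure Topic': 3}
171 |     kept = {label: label_cnt[label] for label in label_cnt if label_cnt[label] > min_per_node}
172 |     # kept == {'Top/News': 120} with the default min_per_node of 50
173 |     return kept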
--------------------------------------------------------------------------------
/readData_rcv1.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from tqdm import tqdm
3 |
4 |
5 | def read_rcv1_ids(filepath):
6 | ids = set()
7 | with open(filepath) as f:
8 | new_doc = True
9 | for line in f:
10 | line_split = line.strip().split()
11 | if new_doc and len(line_split) == 2:
12 | tmp, did = line_split
13 | if tmp == '.I':
14 | ids.add(did)
15 | new_doc = False
16 | else:
17 |                     print(line_split)
18 |                     print('unexpected document header line, possible parsing error')
19 | elif len(line_split) == 0:
20 | new_doc = True
21 | print('{} samples in {}'.format(len(ids), filepath))
22 | return ids
23 |
24 |
25 | def read_rcv1():
26 | p2c = defaultdict(list)
27 | id2doc = defaultdict(lambda: defaultdict(list))
28 | nodes = defaultdict(lambda: defaultdict(list))
29 | with open('rcv1/rcv1.topics.hier.orig.txt') as f:
30 | for line in f:
31 | start = line.find('parent: ') + len('parent: ')
32 | end = line.find(' ', start)
33 | parent = line[start:end]
34 | start = line.find('child: ') + len('child: ')
35 | end = line.find(' ', start)
36 | child = line[start:end]
37 | start = line.find('child-description: ') + len('child-description: ')
38 | end = line.find('\n', start)
39 | child_desc = line[start:end]
40 | p2c[parent].append(child)
41 | for label in p2c:
42 | if label == 'None':
43 | continue
44 | for children in p2c[label]:
45 | nodes[label]['children'].append(children)
46 | nodes[children]['parent'].append(label)
47 |
48 | with open('rcv1/rcv1-v2.topics.qrels') as f:
49 | for line in f:
50 | cat, doc_id, _ = line.strip().split()
51 | id2doc[doc_id]['categories'].append(cat)
52 | X_train = []
53 | X_test = []
54 | train_ids = []
55 | test_ids = []
56 | train_id_set = read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_train.dat')
57 | test_id_set = read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt0.dat')
58 | test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt1.dat')
59 | test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt2.dat')
60 | test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt3.dat')
61 | print('len(test) total={}'.format(len(test_id_set)))
62 | n_not_found = 0
63 | with open('rcv1/docs.txt') as f:
64 | for line in tqdm(f):
65 | doc_id, text = line.strip().split(maxsplit=1)
66 | if doc_id in train_id_set:
67 | train_ids.append(doc_id)
68 | X_train.append(text)
69 | elif doc_id in test_id_set:
70 | test_ids.append(doc_id)
71 | X_test.append(text)
72 | else:
73 | n_not_found += 1
74 | print('there are {} docs that cannot be found in the official tokenized rcv1'.format(n_not_found))
75 | print('len(train_ids)={} len(test_ids)={}'.format(len(train_ids), len(test_ids)))
76 | return X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes)
77 |
78 |
79 | if __name__ == '__main__':
80 | read_rcv1()
81 |
--------------------------------------------------------------------------------
/readData_yelp.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import random
3 |
4 | from tqdm import tqdm
5 |
6 |
7 | def read_yelp_files(root, review_min_sup):
8 | nodes = {}
9 | with open('yelp/Taxonomy_100') as f:
10 | for line in f:
11 | node = ast.literal_eval(line)
12 | nodes[node['title']] = node
13 | subtree_name = root
14 | node_set = set()
15 | q = [subtree_name]
16 | nextq = []
17 | print(q)
18 | while len(q) > 0:
19 | cur_name = q.pop(0)
20 | node_set.add(cur_name)
21 | print(nodes[cur_name]['children'], end=',')
22 | nextq.extend(nodes[cur_name]['children'])
23 | if len(q) == 0:
24 | q = nextq
25 | nextq = []
26 | print()
27 | print('keep {} nodes in total'.format(len(node_set)))
28 | print(node_set)
29 | nodes = {node: nodes[node] for node in node_set}
30 |
31 | business_dict = {}
32 | ct = 0
33 | below_min_sup = 0
34 | with open('yelp/yelp_data_100.csv') as f:
35 | for line in tqdm(f):
36 | business = ast.literal_eval(line)
37 | if len(business['reviews']) < review_min_sup:
38 | below_min_sup += 1
39 | continue
40 | labels = set(business['categories']) & node_set
41 | if len(labels) == 0:
42 | continue
43 | if len(labels) != len(business['categories']):
44 | ct += 1
45 | # print('there are labels on other subtree: {} {}'.format(labels, business['categories']))
46 | business['categories'] = labels
47 | business_dict[business['business_id']] = business
48 | print('keep {} businesses for tree {}'.format(len(business_dict), subtree_name))
49 | print('there are {} businesses with labels on other subtrees'.format(ct))
50 | print('there are {} businesses filtered out because of min_sup'.format(below_min_sup))
51 | return nodes, business_dict
52 |
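# Sketch of the line formats read_yelp_files() assumes (one Python literal per
# line, parsed with ast.literal_eval); field names come from the code and from
# yelp/yelp_data_100.csv.sample, example values are illustrative:
#   yelp/Taxonomy_100:      {'title': 'Beauty & Spas', 'children': ['Hair Salons', ...]}
#   yelp/yelp_data_100.csv: {'reviews': ['...'], 'business_id': '...', 'categories': ['...']}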
53 |
54 | def read_yelp(root, min_reviews=1, max_reviews=10):
55 | nodes, business_dict = read_yelp_files(root, min_reviews)
56 | random.seed(42)
57 | train_ids = []
58 | test_ids = []
59 | X_train = []
60 | X_test = []
61 |
62 | for bid in business_dict:
63 | reviews_concat = ' '.join(business_dict[bid]['reviews'][:max_reviews])
64 | if random.random() > 0.3:
65 | train_ids.append(bid)
66 | X_train.append(reviews_concat)
67 | else:
68 | test_ids.append(bid)
69 | X_test.append(reviews_concat)
70 |
71 | return X_train, X_test, train_ids, test_ids, business_dict, nodes
72 |
73 |
74 | if __name__ == '__main__':
75 | read_yelp('root')  # NOTE: 'root' is a placeholder; pass a subtree root title present in yelp/Taxonomy_100
76 |
--------------------------------------------------------------------------------
/tree.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pickle
3 | from collections import defaultdict, Counter
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from scipy import sparse
8 | from sklearn.metrics import f1_score
9 | from sklearn.preprocessing import MultiLabelBinarizer
10 |
11 |
12 | class Tree:
13 | def __init__(self, args, train_ids, test_ids, id2doc=None, id2doc_a=None, nodes=None, X_train=None,
14 | X_test=None, rootname=None):
15 | # child, parent, ancestor
16 | self.c2p = defaultdict(list)
17 | self.p2c = defaultdict(list)
18 | self.p2c_idx = defaultdict(list)
19 | self.c2p_idx = defaultdict(list)
20 | self.c2a = defaultdict(set)
21 |
22 | # real id to auto-increment id
23 | self.id2idx = {}
24 | self.idx2id = {}
25 | # real id to label name
26 | self.id2name = {}
27 | self.name2id = {}
28 | # real id to height
29 | self.id2h = {}
30 | # doc id to doc obj
31 | self.id2doc = id2doc
32 | self.id2doc_ancestors = id2doc_a
33 | self.nodes = nodes # for yelp
34 | self.rootname = rootname
35 | self.id2doc_h = {}
36 | self.train_ids = train_ids
37 | self.test_ids = test_ids
38 | self.X_train = X_train
39 | self.X_test = X_test
40 | self.taken_actions = None
41 | self.n_update = 0 # global update times
42 | self.data_cache = None # loaded from pkl
43 | self.last_R = None # R at last time step
44 | self.id2prob = defaultdict(dict) # [id][class] = p
45 |
46 | self.args = args
47 | self.logger = logging.getLogger('exp')
48 | self.miF = (0, 0)
49 | self.maF = (0, 0)
50 | self.cur_epoch = 0
51 | self.read_mapping()
52 | self.read_edges()
53 | for parent in self.p2c:
54 | for child in self.p2c[parent]:
55 | self.p2c_idx[self.id2idx[parent]].append(self.id2idx[child])
56 | self.c2p_idx[self.id2idx[child]].append(self.id2idx[parent])
57 | self.logger.info('# class = {}'.format(len(self.name2id)))
58 | self.n_class = len(self.name2id)
59 | self.class_idx = list(range(1, self.n_class))
60 |
61 | self.p2c_idx_np = self.pad_p2c_idx()
62 | if args.mode != 'sl' and args.mode != 'hmcn':
63 | self.next_true_bin, self.next_true = self.generate_next_true(keep_cur=True)
64 | self.logger.info('{} terms have more than 1 parent'.format(sum([len(v) > 1 for v in self.c2p.values()])))
65 | leaves = set(self.c2p) - set(self.p2c)
66 | self.logger.info('{} terms are leaves'.format(len(leaves)))
67 | for term in self.c2p:
68 | ancestor_q = [i for i in self.c2p[term]]
69 | while len(ancestor_q) != 0:
70 | cur = ancestor_q.pop(0)
71 | if cur in self.c2p:
72 | ancestor_q.extend(self.c2p[cur])
73 | # exclude root in one's ancestors
74 | self.c2a[term].add(cur)
75 |
76 | # stat info
77 | # note that only the first path is considered
78 | for k in self.id2name:
79 | cur = k
80 | h = 0
81 | while cur in self.c2p:
82 | h += 1
83 | cur = self.c2p[cur][0]
84 | self.id2h[k] = h
85 | if args.stat_check:
86 | self.logger.info('node height description:')
87 | self.logger.info(pd.Series(list(self.id2h.values())).describe())
88 | self.doc_check(self.id2doc, 'all')
89 |
90 | def pad_p2c_idx(self):
91 | col = max([len(c) for c in self.p2c_idx.values()]) + 2
92 | res = np.zeros((len(self.name2id), col), dtype=int)
93 | for row_i in range(len(self.name2id)):
94 | # stay at cur node
95 | if self.args.allow_stay:
96 | res[row_i, len(self.p2c_idx[row_i])] = row_i
97 | if self.args.allow_up:
98 | if row_i in self.c2p_idx:
99 | res[row_i, len(self.p2c_idx[row_i]) + 1] = self.c2p_idx[row_i][0]
100 | # next level node
101 | res[row_i, :len(self.p2c_idx[row_i])] = self.p2c_idx[row_i]
102 | return res
103 |
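# Layout of the padded action matrix built by pad_p2c_idx() (descriptive note;
# the exact columns depend on allow_stay / allow_up):
#   row i = [child_1, ..., child_k, i (stay), first parent (up), 0, 0, ...]
# Unused trailing columns stay 0, which is why p2c_batch() can drop all-zero columns.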
104 | def get_next_candidates(self, last_selections, cur_candidates, nonstop=False):
105 | all_stop = True
106 | for candi, sel in zip(cur_candidates, last_selections):
107 | if sel != 0: # 0 is not in candi
108 | candi.remove(sel)
109 | if sel == self.n_class: # stop action
110 | candi.clear()
111 | else:
112 | candi.update(self.p2c_idx[sel])
113 | all_stop = False
114 | if not nonstop:
115 | candi.add(self.n_class)
116 | return cur_candidates, self.pad_candidates(cur_candidates), all_stop
117 |
118 | def update_actions(self, cur_class_batch):
119 | for taken, a in zip(self.taken_actions, cur_class_batch):
120 | taken.add(a)
121 |
122 | def pad_candidates(self, cur_candidates):
123 | col = max([len(c) for c in cur_candidates])
124 | res = np.zeros((len(cur_candidates), col), dtype=int)
125 | for row_i, c in enumerate(cur_candidates):
126 | res[row_i, :len(c)] = list(c)
127 | return res
128 |
129 | def doc_check(self, docids, name):
130 | avg_h = []
131 | n_classes = []
132 | hs = []
133 | label_key = 'categories'
134 | classes = []
135 | for docid in docids:
136 | if len(self.id2doc[docid][label_key]) != len(set(self.id2doc[docid][label_key])):
137 | self.logger.error(label_key)
138 | self.logger.error(self.id2doc[docid][label_key])
139 | exit(1)
140 | classes.extend(self.id2doc[docid][label_key])
141 | h = [self.id2h[self.name2id[term]] for term in self.id2doc[docid][label_key]]
142 | hs.append(h)
143 | self.id2doc_h[docid] = h
144 | avg_h.append(np.mean(h))
145 | n_classes.append(len(h))
146 | if self.args.stat_check:
147 | self.logger.info('check info of classes for each doc in {}'.format(name))
148 | self.logger.info('heights of classes')
149 | self.logger.info(hs[:5])
150 | self.logger.info('avg_height of labels')
151 | self.logger.info(pd.Series(avg_h).describe())
152 | self.logger.info('label count of documents')
153 | self.logger.info(pd.Series(n_classes).describe())
154 | self.logger.info('node support')
155 | self.logger.info(Counter(classes))
156 |
157 | def h_doc_batch(self, docids):
158 | return [self.id2doc_h[docid] for docid in docids]
159 |
160 | def h_batch(self, ids):
161 | return [self.id2h[self.idx2id[vid]] for vid in ids]
162 |
163 | def p2c_batch(self, ids):
164 | # ids is virtual
165 | res = self.p2c_idx_np[ids]
166 | # remove columns full of zeros
167 | return res[:, ~np.all(res == 0, axis=0)]
168 |
169 | def generate_next_true(self, keep_cur=False):
170 | self.logger.info('keep_cur={}'.format(keep_cur))
171 | next_true_bin = defaultdict(lambda: defaultdict(list))
172 | next_true = defaultdict(lambda: defaultdict(list))
173 | for did in self.id2doc_ancestors:
174 | class_idx_set = set(self.id2doc_ancestors[did]['class_idx'])
175 | class_idx_set.add(0)
176 | for c in class_idx_set:
177 | for idx, next_c in enumerate(self.p2c_idx[c]):
178 | if next_c in class_idx_set:
179 | next_true_bin[did][c].append(1)
180 | next_true[did][c].append(next_c)
181 | else:
182 | next_true_bin[did][c].append(0)
183 | # if lowest label
184 | if len(next_true[did][c]) == 0:
185 | next_true[did][c].append(c)
186 | if self.args.allow_stay:
187 | next_true_bin[did][c].append(1)
188 | elif keep_cur and c != 0:
189 | # append 1 only for loss calculation
190 | next_true_bin[did][c].append(1)
191 | return next_true_bin, next_true
192 |
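# Rough shape of the lookups built by generate_next_true() (descriptive note):
# for a document `did` currently at class index `c`,
#   next_true_bin[did][c][j] == 1 iff the j-th entry of p2c_idx[c] is a gold label
#   next_true[did][c]        == the gold next class indices, or [c] ("stay") when
#                               c is already the lowest gold label on its path.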
193 | def get_next(self, cur_class_batch, next_classes_batch, doc_ids):
194 | assert len(cur_class_batch) == len(doc_ids)
195 | next_classes_batch_true = np.zeros(next_classes_batch.shape)
196 | indices = []
197 | next_class_batch_true = []
198 | for ct, (c, did) in enumerate(zip(cur_class_batch, doc_ids)):
199 | nt = self.next_true_bin[did][c]
200 | if len(self.next_true[did][c]) == 0:
201 | print(ct, did, c)
202 | print(nt, self.next_true[did][c])
203 | print(self.id2doc_ancestors[did])
204 | exit(-1)
205 | next_classes_batch_true[ct][:len(nt)] = nt
206 | for idx in self.next_true[did][c]:
207 | indices.append(ct)
208 | next_class_batch_true.append(idx)
209 | doc_ids = [doc_ids[idx] for idx in indices]
210 | return next_classes_batch_true, indices, np.array(next_class_batch_true), doc_ids
211 |
212 | def get_next_by_probs(self, cur_class_batch, next_classes_batch, doc_ids, probs, save_prob):
213 | assert len(cur_class_batch) == len(doc_ids) == len(probs)
214 | indices = []
215 | next_class_batch_pred = []
216 | if save_prob:
217 | thres = 0
218 | else:
219 | thres = 0.5
220 | preds = (probs > thres).int().data.cpu().numpy()
221 | for ct, (c, next_classes, did, pred, p) in enumerate(
222 | zip(cur_class_batch, next_classes_batch, doc_ids, preds, probs)):
223 | # need allow_stay=True, filter last one (cur) to avoid duplication
224 | next_pred = np.nonzero(pred)[0]
225 | if not self.args.multi_label:
226 | idx_above_thres = np.argsort(p.data.cpu().numpy()[next_pred])
227 | for idx in idx_above_thres[::-1]:
228 | if next_classes[next_pred[idx]] != c:
229 | next_pred = [next_pred[idx]]
230 | break
231 | else:
232 | if len(next_pred) != 0 and next_classes[next_pred[-1]] == c:
233 | next_pred = next_pred[:-1]
234 | # if no next > threshold, stay at current class
235 | if len(next_pred) == 0:
236 | p_selected = []
237 | next_pred = [c]
238 | else:
239 | p_selected = p.data.cpu().numpy()[next_pred]
240 | next_pred = next_classes[next_pred]
241 | # indices remember where one is from; idx is virtual class idx
242 | for idx in next_pred:
243 | indices.append(ct)
244 | next_class_batch_pred.append(idx)
245 | if save_prob:
246 | for idx, p_ in zip(next_pred, p_selected):
247 | if idx in self.id2prob[did]:
248 | self.logger.warning(f'[{did}][{idx}] already existed!')
249 | self.id2prob[did][idx] = p_
250 | doc_ids = [doc_ids[idx] for idx in indices]
251 | return indices, np.array(next_class_batch_pred), doc_ids
252 |
253 | def get_flat_idx_each_layer(self):
254 | flat_idx_each_layer = [[0]]
255 | idx2layer_idx = {}
256 | idx2layer_idx[0] = (0, 0)
257 | for i in range(self.args.n_steps_sl):
258 | flat_idx_each_layer.append([])
259 | current_nodes = flat_idx_each_layer[i]
260 | for current_node in current_nodes:
261 | child_nodes = self.p2c_idx[current_node]
262 | for (in_layer_idx, child_node) in enumerate(child_nodes):
263 | flat_idx_each_layer[i + 1].append(child_node)
264 | idx2layer_idx[child_node] = (i, in_layer_idx)
265 |
266 | self.flat_idx_each_layer = flat_idx_each_layer
267 | self.idx2layer_idx = idx2layer_idx
268 | return flat_idx_each_layer, idx2layer_idx
269 |
270 | def get_layer_node_number(self):
271 | layer_number = max(self.id2h.values())
272 | local_output_number = [0] * layer_number
273 | for id in self.id2h:
274 | if self.id2h[id] == 0:
275 | continue
276 | local_output_number[self.id2h[id] - 1] += 1
277 | assert sum(local_output_number) == self.n_class - 1
278 | return local_output_number
279 |
280 | def calc_reward(self, notLast, actions, ids):
281 | if self.args.reward == '01':
282 | return self.calc_reward_01(actions, ids)
283 | elif self.args.reward == 'f1':
284 | return self.calc_reward_f1(actions, ids)
285 | elif self.args.reward == '01-1':
286 | return self.calc_reward_neg(actions, ids)
287 | elif self.args.reward == 'direct':
288 | return self.calc_reward_direct(notLast, actions, ids)
289 | elif self.args.reward == 'taxo':
290 | return self.calc_reward_taxo(actions, ids)
291 | else:
292 | raise NotImplementedError
293 |
294 | def calc_reward_f1(self, actions, ids):
295 | R = []
296 | for a, i, taken in zip(actions, ids, self.taken_actions):
297 | taken.add(a)
298 | y_l_a = self.id2doc_ancestors[i]['class_idx']
299 | correct = set(y_l_a) & taken
300 | p = len(correct) / len(taken)
301 | r = len(correct) / len(y_l_a)
302 | f1 = 2 * p * r / (p + r + 1e-32)
303 | R.append(f1)
304 | R = np.array(R)
305 | res = np.copy(R)
306 | if self.last_R is not None:
307 | res -= self.last_R
308 | self.last_R = R
309 | return res
310 |
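# Note on calc_reward_f1(): the step reward is the *change* in sample-level F1
# of the labels taken so far (current F1 minus last_R), so the rewards of an
# episode telescope to the final F1 of the predicted label set.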
311 | def calc_reward_direct(self, notLast, actions, ids):
312 | if notLast:
313 | return [0] * len(actions)
314 | R = []
315 | for a, i, taken in zip(actions, ids, self.taken_actions):
316 | if a in self.id2doc[i]['class_idx']:
317 | R.append(1)
318 | else:
319 | R.append(0)
320 | return R
321 |
322 | def calc_reward_taxo(self, actions, ids):
323 | R = []
324 | for taken, a, i in zip(self.taken_actions, actions, ids):
325 | if a in self.id2doc_ancestors[i]['class_idx']:
326 | R.append(1)
327 | elif a == self.n_class:
328 | R.append(0)
329 | else:
330 | R.append(-1)
331 | return R
332 |
333 | def calc_reward_01(self, actions, ids):
334 | R = []
335 | label_key = 'categories'
336 | for a, i, taken in zip(actions, ids, self.taken_actions):
337 | if a in self.id2doc[i]['class_idx']:
338 | R.append(1)
339 | continue
340 | if a in taken:
341 | R.append(0)
342 | continue
343 | for label in self.id2doc[i][label_key]:
344 | if self.idx2id[a] in self.c2a[self.name2id[label]]:
345 | R.append(1)
346 | break
347 | else:
348 | R.append(0)
349 | taken.add(a)
350 | return R
351 |
352 | def calc_reward_neg(self, actions, ids):
353 | # actions : virtual id
354 | # ids: doc id for mesh, yelp | real labels for review
355 | R = []
356 | label_key = 'categories'
357 | assert len(actions) == len(ids) == len(self.taken_actions)
358 | if self.args.dataset in ['mesh', 'yelp', 'rcv1']:
359 | for a, i, taken in zip(actions, ids, self.taken_actions):
360 | if a in self.id2doc[i]['class_idx']:
361 | R.append(1)
362 | continue
363 | for label in self.id2doc[i][label_key]:
364 | if self.idx2id[a] in self.c2a[self.name2id[label]]:
365 | if a in taken:
366 | R.append(0)
367 | else:
368 | R.append(1)
369 | break
370 | else:
371 | R.append(-1)
372 | taken.add(a)
373 | return R
374 |
375 | def calc_f1(self, pred_l, id_l, save_path=None, output=True):
376 | assert len(pred_l) == len(id_l)
377 | if not output:
378 | mlb = MultiLabelBinarizer(classes=self.class_idx)
379 | y_l_a = [self.id2doc_ancestors[docid]['class_idx'] for docid in id_l]
380 | y_true_a = mlb.fit_transform(y_l_a)
381 | y_pred = mlb.transform(pred_l)
382 | f1_a = f1_score(y_true_a, y_pred, average='micro')
383 | f1_a_macro = f1_score(y_true_a, y_pred, average='macro')
384 | f1_a_s = f1_score(y_true_a, y_pred, average='samples')
385 | self.logger.info(f'micro:{f1_a:.4f} macro:{f1_a_macro:.4f} samples:{f1_a_s:.4f}')
386 | return f1_a, f1_a_macro, f1_a_s
387 | # pred_l_a contains all ancestors (except root)
388 | pred_l_a = []
389 | for pred in pred_l:
390 | pred_l_a.append([])
391 | for i in pred:
392 | pred_l_a[-1].append(i)
393 | # TODO can be refactored
394 | pred_l_a[-1].extend(
395 | [self.id2idx[self.name2id[term]] for term in self.c2a[self.id2name[self.idx2id[i]]]])
396 | y_l = [self.id2doc[docid]['class_idx'] for docid in id_l]
397 | if self.id2doc_ancestors:
398 | y_l_a = [self.id2doc_ancestors[docid]['class_idx'] for docid in id_l]
399 | else:
400 | y_l_a = y_l
401 |
402 | self.logger.info('measuring f1 of {} samples...'.format(len(pred_l)))
403 | avg_len = np.mean([len(p) for p in pred_l])
404 | self.logger.info(f'len(pred): {avg_len}')
405 | # self.logger.info(
406 | # 'docid:{} pred:{} real:{} pred_a:{} real_a:{}'.format(id_l[:5], pred_l[:5], y_l[:5], pred_l_a[:5],
407 | # y_l_a[:5]))
408 | # self.logger.info('Counter(pred_l):{}'.format(Counter(flatten(pred_l))))
409 | # self.logger.info('Counter(y_l):{}'.format(Counter(flatten(y_l))))
410 | # self.logger.info('Counter(y_l_a):{}'.format(Counter(flatten(y_l_a))))
411 | # self.logger.info(self.idx2id)
412 |
413 | calc_leaf_only = False
414 | if calc_leaf_only:
415 | y_true_leaf = [[idx for idx in i if len(self.p2c_idx[idx]) == 0] for i in y_l]
416 | y_pred_leaf = [[idx for idx in i if len(self.p2c_idx[idx]) == 0] for i in pred_l]
417 | mlb = MultiLabelBinarizer()
418 | y_true = sparse.csr_matrix(mlb.fit_transform(y_true_leaf))
419 | print(len(mlb.classes_), 36504)
420 | y_pred = sparse.csr_matrix(mlb.transform(y_pred_leaf))
421 | print(f1_score(y_true, y_pred, average='micro'))
422 | print(f1_score(y_true, y_pred, average='macro'))
423 | print(f1_score(y_true, y_pred, average='samples'))
424 |
425 | mlb = MultiLabelBinarizer(classes=self.class_idx)
426 | y_true = sparse.csr_matrix(mlb.fit_transform(y_l))
427 | y_pred = sparse.csr_matrix(mlb.transform(pred_l))
428 | f1 = f1_score(y_true, y_pred, average='micro')
429 | f1_macro = f1_score(y_true, y_pred, average='macro')
430 |
431 | calc_other_f1 = True
432 | if calc_other_f1:
433 | y_true_a = sparse.csr_matrix(mlb.transform(y_l_a))
434 | y_pred_a = sparse.csr_matrix(mlb.transform(pred_l_a))
435 | if save_path is not None:
436 | self.logger.info('saving to {}/preds.pkl'.format(save_path))
437 | pickle.dump((y_pred, id_l, y_true_a), open('{}/preds.pkl'.format(save_path), 'wb'))
438 | f1_a = f1_score(y_true_a, y_pred, average='micro')
439 | f1_aa = f1_score(y_true_a, y_pred_a, average='micro')
440 | f1_a_macro = f1_score(y_true_a, y_pred, average='macro')
441 | f1_aa_macro = f1_score(y_true_a, y_pred_a, average='macro')
442 | f1_aa_s = f1_score(y_true_a, y_pred_a, average='samples')
443 | # from sklearn.metrics import classification_report
444 | # print(classification_report(y_true, y_pred))
445 | else:
446 | f1_a = 0
447 | f1_aa = 0
448 | f1_a_macro = 0
449 | f1_aa_macro = 0
450 | f1_aa_s = 0
451 | return f1, f1_a, f1_aa, f1_macro, f1_a_macro, f1_aa_macro, f1_aa_s
452 |
453 | def acc(self, pred_l, id_l):
454 | R = []
455 | for pred, i in zip(pred_l, id_l):
456 | R.append(int(pred in self.id2doc[i]['class_idx']))
457 | return np.mean(R)
458 |
459 | def acc_multi(self, pred_l, id_l):
460 | R = []
461 | for preds, i in zip(pred_l, id_l):
462 | for pred in preds:
463 | R.append(int(pred in self.id2doc[i]['class_idx']))
464 | return np.mean(R)
465 |
466 | def read_mapping(self):
467 | for idx, c in enumerate(self.nodes):
468 | if c == self.rootname:
469 | root_idx = idx
470 | self.id2name[c] = c
471 | self.name2id[c] = c
472 | self.id2idx[c] = idx
473 | self.idx2id[idx] = c
474 | # put root to idx 0
475 | self.idx2id[root_idx] = self.idx2id[0]
476 | self.id2idx[self.idx2id[root_idx]] = root_idx
477 | self.id2idx[self.rootname] = 0
478 | self.idx2id[0] = self.rootname
479 | # remove root from labels and add class_idx
480 | for bid in self.id2doc:
481 | self.id2doc[bid]['categories'] = [c for c in self.id2doc[bid]['categories'] if c != self.rootname]
482 | self.id2doc[bid]['class_idx'] = [self.id2idx[c] for c in self.id2doc[bid]['categories']]
483 | if self.id2doc_ancestors is None:
484 | return
485 | for bid in self.id2doc_ancestors:
486 | self.id2doc_ancestors[bid]['categories'] = [c for c in self.id2doc_ancestors[bid]['categories'] if
487 | c != self.rootname]
488 | self.id2doc_ancestors[bid]['class_idx'] = [self.id2idx[c] for c in self.id2doc_ancestors[bid]['categories']]
489 |
490 | def read_edges(self):
491 | for parent in self.nodes:
492 | for child in self.nodes[parent]['children']:
493 | self.c2p[child].append(parent)
494 | self.p2c[parent].append(child)
495 |
496 | def remove_stop(self):
497 | for taken in self.taken_actions:
498 | taken.discard(self.n_class)
499 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import os
3 | import pickle
4 | import subprocess
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | from torch.autograd import Variable
10 | from tqdm import tqdm
11 |
12 |
13 | def isnan(x):
14 | return x != x
15 |
16 |
17 | def contains_nan(x):
18 | return isnan(x).any()
19 |
20 |
21 | def explode(x):
22 | return (x > 10).any()
23 |
24 |
25 | def eu_dist(x):
26 | return sum((x[0] - x[1]) ** 2) / len(x[0])
27 |
28 |
29 | def get_gpu_memory_map():
30 | result = subprocess.check_output(
31 | [
32 | 'nvidia-smi', '--query-gpu=memory.free,utilization.gpu',
33 | '--format=csv,nounits,noheader'
34 | ], encoding='utf-8')
35 | gpu_info = [eval(x) for x in result.strip().split('\n')]
36 | gpu_info = dict(zip(range(len(gpu_info)), gpu_info))
37 | sorted_gpu_info = sorted(gpu_info.items(), key=lambda kv: kv[1][0], reverse=True)
38 | sorted_gpu_info = sorted(sorted_gpu_info, key=lambda kv: kv[1][1])
39 | print(f'gpu_id, (mem_left, util): {sorted_gpu_info}')
40 | return sorted_gpu_info
41 |
42 |
43 | def save_checkpoint(state, modelpath, modelname, logger=None, del_others=True):
44 | if del_others:
45 | for dirpath, dirnames, filenames in os.walk(modelpath):
46 | for filename in filenames:
47 | path = os.path.join(dirpath, filename)
48 | if path.endswith('pth.tar'):
49 | if logger is None:
50 | print(f'rm {path}')
51 | else:
52 | logger.warning(f'rm {path}')
53 | os.system("rm -rf '{}'".format(path))
54 | break
55 | path = os.path.join(modelpath, modelname)
56 | if logger is None:
57 | print('saving model to {}...'.format(path))
58 | else:
59 | logger.warning('saving model to {}...'.format(path))
60 | try:
61 | torch.save(state, path)
62 | except Exception as e:
63 | logger.error(e)
64 |
65 |
66 | def flatten(x):
67 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
68 | return [a for i in x for a in flatten(i)]
69 | else:
70 | return [x]
71 |
72 |
73 | def check_doc_size(X_train, logger):
74 | n_sent = []
75 | n_words = []
76 | n_words_per_doc = []
77 | for doc in X_train:
78 | n_sent.append(len(doc))
79 | words_per_doc = 0
80 | for sent in doc:
81 | n_words.append(len(sent))
82 | words_per_doc += len(sent)
83 | n_words_per_doc.append(words_per_doc)
84 | logger.info('#sent in a document')
85 | logger.info(pd.Series(n_sent).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98]))
86 | logger.info('#words in a sent')
87 | logger.info(pd.Series(n_words).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98]))
88 | logger.info('#words in a document')
89 | logger.info(pd.Series(n_words_per_doc).describe(percentiles=[.25, .5, .75, .8, .85, .9, .95, .96, .98]))
90 |
91 |
92 | def pad_batch(mini_batch):
93 | mini_batch_size = len(mini_batch)
94 | max_sent_len = min(np.max([len(x) for x in mini_batch]), 10)
95 | max_token_len = min(np.max([len(val) for sublist in mini_batch for val in sublist]), 50)
96 | main_matrix = np.zeros((mini_batch_size, max_sent_len, max_token_len), dtype=int)
97 | for i in range(main_matrix.shape[0]):
98 | for j in range(main_matrix.shape[1]):
99 | for k in range(main_matrix.shape[2]):
100 | try:
101 | main_matrix[i, j, k] = mini_batch[i][j][k]
102 | except IndexError:
103 | pass
104 | return Variable(torch.from_numpy(main_matrix).transpose(0, 1))
105 |
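# pad_batch() output (descriptive note): a Variable of shape
# (max_sent_len, batch, max_token_len) -- sentence-major after transpose(0, 1) --
# with documents truncated to 10 sentences and sentences to 50 tokens.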
106 |
107 | def pad_batch_nosent_fast(args, word_index, mini_batch, region, stride):
108 | mini_batch_size = len(mini_batch)
109 | n_tokens = min(args.max_tokens, max([sum([len(sent) for sent in doc]) for doc in mini_batch]))
110 | main_matrix = np.zeros((mini_batch_size, n_tokens, region), dtype=int)
111 | unk_idx = word_index['UNK']
112 | main_matrix.fill(unk_idx)
113 | for i in range(mini_batch_size):
114 | sent_cat = [unk_idx] * (region - 1) + [word for sent in mini_batch[i] for word in sent] # padded
115 | # sent_cat = [word for sent in mini_batch[i] for word in sent]
116 | idx = 0
117 | ct = 0
118 | last_set = set()
119 | while ct < n_tokens and idx < len(sent_cat):
120 | word_set = set() # words in current region
121 | for region_idx, word in enumerate(sent_cat[idx: idx + region]):
122 | if word in word_set:
123 | main_matrix[i][ct][region_idx] = unk_idx
124 | continue
125 | if word != unk_idx:
126 | word_set.add(word)
127 | main_matrix[i][ct][region_idx] = word
128 | if last_set == word_set:
129 | ct -= 1
130 | last_set = word_set
131 | idx += stride
132 | ct += 1
133 |
134 | return main_matrix
135 |
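# How pad_batch_nosent_fast() fills main_matrix (descriptive sketch): row ct of
# main_matrix[i] holds the word ids of one width-`region` window (moved by
# `stride`) over the concatenated document; duplicate words inside a window are
# replaced by UNK, and a window whose word set equals the previous one is
# collapsed onto the same row (ct -= 1), i.e. the variable-stride trick used by
# the BoW-CNN input.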
136 |
137 | # region is for bow-cnn. need to convert vectors to multi-hot
138 | def pad_batch_nosent(mini_batch, word_index, onehot=False, region=None, stride=None):
139 | mini_batch_size = len(mini_batch)
140 | n_tokens = min(256, max([sum([len(sent) for sent in doc]) for doc in mini_batch]))
141 | if onehot:
142 | main_matrix = np.zeros((mini_batch_size, n_tokens, 30000), dtype=np.float32)
143 | unk_idx = word_index['UNK']
144 | for i in range(mini_batch_size):
145 | if not region:
146 | ct = 0
147 | for sent in mini_batch[i]:
148 | for word in sent:
149 | if word != unk_idx:
150 | if word > unk_idx:
151 | word -= 1
152 | main_matrix[i][ct][word] = 1
153 | ct += 1
154 | if ct == n_tokens:
155 | break
156 | if ct == n_tokens:
157 | break
158 | else:
159 | sent_cat = [unk_idx] * (region - 1) + [word for sent in mini_batch[i] for word in sent]
160 | idx = 0
161 | ct = 0
162 | last_set = set()
163 | while ct < n_tokens and idx < len(sent_cat):
164 | word_set = set()
165 | for word in sent_cat[idx: idx + region]:
166 | if word != unk_idx:
167 | if word > unk_idx:
168 | word -= 1
169 | word_set.add(word)
170 | main_matrix[i][ct][word] = 1
171 | # variable-stride
172 | if last_set == word_set:
173 | ct -= 1
174 | last_set = word_set
175 | idx += stride
176 | ct += 1
177 | else:
178 | main_matrix = np.zeros((mini_batch_size, n_tokens), dtype=int)
179 | for i in range(mini_batch_size):
180 | ct = 0
181 | for sent in mini_batch[i]:
182 | for word in sent:
183 | main_matrix[i][ct] = word
184 | ct += 1
185 | if ct == n_tokens:
186 | break
187 | if ct == n_tokens:
188 | break
189 | return Variable(torch.from_numpy(main_matrix))
190 |
191 |
192 | def iterate_minibatches(args, inputs, targets, batchsize, shuffle):
193 | assert inputs.shape[0] == targets.shape[0]
194 | if args.debug:
195 | for _ in range(300):
196 | yield inputs[:batchsize], targets[:batchsize]
197 | return
198 | if shuffle:
199 | indices = np.arange(inputs.shape[0])
200 | np.random.shuffle(indices)
201 | for start_idx in range(0, inputs.shape[0] - batchsize + 1, batchsize):
202 | if shuffle:
203 | excerpt = indices[start_idx:start_idx + batchsize]
204 | else:
205 | excerpt = slice(start_idx, start_idx + batchsize)
206 | yield inputs[excerpt], targets[excerpt]
207 | if start_idx + batchsize < inputs.shape[0]:
208 | if shuffle:
209 | excerpt = indices[start_idx + batchsize:]
210 | else:
211 | excerpt = slice(start_idx + batchsize, start_idx + batchsize * 2)
212 | yield inputs[excerpt], targets[excerpt]
213 |
214 |
215 | def iterate_minibatches_order(args, inputs, targets, batchsize):
216 | assert inputs.shape[0] == targets.shape[0]
217 | if args.debug:
218 | for _ in range(300):
219 | yield inputs[:batchsize], targets[:batchsize]
220 | return
221 | indices = np.argsort([-len(doc) for doc in inputs])
222 | for start_idx in range(0, inputs.shape[0] - batchsize + 1, batchsize):
223 | excerpt = indices[start_idx:start_idx + batchsize]
224 | yield inputs[excerpt], targets[excerpt]
225 | if start_idx + batchsize < inputs.shape[0]:
226 | excerpt = indices[start_idx + batchsize:]
227 | yield inputs[excerpt], targets[excerpt]
228 |
229 |
230 | def gen_minibatch(logger, args, word_index, tokens, labels, mini_batch_size, shuffle=False):
231 | logger.info('# batches = {}'.format(len(tokens) / mini_batch_size))
232 | # for token, label in iterate_minibatches(tokens, labels, mini_batch_size, shuffle=shuffle):
233 | for token, label in iterate_minibatches_order(args, tokens, labels, mini_batch_size):
234 | if args.base_model == 'textcnn':
235 | token = pad_batch_nosent(token, word_index)
236 | elif args.base_model == 'ohcnn-seq':
237 | token = pad_batch_nosent(token, word_index, onehot=True)
238 | elif args.base_model == 'ohcnn-bow':
239 | token = pad_batch_nosent(token, word_index, onehot=True, region=20, stride=2)
240 | elif args.base_model == 'ohcnn-bow-fast':
241 | main_matrix = pad_batch_nosent_fast(args, word_index, token, region=20, stride=2)
242 | token = Variable(torch.from_numpy(main_matrix))
243 | else:
244 | token = pad_batch(token)
245 | if args.gpu:
246 | yield token.cuda(), label
247 | else:
248 | yield token, label
249 |
250 |
251 | def gen_minibatch_from_cache(logger, args, tree, mini_batch_size, name, shuffle):
252 | pkl_path = '{}_{}.pkl'.format(name, mini_batch_size)
253 | if not os.path.exists(pkl_path):
254 | logger.error('{} NOT FOUND'.format(pkl_path))
255 | exit(-1)
256 | if 'train' in name:
257 | if tree.data_cache is not None:
258 | (token_l, label_l) = tree.data_cache
259 | logger.info('loaded from tree.data_cache')
260 | else:
261 | (token_l, label_l) = pickle.load(open(pkl_path, 'rb'))
262 | tree.data_cache = (token_l, label_l)
263 | else:
264 | (token_l, label_l) = pickle.load(open(pkl_path, 'rb'))
265 | logger.info('loaded {} batches from {}'.format(len(label_l), pkl_path))
266 | if args.debug:
267 | for _ in range(1000):
268 | token = Variable(torch.from_numpy(token_l[0]))
269 | label = label_l[0]
270 | if args.gpu:
271 | yield token.cuda(), label
272 | else:
273 | yield token, label
274 | return
275 | if shuffle:
276 | indices = np.arange(len(token_l))
277 | np.random.shuffle(indices)
278 | for i in indices:
279 | token = token_l[i]
280 | label = label_l[i]
281 | token = Variable(torch.from_numpy(token))
282 | if args.gpu:
283 | yield token.cuda(), label
284 | else:
285 | yield token, label
286 | else:
287 | for token, label in zip(token_l, label_l):
288 | # split large batches in half to avoid running out of memory
289 | if mini_batch_size > 32:
290 | new_batch_size = mini_batch_size // 2
291 | for i in range(0, len(token), new_batch_size):
292 | token_v = Variable(torch.from_numpy(token[i:i + new_batch_size]))
293 | label_v = label[i:i + new_batch_size]
294 | if args.gpu:
295 | yield token_v.cuda(), label_v
296 | else:
297 | yield token_v, label_v
298 | else:
299 | token = Variable(torch.from_numpy(token))
300 | if args.gpu:
301 | yield token.cuda(), label
302 | else:
303 | yield token, label
304 |
305 |
306 | def save_minibatch(logger, args, word_index, tokens, labels, mini_batch_size, name=''):
307 | filename = '{}_{}.pkl'.format(name, mini_batch_size)
308 | if os.path.exists(filename):
309 | logger.warning(f'skipped since {filename} existed')
310 | return
311 | token_l = []
312 | label_l = []
313 | for token, label in tqdm(iterate_minibatches_order(args, tokens, labels, mini_batch_size)):
314 | token = pad_batch_nosent_fast(args, word_index, token, region=20, stride=2)
315 | token_l.append(token)
316 | label_l.append(label)
317 | pickle.dump((token_l, label_l), open(filename, 'wb'))
318 |
--------------------------------------------------------------------------------
/yelp/yelp_data_100.csv.sample:
--------------------------------------------------------------------------------
1 | {'reviews': ['They make a plan, for your mouth to get healthy & stay healthy! Wonderful service.', "I have extreme anxiety when it comes to the dentist. I had a really bad experience with my childhood dentist. This made it so when I was an adult I didn't go as often and I should have. I moved to Ahwatukee 9 years ago and a friend recommended this office to me. I've been going regulalry for routine cleaning since. I went last month for a cleaning and found out I needed a crown. Sensing my anxiety he assured me I would be fine. I scheduled my appt. for today. I went in so nervous I wanted to leave. He came in, explaining everything he would be doing. He put me at ease and I've never felt better. I would highly recommend Dr. Kode to anyone looking for a dentist. He's restored my faith in dentists. I will go back in a few weeks for the permanent crown and I'm not a bit nervous about it. Thank you so much for your patience with me and the absolute best experience I've had from a dentist.", "I've only been here once but my husband has been going for a while and I plan to obviously continue. I'm a complete wimp at the dentist and they were pretty great. The staff is friendly and relaxed, which helps. They have been very honest and the fancy office and TV's are just the icing on top. I would recommend any of my family or friends here!!", 'Great experience with the hygenists and pediatric dentist! Clean office, good with kids and timely. The dentist was great. We didnt get the feeling of being "upsold" like other practices. The staff was friendly enough. \n\nWe will definitely be back!', "My husband and I went to this dentist as a recommendation from my sister-in-law. My husband and I went as we were not happy with our current dentsit. What a blessing! We both went and they worked on both of us and did our work that we needed that day! My husband needed a veneer\\/crown with no scheduling down the road because of finances. I had priced the crown and the out of pocket cost and this office was 1\\/2 of what I was quoted from many different offices in the area. We have great dental insurance which has never been an issue but they were concerned about us and not $$$$$. I appreciate that. Everyone in the office was super nice. I went into the office the next day to get 3 fillings and everyone came by to ask how my husband was and how he liked his work. I don't know about anyone else but I will pay extra for great customer service. Here you get the most caring staff and you can understand why when you meet the dentist, Dentist Kobe. He is great at his craft and he sincerely cares about you. God has blessed him and us by introducing us. I highly recommend this dental office to anyone looking for a dentist in the Phoenix area.", 'Recently called Dental By Design because I was recommended by two different people to go and check them out. I called twice and was told they would call me back to schedule my appointment, no call back. Horrible service.', "Dr. Curry has been my dentist for almost 12 years now. She's the definition of awesomeness. AND everyone there always has a smile! Thanks so much", 'Very nice staff Gentle Hygienist. I love the fact that they will call you to schedule an appointment when they have cancellations that might work for you. Off the top of my head, I can\'t recall that lady\'s name that calls from the office all the time, but I just heart her! She is such a gem. 
\n\nI\'ve been to dental offices that are far more technologically advanced with headphones, satellite TV, computer applications that the staff uses to show and explain what\'s going on with your teeth. Since I\'m an engineer, I love gadgets and tech stuff. \n \nIn addition, the main reason I did not give this place a four is the fact that I have to ask "what is that instrument used to detect?", "what does that tell us?", "what does that mean?". I absolutely detest going to an office where I\'m paying you to do work for me and I leave there not understanding what you did for me. So, every time I"ve been to this office, I ask those questions. I don\'t like it. the information should flow freely. It helps when your patient is informed and can make better decisions. \n\nAll-in-All, I enjoy my visits here.', "This place is dishonest and I would never recommend anyone I know to go there. They tried to make me feel bad about my teeth in order to pressure me into their services. They were rude and tried to take advantage of me. They lied to me and said I had multiple cavities which I don't according to the dentist I go to now. They had a really nice office and equipment but I wasn't going to pay that for them. I'm going to pay someone to be honest with me and to take good care of my teeth, not just look at what's in my wallet.", "I am not giving them one star based on their work but I up front told them that I had 2000 left on my insurance and 3 days to use it but I was limited on money in the future and wanted to be sure that after the 3 days that I was all paid up. I would pay daily for what the insurance didn't cover. They called the insurance everyday and told me what it would cost for the work that day which I paid in full everyday. The final day I paid them and made sure I owed them nothing and they said I'm all paid up. Then months later they claim I owe them over $300 and after talking to the insurance company they claimed the dentist got a quote on one thing but charged them for something different. Well how is that my fault they knew i was not going to be able to pay a bill later yet they screwed up and charged me for it and now they are giving me bad credit over something that I had no control over. Now had the lady at the desk who called my insurance know what to ask the insurance company I would have gladly paid what I needed to that day. You do not dock the customers for your screw ups and that is not how a business keeps customers. I would recommend looking else where unless you too want to be billed for their incompetence. There are plenty of dentists that will not try to pull the wool over your eyes. Tell you one thing but mean another.", "Dr Kahn and Misty were very clear and explained everything for me. It was a super friendly environment. I'll definitely be making this my dentist.", 'WOW, all they want is $$$$$$$. I sure don\'t mind to pay for what I need but trying to squeeze more money from our insurance AND extra from patient? Not cool! Their rates are much more than most places and their dental chair office is like an assembly line. Horrible experiences and when issues were brought up, they talk down on you or act like it was a "mistype" error.\nI would not recommend this place if it was the last dental office on earth.\nExpensive, overpriced, low quality of work and sanitary.', "Going to the dentist office has always been a source of anguish for me, ever since I was young. Perhaps due to some of the abnormal experiences I had gone through up until now. 
Not having insurance or being covered by my parents plan. As a result of growing up close to the boarder I often made trips to Mexico for my dentistry work.\n\nMy friend who works there suggested I'd come in for a cleaning and after a few months of reluctancy I finally decided to make an appointment. In all honestly I was visibly uncomfortable which was largely due to the lapse of time between visits.\n\nHowever, Dental By Design made every effort to make me feel as calm and comfortable as possible. The office itself is very clean, the front line were all amazingly nice, pleasant and helpful. They worked with my budget because at the time I was uninsured and I was able to fit into a payment plan that worked best for me.\n\nI'm not sure if bed side manner is a term associated with dentistry but Dr. Curry handled me with care. She spoke me all the through my assessment and made sure I understood her, outlining steps to make moving forward.\n\nI can't say I'll look forward to the dentist again, but I will say I look forward to being in a welcoming, knowledgeable, and safe atmosphere again with individuals who do great work.", 'I live 24 miles away but still come back for Dr. Kode and his lovely assistants. This place does it all. Friendly, latest technology, all levels of service and has great office systems\\/processes in place. They even will provide you a rating of your teeth from 1 to 100, but you have to ask for it. Keep up the good work Dental by Design.', 'I have been seeing Dr Kode since he took over the practice. Have never met a more knowledgeable professional with such great bedside or chair-side manner. He is also a great community member.', "Just like most people, I dislike going to the dentist, but I've been going to Dental by Design for over a year and recently got a crown done. It was fair in price and things went very well through the entire experience. The staff is very friendly and professional and they get you in and out at your set appointment time.", "This dental office is one of the best that I have ever been to. When I first moved to the Phoenix area, I probably tried out about 8 different dentists before finally listening to my co-worker's referral and coming here. Dr. Kode is an awesome dentist. He is highly skilled, and uses up-to-date high tech equipment. Marianne is also the best dental hygienists that I have had. She is thorough, gentle, and very knowledgeable about everything!!! She is also a perfectionist, which I can appreciate-being one myself, and never misses anything. Sheryl, Sandy, and Heather are also fantastic, and always friendly! I don't know the rest of the staff as well, but everyone has always been overly nice. \n\nIn addition to being extremely competent, this dental office is also very patient-oriented. They have events (like movie nights) for patients, and always have really cool giveaways. They are also involved in community events and charity work. \n\nMoreover, they are also very compassionate and considerate of patients who do not have dental insurance. They are very reasonably priced, which is rare for this quality of service. \n\nI moved to California for school almost 2 years ago, but I am still driving to Arizona twice a year just to see these guys! They are totally worth it! I tried out a dentist in California who charged me twice as much for a cleaning and did a horrible job. 
Being in graduate school, I don't have time or money for trial-and-error dentists, and it's good to know that this office is there for me, even though I have to drive about 350 miles to get there. :)\n\nThank you!!!", 'Our family has been going to Dr.Kode even before he formed Dental by Design, over 10 years now. He\'s been there for us through all our family dental "fun"; from the headgear and braces to the wisdom teeth and deep tissue cleaning. He\'s done it all. His staff from the front desk (shout out to Sandy) to the everyone behind the scenes like his best hygienist, MaryAnn take care of our family like we\'re one of their own. Face it, no one likes going to the dentist. But Dental by Design has made the experience a pleasant one. It\'s also very apparent that they appreciate their patients, as they are always holding events and contests to thank them.\nThank you, Dr,Kode and Dental by Design for making my teeth and my family\'s teeth shine :)', 'Been going to Dr. Kode for 7 years now. Very competent and capable dentist, great problem solver and great communicator. Love his staff - Laci is Da Bomb!\n\nAnd... this is coming from the son of a Dentist!', 'Thereza Wright is the best hygienist ive had the pleasure of working with. She is along with the rest of the staff is awesome, understanding, patient and friendly. They have superior service and down to earth vibes in the office putting you at ease when you are sitting in the chair. Going to the dentist has never been easier when you go to these guys and get handled by the team that works there.', 'I have been going here for 5 years now. The staff is always so friendly and there is always a cold water for you to help yourself to. The confirmation process is so convenient because it is texted and emailed to me. Dr. Kode is a licensed Invisalign dentist so I did Invisalign. He did a wonderful job. I have so much more self confidence now. He even threw in an ipad and free teeth whitening. They always do a great job!', "2 bad experiences I thought I would give them the benefit of the doubt but I actually really hated it. I felt like they wanted my money but the work I had done was shoddy. If you want the experience of a used car dealership with the feeling of just being a number like at the DMV. Try it you may like it. Not for me especially since there is a dentist practically on every corner I don't have to settle.", 'Dr. Kode has been my dentist for about 6 years. As well as receiving my dental check ups and care from him, I also did the Invisalign program with him with great success. His staff is very personable, his office is clean and appointments are always handled professionally. Dr Kode is very knowledgable, personable, kind and gentle with all of his treatments. I have recommended him and his staff to several friends and would recommend them to anyone looking for dentist in the Ahwatukee area.'], 'business_id': 'FYWN1wneV18bWNgQjJ2GNg', 'categories': ['Dentists', 'General Dentistry', 'Health & Medical', 'Oral Surgeons', 'Cosmetic Dentists', 'Orthodontists']}
2 | {'reviews': ["I would be grateful to those of you willing to repeat after me slowly: S T E P H E N NEEDS M E D I C A T E D.\nFor every person who loves him let's him go on with himself as he is is just as bad as him. He is an absolute Monster.", 'My family and I have been going to Stephens for many years . I wish I could give them more than 5 stars. We are always treated with the utmost respect , and our stylist is awesome. She is the best we as a family have ever had. Her color, cut, and advice always makes it a pleasant experience. And LOVE the products...', 'Amanda was an absolute peach! So sweet and talented too! I\'d never gotten my hair professionally colored before and she did an amazing job! She gave me her expert opinion on "sombr\xc3\xa9" (subtle ombr\xc3\xa9) and I just love it. Never once did I feel rushed and she made me feel right at home. We chatted the whole time (I probably talked her ear off haha I just felt so comfortable and was having a great time!). My cut and color are fantastic and I\'m super happy with it all :) I\'m from Pittsburgh but I currently live in NYC. Regardless, I plan to come back to Amanda for my next cut and my second coloring experience! Highly recommend!!!', 'I specifically requested Stephen as I thought since he owned the business and had been in business for as long as he has, he would be excellent at what he does. Before washing my hair we discussed what I wanted, which was a simple trim of my current style. I told him it was important to keep as much length as possible. I then had my hair washed and sat down for him to cut my hair. He prepared to make his first cut and had quite a bit of length in his hand. I asked him if that is what he planned on cutting off and he said yes. I reminded him that I wanted a trim only and wanted to keep my length. He then proceeded to tell me he was not going to cut my hair because I would not let him have space to be creative. He said "I\'ll dry your hair for you though." Really? How nice of him. As I left I asked the two women at the front desk if he does that often and they said "sometimes".', "Paid $115.00 for about 8 highlights that dont even look nice. Then she didn't even blow dry my hair nicely. So I spent over a hundred dollars and had to go home and fix my own hair. I will never go back there again.", 'My daughter recently went to have her hair done at Stephen\'s at South Hills Village. During her visit, one of the owners struck up a conversation with her and proceeded to express his unsolicited opinions on her college choices, making comments that were rude, insulting and totally unprofessional. I called the salon to explain what had happened and let them know how upset we were. A manager called me back to pass along that the owner "said he was sorry". He did not even have the courtesy to respond to us personally. I\'ve been a Stephen\'s customer for almost 15 years and I know some great stylists there. BUT if this is the kind of person that is benefitting from my patronage there - I won\'t be back.', "I have been going to Stephen Szabo Salon for 15+ years and have nothing but great things to say about them. I have gone to several of the stylists and have enjoyed the Brazilian blowout, different color variations, straightening, etc. There are many salon options to choose from in the area, but Stephen's is by far the BEST in the South Hills! Thank you for always taking care of your clients :-)", 'I went here to get my hair done for a wedding with the bridal party. 
There were 7 girls including the bride...so there were ALOT of us. You would think they would have anticipated that we would need attention and time, but it was utter chaos. It\'s like they hadn\'t known for months we were coming on that particular Saturday morning to get our hair done. It was so rushed and chaotic that none of us, including the bride, enjoyed the experience. \n\nI had even helped them a great deal by straightening my hair beforehand. My hair is really curly and I figured it would be easier to manage and style if I straightened it. I did warn them that I straightened it and not to put any water on it b\\/c the curls would spring up like perennials in Spring. Now to their credit, the woman who was styling my hair didn\'t put any water on it, but she was using hairspray and a curling iron (which I told her may not be necessary) and she made a comment, "wow you weren\'t kidding when you said your hair was curly". b\\/c as she was spraying the hairspray on the "death by fumes" setting and curling the hair (fire hazard anyone?) my hair was holding on to the curls made by the curling iron with such intensity that the hair spray wasn\'t even needed. Umm, did I stutter when I said that my hair is very curly. \n\nPerhaps I\'m being too hard on them, but I had so wanted and expected it to be a good experience (not even great) and I was so disappointed. They didn\'t do anything that I couldn\'t have managed on my own with a curling iron and a large can of Aquanet and a few bobby pins.', "I'm a regular here for many years. Totally reliable and affordable. Many good stylists.", "I have been a client here for almost 12 years now. My stylist is Leigh Ann. She is the best you will find. She accommodates my hectic schedule as best as she can. She also takes care of my daughter and my son who are in their 20's. Leigh Ann has been in the business for more than 30 years. She keeps up to speed with the latest trends, cuts and coloring. Her personality is awesome. They are currently remodeling now and it looks fabulous. I would recommend them very highly.", 'Ok: I have been going to Stephen on\\/off for 10 years. I have lived through the insults and even staff dragging a comb through my shampoo drenched hair. The latest visit revealed that he is now requesting staff not use conditioner after shampooing. what does he expect when detergent is added to color treated hair? "If only I would have use his personalized shampoo!" Mary T, I will gladly and slowly repeat after you: "STEPHEN NEEDS MEDICATED." whew! I feel much better. On a good note: I do like the environment: staff and hairdressers are very friendly.'], 'business_id': 'He-G7vWjzVUysIKrfNbPUQ', 'categories': ['Hair Stylists', 'Hair Salons', "Men's Hair Salons", 'Blow Dry/Out Services', 'Hair Extensions', 'Beauty & Spas']}
3 |
--------------------------------------------------------------------------------