├── Data.py ├── Nets.py ├── README.md ├── main.py └── prepare_data.py /Data.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import random 3 | import pickle as pkl 4 | import numpy as np 5 | import torch 6 | import itertools 7 | 8 | import torch.utils.data as data 9 | import torch.nn.functional as fn 10 | 11 | from torch.autograd import Variable 12 | from collections import Counter 13 | from tqdm import tqdm 14 | from itertools import groupby 15 | from operator import itemgetter 16 | from collections import OrderedDict 17 | import pickle as pkl 18 | from random import choice, random 19 | from torch.utils.data import DataLoader, Dataset 20 | from torch.utils.data.sampler import Sampler 21 | 22 | 23 | class TuplesListDataset(Dataset): 24 | 25 | def __init__(self, tuplelist): 26 | super(TuplesListDataset, self).__init__() 27 | self.tuplelist = tuplelist 28 | self.data2class = None 29 | self.class_field = None 30 | 31 | def __len__(self): 32 | return len(self.tuplelist) 33 | 34 | def __getitem__(self,index): 35 | if self.data2class is None: 36 | return self.tuplelist[index] 37 | else: 38 | t = list(self.tuplelist[index]) 39 | t[self.class_field] = self.data2class[t[self.class_field]] 40 | return tuple(t) 41 | 42 | def field_iter(self,field): 43 | 44 | def field_iterator(): 45 | for i in range(len(self)): 46 | yield self[i][field] 47 | 48 | return field_iterator 49 | 50 | 51 | def get_stats(self,field): 52 | d = dict(Counter(self.field_iter(field)())) 53 | sumv = sum([v for k,v in d.items()]) 54 | class_per = {k:(v/sumv) for k,v in d.items()} 55 | 56 | return d,class_per 57 | 58 | def get_class_dict(self,field): 59 | self.class_field = field 60 | self.class2data = {i:c for i,c in enumerate(sorted(list(set(self.field_iter(field)()))))} 61 | self.data2class = {c:i for i,c in self.class2data.items()} 62 | return self.class2data 63 | 64 | def set_class_mapping(self,field,mapping): 65 | self.class_field = field 66 | self.class2data = mapping 67 | self.data2class = {c:i for i,c in self.class2data.items()} 68 | 69 | @staticmethod 70 | def build_train_test(datatuples,splits,split_num=0): 71 | train,test = [],[] 72 | 73 | for split,data in tqdm(zip(splits,datatuples),total=len(datatuples),desc="Building train/test of split #{}".format(split_num)): 74 | if split == split_num: 75 | test.append(data) 76 | else: 77 | train.append(data) 78 | return TuplesListDataset(train),TuplesListDataset(test) 79 | 80 | 81 | 82 | class BucketSampler(Sampler): 83 | """ 84 | Evenly sample from bucket for datalen 85 | """ 86 | 87 | def __init__(self, dataset): 88 | self.dataset = dataset 89 | self.index_buckets = self._build_index_buckets() 90 | self.len = min([len(x) for x in self.index_buckets.values()]) 91 | 92 | def __iter__(self): 93 | 94 | return iter(self.bucket_iterator()) 95 | 96 | def __len__(self): 97 | 98 | return self.len 99 | 100 | def bucket_iterator(self): 101 | cl = list(self.index_buckets.keys()) 102 | 103 | for x in range(len(self)): 104 | yield choice(self.index_buckets[choice(cl)]) 105 | 106 | 107 | def _build_index_buckets(self): 108 | class_index = {} 109 | for ind,cl in enumerate(self.dataset.field_iter(1)()): 110 | if cl not in class_index: 111 | class_index[cl] = [ind] 112 | else: 113 | class_index[cl].append(ind) 114 | return class_index 115 | 116 | 117 | 118 | 119 | class Vectorizer(): 120 | 121 | def __init__(self,word_dict=None,max_sent_len=8,max_word_len=32): 122 | self.word_dict = word_dict 123 | self.nlp = spacy.load('en') 124 | 
self.max_sent_len = max_sent_len 125 | self.max_word_len = max_word_len 126 | 127 | 128 | def _get_words_dict(self,data,max_words): 129 | word_counter = Counter(w.lower_ for d in self.nlp.tokenizer.pipe((doc for doc in tqdm(data(),desc="Tokenizing data"))) for w in d) 130 | dict_w = {w: i for i,(w,_) in tqdm(enumerate(word_counter.most_common(max_words),start=2),desc="building word dict",total=max_words)} 131 | dict_w["_padding_"] = 0 132 | dict_w["_unk_word_"] = 1 133 | print("Dictionary has {} words".format(len(dict_w))) 134 | return dict_w 135 | 136 | def build_dict(self,text_iterator,max_f): 137 | self.word_dict = self._get_words_dict(text_iterator,max_f) 138 | 139 | def vectorize_batch(self,t,trim=True): 140 | return self._vect_dict(t,trim) 141 | 142 | def _vect_dict(self,t,trim): 143 | 144 | if self.word_dict is None: 145 | print("No dictionary to vectorize text \n-> call method build_dict \n-> or set a word_dict attribute \n first") 146 | raise Exception 147 | 148 | revs = [] 149 | for rev in t: 150 | review = [] 151 | for j,sent in enumerate(self.nlp(rev).sents): 152 | 153 | if trim and j>= self.max_sent_len: 154 | break 155 | s = [] 156 | for k,word in enumerate(sent): 157 | word = word.lower_ 158 | 159 | if trim and k >= self.max_word_len: 160 | break 161 | 162 | if word in self.word_dict: 163 | s.append(self.word_dict[word]) 164 | else: 165 | s.append(self.word_dict["_unk_word_"]) #_unk_word_ 166 | 167 | if len(s) >= 1: 168 | review.append(torch.LongTensor(s)) 169 | if len(review) == 0: 170 | review = [torch.LongTensor([self.word_dict["_unk_word_"]])] 171 | revs.append(review) 172 | 173 | return revs 174 | 175 | -------------------------------------------------------------------------------- /Nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | 9 | 10 | class AttentionalBiGRU(nn.Module): 11 | 12 | def __init__(self, inp_size, hid_size, dropout=0): 13 | super(AttentionalBiGRU, self).__init__() 14 | self.register_buffer("mask",torch.FloatTensor()) 15 | 16 | natt = hid_size*2 17 | 18 | self.gru = nn.GRU(input_size=inp_size,hidden_size=hid_size,num_layers=1,bias=True,batch_first=True,dropout=dropout,bidirectional=True) 19 | self.lin = nn.Linear(hid_size*2,natt) 20 | self.att_w = nn.Linear(natt,1,bias=False) 21 | self.tanh = nn.Tanh() 22 | 23 | 24 | 25 | def forward(self, packed_batch): 26 | 27 | rnn_sents,_ = self.gru(packed_batch) 28 | enc_sents,len_s = torch.nn.utils.rnn.pad_packed_sequence(rnn_sents) 29 | emb_h = self.tanh(self.lin(enc_sents.view(enc_sents.size(0)*enc_sents.size(1),-1))) # Nwords * Emb 30 | 31 | attend = self.att_w(emb_h).view(enc_sents.size(0),enc_sents.size(1)).transpose(0,1) 32 | all_att = self._masked_softmax(attend,self._list_to_bytemask(list(len_s))).transpose(0,1) # attW,sent 33 | attended = all_att.unsqueeze(2).expand_as(enc_sents) * enc_sents 34 | 35 | return attended.sum(0,True).squeeze(0) 36 | 37 | def forward_att(self, packed_batch): 38 | 39 | rnn_sents,_ = self.gru(packed_batch) 40 | enc_sents,len_s = torch.nn.utils.rnn.pad_packed_sequence(rnn_sents) 41 | 42 | emb_h = self.tanh(self.lin(enc_sents.view(enc_sents.size(0)*enc_sents.size(1),-1))) # Nwords * Emb 43 | attend = self.att_w(emb_h).view(enc_sents.size(0),enc_sents.size(1)).transpose(0,1) 44 | all_att =
self._masked_softmax(attend,self._list_to_bytemask(list(len_s))).transpose(0,1) # attW,sent 45 | attended = all_att.unsqueeze(2).expand_as(enc_sents) * enc_sents 46 | return attended.sum(0,True).squeeze(0), all_att 47 | 48 | def _list_to_bytemask(self,l): 49 | mask = self._buffers['mask'].resize_(len(l),l[0]).fill_(1) 50 | 51 | for i,j in enumerate(l): 52 | if j != l[0]: 53 | mask[i,j:l[0]] = 0 54 | 55 | return mask 56 | 57 | def _masked_softmax(self,mat,mask): 58 | exp = torch.exp(mat) * Variable(mask,requires_grad=False) 59 | sum_exp = exp.sum(1,True)+0.0001 60 | 61 | return exp/sum_exp.expand_as(exp) 62 | 63 | 64 | 65 | class HierarchicalDoc(nn.Module): 66 | 67 | def __init__(self, ntoken, num_class, emb_size=200, hid_size=50): 68 | super(HierarchicalDoc, self).__init__() 69 | 70 | self.embed = nn.Embedding(ntoken, emb_size,padding_idx=0) 71 | self.word = AttentionalBiGRU(emb_size, hid_size) 72 | self.sent = AttentionalBiGRU(hid_size*2, hid_size) 73 | 74 | self.emb_size = emb_size 75 | self.lin_out = nn.Linear(hid_size*2,num_class) 76 | self.register_buffer("reviews",torch.Tensor()) 77 | 78 | 79 | def set_emb_tensor(self,emb_tensor): 80 | self.emb_size = emb_tensor.size(-1) 81 | self.embed.weight.data = emb_tensor 82 | 83 | 84 | def _reorder_sent(self,sents,stats): 85 | 86 | sort_r = sorted([(l,r,s,i) for i,(l,r,s) in enumerate(stats)], key=itemgetter(0,1,2)) #(len(r),r#,s#) 87 | builder = OrderedDict() 88 | 89 | for (l,r,s,i) in sort_r: 90 | if r not in builder: 91 | builder[r] = [i] 92 | else: 93 | builder[r].append(i) 94 | 95 | list_r = list(reversed(builder)) 96 | 97 | revs = Variable(self._buffers["reviews"].resize_(len(builder),len(builder[list_r[0]]),sents.size(1)).fill_(0), requires_grad=False) 98 | lens = [] 99 | real_order = [] 100 | 101 | for i,x in enumerate(list_r): 102 | revs[i,0:len(builder[x]),:] = sents[builder[x],:] 103 | lens.append(len(builder[x])) 104 | real_order.append(x) 105 | 106 | real_order = sorted(range(len(real_order)), key=lambda k: real_order[k]) 107 | 108 | return revs,lens,real_order 109 | 110 | 111 | def forward(self, batch_reviews,stats): 112 | ls,lr,rn,sn = zip(*stats) 113 | emb_w = F.dropout(self.embed(batch_reviews),training=self.training) 114 | 115 | packed_sents = torch.nn.utils.rnn.pack_padded_sequence(emb_w, ls,batch_first=True) 116 | sent_embs = self.word(packed_sents) 117 | 118 | rev_embs,lens,real_order = self._reorder_sent(sent_embs,zip(lr,rn,sn)) 119 | 120 | packed_rev = torch.nn.utils.rnn.pack_padded_sequence(rev_embs, lens,batch_first=True) 121 | doc_embs = self.sent(packed_rev) 122 | 123 | final_emb = doc_embs[real_order,:] 124 | out = self.lin_out(final_emb) 125 | 126 | return out 127 | 128 | 129 | def forward_visu(self, batch_reviews,stats): 130 | ls,lr,rn,sn = zip(*stats) 131 | emb_w = self.embed(batch_reviews) 132 | 133 | packed_sents = torch.nn.utils.rnn.pack_padded_sequence(emb_w, ls,batch_first=True) 134 | sent_embs,att_w = self.word.forward_att(packed_sents) 135 | 136 | rev_embs,lens,real_order = self._reorder_sent(sent_embs,zip(lr,rn,sn)) 137 | 138 | packed_rev = torch.nn.utils.rnn.pack_padded_sequence(rev_embs, lens,batch_first=True) 139 | doc_embs,att_s = self.sent.forward_att(packed_rev) 140 | 141 | final_emb = doc_embs[real_order,:] 142 | att_s = att_s[:,real_order] 143 | 144 | out = self.lin_out(final_emb) 145 | 146 | return out,att_s 147 | 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ----------------- 2 | 3 | ## Deprecated code 4 | A faster and more up-to-date implementation is [in my other repo](https://github.com/cedias/Hierarchical-Sentiment) 5 | 6 | ---------------- 7 | 8 | # HAN-pytorch 9 | Batched implementation of the [Hierarchical Attention Networks for Document Classification paper](https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf) 10 | 11 | ## Requirements 12 | - PyTorch (>= 0.2) 13 | - spaCy (for tokenizing) 14 | - Gensim (for building word vectors) 15 | - tqdm (for fancy graphics) 16 | 17 | ## Scripts: 18 | - `prepare_data.py` transforms gzip files as found on the [Julian McAuley Amazon product data page](http://jmcauley.ucsd.edu/data/amazon/) into lists of `(review,rating)` tuples and builds word vectors if the `--create-emb` option is specified. 19 | - `main.py` trains a Hierarchical Attention Network. 20 | - `Data.py` holds data managing objects. 21 | - `Nets.py` holds networks. 22 | - `beer2json.py` is a helper script if you happen to have the RateBeer/BeerAdvocate datasets. 23 | 24 | ## Note: 25 | The whole dataset (including the test split) is used to create the word embeddings, which can be an issue since test-set text leaks into the vocabulary and word vectors. -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import argparse 4 | import torch 5 | import spacy 6 | import pickle as pkl 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as fn 11 | 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader, Dataset 14 | from torch.utils.data.sampler import Sampler 15 | from collections import Counter 16 | from tqdm import tqdm 17 | from operator import itemgetter 18 | from random import choice 19 | from collections import OrderedDict,Counter 20 | from Nets import HierarchicalDoc 21 | from Data import TuplesListDataset, Vectorizer, BucketSampler 22 | import sys 23 | 24 | 25 | 26 | def checkpoint(epoch,net,output): 27 | model_out_path = output+"_epoch_{}.pth".format(epoch) 28 | torch.save(net, model_out_path) 29 | print("Checkpoint saved to {}".format(model_out_path)) 30 | 31 | def check_memory(max_sents,max_words,emb_size,b_size,cuda): # argument order matches the call in main() 32 | try: 33 | e_size = (2,b_size,max_sents,max_words,emb_size) 34 | d_size = (b_size,max_sents,max_words) 35 | t = torch.rand(*e_size) 36 | db = torch.rand(*d_size) 37 | 38 | if cuda: 39 | db = db.cuda() 40 | t = t.cuda() 41 | 42 | print("-> Quick memory check : OK\n") 43 | 44 | except Exception as e: 45 | print(e) 46 | print("Not enough memory to handle current settings {} ".format(e_size)) 47 | print("Try lowering --max-sents and --max-words.") 48 | sys.exit() 49 | 50 | 51 | 52 | 53 | def load_embeddings(file): 54 | emb_file = open(file).readlines() 55 | first = emb_file[0] 56 | word, vec = int(first.split()[0]),int(first.split()[1]) 57 | size = (word,vec) 58 | print("--> Got {} words of {} dimensions".format(size[0],size[1])) 59 | tensor = np.zeros((size[0]+2,size[1]),dtype=np.float32) ## adding padding + unknown 60 | word_d = {} 61 | word_d["_padding_"] = 0 62 | word_d["_unk_word_"] = 1 63 | 64 | print("--> Shape with padding and unk_token:") 65 | print(tensor.shape) 66 | 67 | for i,line in tqdm(enumerate(emb_file,1),desc="Creating embedding tensor",total=len(emb_file)): 68 | if i==1: #skip the header line; enumeration starts at 1, so real words get indices >= 2 (0/1 are reserved for padding/unk) 69 | continue 70 | 71 | spl = line.strip().split(" ") 72 | 73 | if
len(spl[1:]) == size[1]: #word is most probably whitespace or junk if badly parsed 74 | word_d[spl[0]] = i 75 | tensor[i] = np.array(spl[1:],dtype=np.float32) 76 | else: 77 | print("WARNING: MALFORMED EMBEDDING DICTIONARY:\n {} \n line isn't parsed correctly".format(line)) 78 | 79 | try: 80 | assert(len(word_d)==size[0]+2) 81 | except: 82 | print("Final dictionary length differs from number of embeddings - some lines were malformed.") 83 | 84 | return tensor, word_d 85 | 86 | def save(net,dic,path): 87 | dict_m = net.state_dict() 88 | dict_m["word_dic"] = dic 89 | dict_m["reviews"] = torch.Tensor() 90 | dict_m["word.mask"] = torch.Tensor() 91 | dict_m["sent.mask"] = torch.Tensor() 92 | 93 | torch.save(dict_m,path) 94 | 95 | 96 | def tuple_batcher_builder(vectorizer, trim=True): 97 | 98 | def tuple_batch(l): 99 | review,rating = zip(*l) 100 | r_t = torch.Tensor(rating).long() 101 | list_rev = vectorizer.vectorize_batch(review,trim) 102 | 103 | # sort sentences by decreasing length (required by pack_padded_sequence), keeping review length and review/sentence indices 104 | stat = sorted([(len(s),len(r),r_n,s_n,s) for r_n,r in enumerate(list_rev) for s_n,s in enumerate(r)],reverse=True) 105 | 106 | max_len = stat[0][0] 107 | batch_t = torch.zeros(len(stat),max_len).long() 108 | 109 | for i,s in enumerate(stat): 110 | for j,w in enumerate(s[-1]): # s[-1] is sentence in stat tuple 111 | batch_t[i,j] = w 112 | 113 | stat = [(ls,lr,r_n,s_n) for ls,lr,r_n,s_n,_ in stat] 114 | 115 | return batch_t,r_t, stat,review 116 | 117 | return tuple_batch 118 | 119 | 120 | def tuple2var(tensors,data): 121 | def copy2tensor(t,data): 122 | t.resize_(data.size()).copy_(data) 123 | return Variable(t) 124 | return tuple(map(copy2tensor,tensors,data)) 125 | 126 | 127 | def new_tensors(n,cuda,types={}): 128 | def new_tensor(t_type,cuda): 129 | x = torch.Tensor() 130 | 131 | if t_type: 132 | x = x.type(t_type) 133 | if cuda: 134 | x = x.cuda() 135 | return x 136 | 137 | return tuple([new_tensor(types.setdefault(i,None),cuda) for i in range(0,n)]) 138 | 139 | def train(epoch,net,optimizer,dataset,criterion,cuda): 140 | epoch_loss = 0 141 | ok_all = 0 142 | data_tensors = new_tensors(2,cuda,types={0:torch.LongTensor,1:torch.LongTensor}) #data-tensors 143 | 144 | with tqdm(total=len(dataset),desc="Training") as pbar: 145 | for iteration, (batch_t,r_t,stat,rev) in enumerate(dataset): 146 | 147 | data = tuple2var(data_tensors,(batch_t,r_t)) 148 | optimizer.zero_grad() 149 | out = net(data[0],stat) 150 | ok,per = accuracy(out,data[1]) 151 | loss = criterion(out, data[1]) 152 | epoch_loss += loss.data[0] 153 | loss.backward() 154 | 155 | optimizer.step() 156 | 157 | ok_all += per.data[0] 158 | 159 | pbar.update(1) 160 | pbar.set_postfix({"acc":ok_all/(iteration+1),"CE":epoch_loss/(iteration+1)}) 161 | 162 | print("===> Epoch {} Complete: Avg.
Loss: {:.4f}, {}% accuracy".format(epoch, epoch_loss /len(dataset),ok_all/len(dataset))) 163 | 164 | 165 | 166 | def test(epoch,net,dataset,cuda): 167 | epoch_loss = 0 168 | ok_all = 0 169 | pred = 0 170 | skipped = 0 171 | data_tensors = new_tensors(2,cuda,types={0:torch.LongTensor,1:torch.LongTensor}) #data-tensors 172 | with tqdm(total=len(dataset),desc="Evaluating") as pbar: 173 | for iteration, (batch_t,r_t, stat,rev) in enumerate(dataset): 174 | data = tuple2var(data_tensors,(batch_t,r_t)) 175 | out,att = net.forward_visu(data[0],stat) 176 | 177 | ok,per = accuracy(out,data[1]) 178 | ok_all += per.data[0] 179 | pred+=1 180 | 181 | pbar.update(1) 182 | pbar.set_postfix({"acc":ok_all/pred, "skipped":skipped}) 183 | 184 | 185 | print("===> TEST Complete: {}% accuracy".format(ok_all/pred)) 186 | 187 | def accuracy(out,truth): 188 | def sm(mat): 189 | exp = torch.exp(mat) 190 | sum_exp = exp.sum(1,True)+0.0001 191 | return exp/sum_exp.expand_as(exp) 192 | 193 | _,max_i = torch.max(sm(out),1) 194 | 195 | eq = torch.eq(max_i,truth).float() 196 | all_eq = torch.sum(eq) 197 | 198 | return all_eq, all_eq/truth.size(0)*100 199 | 200 | def main(args): 201 | print(32*"-"+"\nHierarchical Attention Network:\n" + 32*"-") 202 | 203 | print("\nLoading Data:\n" + 25*"-") 204 | 205 | max_features = args.max_feat 206 | datadict = pkl.load(open(args.filename,"rb")) 207 | tuples = datadict["data"] 208 | splits = datadict["splits"] 209 | split_keys = set(x for x in splits) 210 | 211 | if args.split not in split_keys: 212 | print("Chosen split (#{}) not in split set {}".format(args.split,split_keys)) 213 | else: 214 | print("Split #{} chosen".format(args.split)) 215 | 216 | train_set,test_set = TuplesListDataset.build_train_test(tuples,splits,args.split) 217 | 218 | print("Train set length:",len(train_set)) 219 | print("Test set length:",len(test_set)) 220 | 221 | classes = train_set.get_class_dict(1) #create class mapping 222 | test_set.set_class_mapping(1,classes) #set same class mapping 223 | num_class = len(classes) 224 | 225 | print(classes) 226 | 227 | 228 | print(25*"-"+"\nClass stats:\n" + 25*"-") 229 | print("Train set:\n" + 10*"-") 230 | 231 | class_stats,class_per = train_set.get_stats(1) 232 | print(class_stats) 233 | print(class_per) 234 | 235 | if args.weight_classes: 236 | class_weight = torch.zeros(num_class) 237 | for c,p in class_per.items(): 238 | class_weight[c] = 1-p 239 | 240 | print(class_weight) 241 | 242 | if args.cuda: 243 | class_weight = class_weight.cuda() 244 | 245 | print(10*"-" + "\n Test set:\n" + 10*"-") 246 | 247 | test_stats,test_per = test_set.get_stats(1) 248 | print(test_stats) 249 | print(test_per) 250 | 251 | 252 | vectorizer = Vectorizer(max_word_len=args.max_words,max_sent_len=args.max_sents) 253 | 254 | if args.load: 255 | state = torch.load(args.load) 256 | vectorizer.word_dict = state["word_dic"] 257 | net = HierarchicalDoc(ntoken=len(state["word_dic"]),emb_size=state["embed.weight"].size(1),hid_size=state["sent.gru.weight_hh_l0"].size(1),num_class=state["lin_out.weight"].size(0)) 258 | del state["word_dic"] 259 | net.load_state_dict(state) 260 | else: 261 | 262 | if args.emb: 263 | tensor,dic = load_embeddings(args.emb) 264 | print(len(dic)) 265 | net = HierarchicalDoc(ntoken=len(dic),emb_size=len(tensor[1]),hid_size=args.hid_size,num_class=num_class) 266 | net.set_emb_tensor(torch.FloatTensor(tensor)) 267 | vectorizer.word_dict = dic 268 | else: 269 | print(25*"-" + "\nBuilding word vectors: \n"+"-"*25) 270 | 
vectorizer.build_dict(train_set.field_iter(0),args.max_feat) 271 | net = HierarchicalDoc(ntoken=len(vectorizer.word_dict), emb_size=args.emb_size,hid_size=args.hid_size, num_class=num_class) 272 | 273 | 274 | tuple_batch = tuple_batcher_builder(vectorizer,trim=True) 275 | tuple_batch_test = tuple_batcher_builder(vectorizer,trim=True) 276 | 277 | 278 | 279 | sampler = None 280 | if args.balance: 281 | sampler = BucketSampler(train_set) 282 | sampler_t = BucketSampler(test_set) 283 | 284 | 285 | dataloader = DataLoader(train_set, batch_size=args.b_size, shuffle=False, sampler=sampler, num_workers=2, collate_fn=tuple_batch,pin_memory=True) 286 | dataloader_test = DataLoader(test_set, batch_size=args.b_size, shuffle=False, num_workers=2, collate_fn=tuple_batch_test) 287 | else: 288 | dataloader = DataLoader(train_set, batch_size=args.b_size, shuffle=True, num_workers=2, collate_fn=tuple_batch,pin_memory=True) 289 | dataloader_test = DataLoader(test_set, batch_size=args.b_size, shuffle=True, num_workers=2, collate_fn=tuple_batch_test) 290 | 291 | 292 | if args.weight_classes: 293 | criterion = torch.nn.CrossEntropyLoss(weight=class_weight) 294 | else: 295 | criterion = torch.nn.CrossEntropyLoss() 296 | 297 | 298 | 299 | if args.cuda: 300 | net.cuda() 301 | 302 | print("-"*20) 303 | 304 | 305 | 306 | check_memory(args.max_sents,args.max_words,net.emb_size,args.b_size,args.cuda) 307 | 308 | optimizer = optim.Adam(net.parameters())#,lr=args.lr,momentum=args.momentum) 309 | torch.nn.utils.clip_grad_norm(net.parameters(), args.clip_grad) 310 | 311 | 312 | for epoch in range(1, args.epochs + 1): 313 | train(epoch,net,optimizer,dataloader,criterion,args.cuda) 314 | 315 | 316 | if args.snapshot: 317 | print("snapshot of model saved as {}".format(args.save+"_snapshot")) 318 | save(net,vectorizer.word_dict,args.save+"_snapshot") 319 | 320 | test(epoch,net,dataloader_test,args.cuda) 321 | 322 | if args.save: 323 | print("model saved to {}".format(args.save)) 324 | save(net,vectorizer.word_dict,args.save) 325 | 326 | 327 | if __name__ == '__main__': 328 | 329 | parser = argparse.ArgumentParser(description='Hierarchical Attention Networks for Document Classification') 330 | parser.add_argument("--split", type=int, default=0) 331 | parser.add_argument("--emb-size",type=int,default=200) 332 | parser.add_argument("--hid-size",type=int,default=50) 333 | parser.add_argument("--b-size", type=int, default=32) 334 | parser.add_argument("--max-feat", type=int,default=10000) 335 | parser.add_argument("--epochs", type=int,default=10) 336 | parser.add_argument("--clip-grad", type=float,default=1) 337 | parser.add_argument("--lr", type=float, default=0.01) 338 | parser.add_argument("--max-words", type=int,default=32) 339 | parser.add_argument("--max-sents",type=int,default=16) 340 | parser.add_argument("--momentum",type=float,default=0.9) 341 | parser.add_argument("--emb", type=str) 342 | parser.add_argument("--load", type=str) 343 | parser.add_argument("--save", type=str) 344 | parser.add_argument("--snapshot", action='store_true') 345 | parser.add_argument("--weight-classes", action='store_true') 346 | parser.add_argument("--output", type=str) 347 | parser.add_argument('--cuda', action='store_true', 348 | help='use CUDA') 349 | parser.add_argument('--balance', action='store_true', 350 | help='balance class in batches') 351 | parser.add_argument('filename', type=str) 352 | args = parser.parse_args() 353 | 354 | 355 | main(args) 356 | -------------------------------------------------------------------------------- 
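The collate format produced by `tuple_batcher_builder` and consumed by `HierarchicalDoc` is the least obvious part of the code, so here is a minimal sketch of it on a hand-built toy batch. This block is not part of the repository: the reviews, word indices and hyper-parameters are made up, and it assumes the old Variable-based PyTorch API (0.2/0.3 era) that the rest of the code targets.

```python
import torch
from torch.autograd import Variable
from Nets import HierarchicalDoc

# Two toy reviews: review 0 has two sentences, review 1 has one sentence.
# Each sentence is a list of word indices (0 = _padding_, 1 = _unk_word_).
reviews = [[[3, 4, 5], [6, 7]], [[8, 9, 10, 11]]]

# Same flattening as tuple_batch(): (sent_len, review_len, review_idx, sent_idx, words),
# sorted by decreasing sentence length as required by pack_padded_sequence.
stat = sorted([(len(s), len(r), r_n, s_n, s)
               for r_n, r in enumerate(reviews)
               for s_n, s in enumerate(r)], reverse=True)

batch_t = torch.zeros(len(stat), stat[0][0]).long()   # (num_sentences, max_sent_len)
for i, (_, _, _, _, words) in enumerate(stat):
    batch_t[i, :len(words)] = torch.LongTensor(words)

stat = [(ls, lr, r_n, s_n) for ls, lr, r_n, s_n, _ in stat]

net = HierarchicalDoc(ntoken=12, num_class=5, emb_size=20, hid_size=8)
out = net(Variable(batch_t), stat)
print(out.size())   # -> (2, 5): one row of class scores per review
```

The model never sees the nested review structure directly: it gets one flat, length-sorted sentence batch plus the `stat` bookkeeping tuples, and `_reorder_sent` uses those tuples to regroup sentence embeddings back into per-review sequences for the sentence-level GRU.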
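For reference, the masked softmax that both attention layers rely on can be checked in isolation. This is a new snippet, not from the repository; it just replays the arithmetic of `_masked_softmax` with the mask that `_list_to_bytemask([3, 2])` would build, so that the padded time step gets zero attention weight.

```python
import torch

# One row per sequence in the batch, one column per (padded) time step.
scores = torch.Tensor([[1.0, 2.0, 3.0],
                       [0.5, 1.5, 0.0]])   # second sequence has true length 2
mask = torch.Tensor([[1., 1., 1.],
                     [1., 1., 0.]])        # what _list_to_bytemask([3, 2]) produces

exp = torch.exp(scores) * mask             # masked positions contribute nothing
att = exp / (exp.sum(1, keepdim=True) + 0.0001)
print(att)   # each row sums to ~1, and att[1, 2] is exactly 0
```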
/prepare_data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pickle as pkl 3 | import argparse 4 | from tqdm import tqdm 5 | from random import randint 6 | from collections import Counter 7 | import sys 8 | import collections 9 | import gensim 10 | import logging 11 | import spacy 12 | import itertools 13 | import re 14 | import json 15 | import torch 16 | 17 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) #gensim logging 18 | 19 | 20 | def count_lines(file): 21 | count = 0 22 | for _ in file: 23 | count += 1 24 | file.seek(0) 25 | return count 26 | 27 | 28 | def build_dataset(args): 29 | 30 | def preprocess(datas): 31 | for data in datas: 32 | yield (data['reviewText'],max(1,int(round(float(data["overall"]))))-1) #zero is useless, classes between 0-4 for 1-5 reviews 33 | 34 | def preprocess_rescale(datas): 35 | for data in datas: 36 | rating = max(1,int(round(float(data["overall"]))))-1 37 | 38 | if rating > 3: 39 | rating = 1 40 | elif rating == 3: 41 | yield None 42 | continue 43 | else: 44 | rating = 0 45 | yield (data['reviewText'],rating) #zero is useless 46 | 47 | def data_generator(data): 48 | with gzip.open(args.input,"r") as f: 49 | for x in tqdm(f,desc="Reviews",total=count_lines(f)): 50 | yield json.loads(x) 51 | 52 | class TokIt(collections.Iterator): 53 | def __init__(self, tokenized): 54 | self.tok = tokenized 55 | self.x = 0 56 | self.stop = len(tokenized) 57 | 58 | def __iter__(self): 59 | return self 60 | 61 | def next(self): 62 | if self.x < self.stop: 63 | self.x += 1 64 | return list(w.orth_ for w in self.tok[self.x-1] if len(w.orth_.strip()) >= 1 ) #whitespace shouldn't be a word. 65 | else: 66 | self.x = 0 67 | raise StopIteration 68 | __next__ = next 69 | 70 | 71 | 72 | 73 | print("Building dataset from : {}".format(args.input)) 74 | print("-> Building {} random splits".format(args.nb_splits)) 75 | 76 | nlp = spacy.load('en') 77 | 78 | tokenized = [tok for tok in tqdm(nlp.tokenizer.pipe((x["reviewText"] for x in data_generator(args.input)),batch_size=10000, n_threads=8),desc="Tokenizing")] 79 | 80 | 81 | 82 | if args.create_emb: 83 | w2vmodel = gensim.models.Word2Vec(TokIt(tokenized), size=args.emb_size, window=5, min_count=5, iter=args.epochs, max_vocab_size=args.dic_size, workers=4) 84 | print(len(w2vmodel.wv.vocab)) 85 | w2vmodel.wv.save_word2vec_format(args.emb_file,total_vec=len(w2vmodel.wv.vocab)) 86 | 87 | if args.rescale: 88 | print("-> Rescaling data to 0-1 (3's are discarded)") 89 | data = [dt for dt in tqdm(preprocess_rescale(data_generator(args.input)),desc="Processing") if dt is not None] 90 | else: 91 | data = [dt for dt in tqdm(preprocess(data_generator(args.input)),desc="Processing")] 92 | 93 | 94 | splits = [randint(0,args.nb_splits-1) for _ in range(0,len(data))] 95 | count = Counter(splits) 96 | 97 | print("Split distribution is the following:") 98 | print(count) 99 | 100 | return {"data":data,"splits":splits,"rows":("review","rating")} 101 | 102 | 103 | def main(args): 104 | ds = build_dataset(args) 105 | pkl.dump(ds,open(args.output,"wb")) 106 | 107 | if __name__ == '__main__': 108 | 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument("input", type=str) 111 | parser.add_argument("output", type=str, default="sentences.pkl") 112 | parser.add_argument("--rescale",action="store_true") 113 | parser.add_argument("--nb_splits",type=int, default=5) 114 | 115 | parser.add_argument("--create-emb",action="store_true") 116 | 
parser.add_argument("--emb-file", type=str, default=None) 117 | parser.add_argument("--emb-size",type=int, default=100) 118 | parser.add_argument("--dic-size", type=int,default=10000000) 119 | parser.add_argument("--epochs", type=int,default=1) 120 | args = parser.parse_args() 121 | 122 | if args.emb_file is None: 123 | args.emb_file = args.output + "_emb.txt" 124 | 125 | main(args) 126 | --------------------------------------------------------------------------------
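To tie the two scripts together, here is a minimal sketch of the pickle layout that `prepare_data.py` writes and `main.py` reads; the reviews, ratings and file names are all hypothetical. A typical run would look something like `python prepare_data.py reviews_Books_5.json.gz books.pkl --create-emb` followed by `python main.py books.pkl --emb books.pkl_emb.txt --cuda` (again, hypothetical file names).

```python
import pickle as pkl
from Data import TuplesListDataset

# Toy stand-in for what prepare_data.py pickles: one (review, rating) tuple
# per example, ratings already shifted to classes 0-4, plus a split id per example.
dataset = {
    "data": [("great book, loved it", 4),
             ("not my cup of tea", 0),
             ("average at best", 2)],
    "splits": [0, 1, 2],
    "rows": ("review", "rating"),
}
pkl.dump(dataset, open("toy.pkl", "wb"))

# main.py keeps split k as the test set and trains on everything else:
train, test = TuplesListDataset.build_train_test(dataset["data"], dataset["splits"], split_num=0)
print(len(train), len(test))   # -> 2 1
```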