├── README.md
├── tf_data_utils.py
├── tf_sentimentmain.py
├── tf_seq_lstm.py
├── tf_tree_lstm.py
└── tf_treenode.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
TensorFlow implementation of recursive (tree-structured) neural networks with LSTM units, as described in "Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks" by Kai Sheng Tai, Richard Socher, and Christopher D. Manning.

Please download the relevant data before running this code, i.e. the Stanford Sentiment Treebank.
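The data loader expects the directory passed to load_sentiment_treebank (DIR in
tf_sentimentmain.py, './project_data/sst/' by default) to contain vocab-cased.txt
plus train/, dev/ and test/ subdirectories, each holding sents.txt, parents.txt
and labels.txt. To train the N-ary Tree-LSTM and report test accuracy:

    python tf_sentimentmain.py          # train from scratch
    python tf_sentimentmain.py restore  # any extra argument restores ./ckpt weights first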
--------------------------------------------------------------------------------
/tf_data_utils.py:
--------------------------------------------------------------------------------

from tf_treenode import tNode, processTree
import numpy as np
import os
import random


class Vocab(object):

    def __init__(self, path):
        self.words = []
        self.word2idx = {}
        self.idx2word = {}
        self.load(path)

    def load(self, path):
        with open(path, 'r') as f:
            for line in f:
                w = line.strip()
                assert w not in self.words
                self.words.append(w)
                self.word2idx[w] = len(self.words) - 1  # 0-based index
                self.idx2word[self.word2idx[w]] = w

    def __len__(self):
        return len(self.words)

    def encode(self, word):
        return self.word2idx[word]

    def decode(self, idx):
        assert idx < len(self.words)
        return self.idx2word[idx]

    def size(self):
        return len(self.words)


def load_sentiment_treebank(data_dir, fine_grained):
    voc = Vocab(os.path.join(data_dir, 'vocab-cased.txt'))

    split_paths = {}
    for split in ['train', 'test', 'dev']:
        split_paths[split] = os.path.join(data_dir, split)

    fnlist = [tNode.encodetokens, tNode.relabel]
    arglist = [voc.encode, fine_grained]

    data = {}
    for split, path in split_paths.iteritems():
        sentencepath = os.path.join(path, 'sents.txt')
        treepath = os.path.join(path, 'parents.txt')
        labelpath = os.path.join(path, 'labels.txt')
        trees = parse_trees(sentencepath, treepath, labelpath)
        if not fine_grained:
            # drop neutral-root sentences for the binary task
            trees = [tree for tree in trees if tree.label != 0]
        trees = [(processTree(tree, fnlist, arglist), tree.label) for tree in trees]
        data[split] = trees

    return data, voc


def parse_trees(sentencepath, treepath, labelpath):
    trees = []
    with open(treepath, 'r') as ft, open(labelpath) as fl, open(
            sentencepath, 'r') as f:
        while True:
            parentidxs = ft.readline()
            labels = fl.readline()
            sentence = f.readline()
            if not parentidxs or not labels or not sentence:
                break
            parentidxs = [int(p) for p in parentidxs.strip().split()]
            labels = [int(l) if l != '#' else None for l in labels.strip().split()]

            tree = parse_tree(sentence, parentidxs, labels)
            trees.append(tree)
    return trees


def parse_tree(sentence, parents, labels):
    nodes = {}
    parents = [p - 1 for p in parents]  # change to zero-based
    sentence = [w for w in sentence.strip().split()]
    for i in xrange(len(parents)):
        if i not in nodes:
            idx = i
            prev = None
            while True:
                node = tNode(idx)
                if prev is not None:
                    assert prev.idx != node.idx
                    node.add_child(prev)

                node.label = labels[idx]
                nodes[idx] = node

                if idx < len(sentence):
                    node.word = sentence[idx]

                parent = parents[idx]
                if parent in nodes:
                    assert len(nodes[parent].children) < 2
                    nodes[parent].add_child(node)
                    break
                elif parent == -1:
                    root = node
                    break

                prev = node
                idx = parent

    return root


def BFStree(root):
    from collections import deque
    node = root
    leaves = []
    inodes = []
    queue = deque([node])
    func = lambda node: node.children == []

    while queue:
        node = queue.popleft()
        if func(node):
            leaves.append(node)
        else:
            inodes.append(node)
        if node.children:
            queue.extend(node.children)

    return leaves, inodes


def extract_tree_data(tree, max_degree=2, only_leaves_have_vals=True, with_labels=False):
    leaves, inodes = BFStree(tree)
    labels = []
    leaf_emb = []
    tree_str = []
    i = 0
    # re-index nodes: leaves first, then internal nodes bottom-up
    for leaf in reversed(leaves):
        leaf.idx = i
        i += 1
        labels.append(leaf.label)
        leaf_emb.append(leaf.word)
    for node in reversed(inodes):
        node.idx = i
        c = [child.idx for child in node.children]
        tree_str.append(c)
        labels.append(node.label)
        if not only_leaves_have_vals:
            leaf_emb.append(-1)
        i += 1
    if with_labels:
        labels_exist = [l is not None for l in labels]
        labels = [l or 0 for l in labels]
        return (np.array(leaf_emb, dtype='int32'),
                np.array(tree_str, dtype='int32'),
                np.array(labels, dtype=float),
                np.array(labels_exist, dtype=float))
    else:
        return (np.array(leaf_emb, dtype='int32'),
                np.array(tree_str, dtype='int32'))


def extract_batch_tree_data(batchdata, fillnum=120):

    dim1, dim2 = len(batchdata), fillnum
    leaf_emb_arr = np.empty([dim1, dim2], dtype='int32')
    leaf_emb_arr.fill(-1)
    treestr_arr = np.empty([dim1, dim2, 2], dtype='int32')
    treestr_arr.fill(-1)
    labels_arr = np.empty([dim1, dim2], dtype=float)
    labels_arr.fill(-1)
    for i, (tree, _) in enumerate(batchdata):
        input_, treestr, labels, _ = extract_tree_data(tree,
                                                       max_degree=2,
                                                       only_leaves_have_vals=False,
                                                       with_labels=True)
        leaf_emb_arr[i, 0:len(input_)] = input_
        treestr_arr[i, 0:len(treestr), 0:2] = treestr
        labels_arr[i, 0:len(labels)] = labels

    return leaf_emb_arr, treestr_arr, labels_arr
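
# Illustration (a hypothetical mini-example, not part of the pipeline): for the
# two-word sentence "good movie", parents.txt would hold "3 3 0" -- tokens 1 and
# 2 attach to node 3, whose parent value 0 marks it as root.  parse_tree turns
# that into a 3-node tNode tree, and extract_tree_data re-indexes it so leaves
# come first (idx 0..num_leaves-1) followed by internal nodes, with tree_str
# holding one [left_child_idx, right_child_idx] row per internal node.
# extract_batch_tree_data pads every array with -1 up to fillnum, which is why
# the models mask on `!= -1` throughout.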
def extract_seq_data(data, numsamples=0, fillnum=100):
    seqdata = []
    seqlabels = []
    for tree, _ in data:
        seq, seqlbls = extract_seq_from_tree(tree, numsamples)
        seqdata.extend(seq)
        seqlabels.extend(seqlbls)

    seqlngths = [len(s) for s in seqdata]
    maxl = max(seqlngths)
    assert fillnum >= maxl
    seqarr = np.empty([len(seqdata), fillnum], dtype='int32')
    seqarr.fill(-1)
    for i, s in enumerate(seqdata):
        seqarr[i, 0:len(s)] = np.array(s, dtype='int32')
    seqdata = seqarr
    return seqdata, seqlabels, seqlngths, maxl


def extract_seq_from_tree(tree, numsamples=0):

    if tree.span is None:
        tree.postOrder(tree, tree.get_spans)

    seq, lbl = [], []
    s, l = tree.span, tree.label
    seq.append(s)
    lbl.append(l)

    if not numsamples:
        return seq, lbl

    num_nodes = tree.idx
    if numsamples == -1:
        numsamples = num_nodes

    subtrees = {}

    def func_(self, su):
        su.update([(self.idx, self)])

    tree.postOrder(tree, func_, subtrees)

    # sample subtree spans (with replacement) as extra training sequences
    for j in xrange(numsamples):
        i = random.randint(0, num_nodes)
        root = subtrees[i]
        s, l = root.span, root.label
        seq.append(s)
        lbl.append(l)

    return seq, lbl


def get_max_len_data(datadic):
    maxlen = 0
    for data in datadic.values():
        for tree, _ in data:
            tree.postOrder(tree, tree.get_numleaves)
            assert tree.num_leaves > 1
            if tree.num_leaves > maxlen:
                maxlen = tree.num_leaves

    return maxlen


def get_max_node_size(datadic):
    maxsize = 0
    for data in datadic.values():
        for tree, _ in data:
            tree.postOrder(tree, tree.get_size)
            assert tree.size > 1
            if tree.size > maxsize:
                maxsize = tree.size

    return maxsize


def test_fn():
    data_dir = './stanford_lstm/data/sst'
    fine_grained = 0
    data, _ = load_sentiment_treebank(data_dir, fine_grained)
    for d in data.itervalues():
        print len(d)

    d = data['dev']
    a, b, c, _ = extract_seq_data(d[0:1], 5)
    print a, b, c

    print get_max_len_data(data)
    return data

if __name__ == '__main__':
    test_fn()
--------------------------------------------------------------------------------
/tf_sentimentmain.py:
--------------------------------------------------------------------------------
import os
import sys
import time
import random

import numpy as np
import tensorflow as tf

import tf_data_utils as utils
import tf_seq_lstm
import tf_tree_lstm

DIR = './project_data/sst/'
GLOVE_DIR = './'


class Config(object):

    num_emb = None

    emb_dim = 300
    hidden_dim = 150
    output_dim = None
    degree = 2

    num_epochs = 1
    early_stopping = 2
    dropout = 0.5
    lr = 0.05
    emb_lr = 0.1
    reg = 0.0001

    batch_size = 5
    maxseqlen = None
    maxnodesize = None
    fine_grained = False
    trainable_embeddings = True
    nonroot_labels = True
    # dependency trees are not supported


def train(restore=False):

    config = Config()

    data, vocab = utils.load_sentiment_treebank(DIR, config.fine_grained)

    train_set, dev_set, test_set = data['train'], data['dev'], data['test']
    print 'train', len(train_set)
    print 'dev', len(dev_set)
    print 'test', len(test_set)

    num_emb = len(vocab)
    num_labels = 5 if config.fine_grained else 3
    for _, dataset in data.items():
        labels = [label for _, label in dataset]
        assert set(labels) <= set(xrange(num_labels)), set(labels)
    print 'num emb', num_emb
    print 'num labels', num_labels

    config.num_emb = num_emb
    config.output_dim = num_labels

    config.maxseqlen = utils.get_max_len_data(data)
    config.maxnodesize = utils.get_max_node_size(data)

    print config.maxnodesize, config.maxseqlen, " max sizes"
    random.seed()
    np.random.seed()

    with tf.Graph().as_default():

        #model = tf_seq_lstm.tf_seqLSTM(config)
        model = tf_tree_lstm.tf_NarytreeLSTM(config)

        init = tf.initialize_all_variables()
        saver = tf.train.Saver()
        best_valid_score = 0.0
        best_valid_epoch = 0
        dev_score = 0.0
        test_score = 0.0
        with tf.Session() as sess:

            sess.run(init)
            start_time = time.time()

            if restore: saver.restore(sess, './ckpt/tree_rnn_weights')
            for epoch in range(config.num_epochs):
                print 'epoch', epoch
                avg_loss = train_epoch(model, train_set, sess)
                print 'avg loss', avg_loss

                dev_score = evaluate(model, dev_set, sess)
                print 'dev score', dev_score

                if dev_score > best_valid_score:
                    best_valid_score = dev_score
                    best_valid_epoch = epoch
                    saver.save(sess, './ckpt/tree_rnn_weights')

                if epoch - best_valid_epoch > config.early_stopping:
                    break

                print "time per epoch is {0}".format(
                    time.time() - start_time)
            test_score = evaluate(model, test_set, sess)
            print test_score, 'test_score'


def train_epoch(model, data, sess):
    loss = model.train(data, sess)
    return loss


def evaluate(model, data, sess):
    acc = model.evaluate(data, sess)
    return acc


if __name__ == '__main__':
    restore = len(sys.argv) > 1
    train(restore)
--------------------------------------------------------------------------------
/tf_seq_lstm.py:
--------------------------------------------------------------------------------

import numpy as np
import tensorflow as tf
import os
import sys

from tensorflow.python.ops import rnn_cell, rnn
from tf_data_utils import extract_seq_data


class tf_seqLSTM(object):

    def add_placeholders(self):

        self.batch_len = tf.placeholder(tf.int32, name="batch_len")
        self.max_time = tf.placeholder(tf.int32, name="max_time")
        self.input = tf.placeholder(tf.int32, shape=[None, self.config.maxseqlen],
                                    name="input")
        self.labels = tf.placeholder(tf.int32, shape=None, name="labels")
        self.dropout = tf.placeholder(tf.float32, name="dropout")
        self.lngths = tf.placeholder(tf.int32, shape=None, name="lngths")

    def __init__(self, config):
        self.emb_dim = config.emb_dim
        self.hidden_dim = config.hidden_dim
        self.num_emb = config.num_emb
        self.output_dim = config.output_dim
        self.config = config
        self.batch_size = config.batch_size
        self.reg = self.config.reg
        self.internal = 4  # parameter for sampling sequences corresponding to subtrees
        assert self.emb_dim > 1 and self.hidden_dim > 1

        self.add_placeholders()

        emb_input = self.add_embedding()

        output_states = self.compute_states(emb_input)

        logits = self.create_output(output_states)

        self.pred = tf.nn.softmax(logits)

        self.loss, self.total_loss = self.calc_loss(logits)

        self.train_op1, self.train_op2 = self.add_training_op()

    def add_embedding(self):

        with tf.device('/cpu:0'):
            with tf.variable_scope("Embed"):
                embedding = tf.get_variable('embedding',
                                            [self.num_emb, self.emb_dim],
                                            initializer=tf.random_uniform_initializer(-0.05, 0.05),
                                            trainable=True,
                                            regularizer=tf.contrib.layers.l2_regularizer(0.0))
                # -1 marks padding; clamp it to index 0 for the lookup, then
                # zero out the padded positions afterwards
                ix = tf.to_int32(tf.not_equal(self.input, -1)) * self.input
                emb = tf.nn.embedding_lookup(embedding, ix)
                emb = emb * tf.to_float(tf.not_equal(tf.expand_dims(self.input, 2), -1))
                return emb
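
    # Padding convention, illustrated with hypothetical values: a batch of two
    # sequences [4, 9, 7] and [2, 5] with maxseqlen=4 arrives as
    #   input  = [[4, 9, 7, -1],
    #             [2, 5, -1, -1]],   lngths = [3, 2]
    # The -1 cells are clamped to index 0 for the embedding lookup and their
    # embeddings are multiplied by 0, so padding contributes nothing downstream.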

    def compute_states(self, emb):

        def unpack_sequence(tensor):
            return tf.unpack(tf.transpose(tensor, perm=[1, 0, 2]))

        with tf.variable_scope("Composition",
                               initializer=tf.contrib.layers.xavier_initializer(),
                               regularizer=tf.contrib.layers.l2_regularizer(self.reg)):
            cell = rnn_cell.LSTMCell(self.hidden_dim)
            cell = rnn_cell.DropoutWrapper(cell,
                                           output_keep_prob=self.dropout,
                                           input_keep_prob=self.dropout)
            outputs, _ = rnn.rnn(cell, unpack_sequence(emb),
                                 sequence_length=self.lngths, dtype=tf.float32)

            # mean-pool the step outputs into a fixed-size sentence representation
            sum_out = tf.reduce_sum(tf.pack(outputs), [0])
            sent_rep = tf.div(sum_out, tf.expand_dims(tf.to_float(self.lngths), 1))
            final_state = sent_rep
        return final_state

    def create_output(self, rnn_out):

        with tf.variable_scope("Projection",
                               regularizer=tf.contrib.layers.l2_regularizer(self.reg)):
            U = tf.get_variable("U", [self.output_dim, self.hidden_dim],
                                initializer=tf.random_uniform_initializer(-0.05, 0.05))
            bu = tf.get_variable("bu", [self.output_dim],
                                 initializer=tf.constant_initializer(0.0),
                                 regularizer=tf.contrib.layers.l2_regularizer(0.0))

            logits = tf.matmul(rnn_out, U, transpose_b=True) + bu

            return logits

    def calc_loss(self, logits):

        l1 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, self.labels)
        loss = tf.reduce_sum(l1, [0])
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        regpart = tf.add_n(reg_losses)
        total_loss = loss + 0.5 * regpart
        return loss, total_loss

    def add_training_op_old(self):

        opt = tf.train.AdagradOptimizer(self.config.lr)
        train_op = opt.minimize(self.total_loss)
        return train_op

    def add_training_op(self):
        loss = self.total_loss
        opt1 = tf.train.AdagradOptimizer(self.config.lr)
        opt2 = tf.train.AdagradOptimizer(self.config.emb_lr)

        ts = tf.trainable_variables()
        gs = tf.gradients(loss, ts)
        gs_ts = zip(gs, ts)

        gt_emb, gt_nn = [], []
        for g, t in gs_ts:
            if "embedding" in t.name:
                gt_emb.append((g, t))
            else:
                gt_nn.append((g, t))

        train_op2 = opt2.apply_gradients(gt_emb)
        train_op1 = opt1.apply_gradients(gt_nn)

        return train_op1, train_op2
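
    # Why two optimizers: the word embeddings and the network weights are
    # updated by separate Adagrad instances (config.emb_lr vs. config.lr), so
    # the embedding table can learn at its own rate.  Both apply_gradients ops
    # are run together on every training step.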

    def train(self, data, sess, isTree=True):

        from random import shuffle
        shuffle(data)
        losses = []
        for i in range(0, len(data), self.batch_size):
            batch_size = min(i + self.batch_size, len(data)) - i
            batch_data = data[i:i + batch_size]

            seqdata, seqlabels, seqlngths, max_len = extract_seq_data(
                batch_data, self.internal, self.config.maxseqlen)
            feed = {self.input: seqdata, self.labels: seqlabels,
                    self.dropout: self.config.dropout, self.lngths: seqlngths,
                    self.batch_len: len(seqdata), self.max_time: max_len}
            loss, _, _ = sess.run([self.loss, self.train_op1, self.train_op2],
                                  feed_dict=feed)

            losses.append(loss)
            avg_loss = np.mean(losses)
            sstr = 'avg loss %.2f at example %d of %d\r' % (avg_loss, i, len(data))
            sys.stdout.write(sstr)
            sys.stdout.flush()
        return np.mean(losses)

    def evaluate(self, data, sess):
        num_correct = 0
        total_data = 0
        for i in range(0, len(data), self.batch_size):
            batch_size = min(i + self.batch_size, len(data)) - i
            batch_data = data[i:i + batch_size]

            seqdata, seqlabels, seqlngths, max_len = extract_seq_data(
                batch_data, 0, self.config.maxseqlen)
            feed = {self.input: seqdata, self.labels: seqlabels,
                    self.dropout: 1.0, self.lngths: seqlngths,
                    self.batch_len: len(seqdata), self.max_time: max_len}
            pred = sess.run(self.pred, feed_dict=feed)
            y = np.argmax(pred, axis=1)
            for j, v in enumerate(y):
                if seqlabels[j] == v:
                    num_correct += 1
                total_data += 1
        acc = float(num_correct) / float(total_data)
        return acc


class tf_seqbiLSTM(tf_seqLSTM):
    """Bidirectional variant; the training ops are inherited from tf_seqLSTM."""

    def compute_states(self, emb):

        def unpack_sequence(tensor):
            return tf.unpack(tf.transpose(tensor, perm=[1, 0, 2]))

        with tf.variable_scope("Composition",
                               initializer=tf.contrib.layers.xavier_initializer(),
                               regularizer=tf.contrib.layers.l2_regularizer(self.reg)):
            cell_fw = rnn_cell.LSTMCell(self.hidden_dim)
            cell_bw = rnn_cell.LSTMCell(self.hidden_dim)
            cell_fw = rnn_cell.DropoutWrapper(cell_fw,
                                              output_keep_prob=self.dropout,
                                              input_keep_prob=self.dropout)
            cell_bw = rnn_cell.DropoutWrapper(cell_bw,
                                              output_keep_prob=self.dropout,
                                              input_keep_prob=self.dropout)

            outputs, _, _ = rnn.bidirectional_rnn(cell_fw, cell_bw,
                                                  unpack_sequence(emb),
                                                  sequence_length=self.lngths,
                                                  dtype=tf.float32)
            sum_out = tf.reduce_sum(tf.pack(outputs), [0])
            sent_rep = tf.div(sum_out, tf.expand_dims(tf.to_float(self.lngths), 1))
            final_state = sent_rep
        return final_state

    def create_output(self, rnn_out):

        with tf.variable_scope("Projection",
                               regularizer=tf.contrib.layers.l2_regularizer(self.reg)):
            # fw and bw outputs are concatenated, hence 2*hidden_dim
            U = tf.get_variable("U", [self.output_dim, 2 * self.hidden_dim],
                                initializer=tf.random_uniform_initializer(-0.05, 0.05))
            bu = tf.get_variable("bu", [self.output_dim],
                                 initializer=tf.constant_initializer(0.0),
                                 regularizer=tf.contrib.layers.l2_regularizer(0.0))

            logits = tf.matmul(rnn_out, U, transpose_b=True) + bu

            return logits
--------------------------------------------------------------------------------
/tf_tree_lstm.py:
--------------------------------------------------------------------------------

import numpy as np
import tensorflow as tf
import os
import sys

from tf_data_utils import extract_tree_data, extract_batch_tree_data
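
# For reference, the binary (N=2) Tree-LSTM equations from Tai et al. that
# tf_NarytreeLSTM implements (internal nodes take no word input here, so the
# input-to-hidden weights apply only at the leaves):
#
#   i   = sigmoid(U_l^(i) h_l + U_r^(i) h_r + b^(i))
#   f_k = sigmoid(U_kl^(f) h_l + U_kr^(f) h_r + b^(f)),  k in {l, r}
#   o   = sigmoid(U_l^(o) h_l + U_r^(o) h_r + b^(o))
#   u   = tanh(U_l^(u) h_l + U_r^(u) h_r + b^(u))
#   c   = i * u + f_l * c_l + f_r * c_r
#   h   = o * tanh(c)
#
# In the code below, cW packs all the U matrices into one [2h, 5h] weight
# (outputs split as u, o, i, f_l, f_r, with a shared forget bias) and cU
# handles the leaf case.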


class tf_NarytreeLSTM(object):

    def __init__(self, config):
        self.emb_dim = config.emb_dim
        self.hidden_dim = config.hidden_dim
        self.num_emb = config.num_emb
        self.output_dim = config.output_dim
        self.config = config
        self.batch_size = config.batch_size
        self.reg = self.config.reg
        self.degree = config.degree
        assert self.emb_dim > 1 and self.hidden_dim > 1

        self.add_placeholders()

        emb_leaves = self.add_embedding()

        self.add_model_variables()

        batch_loss = self.compute_loss(emb_leaves)

        self.loss, self.total_loss = self.calc_batch_loss(batch_loss)

        self.train_op1, self.train_op2 = self.add_training_op()

    def add_embedding(self):

        with tf.variable_scope("Embed", regularizer=None):
            embedding = tf.get_variable('embedding',
                                        [self.num_emb, self.emb_dim],
                                        initializer=tf.random_uniform_initializer(-0.05, 0.05),
                                        trainable=True, regularizer=None)
            # clamp the -1 padding to index 0, then zero out padded embeddings
            ix = tf.to_int32(tf.not_equal(self.input, -1)) * self.input
            emb_tree = tf.nn.embedding_lookup(embedding, ix)
            emb_tree = emb_tree * (tf.expand_dims(
                tf.to_float(tf.not_equal(self.input, -1)), 2))

            return emb_tree

    def add_placeholders(self):
        dim2 = self.config.maxnodesize
        dim1 = self.config.batch_size
        self.input = tf.placeholder(tf.int32, [dim1, dim2], name='input')
        self.treestr = tf.placeholder(tf.int32, [dim1, dim2, 2], name='tree')
        self.labels = tf.placeholder(tf.int32, [dim1, dim2], name='labels')
        self.dropout = tf.placeholder(tf.float32, name='dropout')

        # each internal node contributes two child entries, hence the /2
        self.n_inodes = tf.reduce_sum(tf.to_int32(tf.not_equal(self.treestr, -1)), [1, 2])
        self.n_inodes = self.n_inodes / 2

        self.num_leaves = tf.reduce_sum(tf.to_int32(tf.not_equal(self.input, -1)), [1])
        self.batch_len = tf.placeholder(tf.int32, name="batch_len")

    def calc_wt_init(self, fan_in=300):
        eps = 1.0 / np.sqrt(fan_in)
        return eps

    def add_model_variables(self):

        with tf.variable_scope("Composition",
                               initializer=tf.contrib.layers.xavier_initializer(),
                               regularizer=tf.contrib.layers.l2_regularizer(self.config.reg)):

            cU = tf.get_variable("cU", [self.emb_dim, 2 * self.hidden_dim],
                                 initializer=tf.random_uniform_initializer(
                                     -self.calc_wt_init(), self.calc_wt_init()))
            cW = tf.get_variable("cW",
                                 [self.degree * self.hidden_dim,
                                  (self.degree + 3) * self.hidden_dim],
                                 initializer=tf.random_uniform_initializer(
                                     -self.calc_wt_init(self.hidden_dim),
                                     self.calc_wt_init(self.hidden_dim)))
            cb = tf.get_variable("cb", [4 * self.hidden_dim],
                                 initializer=tf.constant_initializer(0.0),
                                 regularizer=tf.contrib.layers.l2_regularizer(0.0))

        with tf.variable_scope("Projection",
                               regularizer=tf.contrib.layers.l2_regularizer(self.config.reg)):

            U = tf.get_variable("U", [self.output_dim, self.hidden_dim],
                                initializer=tf.random_uniform_initializer(
                                    -self.calc_wt_init(self.hidden_dim),
                                    self.calc_wt_init(self.hidden_dim)))
            bu = tf.get_variable("bu", [self.output_dim],
                                 initializer=tf.constant_initializer(0.0),
                                 regularizer=tf.contrib.layers.l2_regularizer(0.0))

    def process_leafs(self, emb):

        with tf.variable_scope("Composition", reuse=True):
            cU = tf.get_variable("cU", [self.emb_dim, 2 * self.hidden_dim])
            cb = tf.get_variable("cb", [4 * self.hidden_dim])
            b = tf.slice(cb, [0], [2 * self.hidden_dim])

            def _recurseleaf(x):
                # leaves have no children: compute u and o from the word
                # embedding alone, with c = u
                concat_uo = tf.matmul(tf.expand_dims(x, 0), cU) + b
                u, o = tf.split(1, 2, concat_uo)
                o = tf.nn.sigmoid(o)
                u = tf.nn.tanh(u)

                c = u
                h = o * tf.nn.tanh(c)

                hc = tf.concat(1, [h, c])
                hc = tf.squeeze(hc)
                return hc

            hc = tf.map_fn(_recurseleaf, emb)
            return hc

    def compute_loss(self, emb_batch, curr_batch_size=None):
        outloss = []
        prediction = []
        for idx_batch in range(self.config.batch_size):

            tree_states = self.compute_states(emb_batch, idx_batch)
            logits = self.create_output(tree_states)

            labels1 = tf.gather(self.labels, idx_batch)
            labels2 = tf.reduce_sum(tf.to_int32(tf.not_equal(labels1, -1)))
            labels = tf.gather(labels1, tf.range(labels2))
            loss = self.calc_loss(logits, labels)

            pred = tf.nn.softmax(logits)

            # the root is the last node in the leaves-first ordering
            pred_root = tf.gather(pred, labels2 - 1)

            prediction.append(pred_root)
            outloss.append(loss)

        batch_loss = tf.pack(outloss)
        self.pred = tf.pack(prediction)

        return batch_loss
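
    # Node ordering used by compute_states below (set up in extract_tree_data):
    # leaves occupy indices 0..num_leaves-1 and internal nodes follow in
    # bottom-up order, so treestr row k lists the two child indices of internal
    # node num_leaves+k, and every child state is already available when the
    # while_loop reaches it.  node_h/node_c grow by one row per iteration.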
tf.variable_scope("Projection",reuse=True): 211 | 212 | U = tf.get_variable("U",[self.output_dim,self.hidden_dim], 213 | ) 214 | bu = tf.get_variable("bu",[self.output_dim]) 215 | 216 | h=tf.matmul(tree_states,U,transpose_b=True)+bu 217 | return h 218 | 219 | 220 | 221 | def calc_loss(self,logits,labels): 222 | 223 | l1=tf.nn.sparse_softmax_cross_entropy_with_logits( 224 | logits,labels) 225 | loss=tf.reduce_sum(l1,[0]) 226 | return loss 227 | 228 | def calc_batch_loss(self,batch_loss): 229 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 230 | regpart=tf.add_n(reg_losses) 231 | loss=tf.reduce_mean(batch_loss) 232 | total_loss=loss+0.5*regpart 233 | return loss,total_loss 234 | 235 | def add_training_op_old(self): 236 | 237 | opt = tf.train.AdagradOptimizer(self.config.lr) 238 | train_op = opt.minimize(self.total_loss) 239 | return train_op 240 | 241 | def add_training_op(self): 242 | loss=self.total_loss 243 | opt1=tf.train.AdagradOptimizer(self.config.lr) 244 | opt2=tf.train.AdagradOptimizer(self.config.emb_lr) 245 | 246 | ts=tf.trainable_variables() 247 | gs=tf.gradients(loss,ts) 248 | gs_ts=zip(gs,ts) 249 | 250 | gt_emb,gt_nn=[],[] 251 | for g,t in gs_ts: 252 | #print t.name,g.name 253 | if "Embed/embedding:0" in t.name: 254 | #g=tf.Print(g,[g.get_shape(),t.get_shape()]) 255 | gt_emb.append((g,t)) 256 | #print t.name 257 | else: 258 | gt_nn.append((g,t)) 259 | #print t.name 260 | 261 | train_op1=opt1.apply_gradients(gt_nn) 262 | train_op2=opt2.apply_gradients(gt_emb) 263 | train_op=[train_op1,train_op2] 264 | 265 | return train_op 266 | 267 | 268 | 269 | def train(self,data,sess): 270 | from random import shuffle 271 | data_idxs=range(len(data)) 272 | shuffle(data_idxs) 273 | losses=[] 274 | for i in range(0,len(data),self.batch_size): 275 | batch_size = min(i+self.batch_size,len(data))-i 276 | if batch_size < self.batch_size:break 277 | 278 | batch_idxs=data_idxs[i:i+batch_size] 279 | batch_data=[data[ix] for ix in batch_idxs]#[i:i+batch_size] 280 | 281 | input_b,treestr_b,labels_b=extract_batch_tree_data(batch_data,self.config.maxnodesize) 282 | 283 | feed={self.input:input_b,self.treestr:treestr_b,self.labels:labels_b,self.dropout:self.config.dropout,self.batch_len:len(input_b)} 284 | 285 | loss,_,_=sess.run([self.loss,self.train_op1,self.train_op2],feed_dict=feed) 286 | #sess.run(self.train_op,feed_dict=feed) 287 | 288 | losses.append(loss) 289 | avg_loss=np.mean(losses) 290 | sstr='avg loss %.2f at example %d of %d\r' % (avg_loss, i, len(data)) 291 | sys.stdout.write(sstr) 292 | sys.stdout.flush() 293 | 294 | #if i>1000: break 295 | return np.mean(losses) 296 | 297 | 298 | def evaluate(self,data,sess): 299 | num_correct=0 300 | total_data=0 301 | data_idxs=range(len(data)) 302 | test_batch_size=self.config.batch_size 303 | losses=[] 304 | for i in range(0,len(data),test_batch_size): 305 | batch_size = min(i+test_batch_size,len(data))-i 306 | if batch_size < test_batch_size:break 307 | batch_idxs=data_idxs[i:i+batch_size] 308 | batch_data=[data[ix] for ix in batch_idxs]#[i:i+batch_size] 309 | labels_root=[l for _,l in batch_data] 310 | input_b,treestr_b,labels_b=extract_batch_tree_data(batch_data,self.config.maxnodesize) 311 | 312 | feed={self.input:input_b,self.treestr:treestr_b,self.labels:labels_b,self.dropout:1.0,self.batch_len:len(input_b)} 313 | 314 | pred_y=sess.run(self.pred,feed_dict=feed) 315 | #print pred_y,labels_root 316 | y=np.argmax(pred_y,axis=1) 317 | #num_correct+=np.sum(y==np.array(labels_root)) 318 | for i,v in enumerate(labels_root): 319 | 

    def evaluate(self, data, sess):
        num_correct = 0
        total_data = 0
        data_idxs = range(len(data))
        test_batch_size = self.config.batch_size
        for i in range(0, len(data), test_batch_size):
            batch_size = min(i + test_batch_size, len(data)) - i
            if batch_size < test_batch_size: break
            batch_idxs = data_idxs[i:i + batch_size]
            batch_data = [data[ix] for ix in batch_idxs]
            labels_root = [l for _, l in batch_data]
            input_b, treestr_b, labels_b = extract_batch_tree_data(
                batch_data, self.config.maxnodesize)

            feed = {self.input: input_b, self.treestr: treestr_b,
                    self.labels: labels_b, self.dropout: 1.0,
                    self.batch_len: len(input_b)}

            pred_y = sess.run(self.pred, feed_dict=feed)
            y = np.argmax(pred_y, axis=1)
            for j, v in enumerate(labels_root):
                if y[j] == v: num_correct += 1
                total_data += 1

        acc = float(num_correct) / float(total_data)
        return acc


class tf_ChildsumtreeLSTM(tf_NarytreeLSTM):

    def add_model_variables(self):
        with tf.variable_scope("Composition",
                               initializer=tf.contrib.layers.xavier_initializer(),
                               regularizer=tf.contrib.layers.l2_regularizer(self.config.reg)):

            cUW = tf.get_variable("cUW", [self.emb_dim + self.hidden_dim,
                                          4 * self.hidden_dim])
            cb = tf.get_variable("cb", [4 * self.hidden_dim],
                                 initializer=tf.constant_initializer(0.0),
                                 regularizer=tf.contrib.layers.l2_regularizer(0.0))

        with tf.variable_scope("Projection",
                               regularizer=tf.contrib.layers.l2_regularizer(self.config.reg)):

            U = tf.get_variable("U", [self.output_dim, self.hidden_dim],
                                initializer=tf.random_uniform_initializer(-0.05, 0.05))
            bu = tf.get_variable("bu", [self.output_dim],
                                 initializer=tf.constant_initializer(0.0),
                                 regularizer=tf.contrib.layers.l2_regularizer(0.0))

    def process_leafs(self, emb):

        with tf.variable_scope("Composition", reuse=True):
            cUW = tf.get_variable("cUW")
            cb = tf.get_variable("cb")
            U = tf.slice(cUW, [0, 0], [self.emb_dim, 2 * self.hidden_dim])
            b = tf.slice(cb, [0], [2 * self.hidden_dim])

            def _recurseleaf(x):
                concat_uo = tf.matmul(tf.expand_dims(x, 0), U) + b
                u, o = tf.split(1, 2, concat_uo)
                o = tf.nn.sigmoid(o)
                u = tf.nn.tanh(u)

                c = u
                h = o * tf.nn.tanh(c)

                hc = tf.concat(1, [h, c])
                hc = tf.squeeze(hc)
                return hc

            hc = tf.map_fn(_recurseleaf, emb)
            return hc
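
    # Child-Sum Tree-LSTM (Tai et al.): the gates i, o and candidate u are
    # computed from the node's own input x and the *sum* of its children's
    # hidden states, while each child k gets its own forget gate from [x, h_k]:
    #   h~  = sum_k h_k
    #   i   = sigmoid(W^(i) x + U^(i) h~ + b^(i))    (o, u analogous)
    #   f_k = sigmoid(W^(f) x + U^(f) h_k + b^(f))
    #   c   = i * u + sum_k f_k * c_k,   h = o * tanh(c)
    # compute_states below implements this for binary trees, with cUW packing
    # [W U] for (u, o, i) and a shared [W_f U_f] slice for the forget gates.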

    def compute_states(self, emb, idx_batch=0):

        num_leaves = tf.squeeze(tf.gather(self.num_leaves, idx_batch))
        n_inodes = tf.gather(self.n_inodes, idx_batch)
        emb_tree = tf.gather(emb, idx_batch)
        emb_leaf = tf.gather(emb_tree, tf.range(num_leaves))
        treestr = tf.gather(tf.gather(self.treestr, idx_batch), tf.range(n_inodes))
        leaf_hc = self.process_leafs(emb_leaf)
        leaf_h, leaf_c = tf.split(1, 2, leaf_hc)

        node_h = tf.identity(leaf_h)
        node_c = tf.identity(leaf_c)

        idx_var = tf.constant(0)

        with tf.variable_scope("Composition", reuse=True):

            cUW = tf.get_variable("cUW", [self.emb_dim + self.hidden_dim,
                                          4 * self.hidden_dim])
            cb = tf.get_variable("cb", [4 * self.hidden_dim])
            bu, bo, bi, bf = tf.split(0, 4, cb)

            UW = tf.slice(cUW, [0, 0], [-1, 3 * self.hidden_dim])
            U_fW_f = tf.slice(cUW, [0, 3 * self.hidden_dim], [-1, -1])

            def _recurrence(emb_tree, node_h, node_c, idx_var):
                node_x = tf.gather(emb_tree, num_leaves + idx_var)
                node_info = tf.gather(treestr, idx_var)

                child_h = tf.gather(node_h, node_info)
                child_c = tf.gather(node_c, node_info)

                # sum this node's children's hidden states (h~ in the paper)
                concat_xh = tf.concat(0, [node_x, tf.reduce_sum(child_h, [0])])

                tmp = tf.matmul(tf.expand_dims(concat_xh, 0), UW)
                u, o, i = tf.split(1, 3, tmp)

                hl, hr = tf.split(0, 2, child_h)
                x_hl = tf.concat(0, [node_x, tf.squeeze(hl)])
                x_hr = tf.concat(0, [node_x, tf.squeeze(hr)])
                fl = tf.matmul(tf.expand_dims(x_hl, 0), U_fW_f)
                fr = tf.matmul(tf.expand_dims(x_hr, 0), U_fW_f)

                i = tf.nn.sigmoid(i + bi)
                o = tf.nn.sigmoid(o + bo)
                u = tf.nn.tanh(u + bu)
                fl = tf.nn.sigmoid(fl + bf)
                fr = tf.nn.sigmoid(fr + bf)

                f = tf.concat(0, [fl, fr])
                c = i * u + tf.reduce_sum(f * child_c, [0])
                h = o * tf.nn.tanh(c)

                node_h = tf.concat(0, [node_h, h])
                node_c = tf.concat(0, [node_c, c])

                idx_var = tf.add(idx_var, 1)

                return emb_tree, node_h, node_c, idx_var

            loop_cond = lambda a1, b1, c1, idx_var: tf.less(idx_var, n_inodes)

            loop_vars = [emb_tree, node_h, node_c, idx_var]
            emb_tree, node_h, node_c, idx_var = tf.while_loop(
                loop_cond, _recurrence, loop_vars, parallel_iterations=1)

            return node_h
--------------------------------------------------------------------------------
/tf_treenode.py:
--------------------------------------------------------------------------------

class tNode(object):
    def __init__(self, idx=-1, word=None):
        self.left = None
        self.right = None
        self.word = word
        self.size = 0
        self.height = 1
        self.parent = None
        self.label = None
        self.children = []
        self.idx = idx
        self.span = None

    def add_parent(self, parent):
        self.parent = parent

    def add_child(self, node):
        assert len(self.children) < 2
        self.children.append(node)

    def add_children(self, children):
        self.children.extend(children)

    def get_left(self):
        left = None
        if self.children:
            left = self.children[0]
        return left

    def get_right(self):
        right = None
        if len(self.children) == 2:
            right = self.children[1]
        return right

    @staticmethod
    def get_height(root):
        if root.children:
            root.height = max(root.get_left().height, root.get_right().height) + 1
        else:
            root.height = 1

    @staticmethod
    def get_size(root):
        if root.children:
            root.size = root.get_left().size + root.get_right().size + 1
        else:
            root.size = 1

    @staticmethod
    def get_spans(root):
        if root.children:
            root.span = root.get_left().span + root.get_right().span
        else:
            root.span = [root.word]

    @staticmethod
    def get_numleaves(self):
        if self.children:
            self.num_leaves = self.get_left().num_leaves + self.get_right().num_leaves
        else:
            self.num_leaves = 1

    @staticmethod
    def postOrder(root, func=None, args=None):
        if root is None:
            return
        tNode.postOrder(root.get_left(), func, args)
        tNode.postOrder(root.get_right(), func, args)

        if args is not None:
            func(root, args)
        else:
            func(root)

    @staticmethod
    def encodetokens(root, func):
        if root is None:
            return
        if root.word is None:
            return
        else:
            root.word = func(root.word)

    @staticmethod
    def relabel(root, fine_grained):
        if root is None:
            return
        if root.label is not None:
            if fine_grained:
                root.label += 2
            else:
                if root.label < 0:
                    root.label = 0
                elif root.label == 0:
                    root.label = 1
                else:
                    root.label = 2
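
# Label convention after relabel (raw SST node labels are -2..2, '#' -> None):
#   fine_grained=True:  -2,-1,0,1,2 -> 0,1,2,3,4                     (5-way)
#   fine_grained=False: negative -> 0, neutral -> 1, positive -> 2   (3-way)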

def processTree(root, funclist=None, argslist=None):
    if funclist is None:
        root.postOrder(root, root.get_height)
        root.postOrder(root, root.get_numleaves)
        root.postOrder(root, root.get_size)
    else:
        for func, args in zip(funclist, argslist):
            root.postOrder(root, func, args)

    return root


def test_tNode():

    nodes = {}
    for i in range(7):
        nodes[i] = tNode(i)
        if i < 4: nodes[i].word = i + 10
    nodes[0].parent = nodes[1].parent = nodes[4]
    nodes[2].parent = nodes[3].parent = nodes[5]
    nodes[4].parent = nodes[5].parent = nodes[6]
    nodes[6].add_child(nodes[4])
    nodes[6].add_child(nodes[5])
    nodes[4].add_children([nodes[0], nodes[1]])
    nodes[5].add_children([nodes[2], nodes[3]])
    root = nodes[6]
    postOrder = root.postOrder
    postOrder(root, tNode.get_height, None)
    postOrder(root, tNode.get_numleaves, None)
    postOrder(root, root.get_spans, None)
    print root.height, root.num_leaves
    for n in nodes.itervalues(): print n.span

if __name__ == '__main__':
    test_tNode()
--------------------------------------------------------------------------------