├── README.md
├── aclImdb
│   └── download.sh
├── main.py
├── attention.py
├── config.py
├── model.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# tf-hierarchical-rnn

Document classification with a hierarchical RNN: a bidirectional GRU with attention encodes each sentence from its words, and a second bidirectional GRU with attention encodes the document from its sentence vectors. Trained on the IMDB sentiment dataset (aclImdb).
--------------------------------------------------------------------------------
/aclImdb/download.sh:
--------------------------------------------------------------------------------
wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from config import get_args
from model import hier_rnn


def main(args):
    network = hier_rnn(args)
    network.train()


if __name__ == '__main__':
    args = get_args()
    main(args)
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def attention(inputs, attention_size):
    # `inputs` is the (output_fw, output_bw) pair returned by
    # tf.nn.bidirectional_dynamic_rnn; concatenate the two directions along
    # the feature axis -> [batch, time, 2 * hidden].
    inputs = tf.concat(inputs, 2)

    # One-layer MLP scoring: u_t = tanh(W h_t + b), score_t = u_t . u_w.
    v = tf.layers.dense(inputs, attention_size, activation=tf.nn.tanh)
    vu = tf.layers.dense(v, 1, use_bias=False)

    # Normalise the scores over the time axis (axis=1), not the trailing
    # singleton axis, so the weights form a distribution over time steps.
    alphas = tf.nn.softmax(vu, axis=1)

    # Weighted sum of the RNN outputs -> [batch, 2 * hidden].
    output = tf.reduce_sum(alphas * inputs, axis=1)
    return output
--------------------------------------------------------------------------------
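A quick shape check for the attention() layer above (a minimal sketch, assuming TensorFlow 1.x; the tensor sizes and the random feed values are illustrative and not part of the repository):

import numpy as np
import tensorflow as tf

from attention import attention

batch, time_steps, hidden = 8, 20, 128

# Stand-ins for the (output_fw, output_bw) pair of a bidirectional GRU.
fw = tf.placeholder(tf.float32, [batch, time_steps, hidden])
bw = tf.placeholder(tf.float32, [batch, time_steps, hidden])

context = attention((fw, bw), attention_size=128)  # expected shape: [batch, 2 * hidden]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(context, feed_dict={
        fw: np.random.randn(batch, time_steps, hidden),
        bw: np.random.randn(batch, time_steps, hidden)})
    print(out.shape)  # (8, 256)

The same function is applied at both levels of the hierarchy: over word positions inside a sentence and over sentences inside a document.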
/config.py:
--------------------------------------------------------------------------------
import argparse


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--attention_size', default=128, type=int)
    parser.add_argument('--hidden_layers', default=3, type=int)
    parser.add_argument('--hidden_units', default=128, type=int)
    return parser.parse_args()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from attention import attention
from utils import next_batch


class hier_rnn():
    def __init__(self, args):
        self.args = args
        # Padded documents of word ids: [batch, max_sentences, max_words].
        self.sentence = tf.placeholder(tf.int32, [self.args.batch_size, None, None])
        self.target = tf.placeholder(tf.int64, [self.args.batch_size])
        # Word count of every sentence, flattened to [batch * max_sentences].
        self.seq_len = tf.placeholder(tf.int32, [None])
        # Number of sentences each document is padded to in the current batch.
        self.max_len = tf.placeholder(tf.int32, shape=())

    def word_embedding(self, inputs):
        def cell():
            return tf.nn.rnn_cell.GRUCell(self.args.hidden_units)

        # Build separate forward and backward stacks; sharing one cell object
        # between the two directions would tie their variables together.
        cell_fw = tf.nn.rnn_cell.MultiRNNCell([cell() for _ in range(self.args.hidden_layers)])
        cell_bw = tf.nn.rnn_cell.MultiRNNCell([cell() for _ in range(self.args.hidden_layers)])
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs,
                                                     sequence_length=self.seq_len,
                                                     dtype=tf.float32,
                                                     scope='word_embedding')
        return attention(outputs, self.args.attention_size)

    def sentence_embedding(self, inputs):
        def cell():
            return tf.nn.rnn_cell.GRUCell(self.args.hidden_units)

        cell_fw = tf.nn.rnn_cell.MultiRNNCell([cell() for _ in range(self.args.hidden_layers)])
        cell_bw = tf.nn.rnn_cell.MultiRNNCell([cell() for _ in range(self.args.hidden_layers)])
        cell_fw_initial = cell_fw.zero_state(self.args.batch_size, tf.float32)
        cell_bw_initial = cell_bw.zero_state(self.args.batch_size, tf.float32)
        outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs,
                                                     initial_state_fw=cell_fw_initial,
                                                     initial_state_bw=cell_bw_initial,
                                                     scope='sentence_embedding')
        return attention(outputs, self.args.attention_size)

    def forward(self):
        # Fold sentences into the batch dimension so the word-level encoder
        # sees [batch * max_sentences, max_words].
        sen_in = tf.reshape(self.sentence, [self.args.batch_size * self.max_len, -1])
        with tf.device("/cpu:0"):
            # 89526 = vocabulary size of aclImdb/imdb.vocab, 256 = embedding dim.
            embedding = tf.get_variable('embedding', shape=[89526, 256])
            inputs = tf.nn.embedding_lookup(embedding, sen_in)
        word_embedding = self.word_embedding(inputs)
        # Unfold back to [batch, max_sentences, 2 * hidden_units] sentence vectors.
        word_embedding = tf.reshape(word_embedding,
                                    [self.args.batch_size, -1, 2 * self.args.hidden_units])
        sen_embedding = self.sentence_embedding(word_embedding)
        logits = tf.layers.dense(sen_embedding, 2)  # two classes: pos / neg
        cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.target, logits=logits))
        optimizer = tf.train.AdamOptimizer().minimize(cross_entropy)
        correct = tf.equal(self.target, tf.argmax(logits, axis=1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        return cross_entropy, optimizer, accuracy

    def train(self):
        cross_entropy, optimizer, accuracy = self.forward()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(10):
                # next_batch re-reads and re-tokenizes the whole split each epoch.
                x_batch, y_batch, seq_len, max_len = next_batch(self.args.batch_size)
                for step in range(len(x_batch)):
                    loss, _, acc = sess.run([cross_entropy, optimizer, accuracy],
                                            feed_dict={self.sentence: x_batch[step],
                                                       self.target: y_batch[step],
                                                       self.seq_len: seq_len[step],
                                                       self.max_len: max_len[step]})
                    if step % 10 == 0:
                        print("Epoch %d, Step %d, loss is %f" % (epoch, step, loss))
                        print("Epoch %d, Step %d, accuracy is %f" % (epoch, step, acc))
            x_batch, y_batch, seq_len, max_len = next_batch(self.args.batch_size, mode='test')
            test_accuracy = 0
            for step in range(len(x_batch)):
                acc = sess.run(accuracy, feed_dict={self.sentence: x_batch[step],
                                                    self.target: y_batch[step],
                                                    self.seq_len: seq_len[step],
                                                    self.max_len: max_len[step]})
                test_accuracy += acc
            print('test accuracy is %f' % (test_accuracy / len(x_batch)))
--------------------------------------------------------------------------------
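The reshape bookkeeping in forward() is the core of the hierarchy. A small NumPy sketch of the same trick (the sizes are illustrative, chosen only to make the printed shapes easy to follow):

import numpy as np

batch_size, max_sentences, max_words, hidden = 2, 3, 5, 128

# A padded batch of documents: word ids with shape [batch, sentences, words].
doc_ids = np.zeros((batch_size, max_sentences, max_words), dtype=np.int32)

# Word-level pass: fold sentences into the batch axis so every sentence is an
# independent sequence for the word-level bidirectional GRU.
word_level_in = doc_ids.reshape(batch_size * max_sentences, max_words)
print(word_level_in.shape)        # (6, 5)

# After attention, each sentence collapses to a single 2*hidden vector.
sentence_vectors = np.zeros((batch_size * max_sentences, 2 * hidden))

# Sentence-level pass: unfold back to [batch, sentences, 2*hidden] so the
# sentence-level bidirectional GRU runs over sentences within a document.
sentence_level_in = sentence_vectors.reshape(batch_size, max_sentences, 2 * hidden)
print(sentence_level_in.shape)    # (2, 3, 256)

This mirrors the two tf.reshape calls in forward(), with self.max_len playing the role of max_sentences.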
/utils.py:
--------------------------------------------------------------------------------
import os

import nltk
import numpy as np
from nltk.tokenize import WordPunctTokenizer


def load_file(folder):
    # Positive reviews get label 0, negative reviews get label 1.
    pos_folder = os.path.join(folder, 'pos')
    neg_folder = os.path.join(folder, 'neg')
    sentences = []
    labels = []
    for name in os.listdir(pos_folder):
        filename = os.path.join(pos_folder, name)
        with open(filename) as f:
            sentences.append(f.read())
            labels.append(0)
    for name in os.listdir(neg_folder):
        filename = os.path.join(neg_folder, name)
        with open(filename) as f:
            sentences.append(f.read())
            labels.append(1)
    return np.array(sentences), np.array(labels)


def build_dict():
    # Map every word of aclImdb/imdb.vocab to its line index and back.
    word_num = {}
    num_word = {}
    with open('aclImdb/imdb.vocab') as f:
        for idx, line in enumerate(f):
            word = line.strip()
            word_num[word] = idx
            num_word[idx] = word
    return word_num, num_word


def split_sentences(sentences):
    sentence = []
    tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
    sentence.append(tokenizer.tokenize(sentences))
    return sentence


def word2vector(sentence, word_num):
    # Map each token to its vocabulary index; unknown words fall back to 0.
    vector = []
    for word in WordPunctTokenizer().tokenize(sentence):
        try:
            vector.append(word_num[word])
        except KeyError:
            vector.append(0)
    return vector, len(vector)


def prepare_data(sentences, word_num):
    tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
    each_sentences = []
    for sentence in sentences:
        each_sentences.append(tokenizer.tokenize(sentence))

    vectors = []
    word_index = []
    for sen in each_sentences:
        each_vector = []
        for word_sen in sen:
            vec, length = word2vector(word_sen, word_num)
            if length < 60:
                each_vector.append(vec)
                word_index.append(length)
            else:
                # Sentences of 60 tokens or more are replaced by a single padding id.
                each_vector.append([0])
                word_index.append(1)
        vectors.append(each_vector)

    sen_index = [len(sen) for sen in each_sentences]
    max_sen_len = np.max(sen_index)
    max_word_len = np.max(word_index)

    # seq_len holds the word count of every (document, sentence) slot;
    # data is the zero-padded id tensor fed into the word-level encoder.
    seq_len = np.zeros([len(sentences), max_sen_len], dtype=np.int32)
    data = np.zeros((len(sentences), max_sen_len, max_word_len), dtype=np.int32)
    for i, vector in enumerate(vectors):
        for j, each_vector in enumerate(vector):
            seq_len[i, j] = len(each_vector)
            data[i, j, :len(each_vector)] = each_vector

    return data, np.reshape(seq_len, -1), max_sen_len


def next_batch(batch_size, mode='train'):
    if mode == 'train':
        sentences, labels = load_file('aclImdb/train')
    else:
        sentences, labels = load_file('aclImdb/test')
    # Shuffle documents and labels together.
    length = len(sentences)
    idx = np.arange(0, length)
    np.random.shuffle(idx)
    sentences = sentences[idx]
    labels = labels[idx]
    word_num, _ = build_dict()

    sentence_batch = []
    label_batch = []
    seq_len = []
    max_len = []
    start_index = 0
    while True:
        end_index = start_index + batch_size
        if end_index > length:
            break
        sen, sen_num, sen_len = prepare_data(sentences[start_index:end_index], word_num)
        seq_len.append(sen_num)
        max_len.append(sen_len)
        sentence_batch.append(sen)
        label_batch.append(labels[start_index:end_index])
        start_index = end_index
    return sentence_batch, label_batch, seq_len, max_len
--------------------------------------------------------------------------------
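A quick way to sanity-check the data pipeline without the full dataset (a minimal sketch; the tiny vocabulary and the example documents are made up for illustration, and it assumes the NLTK punkt sentence tokenizer has been downloaded via nltk.download('punkt')):

import nltk
from utils import prepare_data

# nltk.download('punkt')  # uncomment on first run

# Hypothetical toy vocabulary; the real mapping comes from build_dict() and aclImdb/imdb.vocab.
word_num = {'this': 1, 'movie': 2, 'is': 3, 'great': 4, 'awful': 5, '.': 6}

docs = ["This movie is great. I loved it.",
        "Awful. I walked out halfway through."]

data, seq_len, max_sen_len = prepare_data(docs, word_num)
print(data.shape)     # (2, max_sen_len, max_word_len)
print(seq_len.shape)  # (2 * max_sen_len,)
print(max_sen_len)    # number of sentences in the longest document

next_batch() wraps the same call for every batch_size-sized slice of aclImdb/train or aclImdb/test, which is why model.train() can feed x_batch[step], seq_len[step], and max_len[step] straight into the placeholders.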