├── ID_LSTM
│   ├── parser.py
│   ├── actor.py
│   ├── datamanager.py
│   ├── test.py
│   ├── LSTM_critic.py
│   └── main.py
├── README.md
└── HS_LSTM
    ├── parser.py
    ├── actor.py
    ├── test.py
    ├── datamanager.py
    ├── main.py
    └── LSTM_critic.py
/ID_LSTM/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | class Parser(object):
4 |     def getParser(self):
5 |         parser = argparse.ArgumentParser()
6 |         parser.add_argument('--name', type=str, default='test')
7 |         parser.add_argument('--fasttest', type=int, default=0, choices=[0, 1])
8 |         parser.add_argument('--seed', type=int, default=int(1000*time.time()))
9 |         parser.add_argument('--dataset', type=str, default='../TrainData/MR')
10 |         parser.add_argument('--maxlenth', type=int, default=70)
11 |         parser.add_argument('--grained', type=int, default=2)
12 |         parser.add_argument('--optimizer', type=str, default='Adam', \
13 |             choices=['SGD', 'Adagrad', 'Adadelta', 'Adam', 'Nadam'])
14 |         parser.add_argument('--lr', type=float, default=0.0005)
15 |         parser.add_argument('--epoch', type=int, default=5)
16 |         parser.add_argument('--batchsize', type=int, default=5)
17 |         parser.add_argument('--word_vector', type=str, default='../WordVector/vector.300dim')
18 |         parser.add_argument('--dim', type=int, default=300)
19 |         parser.add_argument('--tau', type=float, default=0.1)
20 |         parser.add_argument('--dropout', type=float, default=0.5)
21 |         parser.add_argument('--alpha', type=float, default=0.1)
22 |         parser.add_argument('--epsilon', type=float, default=0.05)
23 |         parser.add_argument('--sample_cnt', type=int, default=5)
24 |         parser.add_argument('--LSTMpretrain', type=str, default='')
25 |         parser.add_argument('--RLpretrain', type=str, default='')
26 |         return parser
27 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Structured Representation for Text Classification via Reinforcement Learning
2 | Tianyang Zhang*, Minlie Huang, Li Zhao
3 | 
4 | Representation learning is a fundamental problem in natural language processing. This paper studies how to learn a structured representation for text classification. Unlike most existing representation models that either use no structure or rely on pre-specified structures, we propose a reinforcement learning (RL) method to learn sentence representation by discovering optimized structures automatically. We demonstrate two attempts to build structured representation: Information Distilled LSTM (ID-LSTM) and Hierarchically Structured LSTM (HS-LSTM). ID-LSTM selects only important, task-relevant words, and HS-LSTM discovers phrase structures in a sentence. Structure discovery in the two representation models is formulated as a sequential decision problem: the current structure-discovery decision affects subsequent decisions, which can be addressed by policy gradient RL. Results show that our method can learn task-friendly representations by identifying important words or task-relevant structures without explicit structure annotations, and thus yields competitive performance.
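The decision process the abstract describes is easiest to see in miniature. Below is a minimal NumPy sketch of the REINFORCE-style update that `ID_LSTM/main.py` implements with TensorFlow actor/critic networks: sample one keep/delete action per word from a logistic policy, score each sampled action sequence with the classifier loss, and push the policy toward sequences with below-average loss. The names here (`w`, `features`, `sample_actions`, `reinforce_update`) are illustrative stand-ins, not part of the repository.

```python
import numpy as np

def sample_actions(w, features):
    """Sample one keep(1)/delete(0) action per word from a logistic policy."""
    actions, score_grads = [], []
    for x in features:                          # one feature vector per word
        p_keep = 1.0 / (1.0 + np.exp(-np.dot(w, x)))
        a = int(np.random.rand() < p_keep)
        actions.append(a)
        score_grads.append((a - p_keep) * x)    # d/dw log pi(a | x) for a Bernoulli policy
    return actions, score_grads

def reinforce_update(w, episode_grads, losses, alpha=0.1, lr=0.0005):
    """Nudge the policy toward action sequences whose loss beats the average."""
    baseline = np.mean(losses)                  # like `aveloss` over `samplecnt` samples
    for score_grads, loss in zip(episode_grads, losses):
        advantage = (baseline - loss) * alpha   # lower loss => positive advantage
        for g in score_grads:
            w = w + lr * advantage * g
    return w
```

In the repository itself the state is an LSTM hidden state rather than a fixed feature vector, the loss adds a penalty on how many words are kept, and gradients flow through delayed "target" copies of the networks (see `actor.py` and `LSTM_critic.py` below).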
5 | 6 | @inproceedings{zhang2018learning, 7 | 8 | title={Learning Structured Representation for Text Classification via Reinforcement Learning}, 9 | 10 | author={Zhang, Tianyang and Huang, Minlie and Zhao, Li}, 11 | 12 | booktitle={AAAI}, 13 | 14 | year={2018} 15 | 16 | } 17 | 18 | AGnews dataset used in the experiment: 19 | https://drive.google.com/open?id=1becf7pzfuLL7qgWqv4q-TyDYjSzodWfR 20 | -------------------------------------------------------------------------------- /HS_LSTM/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | class Parser(object): 4 | def getParser(self): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--name', type=str, default='test') 7 | parser.add_argument('--fasttest', type=int, default=0, choices=[0, 1]) 8 | parser.add_argument('--seed', type=int, default=int(1000*time.time())) 9 | parser.add_argument('--dataset', type=str, default='../TrainData/MR') 10 | parser.add_argument('--maxlenth', type=int, default=70) 11 | parser.add_argument('--grained', type=int, default=2) 12 | parser.add_argument('--optimizer', type=str, default='Adam', \ 13 | choices=['SGD', 'Adagrad', 'Adadelta', 'Adam', 'Nadam']) 14 | parser.add_argument('--lr', type=float, default=0.0005) 15 | parser.add_argument('--epoch', type=int, default=5) 16 | parser.add_argument('--batchsize', type=int, default=5) 17 | parser.add_argument('--samplecnt', type=int, default=5) 18 | parser.add_argument('--attention', type=int, default=0, choices=[0, 1]) 19 | parser.add_argument('--word_vector', type=str, default='../WordVector/vector.300dim') 20 | parser.add_argument('--dim', type=int, default=300) 21 | parser.add_argument('--tau', type=float, default=0.1) 22 | parser.add_argument('--dropout', type=float, default=0.5) 23 | parser.add_argument('--alpha', type=float, default=0.1) 24 | parser.add_argument('--LSTMpretrain', type=str, default='') 25 | parser.add_argument('--RLpretrain', type=str, default='') 26 | parser.add_argument('--pretype', type=str, default='N1', choices=['N1', 'ALL1', 'ALL0', 'RANDOM']) 27 | return parser 28 | -------------------------------------------------------------------------------- /HS_LSTM/actor.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import tflearn 4 | 5 | class ActorNetwork(object): 6 | """ 7 | action network 8 | use the state 9 | sample the action 10 | """ 11 | 12 | def __init__(self, sess, dim, optimizer, learning_rate, tau, num_critic_vars): 13 | self.global_step = tf.Variable(0, trainable=False, name="ActorStep") 14 | self.sess = sess 15 | self.dim = dim 16 | self.learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 10000, 0.95, staircase=True) 17 | self.tau = tau 18 | if optimizer == 'Adam': 19 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate) 20 | elif optimizer == 'Adagrad': 21 | self.optimizer = tf.train.AdagradOptimizer(self.learning_rate) 22 | elif optimizer == 'Adadelta': 23 | self.optimizer = tf.train.AdadeltaOptimizer(self.learning_rate) 24 | print self.optimizer 25 | self.num_other_variables = len(tf.trainable_variables()) 26 | 27 | #actor network(updating) 28 | self.input_l, self.input_d, self.scaled_out = self.create_actor_network() 29 | self.network_params = tf.trainable_variables()[self.num_other_variables:] 30 | 31 | #actor network(delayed updating) 32 | self.target_input_l, self.target_input_d, self.target_scaled_out = 
self.create_actor_network() 33 | self.target_network_params = tf.trainable_variables()[self.num_other_variables + len(self.network_params):] 34 | 35 | #delayed updaing actor network 36 | self.update_target_network_params = \ 37 | [self.target_network_params[i].assign(\ 38 | tf.multiply(self.network_params[i], self.tau) +\ 39 | tf.multiply(self.target_network_params[i], 1 - self.tau))\ 40 | for i in range(len(self.target_network_params))] 41 | 42 | self.assign_active_network_params = \ 43 | [self.network_params[i].assign(\ 44 | self.target_network_params[i]) for i in range(len(self.network_params))] 45 | 46 | #gradient provided by critic network 47 | self.action_gradient = tf.placeholder(tf.float32, [2]) 48 | self.log_target_scaled_out = tf.log(self.target_scaled_out) 49 | 50 | self.actor_gradients = tf.gradients(self.log_target_scaled_out, self.target_network_params, self.action_gradient) 51 | 52 | self.grads = [tf.placeholder(tf.float32, [600, 1]), 53 | tf.placeholder(tf.float32, [1,]), 54 | tf.placeholder(tf.float32, [600, 1])] 55 | self.optimize = self.optimizer.apply_gradients(zip(self.grads, self.network_params[:-1]), global_step=self.global_step) 56 | 57 | def create_actor_network(self): 58 | input_l = tf.placeholder(tf.float32, shape=[1, self.dim*2]) 59 | input_d = tf.placeholder(tf.float32, shape=[1, self.dim*2]) 60 | 61 | t1 = tflearn.fully_connected(input_l, 1) 62 | t2 = tflearn.fully_connected(input_d, 1) 63 | 64 | scaled_out = tflearn.activation(\ 65 | tf.matmul(input_l,t1.W) + tf.matmul(input_d,t2.W) + t1.b,\ 66 | activation = 'sigmoid') 67 | 68 | scaled_out = tf.stack([1.0 - scaled_out[0][0], scaled_out[0][0]]) 69 | return input_l, input_d, scaled_out 70 | 71 | def train(self, grad): 72 | self.sess.run(self.optimize, feed_dict={ 73 | self.grads[0]: grad[0], 74 | self.grads[1]: grad[1], 75 | self.grads[2]: grad[2]}) 76 | 77 | def predict_target(self, input_l, input_d): 78 | return self.sess.run(self.target_scaled_out, feed_dict={ 79 | self.target_input_l: input_l, 80 | self.target_input_d: input_d}) 81 | 82 | def get_gradient(self, input_l, input_d, a_gradient): 83 | return self.sess.run(self.actor_gradients[:-1], feed_dict={ 84 | self.target_input_l: input_l, 85 | self.target_input_d: input_d, 86 | self.action_gradient: a_gradient}) 87 | 88 | def update_target_network(self): 89 | self.sess.run(self.update_target_network_params) 90 | 91 | def assign_active_network(self): 92 | self.sess.run(self.assign_active_network_params) 93 | -------------------------------------------------------------------------------- /ID_LSTM/actor.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import tflearn 4 | 5 | class ActorNetwork(object): 6 | """ 7 | action network 8 | use the state 9 | sample the action 10 | """ 11 | 12 | def __init__(self, sess, dim, optimizer, learning_rate, tau): 13 | self.global_step = tf.Variable(0, trainable=False, name="ActorStep") 14 | self.sess = sess 15 | self.dim = dim 16 | self.learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 10000, 0.95, staircase=True) 17 | self.tau = tau 18 | if optimizer == 'Adam': 19 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate) 20 | elif optimizer == 'Adagrad': 21 | self.optimizer = tf.train.AdagradOptimizer(self.learning_rate) 22 | elif optimizer == 'Adadelta': 23 | self.optimizer = tf.train.AdadeltaOptimizer(self.learning_rate) 24 | self.num_other_variables = len(tf.trainable_variables()) 25 | #actor 
network(updating) 26 | self.input_l, self.input_d, self.scaled_out = self.create_actor_network() 27 | self.network_params = tf.trainable_variables()[self.num_other_variables:] 28 | 29 | #actor network(delayed updating) 30 | self.target_input_l, self.target_input_d, self.target_scaled_out = self.create_actor_network() 31 | self.target_network_params = tf.trainable_variables()[self.num_other_variables + len(self.network_params):] 32 | 33 | #delayed updaing actor network 34 | self.update_target_network_params = \ 35 | [self.target_network_params[i].assign(\ 36 | tf.multiply(self.network_params[i], self.tau) +\ 37 | tf.multiply(self.target_network_params[i], 1 - self.tau))\ 38 | for i in range(len(self.target_network_params))] 39 | 40 | self.assign_active_network_params = \ 41 | [self.network_params[i].assign(\ 42 | self.target_network_params[i]) for i in range(len(self.network_params))] 43 | 44 | #gradient provided by critic network 45 | self.action_gradient = tf.placeholder(tf.float32, [2]) 46 | self.log_target_scaled_out = tf.log(self.target_scaled_out) 47 | 48 | self.actor_gradients = tf.gradients(self.log_target_scaled_out, self.target_network_params, self.action_gradient) 49 | print self.actor_gradients 50 | 51 | self.grads = [tf.placeholder(tf.float32, [600,1]), 52 | tf.placeholder(tf.float32, [1,]), 53 | tf.placeholder(tf.float32, [300, 1])] 54 | self.optimize = self.optimizer.apply_gradients(zip(self.grads, self.network_params[:-1]), global_step=self.global_step) 55 | 56 | def create_actor_network(self): 57 | input_l = tf.placeholder(tf.float32, shape=[1, self.dim*2]) 58 | input_d = tf.placeholder(tf.float32, shape=[1, self.dim]) 59 | 60 | t1 = tflearn.fully_connected(input_l, 1) 61 | t2 = tflearn.fully_connected(input_d, 1) 62 | 63 | scaled_out = tflearn.activation(\ 64 | tf.matmul(input_l,t1.W) + tf.matmul(input_d,t2.W) + t1.b,\ 65 | activation = 'sigmoid') 66 | 67 | s_out = tf.clip_by_value(scaled_out[0][0], 1e-5, 1 - 1e-5) 68 | 69 | scaled_out = tf.stack([1.0 - s_out, s_out]) 70 | return input_l, input_d, scaled_out 71 | 72 | def train(self, grad): 73 | self.sess.run(self.optimize, feed_dict={ 74 | self.grads[0]: grad[0], 75 | self.grads[1]: grad[1], 76 | self.grads[2]: grad[2]}) 77 | 78 | def predict_target(self, input_l, input_d): 79 | return self.sess.run(self.target_scaled_out, feed_dict={ 80 | self.target_input_l: input_l, 81 | self.target_input_d: input_d}) 82 | 83 | def get_gradient(self, input_l, input_d, a_gradient): 84 | return self.sess.run(self.actor_gradients[:-1], feed_dict={ 85 | self.target_input_l: input_l, 86 | self.target_input_d: input_d, 87 | self.action_gradient: a_gradient}) 88 | 89 | def update_target_network(self): 90 | self.sess.run(self.update_target_network_params) 91 | 92 | def assign_active_network(self): 93 | self.sess.run(self.assign_active_network_params) 94 | -------------------------------------------------------------------------------- /ID_LSTM/datamanager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import json, random 4 | 5 | class DataManager(object): 6 | def __init__(self, dataset): 7 | ''' 8 | Read the data from dir "dataset" 9 | ''' 10 | self.origin = {} 11 | for fname in ['train', 'dev', 'test']: 12 | data = [] 13 | for line in open('%s/%s.res' % (dataset, fname)): 14 | s = json.loads(line.strip()) 15 | if len(s) > 0: 16 | data.append(s) 17 | self.origin[fname] = data 18 | def getword(self): 19 | ''' 20 | Get the words that appear in the data. 
21 |         Words are indexed by how often they appear, most frequent first, e.g.
22 |         {'ok': 1, 'how': 2, ...}
23 |         Do not call this function twice.
24 |         '''
25 |         wordcount = {}
26 |         def dfs(node):
27 |             if node.has_key('children'):
28 |                 dfs(node['children'][0])
29 |                 dfs(node['children'][1])
30 |             else:
31 |                 word = node['word'].lower()
32 |                 wordcount[word] = wordcount.get(word, 0) + 1
33 |         for fname in ['train', 'dev', 'test']:
34 |             for sent in self.origin[fname]:
35 |                 dfs(sent)
36 |         words = wordcount.items()
37 |         words.sort(key=lambda x: x[1], reverse=True)
38 |         self.words = words
39 |         self.wordlist = {item[0]: index+1 for index, item in enumerate(words)}
40 |         return self.wordlist
41 | 
42 |     def getdata(self, grained, maxlenth):
43 |         '''
44 |         Get all the data, divided into (train, dev, test).
45 |         Each sentence becomes {'words': [1,3,5,...], 'solution': [0,1,0,0,0], 'lenth': n}
46 |         and each split is a list [sentence1, sentence2, ...].
47 |         Do not call this function twice.
48 |         '''
49 |         def one_hot_vector(r):
50 |             s = np.zeros(grained, dtype=np.float32)
51 |             s[r] += 1.0
52 |             return s
53 |         def dfs(node, words):
54 |             if node.has_key('children'):
55 |                 dfs(node['children'][0], words)
56 |                 dfs(node['children'][1], words)
57 |             else:
58 |                 word = self.wordlist[node['word'].lower()]
59 |                 words.append(word)
60 |         self.getword()
61 |         self.data = {}
62 |         for fname in ['train', 'dev', 'test']:
63 |             self.data[fname] = []
64 |             for sent in self.origin[fname]:
65 |                 words = []
66 |                 dfs(sent, words)
67 |                 lens = len(words)
68 |                 if maxlenth < lens:
69 |                     print lens  # warning: sentence longer than maxlenth; the padding below will not truncate it
70 |                 words += [0] * (maxlenth - lens)
71 |                 solution = one_hot_vector(int(sent['rating']))
72 |                 now = {'words': np.array(words), \
73 |                     'solution': solution,\
74 |                     'lenth': lens}
75 |                 self.data[fname].append(now)
76 |         return self.data['train'], self.data['dev'], self.data['test']
77 | 
78 |     def get_wordvector(self, name):
79 |         fr = open(name)
80 |         n, dim = map(int, fr.readline().split())
81 |         self.wv = {}
82 |         for i in range(n - 1):
83 |             vec = fr.readline().split()
84 |             word = vec[0].lower()
85 |             vec = map(float, vec[1:])
86 |             if self.wordlist.has_key(word):
87 |                 self.wv[self.wordlist[word]] = vec
88 |         self.wordvector = []
89 |         losscnt = 0
90 |         for i in range(len(self.wordlist) + 1):
91 |             if self.wv.has_key(i):
92 |                 self.wordvector.append(self.wv[i])
93 |             else:
94 |                 losscnt += 1
95 |                 self.wordvector.append(np.random.uniform(-0.1, 0.1, [dim]))
96 |         self.wordvector = np.array(self.wordvector, dtype=np.float32)
97 |         print losscnt, "words not found in the word vector file"
98 |         print len(self.wordvector), "words in total"
99 |         return self.wordvector
100 | 
101 | #datamanager = DataManager("../TrainData/MR")
102 | #train_data, dev_data, test_data = datamanager.getdata(2, 200)
103 | #wv = datamanager.get_wordvector("../WordVector/vector.25dim")
104 | #mxlen = 0
105 | #for item in train_data:
106 | #    print item['lenth']
107 | #    if item['lenth'] > mxlen:
108 | #        mxlen = item['lenth']
109 | #print mxlen
110 | 
--------------------------------------------------------------------------------
/HS_LSTM/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import random
4 | import sys, os
5 | import json
6 | import argparse
7 | from parser import Parser
8 | from datamanager import DataManager
9 | from actor import ActorNetwork
10 | from LSTM_critic import LSTM_CriticNetwork
11 | tf.logging.set_verbosity(tf.logging.ERROR)
12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13 | #parse args
14 | argv = sys.argv[1:]
15 | parser = Parser().getParser()
16 | args, _ = parser.parse_known_args(argv)
17 | random.seed(args.seed)
18 | 
19 | #get data
20 | dataManager = DataManager(args.dataset)
21 | train_data, dev_data, test_data = dataManager.getdata(args.grained, args.maxlenth)
22 | word_vector = dataManager.get_wordvector(args.word_vector)
23 | 
24 | print "train_data ", len(train_data)
25 | print "dev_data", len(dev_data)
26 | print "test_data", len(test_data)
27 | if args.fasttest == 1:
28 |     train_data = train_data[:100]
29 |     dev_data = dev_data[:20]
30 |     test_data = test_data[:20]
31 | 
32 | def sampling_RL(sess, actor, inputs, lenth, Random=True):
33 |     current_lower_state = np.zeros((1, state_size), dtype=np.float32)
34 |     current_upper_state = np.zeros((1, state_size), dtype=np.float32)
35 |     actions = []
36 |     states = []
37 |     #sampling actions
38 | 
39 |     for pos in range(lenth):
40 |         out_d, current_lower_state = critic.lower_LSTM_target(current_lower_state, [[inputs[pos]]])
41 |         predicted = actor.predict_target(current_upper_state, current_lower_state)
42 |         #print predicted
43 |         states.append([current_upper_state, current_lower_state])
44 |         if Random:
45 |             action = (0 if random.random() < predicted[0] else 1)
46 |         else:
47 |             action = np.argmax(predicted)
48 |         actions.append(action)
49 |         if action == 1:
50 |             current_upper_state = critic.upper_LSTM_target(current_upper_state, out_d)
51 |             current_lower_state = np.zeros_like(current_lower_state)
52 | 
53 |     #pad zeros
54 |     actions += [0] * (args.maxlenth - lenth)
55 |     actions[lenth-1] = 1
56 |     #get the positions of action 1
57 |     action_pos = []
58 |     for (i, j) in enumerate(actions):
59 |         if j == 1:
60 |             action_pos.append(i)
61 |     return actions, states, action_pos
62 | 
63 | 
64 | def test(sess, actor, critic, test_data, Random=False):
65 |     acc = 0
66 |     total_lenth = 0
67 |     total_phrase_count = 0
68 |     for i in range(len(test_data)):
69 |         #prepare
70 |         data = test_data[i]
71 |         inputs, solution, lenth, paction = data['words'], data['solution'], data['lenth'], data['action']
72 |         #get sampling
73 |         if not Random:
74 |             actions, _, action_pos = sampling_RL(sess, actor, inputs, lenth, Random=False)
75 |         else:
76 |             actions, action_pos = sampling_random(lenth, paction)  # sampling_random as defined in main.py
77 | 
78 |         #predict
79 |         out = critic.predict_target([inputs], [actions], [action_pos], [lenth], [len(action_pos)])
80 |         if np.argmax(out) == np.argmax(solution):
81 |             acc += 1
82 | 
83 |         print json.dumps(actions[:lenth])
84 |         print json.dumps([dataManager.words[i-1][0] for i in inputs][:lenth])
85 |         #print out, solution
86 | 
87 |         total_lenth += lenth
88 |         total_phrase_count += len(action_pos)
89 | 
90 |     avelenth = float(total_lenth) / float(len(test_data))
91 |     avephrase = float(total_phrase_count) / float(len(test_data))
92 |     avephraselenth = avelenth / avephrase
93 | 
94 |     #print "average length :", avelenth
95 |     #print "average phrase number :", avephrase
96 |     #print "average phrase length :", avephraselenth
97 | 
98 |     return float(acc) / len(test_data)
99 | 
100 | config = tf.ConfigProto()
101 | config.gpu_options.allow_growth = True
102 | with tf.Session(config = config) as sess:
103 |     #model
104 |     critic = LSTM_CriticNetwork(sess, args.dim, args.optimizer, args.lr, args.tau, args.grained, args.attention, args.maxlenth, args.dropout, word_vector)
105 |     actor = ActorNetwork(sess, args.dim, args.optimizer, args.lr, args.tau, critic.get_num_trainable_vars())
106 |     state_size = critic.state_size
107 | 
108 |     #print variables
109 |     for item in tf.trainable_variables():
110 |         print (item.name, item.get_shape())
111 | 
112 |     saver = tf.train.Saver()
113 |
| 114 | saver.restore(sess, "checkpoints/best821") 115 | test(sess, actor, critic, dev_data) 116 | -------------------------------------------------------------------------------- /ID_LSTM/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import random 4 | import sys, os 5 | import json 6 | import argparse 7 | from parser import Parser 8 | from datamanager import DataManager 9 | from actor import ActorNetwork 10 | from LSTM_critic import LSTM_CriticNetwork 11 | tf.logging.set_verbosity(tf.logging.ERROR) 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | #get parse 14 | argv = sys.argv[1:] 15 | parser = Parser().getParser() 16 | args, _ = parser.parse_known_args(argv) 17 | random.seed(args.seed) 18 | 19 | #get data 20 | dataManager = DataManager(args.dataset) 21 | train_data, dev_data, test_data = dataManager.getdata(args.grained, args.maxlenth) 22 | word_vector = dataManager.get_wordvector(args.word_vector) 23 | 24 | if args.fasttest == 1: 25 | train_data = train_data[:100] 26 | dev_data = dev_data[:20] 27 | test_data = test_data[:20] 28 | print "train_data ", len(train_data) 29 | print "dev_data", len(dev_data) 30 | print "test_data", len(test_data) 31 | 32 | def sampling_RL(sess, actor, inputs, vec, lenth, Random=True): 33 | current_lower_state = np.zeros((1, 2*args.dim), dtype=np.float32) 34 | actions = [] 35 | states = [] 36 | #sampling actions 37 | 38 | for pos in range(lenth): 39 | predicted = actor.predict_target(current_lower_state, [vec[0][pos]]) 40 | #print predicted 41 | states.append([current_lower_state, [vec[0][pos]]]) 42 | if Random: 43 | action = (0 if random.random() < predicted[0] else 1) 44 | else: 45 | action = np.argmax(predicted) 46 | actions.append(action) 47 | if action == 1: 48 | out_d, current_lower_state = critic.lower_LSTM_target(current_lower_state, [[inputs[pos]]]) 49 | 50 | Rinput = [] 51 | for (i, a) in enumerate(actions): 52 | if a == 1: 53 | Rinput.append(inputs[i]) 54 | Rlenth = len(Rinput) 55 | if Rlenth == 0: 56 | actions[lenth-2] = 1 57 | Rinput.append(inputs[lenth-2]) 58 | Rlenth = 1 59 | Rinput += [0] * (args.maxlenth - Rlenth) 60 | return actions, states, Rinput, Rlenth 61 | 62 | def test(sess, actor, critic, test_data, noRL=False): 63 | acc = 0 64 | total_lenth = 0 65 | total_dis = 0 66 | rwords = {} 67 | owords = {} 68 | for i in range(len(test_data)): 69 | #prepare 70 | data = test_data[i] 71 | inputs, solution, lenth = data['words'], data['solution'], data['lenth'] 72 | 73 | #predict 74 | if noRL: 75 | out = critic.predict_target([inputs], [lenth]) 76 | else: 77 | actions, states, Rinput, Rlenth = sampling_RL(sess, actor, inputs, critic.wordvector_find([inputs]), lenth, Random=False) 78 | out = critic.predict_target([Rinput], [Rlenth]) 79 | #print json.dumps(actions) 80 | #print Rinput 81 | #print json.dumps([dataManager.words[i-1][0] for i in inputs][:lenth]) 82 | #print [dataManager.words[i-1][0] for i in Rinput][:Rlenth] 83 | #print out, solution 84 | #print (float(Rlenth)/lenth) * 0.05 * args.grained 85 | 86 | if np.argmax(out) == np.argmax(solution): 87 | acc += 1 88 | 89 | total_lenth += lenth 90 | total_dis += Rlenth 91 | for i in range(lenth): 92 | wd = dataManager.words[inputs[i]-1][0] 93 | if owords.has_key(wd): 94 | owords[wd] = owords[wd] + 1 95 | else: 96 | owords[wd] = 1 97 | if actions[i] == 0: 98 | if rwords.has_key(wd): 99 | rwords[wd] = rwords[wd] + 1 100 | else: 101 | rwords[wd] = 1 102 | ratewords = {} 103 | for (key, value) in 
rwords.items(): 104 | ratewords[key] = float(value) / owords[key] 105 | rdwords = ratewords.items() 106 | rdwords.sort(key = lambda x : x[1], reverse = True) 107 | outcnt = 0 108 | for i in range(len(rdwords)): 109 | if owords[rdwords[i][0]] > 20: 110 | print rdwords[i], owords[rdwords[i][0]] 111 | outcnt += 1 112 | if outcnt > 20: 113 | break; 114 | avelenth = float(total_lenth) / float(len(test_data)) 115 | avedis = float(total_dis) / float(len(test_data)) 116 | #print "average length", avelenth 117 | #print "average distilled length", avedis 118 | return float(acc) / len(test_data) 119 | 120 | config = tf.ConfigProto() 121 | config.gpu_options.allow_growth = True 122 | with tf.Session(config = config) as sess: 123 | #model 124 | critic = LSTM_CriticNetwork(sess, args.dim, args.optimizer, args.lr, args.tau, args.grained, args.maxlenth, args.dropout, word_vector) 125 | actor = ActorNetwork(sess, args.dim, args.optimizer, args.lr, args.tau) 126 | #print variables 127 | for item in tf.trainable_variables(): 128 | print (item.name, item.get_shape()) 129 | 130 | saver = tf.train.Saver() 131 | 132 | saver.restore(sess, "checkpoints/best816") 133 | 134 | print test(sess, actor, critic, dev_data) 135 | 136 | -------------------------------------------------------------------------------- /HS_LSTM/datamanager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import json, random 4 | 5 | class DataManager(object): 6 | def __init__(self, dataset): 7 | ''' 8 | Read the data from dir "dataset" 9 | ''' 10 | self.origin = {} 11 | for fname in ['train', 'dev', 'test']: 12 | data = [] 13 | for line in open('%s/%s.res' % (dataset, fname)): 14 | s = json.loads(line.strip()) 15 | if len(s) > 0: 16 | data.append(s) 17 | self.origin[fname] = data 18 | def getword(self): 19 | ''' 20 | Get the words that appear in the data. 21 | Sorted by the times it appears. 22 | {'ok': 1, 'how': 2, ...} 23 | Never run this function twice. 24 | ''' 25 | wordcount = {} 26 | def dfs(node): 27 | if node.has_key('children'): 28 | dfs(node['children'][0]) 29 | dfs(node['children'][1]) 30 | else: 31 | word = node['word'].lower() 32 | wordcount[word] = wordcount.get(word, 0) + 1 33 | for fname in ['train', 'dev', 'test']: 34 | for sent in self.origin[fname]: 35 | dfs(sent) 36 | words = wordcount.items() 37 | words.sort(key = lambda x : x[1], reverse = True) 38 | self.words = words 39 | self.wordlist = {item[0]: index+1 for index, item in enumerate(words)} 40 | return self.wordlist 41 | 42 | def getdata(self, grained, maxlenth): 43 | ''' 44 | Get all the data, divided into (train,dev,test). 45 | For every sentence, {'words':[1,3,5,...], 'solution': [0,1,0,0,0]} 46 | For each data, [sentence1, sentence2, ...] 47 | Never run this function twice. 
48 | ''' 49 | def one_hot_vector(r): 50 | s = np.zeros(grained, dtype=np.float32) 51 | s[r] += 1.0 52 | return s 53 | def dfs(node, words): 54 | if node.has_key('children'): 55 | dfs(node['children'][0], words) 56 | dfs(node['children'][1], words) 57 | node['size'] = node['children'][0]['size'] + node['children'][1]['size'] 58 | else: 59 | word = self.wordlist[node['word'].lower()] 60 | words.append(word) 61 | node['size'] = 1 62 | def look_action(node, action, ulen): 63 | if node['size'] <= ulen: 64 | action += [0] * (node['size'] - 1) 65 | action.append(1) 66 | elif node.has_key('children'): 67 | look_action(node['children'][0], action, ulen) 68 | look_action(node['children'][1], action, ulen) 69 | self.getword() 70 | self.data = {} 71 | for fname in ['train', 'dev', 'test']: 72 | self.data[fname] = [] 73 | for sent in self.origin[fname]: 74 | words, action = [], [] 75 | dfs(sent, words) 76 | lens = len(words) 77 | words += [0] * (maxlenth - lens) 78 | solution = one_hot_vector(int(sent['rating'])) 79 | look_action(sent, action, int(np.sqrt(lens) + 0.5)) 80 | now = {'words': np.array(words), \ 81 | 'solution': solution,\ 82 | 'lenth': lens, \ 83 | 'action': action} 84 | self.data[fname].append(now) 85 | return self.data['train'], self.data['dev'], self.data['test'] 86 | 87 | def get_wordvector(self, name): 88 | fr = open(name) 89 | n, dim = map(int, fr.readline().split()) 90 | self.wv = {} 91 | for i in range(n - 1): 92 | vec = fr.readline().split() 93 | word = vec[0].lower() 94 | vec = map(float, vec[1:]) 95 | if self.wordlist.has_key(word): 96 | self.wv[self.wordlist[word]] = vec 97 | self.wordvector = [] 98 | losscnt = 0 99 | for i in range(len(self.wordlist) + 1): 100 | if self.wv.has_key(i): 101 | self.wordvector.append(self.wv[i]) 102 | else: 103 | losscnt += 1 104 | self.wordvector.append(np.random.uniform(-0.1,0.1,[dim])) 105 | self.wordvector = np.array(self.wordvector, dtype=np.float32) 106 | print losscnt, "words not find in wordvector" 107 | print len(self.wordvector), "words in total" 108 | return self.wordvector 109 | 110 | #datamanager = DataManager("../TrainData/SUBJ") 111 | #train_data, test_data, dev_data = datamanager.getdata(2, 200) 112 | #wv = datamanager.get_wordvector("../WordVector/vector.25dim") 113 | #mxlen = 0 114 | #for item in test_data: 115 | # print item['action'], item['lenth'] 116 | # if item['lenth'] > mxlen: 117 | # mxlen = item['lenth'] 118 | #print mxlen 119 | 120 | #datamanager = DataManager("../TrainData/MR") 121 | #train_data, dev_data, test_data = datamanager.getdata(2,70); 122 | #for item in dev_data: 123 | # print json.dumps(item['action'][:item['lenth']]) 124 | # print json.dumps([datamanager.words[i-1][0] for i in item['words']][:item['lenth']]) 125 | 126 | -------------------------------------------------------------------------------- /ID_LSTM/LSTM_critic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import LSTMCell 3 | import tflearn 4 | import numpy as np 5 | 6 | class LSTM_CriticNetwork(object): 7 | """ 8 | predict network. 9 | use the word vector and actions(sampled from actor network) 10 | get the final prediction. 
11 | """ 12 | def __init__(self, sess, dim, optimizer, learning_rate, tau, grained, max_lenth, dropout, wordvector): 13 | self.global_step = tf.Variable(0, trainable=False, name="LSTMStep") 14 | self.sess = sess 15 | self.max_lenth = max_lenth 16 | self.dim = dim 17 | self.learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 10000, 0.95, staircase=True) 18 | self.tau = tau 19 | self.grained = grained 20 | self.dropout = dropout 21 | self.init = tf.random_uniform_initializer(-0.05, 0.05, dtype=tf.float32) 22 | self.L2regular = 0.00001 # add to parser 23 | print "optimizer: ", optimizer 24 | if optimizer == 'Adam': 25 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate) 26 | elif optimizer == 'Adagrad': 27 | self.optimizer = tf.train.AdagradOptimizer(self.learning_rate) 28 | elif optimizer == 'Adadelta': 29 | self.optimizer = tf.train.AdadeltaOptimizer(self.learning_rate) 30 | self.keep_prob = tf.placeholder(tf.float32, name="keepprob") 31 | self.num_other_variables = len(tf.trainable_variables()) 32 | self.wordvector = tf.get_variable('wordvector', dtype=tf.float32, initializer=wordvector, trainable=True) 33 | 34 | #lstm cells 35 | self.lower_cell_state, self.lower_cell_input, self.lower_cell_output, self.lower_cell_state1 = self.create_LSTM_cell('Lower/Active') 36 | 37 | #critic network (updating) 38 | self.inputs, self.lenth, self.out = self.create_critic_network("Active") 39 | self.network_params = tf.trainable_variables()[self.num_other_variables:] 40 | 41 | self.target_wordvector = tf.get_variable('wordvector_target', dtype=tf.float32, initializer=wordvector, trainable=True) 42 | 43 | #lstm cells 44 | self.target_lower_cell_state, self.target_lower_cell_input, self.target_lower_cell_output, self.target_lower_cell_state1 = self.create_LSTM_cell('Lower/Target') 45 | 46 | #critic network (delayed updating) 47 | self.target_inputs, self.target_lenth, self.target_out = self.create_critic_network("Target") 48 | self.target_network_params = tf.trainable_variables()[len(self.network_params)+self.num_other_variables:] 49 | 50 | #delayed updating critic network ops 51 | self.update_target_network_params = \ 52 | [self.target_network_params[i].assign(\ 53 | tf.multiply(self.network_params[i], self.tau)+\ 54 | tf.multiply(self.target_network_params[i], 1 - self.tau))\ 55 | for i in range(len(self.target_network_params))] 56 | 57 | self.assign_target_network_params = \ 58 | [self.target_network_params[i].assign(\ 59 | self.network_params[i]) for i in range(len(self.target_network_params))] 60 | self.assign_active_network_params = \ 61 | [self.network_params[i].assign(\ 62 | self.target_network_params[i]) for i in range(len(self.network_params))] 63 | 64 | self.ground_truth = tf.placeholder(tf.float32, [1,self.grained], name="ground_truth") 65 | 66 | 67 | #calculate loss 68 | self.loss_target = tf.nn.softmax_cross_entropy_with_logits(labels=self.ground_truth, logits=self.target_out) 69 | self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.ground_truth, logits=self.out) 70 | self.loss2 = 0 71 | with tf.variable_scope("Lower/Active", reuse=True): 72 | self.loss2+= tf.nn.l2_loss(tf.get_variable('lstm_cell/kernel')) 73 | with tf.variable_scope("Active/pred", reuse=True): 74 | self.loss2+= tf.nn.l2_loss(tf.get_variable('W')) 75 | self.loss += self.loss2 * self.L2regular 76 | self.loss_target += self.loss2 * self.L2regular 77 | self.gradients = tf.gradients(self.loss_target, self.target_network_params) 78 | self.optimize = 
self.optimizer.apply_gradients(zip(self.gradients, self.network_params), global_step = self.global_step) 79 | 80 | #total variables 81 | self.num_trainable_vars = len(self.network_params) + len(self.target_network_params) 82 | 83 | #wordvector look for 84 | self.WVinput, self.WVvec = self.create_wordvector_find() 85 | 86 | 87 | def create_critic_network(self, Scope): 88 | inputs = tf.placeholder(shape=[1, self.max_lenth], dtype=tf.int32, name="inputs") 89 | lenth = tf.placeholder(shape=[1], dtype=tf.int32, name="lenth") 90 | 91 | #Lower network 92 | if Scope[-1] == 'e': 93 | vec = tf.nn.embedding_lookup(self.wordvector, inputs) 94 | else: 95 | vec = tf.nn.embedding_lookup(self.target_wordvector, inputs) 96 | cell = LSTMCell(self.dim, initializer=self.init, state_is_tuple=False) 97 | 98 | with tf.variable_scope("Lower", reuse=True): 99 | out, _ = tf.nn.dynamic_rnn(cell, vec, lenth, dtype=tf.float32, scope=Scope) 100 | out = tf.gather(out[0], lenth-1) 101 | 102 | out = tflearn.dropout(out, self.keep_prob) 103 | out = tflearn.fully_connected(out, self.grained, scope=Scope+"/pred", name="get_pred") 104 | return inputs, lenth, out 105 | 106 | def create_LSTM_cell(self,Scope): 107 | cell = LSTMCell(self.dim, initializer=self.init, state_is_tuple=False) 108 | state = tf.placeholder(tf.float32, shape = [1, cell.state_size], name="cell_state") 109 | inputs = tf.placeholder(tf.int32, shape = [1, 1], name="cell_input") 110 | if Scope[-1] == 'e': 111 | vec = tf.nn.embedding_lookup(self.wordvector, inputs) 112 | else: 113 | vec = tf.nn.embedding_lookup(self.target_wordvector, inputs) 114 | with tf.variable_scope(Scope, reuse=False): 115 | out, state1 = cell(vec[:,0,:], state) 116 | return state, inputs, out, state1 117 | 118 | def create_wordvector_find(self): 119 | inputs = tf.placeholder(tf.int32, shape=[1, self.max_lenth], name="WVtofind") 120 | vec = tf.nn.embedding_lookup(self.target_wordvector, inputs) 121 | return inputs, vec 122 | 123 | def getloss(self, inputs, lenth, ground_truth): 124 | return self.sess.run([self.target_out, self.loss_target], feed_dict={ 125 | self.target_inputs: inputs, 126 | self.target_lenth: lenth, 127 | self.ground_truth: ground_truth, 128 | self.keep_prob: 1.0}) 129 | 130 | def train(self, inputs, lenth, ground_truth): 131 | return self.sess.run([self.target_out, self.loss_target, self.optimize], feed_dict={ 132 | self.target_inputs: inputs, 133 | self.target_lenth: lenth, 134 | self.ground_truth: ground_truth, 135 | self.keep_prob: self.dropout}) 136 | 137 | def predict_target(self, inputs, lenth): 138 | return self.sess.run(self.target_out, feed_dict={ 139 | self.target_inputs: inputs, 140 | self.target_lenth: lenth, 141 | self.keep_prob: 1.0}) 142 | 143 | def update_target_network(self): 144 | self.sess.run(self.update_target_network_params) 145 | 146 | def assign_target_network(self): 147 | self.sess.run(self.assign_target_network_params) 148 | 149 | def assign_active_network(self): 150 | self.sess.run(self.assign_active_network_params) 151 | 152 | def get_num_trainable_vars(self): 153 | return self.num_trainable_vars 154 | 155 | def lower_LSTM_target(self, state, inputs): 156 | return self.sess.run([self.target_lower_cell_output, self.target_lower_cell_state1], feed_dict={ 157 | self.target_lower_cell_state: state, 158 | self.target_lower_cell_input: inputs}) 159 | 160 | def wordvector_find(self, inputs): 161 | return self.sess.run(self.WVvec, feed_dict={ 162 | self.WVinput :inputs}) 163 | 
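A pattern worth calling out before `main.py`: every network in this repository keeps two parameter sets, an *active* copy that the optimizer updates and a delayed *target* copy used for sampling and prediction. `update_target_network_params` above blends them at rate `tau` (default 0.1). Here is a minimal NumPy sketch of that soft update, with hypothetical parameter lists:

```python
import numpy as np

def soft_update(target_params, active_params, tau=0.1):
    # target <- tau * active + (1 - tau) * target: the NumPy analogue of
    # the update_target_network_params assign ops built in __init__ above
    for i in range(len(target_params)):
        target_params[i] = tau * active_params[i] + (1.0 - tau) * target_params[i]

active = [np.ones((2, 2)), np.full(3, 2.0)]   # hypothetical weight tensors
target = [np.zeros((2, 2)), np.zeros(3)]
soft_update(target, active)                   # target[0] -> 0.1, target[1] -> 0.2
```

Note the deliberate asymmetry in the critic above: loss gradients are taken with respect to the target parameters (`tf.gradients(self.loss_target, self.target_network_params)`) but applied to the active ones, and `assign_active_network_params` copies the target back into the active set at the start of each batch (see the calls in `main.py`). The target copy therefore moves slowly even while the active copy is trained aggressively.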
-------------------------------------------------------------------------------- /ID_LSTM/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import random 4 | import sys, os 5 | import json 6 | import argparse 7 | from parser import Parser 8 | from datamanager import DataManager 9 | from actor import ActorNetwork 10 | from LSTM_critic import LSTM_CriticNetwork 11 | tf.logging.set_verbosity(tf.logging.ERROR) 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | #get parse 14 | argv = sys.argv[1:] 15 | parser = Parser().getParser() 16 | args, _ = parser.parse_known_args(argv) 17 | random.seed(args.seed) 18 | 19 | #get data 20 | dataManager = DataManager(args.dataset) 21 | train_data, dev_data, test_data = dataManager.getdata(args.grained, args.maxlenth) 22 | word_vector = dataManager.get_wordvector(args.word_vector) 23 | 24 | if args.fasttest == 1: 25 | train_data = train_data[:100] 26 | dev_data = dev_data[:20] 27 | test_data = test_data[:20] 28 | print "train_data ", len(train_data) 29 | print "dev_data", len(dev_data) 30 | print "test_data", len(test_data) 31 | 32 | def sampling_RL(sess, actor, inputs, vec, lenth, epsilon=0., Random=True): 33 | #print epsilon 34 | current_lower_state = np.zeros((1, 2*args.dim), dtype=np.float32) 35 | actions = [] 36 | states = [] 37 | #sampling actions 38 | 39 | for pos in range(lenth): 40 | predicted = actor.predict_target(current_lower_state, [vec[0][pos]]) 41 | 42 | states.append([current_lower_state, [vec[0][pos]]]) 43 | if Random: 44 | if random.random() > epsilon: 45 | action = (0 if random.random() < predicted[0] else 1) 46 | else: 47 | action = (1 if random.random() < predicted[0] else 0) 48 | else: 49 | action = np.argmax(predicted) 50 | actions.append(action) 51 | if action == 1: 52 | out_d, current_lower_state = critic.lower_LSTM_target(current_lower_state, [[inputs[pos]]]) 53 | 54 | Rinput = [] 55 | for (i, a) in enumerate(actions): 56 | if a == 1: 57 | Rinput.append(inputs[i]) 58 | Rlenth = len(Rinput) 59 | if Rlenth == 0: 60 | actions[lenth-2] = 1 61 | Rinput.append(inputs[lenth-2]) 62 | Rlenth = 1 63 | Rinput += [0] * (args.maxlenth - Rlenth) 64 | return actions, states, Rinput, Rlenth 65 | 66 | def train(sess, actor, critic, train_data, batchsize, samplecnt=5, LSTM_trainable=True, RL_trainable=True): 67 | print "training : total ", len(train_data), "nodes." 68 | random.shuffle(train_data) 69 | for b in range(len(train_data) / batchsize): 70 | datas = train_data[b * batchsize: (b+1) * batchsize] 71 | totloss = 0. 72 | critic.assign_active_network() 73 | actor.assign_active_network() 74 | for j in range(batchsize): 75 | #prepare 76 | data = datas[j] 77 | inputs, solution, lenth = data['words'], data['solution'], data['lenth'] 78 | #train the predict network 79 | if RL_trainable: 80 | actionlist, statelist, losslist = [], [], [] 81 | aveloss = 0. 
82 | for i in range(samplecnt): 83 | actions, states, Rinput, Rlenth = sampling_RL(sess, actor, inputs, critic.wordvector_find([inputs]), lenth, args.epsilon, Random=True) 84 | actionlist.append(actions) 85 | statelist.append(states) 86 | out, loss = critic.getloss([Rinput], [Rlenth], [solution]) 87 | loss += (float(Rlenth) / lenth) **2 *0.15 88 | aveloss += loss 89 | losslist.append(loss) 90 | 91 | aveloss /= samplecnt 92 | totloss += aveloss 93 | grad = None 94 | if LSTM_trainable: 95 | out, loss, _ = critic.train([Rinput], [Rlenth], [solution]) 96 | for i in range(samplecnt): 97 | for pos in range(len(actionlist[i])): 98 | rr = [0., 0.] 99 | rr[actionlist[i][pos]] = (losslist[i] - aveloss) * args.alpha 100 | g = actor.get_gradient(statelist[i][pos][0], statelist[i][pos][1], rr) 101 | if grad == None: 102 | grad = g 103 | else: 104 | grad[0] += g[0] 105 | grad[1] += g[1] 106 | grad[2] += g[2] 107 | actor.train(grad) 108 | else: 109 | out, loss, _ = critic.train([inputs], [lenth], [solution]) 110 | totloss += loss 111 | if RL_trainable: 112 | actor.update_target_network() 113 | if LSTM_trainable: 114 | critic.update_target_network() 115 | else: 116 | critic.assign_target_network() 117 | if (b + 1) % 500 == 0: 118 | acc_test = test(sess, actor, critic, test_data, noRL= not RL_trainable) 119 | acc_dev = test(sess, actor, critic, dev_data, noRL= not RL_trainable) 120 | print "batch ",b , "total loss ", totloss, "----test: ", acc_test, "| dev: ", acc_dev 121 | 122 | 123 | def test(sess, actor, critic, test_data, noRL=False): 124 | acc = 0 125 | for i in range(len(test_data)): 126 | #prepare 127 | data = test_data[i] 128 | inputs, solution, lenth = data['words'], data['solution'], data['lenth'] 129 | 130 | #predict 131 | if noRL: 132 | out = critic.predict_target([inputs], [lenth]) 133 | else: 134 | actions, states, Rinput, Rlenth = sampling_RL(sess, actor, inputs, critic.wordvector_find([inputs]), lenth, Random=False) 135 | out = critic.predict_target([Rinput], [Rlenth]) 136 | if np.argmax(out) == np.argmax(solution): 137 | acc += 1 138 | return float(acc) / len(test_data) 139 | 140 | config = tf.ConfigProto() 141 | config.gpu_options.allow_growth = True 142 | with tf.Session(config = config) as sess: 143 | #model 144 | critic = LSTM_CriticNetwork(sess, args.dim, args.optimizer, args.lr, args.tau, args.grained, args.maxlenth, args.dropout, word_vector) 145 | actor = ActorNetwork(sess, args.dim, args.optimizer, args.lr, args.tau) 146 | #print variables 147 | for item in tf.trainable_variables(): 148 | print (item.name, item.get_shape()) 149 | 150 | saver = tf.train.Saver() 151 | 152 | #LSTM pretrain 153 | if args.RLpretrain != '': 154 | pass 155 | elif args.LSTMpretrain == '': 156 | sess.run(tf.global_variables_initializer()) 157 | for i in range(0, 2): 158 | train(sess, actor, critic, train_data, args.batchsize, args.sample_cnt, RL_trainable=False) 159 | critic.assign_target_network() 160 | acc_test = test(sess, actor, critic, test_data, True) 161 | acc_dev = test(sess, actor, critic, dev_data, True) 162 | print "LSTM_only ",i, "----test: ", acc_test, "| dev: ", acc_dev 163 | saver.save(sess, "checkpoints/"+args.name+"_base", global_step=i) 164 | print "LSTM pretrain OK" 165 | else: 166 | print "Load LSTM from ", args.LSTMpretrain 167 | saver.restore(sess, args.LSTMpretrain) 168 | 169 | print "epsilon", args.epsilon 170 | 171 | if args.RLpretrain == '': 172 | for i in range(0, 5): 173 | train(sess, actor, critic, train_data, args.batchsize, args.sample_cnt, LSTM_trainable=False) 174 | 
acc_test = test(sess, actor, critic, test_data) 175 | acc_dev = test(sess, actor, critic, dev_data) 176 | print "RL pretrain ", i, "----test: ", acc_test, "| dev: ", acc_dev 177 | saver.save(sess, "checkpoints/"+args.name+"_RLpre", global_step=i) 178 | print "RL pretrain OK" 179 | else: 180 | print "Load RL from", args.RLpretrain 181 | saver.restore(sess, args.RLpretrain) 182 | 183 | for e in range(args.epoch): 184 | train(sess, actor, critic, train_data, args.batchsize, args.sample_cnt) 185 | acc_test = test(sess, actor, critic, test_data) 186 | acc_dev = test(sess, actor, critic, dev_data) 187 | print "epoch ", e, "----test: ", acc_test, "| dev: ", acc_dev 188 | saver.save(sess, "checkpoints/"+args.name, global_step=e) 189 | 190 | 191 | -------------------------------------------------------------------------------- /HS_LSTM/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import random 4 | import sys, os 5 | import json 6 | import argparse 7 | from parser import Parser 8 | from datamanager import DataManager 9 | from actor import ActorNetwork 10 | from LSTM_critic import LSTM_CriticNetwork 11 | tf.logging.set_verbosity(tf.logging.ERROR) 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | #get parse 14 | argv = sys.argv[1:] 15 | parser = Parser().getParser() 16 | args, _ = parser.parse_known_args(argv) 17 | random.seed(args.seed) 18 | 19 | #get data 20 | dataManager = DataManager(args.dataset) 21 | train_data, dev_data, test_data = dataManager.getdata(args.grained, args.maxlenth) 22 | word_vector = dataManager.get_wordvector(args.word_vector) 23 | 24 | print "train_data ", len(train_data) 25 | print "dev_data", len(dev_data) 26 | print "test_data", len(test_data) 27 | if args.fasttest == 1: 28 | train_data = train_data[:100] 29 | dev_data = dev_data[:20] 30 | test_data = test_data[:20] 31 | 32 | def sampling_RL(sess, actor, inputs, lenth, Random=True): 33 | current_lower_state = np.zeros((1, state_size), dtype=np.float32) 34 | current_upper_state = np.zeros((1, state_size), dtype=np.float32) 35 | actions = [] 36 | states = [] 37 | #sampling actions 38 | 39 | for pos in range(lenth): 40 | out_d, current_lower_state = critic.lower_LSTM_target(current_lower_state, [[inputs[pos]]]) 41 | predicted = actor.predict_target(current_upper_state, current_lower_state) 42 | #print predicted 43 | states.append([current_upper_state, current_lower_state]) 44 | if Random: 45 | action = (0 if random.random() < predicted[0] else 1) 46 | else: 47 | action = np.argmax(predicted) 48 | actions.append(action) 49 | if action == 1: 50 | current_upper_state = critic.upper_LSTM_target(current_upper_state, out_d) 51 | current_lower_state = np.zeros_like(current_lower_state) 52 | 53 | #pad zeros 54 | actions += [0] * (args.maxlenth - lenth) 55 | actions[lenth-1] = 1 56 | #get the position of action 1 57 | action_pos = [] 58 | for (i, j) in enumerate(actions): 59 | if j == 1: 60 | action_pos.append(i) 61 | return actions, states, action_pos 62 | 63 | def sampling_random(lenth, p_action = None): 64 | actions = [] 65 | typ = args.pretype 66 | actions = np.copy(p_action).tolist() 67 | actions += [0] * (args.maxlenth - lenth) 68 | action_pos = [] 69 | for (i, j) in enumerate(actions): 70 | if j == 1: 71 | action_pos.append(i) 72 | if len(action_pos) == 0: 73 | actions[lenth-1] = 1 74 | action_pos.append(lenth-1) 75 | if len(actions) != args.maxlenth: 76 | print lenth, p_action 77 | return actions, action_pos 78 | 79 | def 
train(sess, actor, critic, train_data, batch_size, samplecnt, LSTM_trainable=True, RL_trainable=True):
80 |     print "training: total", len(train_data), "nodes,", len(train_data)/batch_size, "batches."
81 |     random.shuffle(train_data)
82 |     for b in range(len(train_data)/batch_size):
83 |         datas = train_data[b * batch_size: (b+1) * batch_size]
84 |         totloss = 0.
85 |         actor.assign_active_network()
86 |         critic.assign_active_network()
87 |         for i in range(batch_size):
88 |             #prepare
89 |             data = datas[i]
90 |             inputs, solution, lenth, p_action = data['words'], data['solution'], data['lenth'], data['action']
91 |             aveloss = 0.
92 |             statelist, actionlist, losslist = [], [], []
93 |             #get sampling
94 |             if RL_trainable:
95 |                 for sp in range(samplecnt):
96 |                     actions, states, action_pos = sampling_RL(sess, actor, inputs, lenth)
97 |                     statelist.append(states)
98 |                     actionlist.append(actions)
99 |                     out, loss = critic.getloss([inputs], [actions], [action_pos], [lenth], [len(action_pos)], [solution])
100 |                     # length-control penalty
101 |                     _x = float(len(action_pos)) / lenth
102 |                     loss += (1 * _x + 0.1 / _x - 0.6) * 0.1 * args.grained
103 |                     #
104 |                     aveloss += loss
105 |                     losslist.append(loss)
106 |             else:
107 |                 actions, action_pos = sampling_random(lenth, p_action)
108 |             #train the predict network
109 |             if LSTM_trainable:
110 |                 out, loss, _ = critic.train([inputs], [actions], [action_pos], [lenth], [len(action_pos)], [solution])
111 |                 if not RL_trainable:
112 |                     totloss += loss
113 |             #train the actor network
114 |             if RL_trainable:
115 |                 aveloss /= samplecnt
116 |                 totloss += aveloss
117 |                 grad = None
118 |                 for sp in range(samplecnt):
119 |                     for pos in range(lenth):
120 |                         rr = [0., 0.]
121 |                         rr[actionlist[sp][pos]] = (losslist[sp] - aveloss) * args.alpha
122 | 
123 |                         g = actor.get_gradient(statelist[sp][pos][0], statelist[sp][pos][1], rr)
124 |                         if grad is None:
125 |                             grad = g
126 |                         else:
127 |                             grad[0] += g[0]
128 |                             grad[1] += g[1]
129 |                             grad[2] += g[2]
130 |                 actor.train(grad)
131 | 
132 |         if RL_trainable:
133 |             actor.update_target_network()
134 |         if LSTM_trainable:
135 |             if RL_trainable:
136 |                 critic.update_target_network()
137 |             else:
138 |                 critic.assign_target_network()
139 |         if (b + 1) % 500 == 0:
140 |             acc_test = test(sess, actor, critic, test_data, not RL_trainable)
141 |             acc_dev = test(sess, actor, critic, dev_data, not RL_trainable)
142 |             print "batch ", b, "total loss ", totloss, "----test: ", acc_test, "| dev: ", acc_dev
143 | 
144 | def test(sess, actor, critic, test_data, Random=False):
145 |     acc = 0
146 |     for i in range(len(test_data)):
147 |         #prepare
148 |         data = test_data[i]
149 |         inputs, solution, lenth, paction = data['words'], data['solution'], data['lenth'], data['action']
150 |         #get sampling
151 |         if not Random:
152 |             actions, _, action_pos = sampling_RL(sess, actor, inputs, lenth, Random=False)
153 |         else:
154 |             actions, action_pos = sampling_random(lenth, paction)
155 | 
156 |         if len(actions) != args.maxlenth:
157 |             print inputs
158 |         #predict
159 |         out = critic.predict_target([inputs], [actions], [action_pos], [lenth], [len(action_pos)])
160 |         if np.argmax(out) == np.argmax(solution):
161 |             acc += 1
162 |     return float(acc) / len(test_data)
163 | 
164 | config = tf.ConfigProto()
165 | config.gpu_options.allow_growth = True
166 | with tf.Session(config = config) as sess:
167 |     #model
168 |     critic = LSTM_CriticNetwork(sess, args.dim, args.optimizer, args.lr, args.tau, args.grained, args.attention, args.maxlenth, args.dropout, word_vector)
169 |     actor = ActorNetwork(sess, args.dim, args.optimizer, args.lr,
args.tau, critic.get_num_trainable_vars()) 170 | state_size = critic.state_size 171 | 172 | #print variables 173 | for item in tf.trainable_variables(): 174 | print (item.name, item.get_shape()) 175 | 176 | saver = tf.train.Saver() 177 | 178 | #LSTM pretrain 179 | if args.RLpretrain != '': 180 | pass 181 | elif args.LSTMpretrain == '': 182 | sess.run(tf.global_variables_initializer()) 183 | for i in range(0,2): 184 | train(sess, actor, critic, train_data, args.batchsize, args.samplecnt, RL_trainable=False) 185 | critic.assign_target_network() 186 | acc_test = test(sess, actor, critic, test_data, True) 187 | acc_dev = test(sess, actor, critic, dev_data, True) 188 | print "LSTM_only ",i, "----test: ", acc_test, "| dev: ", acc_dev 189 | saver.save(sess, "checkpoints/"+args.name+"_base", global_step=i) 190 | print "LSTM pretrain OK" 191 | else: 192 | print "Load LSTM from ", args.LSTMpretrain 193 | saver.restore(sess, args.LSTMpretrain) 194 | pass 195 | #RL pretrain 196 | if args.RLpretrain == '': 197 | for i in range(0,5): 198 | train(sess, actor, critic, train_data, args.batchsize, args.samplecnt, LSTM_trainable=False) 199 | acc_test = test(sess, actor, critic, test_data) 200 | acc_dev = test(sess, actor, critic, dev_data) 201 | print "RL pretrain ", i, "----test: ", acc_test, "| dev: ", acc_dev 202 | saver.save(sess, "checkpoints/"+args.name+"_RL", global_step=i) 203 | print "RL pretrain OK" 204 | else: 205 | print "Load RL from ", args.RLpretrain 206 | saver.restore(sess, args.RLpretrain) 207 | #train 208 | results = [] 209 | for e in range(args.epoch): 210 | train(sess, actor, critic, train_data, args.batchsize, args.samplecnt) 211 | acc_test = test(sess, actor, critic, test_data) 212 | acc_dev = test(sess, actor, critic, dev_data) 213 | print "epoch ", e, "---- test: ", acc_test, "| dev: ", acc_dev 214 | saver.save(sess, "checkpoints/"+args.name, global_step=e) 215 | 216 | 217 | -------------------------------------------------------------------------------- /HS_LSTM/LSTM_critic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import LSTMCell 3 | import tflearn 4 | import numpy as np 5 | 6 | class LSTM_CriticNetwork(object): 7 | """ 8 | predict network. 9 | use the word vector and actions(sampled from actor network) 10 | get the final prediction. 
11 | """ 12 | def __init__(self, sess, dim, optimizer, learning_rate, tau, grained, isAttention, max_lenth, dropout, wordvector): 13 | self.global_step = tf.Variable(0, trainable=False, name="LSTMStep") 14 | self.sess = sess 15 | self.dim = dim 16 | self.learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 10000, 0.95, staircase=True) 17 | self.tau = tau 18 | self.grained = grained 19 | self.isAttention = isAttention 20 | self.max_lenth = max_lenth 21 | self.dropout = dropout 22 | self.init = tf.random_uniform_initializer(-0.05, 0.05, dtype=tf.float32) 23 | self.L2regular = 0.00001 # add to parser 24 | print "optimizer: ", optimizer 25 | if optimizer == 'Adam': 26 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate) 27 | elif optimizer == 'Adagrad': 28 | self.optimizer = tf.train.AdagradOptimizer(self.learning_rate) 29 | elif optimizer == 'Adadelta': 30 | self.optimizer = tf.train.AdadeltaOptimizer(self.learning_rate) 31 | self.keep_prob = tf.placeholder(tf.float32, name="keepprob") 32 | self.num_other_variables = len(tf.trainable_variables()) 33 | self.wordvector = tf.get_variable('wordvector', dtype=tf.float32, initializer=wordvector, trainable=True) 34 | 35 | #lstm cells 36 | self.upper_cell_state, self.upper_cell_input, self.upper_cell_output = self.create_Upper_LSTM_cell('Upper/Active') 37 | self.lower_cell_state, self.lower_cell_input, self.lower_cell_output, self.lower_cell_state1 = self.create_Lower_LSTM_cell('Lower/Active') 38 | 39 | #critic network (updating) 40 | self.inputs, self.action, self.action_pos, self.lenth, self.lenth_up, self.out = self.create_critic_network("Active") 41 | self.network_params = tf.trainable_variables()[self.num_other_variables:] 42 | 43 | self.target_wordvector = tf.get_variable('wordvector_target', dtype=tf.float32, initializer=wordvector, trainable=True) 44 | 45 | #lstm cells 46 | self.target_upper_cell_state, self.target_upper_cell_input, self.target_upper_cell_output = self.create_Upper_LSTM_cell('Upper/Target') 47 | self.target_lower_cell_state, self.target_lower_cell_input, self.target_lower_cell_output, self.target_lower_cell_state1 = self.create_Lower_LSTM_cell('Lower/Target') 48 | 49 | #critic network (delayed updating) 50 | self.target_inputs, self.target_action, self.target_action_pos, self.target_lenth, self.target_lenth_up, self.target_out = self.create_critic_network("Target") 51 | self.target_network_params = tf.trainable_variables()[len(self.network_params)+self.num_other_variables:] 52 | 53 | #delayed updating critic network ops 54 | self.update_target_network_params = \ 55 | [self.target_network_params[i].assign(\ 56 | tf.multiply(self.network_params[i], self.tau)+\ 57 | tf.multiply(self.target_network_params[i], 1 - self.tau))\ 58 | for i in range(len(self.target_network_params))] 59 | 60 | self.assign_target_network_params = \ 61 | [self.target_network_params[i].assign(\ 62 | self.network_params[i]) for i in range(len(self.target_network_params))] 63 | self.assign_active_network_params = \ 64 | [self.network_params[i].assign(\ 65 | self.target_network_params[i]) for i in range(len(self.network_params))] 66 | 67 | self.ground_truth = tf.placeholder(tf.float32, [1,self.grained], name="ground_truth") 68 | 69 | 70 | #calculate loss 71 | self.loss_target = tf.nn.softmax_cross_entropy_with_logits(labels=self.ground_truth, logits=self.target_out) 72 | self.loss = tf.nn.softmax_cross_entropy_with_logits(labels=self.ground_truth, logits=self.out) 73 | #with tf.variable_scope("Upper/Active", reuse=True): 74 | 
# self.loss2 = tf.nn.l2_loss(tf.get_variable('lstm_cell/kernel')) 75 | #with tf.variable_scope("Lower/Active", reuse=True): 76 | # self.loss2+= tf.nn.l2_loss(tf.get_variable('lstm_cell/kernel')) 77 | #with tf.variable_scope("Active/pred", reuse=True): 78 | # self.loss2+= tf.nn.l2_loss(tf.get_variable('W')) 79 | #self.loss += self.loss2 * self.L2regular 80 | #self.loss_target += self.loss2 * self.L2regular 81 | self.gradients = tf.gradients(self.loss_target, self.target_network_params) 82 | self.optimize = self.optimizer.apply_gradients(zip(self.gradients, self.network_params), global_step = self.global_step) 83 | #self.optimize = self.optimizer.minimize(self.loss) 84 | 85 | #total variables 86 | self.num_trainable_vars = len(self.network_params) + len(self.target_network_params) 87 | 88 | def create_critic_network(self, Scope): 89 | inputs = tf.placeholder(shape=[1, self.max_lenth], dtype=tf.int32, name="inputs") 90 | action = tf.placeholder(shape=[1, self.max_lenth], dtype=tf.int32, name="action") 91 | action_pos = tf.placeholder(shape=[1, None], dtype=tf.int32, name="action_pos") 92 | lenth = tf.placeholder(shape=[1], dtype=tf.int32, name="lenth") 93 | lenth_up = tf.placeholder(shape=[1], dtype=tf.int32, name="lenth_up") 94 | 95 | #Lower network 96 | if Scope[-1] == 'e': 97 | vec = tf.nn.embedding_lookup(self.wordvector, inputs) 98 | print "active" 99 | else: 100 | vec = tf.nn.embedding_lookup(self.target_wordvector, inputs) 101 | print "target" 102 | cell = LSTMCell(self.dim, initializer=self.init, state_is_tuple=False) 103 | self.state_size = cell.state_size 104 | actions = tf.to_float(action) 105 | h = cell.zero_state(1, tf.float32) 106 | embedding = [] 107 | for step in range(self.max_lenth): 108 | with tf.variable_scope("Lower/"+Scope, reuse=True): 109 | o, h = cell(vec[:,step,:], h) 110 | embedding.append(o[0]) 111 | h = h *(1.0 - actions[0,step]) 112 | 113 | #Upper network 114 | embedding = tf.stack(embedding) 115 | embedding = tf.gather(embedding, action_pos, name="Upper_input") 116 | with tf.variable_scope("Upper", reuse=True): 117 | out, _ = tf.nn.bidirectional_dynamic_rnn(cell, cell, embedding, lenth_up, dtype=tf.float32, scope=Scope) 118 | 119 | if self.isAttention: 120 | out = tf.concat(out, 2) 121 | out = out[0,:,:] 122 | tmp = tflearn.fully_connected(out, self.dim, scope=Scope, name="att") 123 | tmp = tflearn.tanh(tmp) 124 | with tf.variable_scope(Scope): 125 | v_T = tf.get_variable("v_T", dtype=tf.float32, shape=[self.dim, 1], trainable=True) 126 | a = tflearn.softmax(tf.matmul(tmp,v_T)) 127 | out = tf.reduce_sum(out * a, 0) 128 | out = tf.expand_dims(out, 0) 129 | else: 130 | #out = embedding[:, -1, :] 131 | out = tf.concat((out[0][:,-1,:], out[1][:,0,:]), 1) 132 | 133 | out = tflearn.dropout(out, self.keep_prob) 134 | out = tflearn.fully_connected(out, self.grained, scope=Scope+"/pred", name="get_pred") 135 | return inputs, action, action_pos, lenth, lenth_up, out 136 | 137 | def create_Lower_LSTM_cell(self,Scope): 138 | cell = LSTMCell(self.dim, initializer=self.init, state_is_tuple=False) 139 | state = tf.placeholder(tf.float32, shape = [1, cell.state_size], name="cell_state") 140 | inputs = tf.placeholder(tf.int32, shape = [1, 1], name="cell_input") 141 | if Scope[-1] == 'e': 142 | vec = tf.nn.embedding_lookup(self.wordvector, inputs) 143 | else: 144 | vec = tf.nn.embedding_lookup(self.target_wordvector, inputs) 145 | with tf.variable_scope(Scope, reuse=False): 146 | out, state1 = cell(vec[:,0,:], state) 147 | return state, inputs, out, state1 148 | 149 | def 
create_Upper_LSTM_cell(self, Scope): 150 | cell = LSTMCell(self.dim, initializer=self.init, state_is_tuple=False) 151 | state_l = tf.placeholder(tf.float32, shape = [1, cell.state_size], name="cell_state_l") 152 | state_d = tf.placeholder(tf.float32, shape = [1, self.dim], name="cell_state_d") 153 | with tf.variable_scope(Scope, reuse=False): 154 | _, out = cell(state_d, state_l) 155 | return state_l, state_d, out 156 | 157 | def getloss(self, inputs, action, action_pos, lenth, lenth_up, ground_truth): 158 | return self.sess.run([self.target_out, self.loss_target], feed_dict={ 159 | self.target_inputs: inputs, 160 | self.target_action: action, 161 | self.target_action_pos: action_pos, 162 | self.target_lenth: lenth, 163 | self.target_lenth_up: lenth_up, 164 | self.ground_truth: ground_truth, 165 | self.keep_prob: 1.0}) 166 | 167 | def train(self, inputs, action, action_pos, lenth, lenth_up, ground_truth): 168 | return self.sess.run([self.target_out, self.loss_target, self.optimize], feed_dict={ 169 | self.target_inputs: inputs, 170 | self.target_action: action, 171 | self.target_action_pos: action_pos, 172 | self.target_lenth: lenth, 173 | self.target_lenth_up: lenth_up, 174 | self.ground_truth: ground_truth, 175 | self.keep_prob: self.dropout}) 176 | 177 | def predict_target(self, inputs, action, action_pos, lenth, lenth_up): 178 | return self.sess.run(self.target_out, feed_dict={ 179 | self.target_inputs: inputs, 180 | self.target_action: action, 181 | self.target_action_pos: action_pos, 182 | self.target_lenth: lenth, 183 | self.target_lenth_up: lenth_up, 184 | self.keep_prob: 1.0}) 185 | 186 | def update_target_network(self): 187 | self.sess.run(self.update_target_network_params) 188 | 189 | def assign_target_network(self): 190 | self.sess.run(self.assign_target_network_params) 191 | 192 | def assign_active_network(self): 193 | self.sess.run(self.assign_active_network_params) 194 | 195 | def get_num_trainable_vars(self): 196 | return self.num_trainable_vars 197 | 198 | def upper_LSTM_target(self, state, inputs): 199 | return self.sess.run(self.target_upper_cell_output, feed_dict={ 200 | self.target_upper_cell_state: state, 201 | self.target_upper_cell_input: inputs}) 202 | 203 | def lower_LSTM_target(self, state, inputs): 204 | return self.sess.run([self.target_lower_cell_output, self.target_lower_cell_state1], feed_dict={ 205 | self.target_lower_cell_state: state, 206 | self.target_lower_cell_input: inputs}) 207 | --------------------------------------------------------------------------------
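Taken together, the HS-LSTM pieces above implement a two-level recursion: a lower LSTM reads words within a phrase, and whenever the actor emits action 1 the phrase embedding is pushed into an upper LSTM and the lower state is reset. The sketch below summarizes the control flow of `sampling_RL` (in `HS_LSTM/main.py` and `test.py`); `lower_step`, `upper_step`, and `policy` are hypothetical stand-ins for the session calls `lower_LSTM_target`, `upper_LSTM_target`, and `actor.predict_target`.

```python
import numpy as np

def hs_lstm_sample(words, policy, lower_step, upper_step, state_size):
    """Two-level pass: action 1 closes the current phrase, feeds its
    embedding to the upper LSTM, and restarts the lower LSTM."""
    lower = np.zeros(state_size)               # within-phrase state
    upper = np.zeros(state_size)               # between-phrase state
    actions = []
    for word in words:
        out, lower = lower_step(lower, word)   # read one word
        a = policy(upper, lower)               # 0 = continue phrase, 1 = end phrase
        actions.append(a)
        if a == 1:
            upper = upper_step(upper, out)     # absorb the finished phrase
            lower = np.zeros_like(lower)       # start a new phrase
    actions[-1] = 1                            # the last word always ends a phrase
    return actions
```

At prediction time the critic does not reuse this sampled upper state: `create_critic_network` reruns the lower cell over the whole sentence (resetting the state wherever `action == 1`), gathers the outputs at the phrase boundaries in `action_pos`, and feeds them to a bidirectional upper LSTM (optionally with attention) before the final prediction layer.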