├── agents
│   ├── __init__.py
│   └── Train.py
├── functionApproximation
│   ├── __init__.py
│   ├── FA.py
│   └── base.py
├── envs
│   ├── __init__.py
│   └── env.py
├── model_train.sh
├── util.py
├── launch.py
└── logger.py
/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .Train import Train 2 | -------------------------------------------------------------------------------- /functionApproximation/__init__.py: -------------------------------------------------------------------------------- 1 | from .FA import FA 2 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | 5 | #================================== 6 | #@file name: __init__.py 7 | #@author: Lixin Zou 8 | #@contact: zoulixin15@gmail.com 9 | #@time:2019/10/28,11:37 PM 10 | #================================== 11 | from .env import env -------------------------------------------------------------------------------- /model_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ~/pyenv3/bin/python ./launch.py -data_path ./data/data/ -environment env -T 40 -ST [5,10,20,40] -agent Train -FA FA -latent_factor 50 \ 3 | -learning_rate 0.001 -training_epoch 3000 -seed 145 -gpu_no 0 -inner_epoch 50 -rnn_layer 2 -gamma 0.8 -batch 50 -restore_model False 4 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import numpy as np 5 | import ipdb 6 | import inspect 7 | import random 8 | import tensorflow as tf 9 | import os 10 | def arg_parser(): 11 | """ 12 | Create an empty argparse.ArgumentParser. 13 | """ 14 | import argparse 15 | return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | 17 | def get_objects(name_space): 18 | res = {} 19 | for name, obj in inspect.getmembers(name_space): 20 | if inspect.isclass(obj): 21 | res[name] = obj 22 | return res 23 | 24 | def set_global_seeds(i): 25 | tf.set_random_seed(i) 26 | np.random.seed(i) 27 | random.seed(i) 28 | 29 | def softmax(x): 30 | z = x - max(x) 31 | numerator = np.exp(z) 32 | denominator = np.sum(numerator) 33 | softmax = numerator / denominator 34 | return softmax 35 | 36 | def path_join(a,b): 37 | return os.path.join(a,b) 38 | 39 | save4float = lambda x:str(round(x,4)) 40 | 41 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import numpy as np 5 | import ipdb 6 | import logger 7 | from datetime import datetime 8 | import sys 9 | from util import get_objects,set_global_seeds,arg_parser 10 | import envs as all_envs 11 | import agents as all_agents 12 | import functionApproximation as all_FA 13 | import os 14 | def str2bool(str=""): 15 | str = str.lower() 16 | if str.__contains__("yes") or str.__contains__("true") or str.__contains__("y") or str.__contains__("t"): 17 | return True 18 | else: 19 | return False 20 | 21 | def common_arg_parser(): 22 | """ 23 | Create an argparse.ArgumentParser for run.py.
24 | """ 25 | parser = arg_parser() 26 | parser.add_argument('-seed',type=int, default=123) 27 | parser.add_argument('-environment', type=str, default="Env") 28 | parser.add_argument('-data_path',type=str,default="./data/m100k") 29 | parser.add_argument('-agent',type=str,default="training methods") 30 | parser.add_argument('-FA',type=str,default="function approximation") 31 | parser.add_argument('-T', dest='T', type=int, default=3, help="time_step") 32 | parser.add_argument('-ST', dest='ST', type=eval, default="[10,30,60,120]", help="evaluation_time_step") 33 | parser.add_argument('-gpu_no', dest='gpu_no', type=str, default="0", help='which gpu for usage') 34 | parser.add_argument('-latent_factor', dest='latent_factor', type=int, default=10, help="latent factor") 35 | parser.add_argument('-learning_rate', dest='learning_rate', type=float, default=0.01, help="learning rate") 36 | parser.add_argument('-training_epoch', dest='training_epoch', type=int, default=30000, help="training epoch") 37 | parser.add_argument('-rnn_layer', dest='rnn_layer', type=int, default=1, help="rnn_layer") 38 | parser.add_argument('-inner_epoch', dest='inner_epoch', type=int, default=50, help="rnn_layer") 39 | parser.add_argument('-batch', dest='batch', type=int, default=128, help="batch_size") 40 | parser.add_argument('-gamma', dest='gamma', type=float, default=0.0, help="gamma") 41 | parser.add_argument('-clip_param', dest='clip_param', type=float, default=0.2, help="clip_param") 42 | parser.add_argument('-restore_model', dest='restore_model', type=str2bool, default="False", help="") 43 | parser.add_argument('-num_blocks', dest='num_blocks', type=int, default=2, help="") 44 | parser.add_argument('-num_heads', dest='num_heads', type=int, default=1, help="") 45 | parser.add_argument('-dropout_rate', dest='dropout_rate', type=float, default=0.1, help="") 46 | return parser 47 | 48 | def main(args): 49 | # arguments 50 | arg_parser = common_arg_parser() 51 | args, unknown_args = arg_parser.parse_known_args(args) 52 | args.model = "_".join([args.agent,args.FA,str(args.T)]) 53 | # initialization 54 | set_global_seeds(args.seed) 55 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_no) 56 | # logger 57 | logger.configure("./log"+args.data_path.split("/")[-2]+"/"+"_".join([args.model,datetime.now().strftime("%Y%m%d_%H%M%S"),args.data_path.split("/")[-2],str(args.learning_rate),str(args.T),str(args.ST),str(args.gamma)])) 58 | logger.log("Training Model: "+args.model) 59 | # environments 60 | envs = get_objects(all_envs) 61 | env = envs[args.environment](args) 62 | # ipdb.set_trace() 63 | # policy network 64 | args.user_num = env.user_num 65 | args.item_num = env.item_num 66 | args.utype_num = env.utype_num 67 | # ipdb.set_trace() 68 | args.saved_path = os.path.join(os.path.abspath("./"),"saved_path_"+args.data_path.split("/")[-2]+"_"+str(args.FA)+"_"+str(args.learning_rate)+"_"+str(args.agent)+"_"+str(args.seed)) 69 | 70 | 71 | nets = get_objects(all_FA) 72 | fa = nets[args.FA].create_model_without_distributed(args) 73 | 74 | logger.log("Hype-Parameters: "+str(args)) 75 | # # agents 76 | agents = get_objects(all_agents) 77 | agents[args.agent](env,fa,args).train() 78 | 79 | 80 | 81 | 82 | if __name__ == '__main__': 83 | main(sys.argv) -------------------------------------------------------------------------------- /functionApproximation/FA.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import numpy as np 5 | import ipdb 6 | from 
.base import basic_model 7 | import tensorflow as tf 8 | from .base import * 9 | 10 | class FA(basic_model): 11 | def _create_placeholders(self): 12 | self.utype = tf.placeholder(tf.int32, (None,), name='uid') 13 | self.p_rec = [tf.placeholder(tf.int32,(None,None,), name="p"+str(i)+"_rec") for i in range(6)] 14 | self.pt = [tf.placeholder(tf.int32,(None,2),"p"+str(i)+"t") for i in range(6)] 15 | self.rec = tf.placeholder(tf.int32, (None,), name='iid') 16 | self.target = tf.placeholder(tf.float32,(None,),name="target") 17 | 18 | def _update_placehoders(self): 19 | self.placeholders["all"] = {"uid":self.utype, 20 | "iid":self.rec, 21 | "goal":self.target} 22 | for i in range(6): 23 | self.placeholders["all"]["p"+str(i)+"_rec"] = self.p_rec[i] 24 | self.placeholders["all"]["p"+str(i)+"t"] = self.pt[i] 25 | self.placeholders["predict"] = {item: self.placeholders["all"][item] for item in ["uid"] + ["p"+str(i)+"_rec" for i in range(6)] + ["p"+str(i)+"t" for i in range(6)]} 26 | self.placeholders["optimize"] = self.placeholders["all"] 27 | 28 | def _create_inference(self): 29 | p_f = [tf.Variable(np.random.uniform(-0.01, 0.01,(self.args.item_num,self.args.latent_factor)), 30 | dtype=tf.float32, trainable=True, name='item'+str(i)+'_feature') for i in range(6)] 31 | u_f = tf.Variable(np.random.uniform(-0.01, 0.01,(self.args.utype_num,self.args.latent_factor)), 32 | dtype=tf.float32, trainable=True, name='user_feature') 33 | u_emb = tf.nn.embedding_lookup(u_f, self.utype) 34 | self.p_rec = [tf.transpose(item,[1,0]) for item in self.p_rec] 35 | i_p_mask = [tf.expand_dims(tf.to_float(tf.not_equal(item, 0)), -1) for item in self.p_rec] 36 | 37 | 38 | self.p_seq = [tf.nn.embedding_lookup(p_f[i],self.p_rec[i]) for i in range(6)] 39 | for iii,item in enumerate(self.p_seq): 40 | for i in range(self.args.num_blocks): 41 | with tf.variable_scope("rate_"+str(iii)+"_num_blocks_"+str(i)): 42 | item = multihead_attention(queries=normalize(item), 43 | keys=item, 44 | num_units=self.args.latent_factor, 45 | num_heads=self.args.num_heads, 46 | dropout_rate=self.args.dropout_rate, 47 | is_training=True, 48 | causality=True, 49 | scope="self_attention_pos_"+str(i)) 50 | 51 | item = feedforward(normalize(item), num_units=[self.args.latent_factor, self.args.latent_factor], 52 | dropout_rate=self.args.dropout_rate, is_training=True,scope="feed_forward_pos_"+str(i)) 53 | item *= i_p_mask[iii] 54 | self.p_seq = [normalize(item) for item in self.p_seq] 55 | 56 | p_out = [tf.gather_nd(tf.transpose(self.p_seq[i],[1,0,2]), self.pt[i]) for i in range(6)] 57 | context = tf.concat(p_out,1) 58 | hidden = tf.layers.dense(context,self.args.latent_factor,activation=tf.nn.relu) 59 | self.pi = tf.layers.dense(hidden, self.args.item_num, trainable=True) 60 | 61 | def _build_actor(self,context,name,trainable): 62 | with tf.variable_scope(name): 63 | a_prob = tf.layers.dense(context, self.args.item_num, trainable=trainable) 64 | params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name) 65 | return a_prob, params 66 | 67 | def _create_optimizer(self): 68 | a_indices = tf.stack([tf.range(tf.shape(self.rec)[0], dtype=tf.int32), self.rec], axis=1) 69 | self.npi = tf.gather_nd(params=self.pi, indices=a_indices) 70 | self.loss = tf.losses.mean_squared_error(self.npi,self.target) 71 | self.optimizer = tf.train.AdamOptimizer(self.args.learning_rate).minimize(self.loss) 72 | 73 | def optimize_model(self,sess,data): 74 | feed_dicts = self._get_feed_dict("optimize",data) 75 | return 
sess.run([self.loss,self.npi,self.optimizer],feed_dicts)[:2] 76 | 77 | def predict(self,sess,data): 78 | feed_dicts = self._get_feed_dict("predict", data) 79 | return sess.run(self.pi, feed_dicts) 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /envs/env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import numpy as np 5 | import ipdb 6 | import numpy as np 7 | import ipdb 8 | import logger 9 | from sklearn.model_selection import train_test_split 10 | from collections import OrderedDict 11 | from util import * 12 | import math 13 | from collections import Counter 14 | import copy as cp 15 | 16 | # TODO: 17 | class env(object): 18 | def __init__(self, args): 19 | logger.log("initialize environment") 20 | self.T = args.T 21 | self.rates = {} 22 | self.items = {} 23 | self.users = {} 24 | self.utypes = {} 25 | self.utype_kind = {} 26 | self.ideal_list = {} 27 | self.args = args 28 | with open(path_join(self.args.data_path, "env.dat"), "r") as f: 29 | for line in f: 30 | line = line.strip("\n").split("\t") 31 | iid = list(map(lambda x:x.split(":"),line[1:])) 32 | self.rates[int(line[0])] = {int(i[0]):int(i[1]) for i in iid} 33 | for i in iid: self.items[int(i[0])]=int(i[1]) 34 | logger.log("user number: " + str(len(self.rates) + 1)) 35 | logger.log("item number: " + str(len(self.items) + 1)) 36 | logger.log("user type" 37 | " number: " + str(len(self.utype_kind) + 1)) 38 | self.setup_train_test() 39 | 40 | @property 41 | def user_num(self): 42 | return len(self.rates) + 1 43 | 44 | @property 45 | def item_num(self): 46 | return len(self.items) + 1 47 | 48 | @property 49 | def utype_num(self): 50 | return len(self.utypes) + 1 51 | 52 | def setup_train_test(self): 53 | users = list(range(1, self.user_num)) 54 | np.random.shuffle(users) 55 | self.training, self.validation, self.evaluation = np.split(np.asarray(users), [int(.85 * self.user_num - 1), 56 | int(.9 * self.user_num - 1)]) 57 | 58 | def reset(self): 59 | self.reset_with_users(np.random.choice(self.training)) 60 | 61 | def reset_with_users(self, uid): 62 | self.state = [(uid,1), []] 63 | self.short = {} 64 | return self.state 65 | 66 | def step(self, action): 67 | if action in self.rates[self.state[0][0]] and (not action in self.short): 68 | rate = self.rates[self.state[0][0]][action] 69 | if rate>=4: 70 | reward = 1 71 | else: 72 | reward = 0 73 | else: 74 | rate = 0 75 | reward = 0 76 | 77 | if len(self.state[1]) < self.T - 1: 78 | done = False 79 | else: 80 | done = True 81 | self.short[action] = 1 82 | t = self.state[1] + [[action, reward, done]] 83 | info = {"precision": self.precision(t), 84 | "recall": self.recall(t, self.state[0][0]), 85 | "rate":rate} 86 | self.state[1].append([action, reward, done, info]) 87 | return self.state, reward, done, info 88 | 89 | def step_policy(self,policy): 90 | policy = policy[:self.args.T] 91 | rewards = [] 92 | for action in policy: 93 | if action in self.rates[self.state[0][0]]: 94 | rewards.append(self.rates[self.state[0][0]][action]) 95 | else: 96 | rewards.append(0) 97 | t = [[a,rewards[i],False] for i,a in enumerate(policy)] 98 | info = {"precision": self.precision(t), 99 | "recall": self.recall(t, self.state[0][0])} 100 | self.state[1].extend(t) 101 | return self.state,rewards,True,info 102 | 103 | 104 | def ndcg(self, episode, uid): 105 | if len(self.rates[uid]) > len(episode): 106 | return self.dcg_at_k(list(map(lambda x: x[1], episode)), 107 | 
len(episode), 108 | method=1) / self.dcg_at_k(sorted(list(self.rates[uid].values()),reverse=True), 109 | len(episode), 110 | method=1) 111 | else: 112 | return self.dcg_at_k(list(map(lambda x: x[1], episode)), 113 | len(episode), 114 | method=1) / self.dcg_at_k( 115 | list(self.rates[uid].values()) + [0] * (len(episode) - len(self.rates[uid])), 116 | len(episode), method=1) 117 | 118 | def dcg_at_k(self, r, k, method=1): 119 | r = np.asfarray(r)[:k] 120 | if r.size: 121 | if method == 0: 122 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 123 | elif method == 1: 124 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 125 | else: 126 | raise ValueError('method must be 0 or 1.') 127 | 128 | def alpha_dcg(self, item_list, k=10, alpha=0.5, *args): 129 | items = [] 130 | G = [] 131 | for i, item in enumerate(item_list[:k]): 132 | items += item 133 | G.append(sum(map(lambda x: math.pow(alpha, x - 1), dict(Counter(items)).values())) / math.log(i + 2, 2)) 134 | return sum(G) 135 | 136 | def precision(self, episode): 137 | return sum([i[1] for i in episode]) 138 | 139 | def recall(self, episode, uid): 140 | return sum([i[1] for i in episode]) / len(self.rates[uid]) 141 | -------------------------------------------------------------------------------- /agents/Train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import numpy as np 5 | import ipdb 6 | import copy as cp 7 | import logger 8 | from util import * 9 | import time 10 | from collections import Counter 11 | import copy as cp 12 | 13 | MEMORYSIZE = 50000 14 | BATCHSIZE = 128 15 | THRESHOLD = 300 16 | 17 | start = 0 18 | end = 3000 19 | 20 | def decay_function1(x): 21 | x = 50+x 22 | return max(2.0/(1+np.power(x,0.2)),0.001) 23 | 24 | START = decay_function1(start) 25 | END = decay_function1(end) 26 | 27 | def decay_function(x): 28 | x = max(min(end,x),start) 29 | return (decay_function1(x)-END)/(START-END+0.0000001) 30 | 31 | 32 | class Train(object): 33 | def __init__(self,env,fa,args): 34 | self.env = env 35 | self.fa = fa 36 | self.args = args 37 | self.tau = 0 38 | self.memory = [] 39 | 40 | def train(self): 41 | for epoch in range(self.args.training_epoch): 42 | logger.log(epoch) 43 | self.collecting_data_update_model("training", epoch) 44 | if epoch % 100 == 0 and epoch>=300: 45 | self.collecting_data_update_model("validation", epoch) 46 | self.collecting_data_update_model("evaluation", epoch) 47 | 48 | def collecting_data_update_model(self, type="training", epoch=0): 49 | if type=="training": 50 | selected_users = np.random.choice(self.env.training,(self.args.inner_epoch,)) 51 | elif type=="validation": 52 | selected_users = self.env.validation 53 | elif type=="evaluation": 54 | selected_users = self.env.evaluation 55 | elif type=="verified": 56 | selected_users = self.env.training 57 | else: 58 | selected_users = range(1,3) 59 | infos = {item:[] for item in self.args.ST} 60 | used_actions = [] 61 | for uuid in selected_users: 62 | actions = {} 63 | rwds = 0 64 | done = False 65 | state = self.env.reset_with_users(uuid) 66 | while not done: 67 | data = {"uid": [state[0][1]]} 68 | for i in range(6): 69 | p_r,pnt = self.convert_item_seq2matrix([[0]+[item[0] for item in state[1] if item[3]["rate"] == i]]) 70 | data["p"+str(i)+"_rec"] = p_r 71 | data["p"+str(i)+"t"] = pnt 72 | policy = self.fa["model"].predict(self.fa["sess"],data)[0] 73 | if type == "training": 74 | if np.random.random()<5*THRESHOLD/(THRESHOLD+self.tau): policy = 
np.random.uniform(0,1,(self.args.item_num,)) 75 | for item in actions: policy[item] = -np.inf 76 | action = np.argmax(policy[1:]) + 1 77 | else: 78 | for item in actions: policy[item] = -np.inf 79 | action = np.argmax(policy[1:]) + 1 80 | s_pre = cp.deepcopy(state) 81 | state_next, rwd, done, info = self.env.step(action) 82 | if type == "training": 83 | self.memory.append([s_pre,action,rwd,done,cp.deepcopy(state_next)]) 84 | actions[action] = 1 85 | rwds += rwd 86 | state = state_next 87 | if len(state[1]) in self.args.ST: 88 | infos[len(state[1])].append(info) 89 | used_actions.extend(list(actions.keys())) 90 | if type == "training": 91 | if len(self.memory)>=BATCHSIZE: 92 | self.memory = self.memory[-MEMORYSIZE:] 93 | batch = [self.memory[item] for item in np.random.choice(range(len(self.memory)),(BATCHSIZE,))] 94 | data = self.convert_batch2dict(batch,epoch) 95 | loss,_ = self.fa["model"].optimize_model(self.fa["sess"], data) 96 | logger.record_tabular("loss ", "|".join([str(round(loss,4))])) 97 | self.tau += 5 98 | for item in self.args.ST: 99 | logger.record_tabular(str(item)+"precision",round(np.mean([i["precision"] for i in infos[item]]),4)) 100 | logger.record_tabular(str(item)+"recall",round(np.mean([i["recall"] for i in infos[item]]),4)) 101 | logger.log(str(item)+" precision: ",round(np.mean([i["precision"] for i in infos[item]]),4)) 102 | logger.record_tabular("epoch", epoch) 103 | logger.record_tabular("type", type) 104 | logger.dump_tabular() 105 | 106 | def convert_batch2dict(self,batch,epoch): 107 | uids = [] 108 | pos_recs = {i:[] for i in range(6)} 109 | next_pos = {i:[] for i in range(6)} 110 | iids = [] 111 | goals = [] 112 | dones = [] 113 | for item in batch: 114 | uids.append(item[0][0][1]) 115 | ep = item[0][1] 116 | for xxx in range(6): 117 | pos_recs[xxx].append([0] + [j[0] for j in ep if j[3]["rate"]==xxx]) 118 | iids.append(item[1]) 119 | goals.append(item[2]) 120 | if item[3]:dones.append(0.0) 121 | else:dones.append(1.0) 122 | ep = item[4][1] 123 | for xxx in range(6): 124 | next_pos[xxx].append([0] + [j[0] for j in ep if j[3]["rate"] == xxx]) 125 | data = {"uid":uids} 126 | for xxx in range(6): 127 | p_r, pnt = self.convert_item_seq2matrix(next_pos[xxx]) 128 | data["p" + str(xxx) + "_rec"] = p_r 129 | data["p" + str(xxx) + "t"] = pnt 130 | value = self.fa["model"].predict(self.fa["sess"], data) 131 | value[:,0] = -500 132 | goals = np.max(value,axis=-1)*np.asarray(dones)*min(self.args.gamma,decay_function(max(end-epoch,0)+1)) + goals 133 | data = {"uid":uids,"iid":iids,"goal":goals} 134 | for i in range(6): 135 | p_r, pnt = self.convert_item_seq2matrix(pos_recs[i]) 136 | data["p" + str(i) + "_rec"] = p_r 137 | data["p" + str(i) + "t"] = pnt 138 | return data 139 | 140 | def convert_item_seq2matrix(self, item_seq): 141 | max_length = max([len(item) for item in item_seq]) 142 | matrix = np.zeros((max_length, len(item_seq)),dtype=np.int32) 143 | for x, xx in enumerate(item_seq): 144 | for y, yy in enumerate(xx): 145 | matrix[y, x] = yy 146 | target_index = list(zip([len(i) - 1 for i in item_seq], range(len(item_seq)))) 147 | return matrix, target_index 148 | -------------------------------------------------------------------------------- /functionApproximation/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | 5 | import numpy as np 6 | import ipdb 7 | import tensorflow as tf 8 | import logger 9 | 10 | class basic_model(object): 11 | GRAPHS = {} 12 | SESS = {} 13 | SAVER = 
{} 14 | 15 | def c_opt(self,learning_rate,name): 16 | if str(name).__contains__("adam"): 17 | print("adam") 18 | optimizer = tf.train.AdamOptimizer(learning_rate) 19 | elif str(name).__contains__("adagrad"): 20 | print("adagrad") 21 | optimizer = tf.train.AdagradOptimizer(learning_rate) 22 | elif str(name).__contains__("sgd"): 23 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 24 | elif str(name).__contains__("rms"): 25 | optimizer = tf.train.RMSPropOptimizer(learning_rate) 26 | elif str(name).__contains__("moment"): 27 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.1) 28 | return optimizer 29 | 30 | @classmethod 31 | def create_model(cls, config, variable_scope = "target", trainable = True, graph_name="DEFAULT",task_index=0): 32 | jobs = config.jobs 33 | job = list(jobs.keys())[0] 34 | logger.info("CREATE MODEL", config.model, "GRAPH", graph_name, "VARIABLE SCOPE", variable_scope,"jobs",jobs,"job",job,"task_index",task_index) 35 | cls.CLUSTER = tf.train.ClusterSpec(jobs) 36 | cls.SERVER = tf.train.Server(cls.CLUSTER, job_name=job, task_index=task_index,config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))) 37 | if not graph_name in cls.GRAPHS: 38 | logger.info("Adding a new tensorflow graph:",graph_name) 39 | cls.GRAPHS[graph_name] = tf.Graph() 40 | with cls.GRAPHS[graph_name].as_default(): 41 | model = cls(config, variable_scope=variable_scope, trainable=trainable) 42 | if not graph_name in cls.SESS: 43 | cls.SESS[graph_name] = tf.Session(cls.SERVER.target) 44 | cls.SAVER[graph_name] = tf.train.Saver(max_to_keep=50) 45 | cls.SESS[graph_name].run(model.init) 46 | return {"graph": cls.GRAPHS[graph_name], 47 | "sess": cls.SESS[graph_name], 48 | "saver": cls.SAVER[graph_name], 49 | "model": model,"cluster":cls.CLUSTER,"server":cls.SERVER} 50 | 51 | @classmethod 52 | def create_model_without_distributed(cls, config, variable_scope = "target", trainable = True, graph_name="DEFAULT"): 53 | logger.info("CREATE MODEL", config.model, "GRAPH", graph_name, "VARIABLE SCOPE", variable_scope) 54 | if not graph_name in cls.GRAPHS: 55 | logger.info("Adding a new tensorflow graph:",graph_name) 56 | cls.GRAPHS[graph_name] = tf.Graph() 57 | with cls.GRAPHS[graph_name].as_default(): 58 | model = cls(config, variable_scope=variable_scope, trainable=trainable) 59 | if not graph_name in cls.SESS: 60 | cls.SESS[graph_name] = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))) 61 | cls.SAVER[graph_name] = tf.train.Saver(max_to_keep=50) 62 | cls.SESS[graph_name].run(model.init) 63 | return {"graph": cls.GRAPHS[graph_name], 64 | "sess": cls.SESS[graph_name], 65 | "saver": cls.SAVER[graph_name], 66 | "model": model} 67 | 68 | def _update_placehoders(self): 69 | self.placeholders = {"none":{}} 70 | raise NotImplemented 71 | 72 | def _get_feed_dict(self,task,data_dicts): 73 | place_holders = self.placeholders[task] 74 | res = {} 75 | for key, value in place_holders.items(): 76 | res[value] = data_dicts[key] 77 | return res 78 | 79 | def __init__(self, args, variable_scope = "target", trainable = True): 80 | print(self.__class__) 81 | self.args = args 82 | self.variable_scope = variable_scope 83 | self.trainable = trainable 84 | self.placeholders = {} 85 | self._build_model() 86 | 87 | def _build_model(self): 88 | with tf.variable_scope(self.variable_scope): 89 | self._create_placeholders() 90 | self._create_global_step() 91 | self._update_placehoders() 92 | self._create_inference() 93 | if self.trainable: 94 | self._create_optimizer() 95 | 
self._create_intializer() 96 | 97 | def _create_global_step(self): 98 | self.global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step') 99 | 100 | def _create_intializer(self): 101 | with tf.name_scope("initlializer"): 102 | self.init = tf.global_variables_initializer() 103 | 104 | def _create_placeholders(self): 105 | raise NotImplementedError 106 | 107 | def _create_inference(self): 108 | raise NotImplementedError 109 | 110 | def _create_optimizer(self): 111 | raise NotImplementedError 112 | 113 | def chose_action(self, state, sess): 114 | raise NotImplementedError 115 | pass 116 | 117 | def build_cell(self,rnn_type,initializer,hidden,input_data,initial_state): 118 | cell = tf.contrib.rnn.MultiRNNCell([tf.nn.rnn_cell.GRUCell(hidden, 119 | kernel_initializer=initializer, 120 | bias_initializer=tf.zeros_initializer)] * self.args.rnn_layer) 121 | return tf.nn.dynamic_rnn(cell, input_data, initial_state=initial_state, 122 | dtype=tf.float32, 123 | time_major=True) 124 | 125 | 126 | def positional_encoding(dim, sentence_length, dtype=tf.float32): 127 | encoded_vec = np.array([pos / np.power(10000, 2 * i / dim) for pos in range(sentence_length) for i in range(dim)]) 128 | encoded_vec[::2] = np.sin(encoded_vec[::2]) 129 | encoded_vec[1::2] = np.cos(encoded_vec[1::2]) 130 | 131 | return tf.convert_to_tensor(encoded_vec.reshape([sentence_length, dim]), dtype=dtype) 132 | 133 | 134 | def normalize(inputs, 135 | epsilon=1e-8, 136 | scope="ln", 137 | reuse=None): 138 | '''Applies layer normalization. 139 | 140 | Args: 141 | inputs: A tensor with 2 or more dimensions, where the first dimension has 142 | `batch_size`. 143 | epsilon: A floating number. A very small number for preventing ZeroDivision Error. 144 | scope: Optional scope for `variable_scope`. 145 | reuse: Boolean, whether to reuse the weights of a previous layer 146 | by the same name. 147 | 148 | Returns: 149 | A tensor with the same shape and data dtype as `inputs`. 150 | ''' 151 | with tf.variable_scope(scope, reuse=reuse): 152 | inputs_shape = inputs.get_shape() 153 | params_shape = inputs_shape[-1:] 154 | 155 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True) 156 | beta = tf.Variable(tf.zeros(params_shape)) 157 | gamma = tf.Variable(tf.ones(params_shape)) 158 | normalized = (inputs - mean) / ((variance + epsilon) ** (.5)) 159 | outputs = gamma * normalized + beta 160 | 161 | return outputs 162 | 163 | 164 | def embedding(inputs, 165 | vocab_size, 166 | num_units, 167 | zero_pad=True, 168 | scale=True, 169 | l2_reg=0.0, 170 | scope="embedding", 171 | with_t=False, 172 | reuse=None): 173 | '''Embeds a given tensor. 174 | 175 | Args: 176 | inputs: A `Tensor` with type `int32` or `int64` containing the ids 177 | to be looked up in `lookup table`. 178 | vocab_size: An int. Vocabulary size. 179 | num_units: An int. Number of embedding hidden units. 180 | zero_pad: A boolean. If True, all the values of the fist row (id 0) 181 | should be constant zeros. 182 | scale: A boolean. If True. the outputs is multiplied by sqrt num_units. 183 | scope: Optional scope for `variable_scope`. 184 | reuse: Boolean, whether to reuse the weights of a previous layer 185 | by the same name. 186 | 187 | Returns: 188 | A `Tensor` with one more rank than inputs's. The last dimensionality 189 | should be `num_units`. 
190 | 191 | For example, 192 | 193 | ``` 194 | import tensorflow as tf 195 | 196 | inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3))) 197 | outputs = embedding(inputs, 6, 2, zero_pad=True) 198 | with tf.Session() as sess: 199 | sess.run(tf.global_variables_initializer()) 200 | print sess.run(outputs) 201 | >> 202 | [[[ 0. 0. ] 203 | [ 0.09754146 0.67385566] 204 | [ 0.37864095 -0.35689294]] 205 | 206 | [[-1.01329422 -1.09939694] 207 | [ 0.7521342 0.38203377] 208 | [-0.04973143 -0.06210355]]] 209 | ``` 210 | 211 | ``` 212 | import tensorflow as tf 213 | 214 | inputs = tf.to_int32(tf.reshape(tf.range(2*3), (2, 3))) 215 | outputs = embedding(inputs, 6, 2, zero_pad=False) 216 | with tf.Session() as sess: 217 | sess.run(tf.global_variables_initializer()) 218 | print sess.run(outputs) 219 | >> 220 | [[[-0.19172323 -0.39159766] 221 | [-0.43212751 -0.66207761] 222 | [ 1.03452027 -0.26704335]] 223 | 224 | [[-0.11634696 -0.35983452] 225 | [ 0.50208133 0.53509563] 226 | [ 1.22204471 -0.96587461]]] 227 | ``` 228 | ''' 229 | with tf.variable_scope(scope, reuse=reuse): 230 | lookup_table = tf.get_variable('lookup_table', 231 | dtype=tf.float32, 232 | shape=[vocab_size, num_units], 233 | # initializer=tf.contrib.layers.xavier_initializer(), 234 | regularizer=tf.contrib.layers.l2_regularizer(l2_reg)) 235 | if zero_pad: 236 | lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), 237 | lookup_table[1:, :]), 0) 238 | outputs = tf.nn.embedding_lookup(lookup_table, inputs) 239 | 240 | if scale: 241 | outputs = outputs * (num_units ** 0.5) 242 | if with_t: 243 | return outputs, lookup_table 244 | else: 245 | return outputs 246 | 247 | 248 | def multihead_attention(queries, 249 | keys, 250 | num_units=None, 251 | num_heads=8, 252 | dropout_rate=0, 253 | is_training=True, 254 | causality=False, 255 | scope="multihead_attention", 256 | reuse=None, 257 | with_qk=False): 258 | '''Applies multihead attention. 259 | 260 | Args: 261 | queries: A 3d tensor with shape of [N, T_q, C_q]. 262 | keys: A 3d tensor with shape of [N, T_k, C_k]. 263 | num_units: A scalar. Attention size. 264 | dropout_rate: A floating point number. 265 | is_training: Boolean. Controller of mechanism for dropout. 266 | causality: Boolean. If true, units that reference the future are masked. 267 | num_heads: An int. Number of heads. 268 | scope: Optional scope for `variable_scope`. 269 | reuse: Boolean, whether to reuse the weights of a previous layer 270 | by the same name. 
271 | 272 | Returns 273 | A 3d tensor with shape of (N, T_q, C) 274 | ''' 275 | with tf.variable_scope(scope, reuse=reuse): 276 | # Set the fall back option for num_units 277 | if num_units is None: 278 | num_units = queries.get_shape().as_list[-1] 279 | 280 | # Linear projections 281 | # Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu) # (N, T_q, C) 282 | # K = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C) 283 | # V = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C) 284 | Q = tf.layers.dense(queries, num_units, activation=None) # (N, T_q, C) 285 | K = tf.layers.dense(keys, num_units, activation=None) # (N, T_k, C) 286 | V = tf.layers.dense(keys, num_units, activation=None) # (N, T_k, C) 287 | 288 | # Split and concat 289 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, C/h) 290 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 291 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 292 | 293 | # Multiplication 294 | outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) # (h*N, T_q, T_k) 295 | 296 | # Scale 297 | outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5) 298 | 299 | # Key Masking 300 | key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1))) # (N, T_k) 301 | key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k) 302 | key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k) 303 | 304 | paddings = tf.ones_like(outputs) * (-2 ** 32 + 1) 305 | outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) # (h*N, T_q, T_k) 306 | 307 | # Causality = Future blinding 308 | if causality: 309 | diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k) 310 | tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k) 311 | masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k) 312 | 313 | paddings = tf.ones_like(masks) * (-2 ** 32 + 1) 314 | outputs = tf.where(tf.equal(masks, 0), paddings, outputs) # (h*N, T_q, T_k) 315 | 316 | # Activation 317 | outputs = tf.nn.softmax(outputs) # (h*N, T_q, T_k) 318 | 319 | # Query Masking 320 | query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1))) # (N, T_q) 321 | query_masks = tf.tile(query_masks, [num_heads, 1]) # (h*N, T_q) 322 | query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) # (h*N, T_q, T_k) 323 | outputs *= query_masks # broadcasting. (N, T_q, C) 324 | 325 | # Dropouts 326 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training)) 327 | 328 | # Weighted sum 329 | outputs = tf.matmul(outputs, V_) # ( h*N, T_q, C/h) 330 | 331 | # Restore shape 332 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2) # (N, T_q, C) 333 | 334 | # Residual connection 335 | outputs += queries 336 | 337 | # Normalize 338 | # outputs = normalize(outputs) # (N, T_q, C) 339 | 340 | if with_qk: 341 | return Q, K 342 | else: 343 | return outputs 344 | 345 | 346 | def feedforward(inputs, 347 | num_units=[2048, 512], 348 | scope="multihead_attention", 349 | dropout_rate=0.2, 350 | is_training=True, 351 | reuse=None): 352 | '''Point-wise feed forward net. 353 | 354 | Args: 355 | inputs: A 3d tensor with shape of [N, T, C]. 356 | num_units: A list of two integers. 357 | scope: Optional scope for `variable_scope`. 358 | reuse: Boolean, whether to reuse the weights of a previous layer 359 | by the same name. 
360 | 361 | Returns: 362 | A 3d tensor with the same shape and dtype as inputs 363 | ''' 364 | with tf.variable_scope(scope, reuse=reuse): 365 | # Inner layer 366 | params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1, 367 | "activation": tf.nn.relu, "use_bias": True} 368 | outputs = tf.layers.conv1d(**params) 369 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training)) 370 | # Readout layer 371 | params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1, 372 | "activation": None, "use_bias": True} 373 | outputs = tf.layers.conv1d(**params) 374 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training)) 375 | 376 | # Residual connection 377 | outputs += inputs 378 | 379 | # Normalize 380 | # outputs = normalize(outputs) 381 | 382 | return outputs 383 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import os.path as osp 5 | import json 6 | import time 7 | import datetime 8 | import tempfile 9 | from collections import defaultdict 10 | from contextlib import contextmanager 11 | 12 | DEBUG = 10 13 | INFO = 20 14 | WARN = 30 15 | ERROR = 40 16 | 17 | DISABLED = 50 18 | 19 | class KVWriter(object): 20 | def writekvs(self, kvs): 21 | raise NotImplementedError 22 | 23 | class SeqWriter(object): 24 | def writeseq(self, seq): 25 | raise NotImplementedError 26 | 27 | class HumanOutputFormat(KVWriter, SeqWriter): 28 | def __init__(self, filename_or_file): 29 | if isinstance(filename_or_file, str): 30 | self.file = open(filename_or_file, 'wt') 31 | self.own_file = True 32 | else: 33 | assert hasattr(filename_or_file, 'read'), 'expected file or str, got %s'%filename_or_file 34 | self.file = filename_or_file 35 | self.own_file = False 36 | 37 | def writekvs(self, kvs): 38 | # Create strings for printing 39 | key2str = {} 40 | for (key, val) in sorted(kvs.items()): 41 | if hasattr(val, '__float__'): 42 | valstr = '%-8.3g' % val 43 | else: 44 | valstr = str(val) 45 | key2str[self._truncate(key)] = self._truncate(valstr) 46 | 47 | # Find max widths 48 | if len(key2str) == 0: 49 | print('WARNING: tried to write empty key-value dict') 50 | return 51 | else: 52 | keywidth = max(map(len, key2str.keys())) 53 | valwidth = max(map(len, key2str.values())) 54 | 55 | # Write out the data 56 | dashes = '-' * (keywidth + valwidth + 7) 57 | lines = [dashes] 58 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 59 | lines.append('| %s%s | %s%s |' % ( 60 | key, 61 | ' ' * (keywidth - len(key)), 62 | val, 63 | ' ' * (valwidth - len(val)), 64 | )) 65 | lines.append(dashes) 66 | self.file.write('\n'.join(lines) + '\n') 67 | 68 | # Flush the output to the file 69 | self.file.flush() 70 | 71 | def _truncate(self, s): 72 | maxlen = 30 73 | return s[:maxlen-3] + '...' 
if len(s) > maxlen else s 74 | 75 | def writeseq(self, seq): 76 | seq = list(seq) 77 | for (i, elem) in enumerate(seq): 78 | self.file.write(elem) 79 | if i < len(seq) - 1: # add space unless this is the last one 80 | self.file.write(' ') 81 | self.file.write('\n') 82 | self.file.flush() 83 | 84 | def close(self): 85 | if self.own_file: 86 | self.file.close() 87 | 88 | class JSONOutputFormat(KVWriter): 89 | def __init__(self, filename): 90 | self.file = open(filename, 'wt') 91 | 92 | def writekvs(self, kvs): 93 | for k, v in sorted(kvs.items()): 94 | if hasattr(v, 'dtype'): 95 | kvs[k] = float(v) 96 | self.file.write(json.dumps(kvs) + '\n') 97 | self.file.flush() 98 | 99 | def close(self): 100 | self.file.close() 101 | 102 | class CSVOutputFormat(KVWriter): 103 | def __init__(self, filename): 104 | self.file = open(filename, 'w+t') 105 | self.keys = [] 106 | self.sep = ',' 107 | 108 | def writekvs(self, kvs): 109 | # Add our current row to the history 110 | extra_keys = list(kvs.keys() - self.keys) 111 | extra_keys.sort() 112 | if extra_keys: 113 | self.keys.extend(extra_keys) 114 | self.file.seek(0) 115 | lines = self.file.readlines() 116 | self.file.seek(0) 117 | for (i, k) in enumerate(self.keys): 118 | if i > 0: 119 | self.file.write(',') 120 | self.file.write(k) 121 | self.file.write('\n') 122 | for line in lines[1:]: 123 | self.file.write(line[:-1]) 124 | self.file.write(self.sep * len(extra_keys)) 125 | self.file.write('\n') 126 | for (i, k) in enumerate(self.keys): 127 | if i > 0: 128 | self.file.write(',') 129 | v = kvs.get(k) 130 | if v is not None: 131 | self.file.write(str(v)) 132 | self.file.write('\n') 133 | self.file.flush() 134 | 135 | def close(self): 136 | self.file.close() 137 | 138 | 139 | class TensorBoardOutputFormat(KVWriter): 140 | """ 141 | Dumps key/value pairs into TensorBoard's numeric format. 142 | """ 143 | def __init__(self, dir): 144 | os.makedirs(dir, exist_ok=True) 145 | self.dir = dir 146 | self.step = 1 147 | prefix = 'events' 148 | path = osp.join(osp.abspath(dir), prefix) 149 | import tensorflow as tf 150 | from tensorflow.python import pywrap_tensorflow 151 | from tensorflow.core.util import event_pb2 152 | from tensorflow.python.util import compat 153 | self.tf = tf 154 | self.event_pb2 = event_pb2 155 | self.pywrap_tensorflow = pywrap_tensorflow 156 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 157 | 158 | def writekvs(self, kvs): 159 | def summary_val(k, v): 160 | kwargs = {'tag': k, 'simple_value': float(v)} 161 | return self.tf.Summary.Value(**kwargs) 162 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 163 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 164 | event.step = self.step # is there any reason why you'd want to specify the step? 
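# event.step matters because TensorBoard plots each scalar against the event's step value;
# self.step is advanced once per writekvs call (a few lines below), so every dumpkvs()
# appears as its own point on the x-axis instead of everything collapsing onto step 0.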
165 | self.writer.WriteEvent(event) 166 | self.writer.Flush() 167 | self.step += 1 168 | 169 | def close(self): 170 | if self.writer: 171 | self.writer.Close() 172 | self.writer = None 173 | 174 | def make_output_format(format, ev_dir, log_suffix=''): 175 | os.makedirs(ev_dir, exist_ok=True) 176 | if format == 'stdout': 177 | return HumanOutputFormat(sys.stdout) 178 | elif format == 'log': 179 | return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)) 180 | elif format == 'json': 181 | return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)) 182 | elif format == 'csv': 183 | return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)) 184 | elif format == 'tensorboard': 185 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)) 186 | else: 187 | raise ValueError('Unknown format specified: %s' % (format,)) 188 | 189 | # ================================================================ 190 | # API 191 | # ================================================================ 192 | 193 | def logkv(key, val): 194 | """ 195 | Log a value of some diagnostic 196 | Call this once for each diagnostic quantity, each iteration 197 | If called many times, last value will be used. 198 | """ 199 | get_current().logkv(key, val) 200 | 201 | def logkv_mean(key, val): 202 | """ 203 | The same as logkv(), but if called many times, values averaged. 204 | """ 205 | get_current().logkv_mean(key, val) 206 | 207 | def logkvs(d): 208 | """ 209 | Log a dictionary of key-value pairs 210 | """ 211 | for (k, v) in d.items(): 212 | logkv(k, v) 213 | 214 | def dumpkvs(): 215 | """ 216 | Write all of the diagnostics from the current iteration 217 | """ 218 | return get_current().dumpkvs() 219 | 220 | def getkvs(): 221 | return get_current().name2val 222 | 223 | 224 | def log(*args, level=INFO): 225 | """ 226 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 227 | """ 228 | get_current().log(*args, level=level) 229 | 230 | def debug(*args): 231 | log(*args, level=DEBUG) 232 | 233 | def info(*args): 234 | log(*args, level=INFO) 235 | 236 | def warn(*args): 237 | log(*args, level=WARN) 238 | 239 | def error(*args): 240 | log(*args, level=ERROR) 241 | 242 | 243 | def set_level(level): 244 | """ 245 | Set logging threshold on current logger. 246 | """ 247 | get_current().set_level(level) 248 | 249 | def set_comm(comm): 250 | get_current().set_comm(comm) 251 | 252 | def get_dir(): 253 | """ 254 | Get directory that log files are being written to. 
255 | will be None if there is no output directory (i.e., if you didn't call start) 256 | """ 257 | return get_current().get_dir() 258 | 259 | record_tabular = logkv 260 | dump_tabular = dumpkvs 261 | 262 | @contextmanager 263 | def profile_kv(scopename): 264 | logkey = 'wait_' + scopename 265 | tstart = time.time() 266 | try: 267 | yield 268 | finally: 269 | get_current().name2val[logkey] += time.time() - tstart 270 | 271 | def profile(n): 272 | """ 273 | Usage: 274 | @profile("my_func") 275 | def my_func(): code 276 | """ 277 | def decorator_with_name(func): 278 | def func_wrapper(*args, **kwargs): 279 | with profile_kv(n): 280 | return func(*args, **kwargs) 281 | return func_wrapper 282 | return decorator_with_name 283 | 284 | 285 | # ================================================================ 286 | # Backend 287 | # ================================================================ 288 | 289 | def get_current(): 290 | if Logger.CURRENT is None: 291 | _configure_default_logger() 292 | 293 | return Logger.CURRENT 294 | 295 | 296 | class Logger(object): 297 | DEFAULT = None # A logger with no output files. (See right below class definition) 298 | # So that you can still log to the terminal without setting up any output files 299 | CURRENT = None # Current logger being used by the free functions above 300 | 301 | def __init__(self, dir, output_formats, comm=None): 302 | self.name2val = defaultdict(float) # values this iteration 303 | self.name2cnt = defaultdict(int) 304 | self.level = INFO 305 | self.dir = dir 306 | self.output_formats = output_formats 307 | self.comm = comm 308 | 309 | # Logging API, forwarded 310 | # ---------------------------------------- 311 | def logkv(self, key, val): 312 | self.name2val[key] = val 313 | 314 | def logkv_mean(self, key, val): 315 | oldval, cnt = self.name2val[key], self.name2cnt[key] 316 | self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1) 317 | self.name2cnt[key] = cnt + 1 318 | 319 | def dumpkvs(self): 320 | if self.comm is None: 321 | d = self.name2val 322 | else: 323 | from baselines.common import mpi_util 324 | d = mpi_util.mpi_weighted_mean(self.comm, 325 | {name : (val, self.name2cnt.get(name, 1)) 326 | for (name, val) in self.name2val.items()}) 327 | if self.comm.rank != 0: 328 | d['dummy'] = 1 # so we don't get a warning about empty dict 329 | out = d.copy() # Return the dict for unit testing purposes 330 | for fmt in self.output_formats: 331 | if isinstance(fmt, KVWriter): 332 | fmt.writekvs(d) 333 | self.name2val.clear() 334 | self.name2cnt.clear() 335 | return out 336 | 337 | def log(self, *args, level=INFO): 338 | if self.level <= level: 339 | self._do_log(args) 340 | 341 | # Configuration 342 | # ---------------------------------------- 343 | def set_level(self, level): 344 | self.level = level 345 | 346 | def set_comm(self, comm): 347 | self.comm = comm 348 | 349 | def get_dir(self): 350 | return self.dir 351 | 352 | def close(self): 353 | for fmt in self.output_formats: 354 | fmt.close() 355 | 356 | # Misc 357 | # ---------------------------------------- 358 | def _do_log(self, args): 359 | for fmt in self.output_formats: 360 | if isinstance(fmt, SeqWriter): 361 | fmt.writeseq(map(str, args)) 362 | 363 | def get_rank_without_mpi_import(): 364 | # check environment variables here instead of importing mpi4py 365 | # to avoid calling MPI_Init() when this module is imported 366 | for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']: 367 | if varname in os.environ: 368 | return int(os.environ[varname]) 369 | return 0 370 | 371 | 372 | 
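# configure() below falls back to the OPENAI_LOGDIR environment variable when `dir` is
# None and to OPENAI_LOG_FORMAT (default 'stdout,log,csv') when `format_strs` is None.
# launch.py always passes an explicit directory, so in this repository only the format
# default applies. Illustrative override (not used by model_train.sh as shipped):
#   OPENAI_LOG_FORMAT=stdout,csv,tensorboard bash ./model_train.sh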
def configure(dir=None, format_strs=None, comm=None, log_suffix=''): 373 | """ 374 | If comm is provided, average all numerical stats across that comm 375 | """ 376 | if dir is None: 377 | dir = os.getenv('OPENAI_LOGDIR') 378 | if dir is None: 379 | dir = osp.join(tempfile.gettempdir(), 380 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")) 381 | assert isinstance(dir, str) 382 | dir = os.path.expanduser(dir) 383 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 384 | 385 | rank = get_rank_without_mpi_import() 386 | if rank > 0: 387 | log_suffix = log_suffix + "-rank%03i" % rank 388 | 389 | if format_strs is None: 390 | if rank == 0: 391 | format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',') 392 | else: 393 | format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',') 394 | format_strs = filter(None, format_strs) 395 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 396 | 397 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 398 | if output_formats: 399 | log('Logging to %s'%dir) 400 | 401 | def _configure_default_logger(): 402 | configure() 403 | Logger.DEFAULT = Logger.CURRENT 404 | 405 | def reset(): 406 | if Logger.CURRENT is not Logger.DEFAULT: 407 | Logger.CURRENT.close() 408 | Logger.CURRENT = Logger.DEFAULT 409 | log('Reset logger') 410 | 411 | @contextmanager 412 | def scoped_configure(dir=None, format_strs=None, comm=None): 413 | prevlogger = Logger.CURRENT 414 | configure(dir=dir, format_strs=format_strs, comm=comm) 415 | try: 416 | yield 417 | finally: 418 | Logger.CURRENT.close() 419 | Logger.CURRENT = prevlogger 420 | 421 | # ================================================================ 422 | 423 | def _demo(): 424 | info("hi") 425 | debug("shouldn't appear") 426 | set_level(DEBUG) 427 | debug("should appear") 428 | dir = "/tmp/testlogging" 429 | if os.path.exists(dir): 430 | shutil.rmtree(dir) 431 | configure(dir=dir) 432 | logkv("a", 3) 433 | logkv("b", 2.5) 434 | dumpkvs() 435 | logkv("b", -2.5) 436 | logkv("a", 5.5) 437 | dumpkvs() 438 | info("^^^ should see a = 5.5") 439 | logkv_mean("b", -22.5) 440 | logkv_mean("b", -44.4) 441 | logkv("a", 5.5) 442 | dumpkvs() 443 | info("^^^ should see b = -33.3") 444 | 445 | logkv("b", -2.5) 446 | dumpkvs() 447 | 448 | logkv("a", "longasslongasslongasslongasslongasslongassvalue") 449 | dumpkvs() 450 | 451 | 452 | # ================================================================ 453 | # Readers 454 | # ================================================================ 455 | 456 | def read_json(fname): 457 | import pandas 458 | ds = [] 459 | with open(fname, 'rt') as fh: 460 | for line in fh: 461 | ds.append(json.loads(line)) 462 | return pandas.DataFrame(ds) 463 | 464 | def read_csv(fname): 465 | import pandas 466 | return pandas.read_csv(fname, index_col=None, comment='#') 467 | 468 | def read_tb(path): 469 | """ 470 | path : a tensorboard file OR a directory, where we will find all TB files 471 | of the form events.* 472 | """ 473 | import pandas 474 | import numpy as np 475 | from glob import glob 476 | import tensorflow as tf 477 | if osp.isdir(path): 478 | fnames = glob(osp.join(path, "events.*")) 479 | elif osp.basename(path).startswith("events."): 480 | fnames = [path] 481 | else: 482 | raise NotImplementedError("Expected tensorboard file or directory containing them. 
Got %s"%path) 483 | tag2pairs = defaultdict(list) 484 | maxstep = 0 485 | for fname in fnames: 486 | for summary in tf.train.summary_iterator(fname): 487 | if summary.step > 0: 488 | for v in summary.summary.value: 489 | pair = (summary.step, v.simple_value) 490 | tag2pairs[v.tag].append(pair) 491 | maxstep = max(summary.step, maxstep) 492 | data = np.empty((maxstep, len(tag2pairs))) 493 | data[:] = np.nan 494 | tags = sorted(tag2pairs.keys()) 495 | for (colidx,tag) in enumerate(tags): 496 | pairs = tag2pairs[tag] 497 | for (step, value) in pairs: 498 | data[step-1, colidx] = value 499 | return pandas.DataFrame(data, columns=tags) 500 | 501 | if __name__ == "__main__": 502 | _demo() 503 | --------------------------------------------------------------------------------