├── README.md ├── dataset.py ├── layers.py ├── models.py ├── notebooks ├── check_result.ipynb └── example.ipynb ├── parser ├── README.md └── parser.jar ├── requirements.txt ├── retrain.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Attention-based Tree-to-Sequence Code Summarization Model 2 | 3 | The TensorFlow Eager Execution implementation of [Source Code Summarization with Extended Tree-LSTM](https://arxiv.org/abs/1906.08094) (Shido+, 2019) 4 | 5 | including: 6 | 7 | - **Multi-way Tree-LSTM model (Ours)** 8 | - Child-sum Tree-LSTM model 9 | - N-ary Tree-LSTM model 10 | - DeepCom (Hu et al.) 11 | - CODE-NN (Iyer et al.) 12 | 13 | ## Dataset 14 | 15 | 1. Download raw dataset from [https://github.com/xing-hu/DeepCom] 16 | 2. Parse them with parser.jar 17 | 18 | ## Usage 19 | 20 | 1. Prepare tree-structured data with `dataset.py` 21 | - Run `$ python dataset.py [dir]` 22 | 2. Train and evaluate model with `train.py` 23 | - See `$ python train.py -h` 24 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | from glob import glob 4 | from utils import Node, traverse_label, traverse 5 | import pickle 6 | import os 7 | from joblib import Parallel, delayed 8 | from collections import Counter 9 | import re 10 | from os.path import abspath 11 | import nltk 12 | 13 | 14 | def parse(path): 15 | with open(path, "r") as f: 16 | num_objects = f.readline() 17 | nodes = [Node(num=i, children=[]) for i in range(int(num_objects))] 18 | for i in range(int(num_objects)): 19 | label = " ".join(f.readline().split(" ")[1:])[:-1] 20 | nodes[i].label = label 21 | while 1: 22 | line = f.readline() 23 | if line == "\n": 24 | break 25 | p, c = map(int, line.split(" ")) 26 | nodes[p].children.append(nodes[c]) 27 | nodes[c].parent = nodes[p] 28 | nl = f.readline()[:-1] 29 | return nodes[0], nl 30 | 31 | 32 | def is_invalid_com(s): 33 | return s[:2] == "/*" and len(s) > 1 34 | 35 | 36 | def is_invalid_seq(s): 37 | return len(s) < 4 38 | 39 | 40 | def get_method_name(root): 41 | for c in root.children: 42 | if c.label == "name (SimpleName)": 43 | return c.children[0].label[12:-1] 44 | 45 | 46 | def is_invalid_tree(root): 47 | labels = traverse_label(root) 48 | if root.label == 'root (ConstructorDeclaration)': 49 | return True 50 | if len(labels) >= 100: 51 | return True 52 | method_name = get_method_name(root) 53 | for word in ["test", "Test", "set", "Set", "get", "Get"]: 54 | if method_name[:len(word)] == word: 55 | return True 56 | return False 57 | 58 | 59 | def clean_nl(s): 60 | if s[-1] == ".": 61 | s = s[:-1] 62 | s = s.split(". ")[0] 63 | s = re.sub("[<].+?[>]", "", s) 64 | s = re.sub("[\[\]\%]", "", s) 65 | s = s[0:1].lower() + s[1:] 66 | return s 67 | 68 | 69 | def tokenize(s): 70 | return [""] + nltk.word_tokenize(s) + [""] 71 | 72 | 73 | def parse_dir(path_to_dir): 74 | files = sorted(glob(path_to_dir + "/*")) 75 | set_name = path_to_dir.split("/")[-1] 76 | 77 | nls = {} 78 | skip = 0 79 | 80 | for file in tqdm(files, "parsing {}".format(path_to_dir)): 81 | tree, nl = parse(file) 82 | nl = clean_nl(nl) 83 | if is_invalid_com(nl): 84 | skip += 1 85 | continue 86 | if is_invalid_tree(tree): 87 | skip += 1 88 | continue 89 | number = int(file.split("/")[-1]) 90 | seq = tokenize(nl) 91 | if is_invalid_seq(seq): 92 | skip += 1 93 | continue 94 | nls[abspath("./dataset/tree/" + set_name + "/" + str(number))] = seq 95 | with open("./dataset/tree_raw/" + set_name + "/" + str(number), "wb", 1) as f: 96 | pickle.dump(tree, f) 97 | 98 | print("{} files skipped".format(skip)) 99 | 100 | if set_name == "train": 101 | vocab = Counter([x for l in nls.values() for x in l]) 102 | nl_i2w = {i: w for i, w in enumerate( 103 | ["", ""] + sorted([x[0] for x in vocab.most_common(30000)]))} 104 | nl_w2i = {w: i for i, w in enumerate( 105 | ["", ""] + sorted([x[0] for x in vocab.most_common(30000)]))} 106 | pickle.dump(nl_i2w, open("./dataset/nl_i2w.pkl", "wb")) 107 | pickle.dump(nl_w2i, open("./dataset/nl_w2i.pkl", "wb")) 108 | 109 | return nls 110 | 111 | 112 | def pickling(): 113 | args = sys.argv 114 | 115 | if len(args) <= 1: 116 | raise Exception("(usage) $ python dataset.py [dir]") 117 | 118 | data_dir = args[1] 119 | 120 | dirs = [ 121 | "dataset", 122 | "dataset/tree_raw", 123 | "dataset/tree_raw/train", 124 | "dataset/tree_raw/valid", 125 | "dataset/tree_raw/test", 126 | "dataset/nl" 127 | ] 128 | for d in dirs: 129 | if not os.path.exists(d): 130 | os.mkdir(d) 131 | 132 | for path in [data_dir + "/" + s for s in ["train", "valid", "test"]]: 133 | set_name = path.split("/")[-1] 134 | nl = parse_dir(path) 135 | with open("./dataset/nl/" + set_name + ".pkl", "wb", 1) as f: 136 | pickle.dump(nl, f) 137 | 138 | 139 | def isnum(s): 140 | try: 141 | float(s) 142 | except ValueError: 143 | return False 144 | else: 145 | return True 146 | 147 | 148 | def get_labels(path): 149 | tree = pickle.load(open(path, "rb")) 150 | return traverse_label(tree) 151 | 152 | 153 | def get_bracket(s): 154 | if "value=" == s[:6] or "identifier=" in s[:11]: 155 | return None 156 | p = "\(.+?\)" 157 | res = re.findall(p, s) 158 | if len(res) == 1: 159 | return res[0] 160 | return s 161 | 162 | 163 | def get_identifier(s): 164 | if "identifier=" == s[:11]: 165 | return "SimpleName_" + s[11:] 166 | else: 167 | return None 168 | 169 | 170 | def is_SimpleName(s): 171 | return "SimpleName_" == s[:11] 172 | 173 | 174 | def get_values(s): 175 | if "value=" == s[:6]: 176 | return "Value_" + s[6:] 177 | else: 178 | return None 179 | 180 | 181 | def is_value(s): 182 | return "Value_" == s[:6] 183 | 184 | 185 | def make_dict(): 186 | labels = Parallel(n_jobs=-1)(delayed(get_labels)(p) for p in tqdm( 187 | glob("./dataset/tree_raw/train/*"), "reading all labels")) 188 | labels = [l for s in labels for l in s] 189 | 190 | non_terminals = set( 191 | [get_bracket(x) for x in tqdm( 192 | list(set(labels)), "collect non-tarminals")]) - set([None, "(SimpleName)"]) 193 | non_terminals = sorted(list(non_terminals)) 194 | 195 | ids = Counter( 196 | [y for y in [get_identifier(x) for x in tqdm( 197 | labels, "collect identifiers")] if y is not None]) 198 | ids_list = [x[0] for x in ids.most_common(30000)] 199 | 200 | values = Counter( 201 | [y for y in [get_values(x) for x in tqdm( 202 | labels, "collect values")] if y is not None]) 203 | values_list = [x[0] for x in values.most_common(1000)] 204 | 205 | vocab = ["", "SimpleName_", "Value_", "Value_"] 206 | vocab += non_terminals + ids_list + values_list + ["(", ")"] 207 | 208 | code_i2w = {i: w for i, w in enumerate(vocab)} 209 | code_w2i = {w: i for i, w in enumerate(vocab)} 210 | 211 | pickle.dump(code_i2w, open("./dataset/code_i2w.pkl", "wb")) 212 | pickle.dump(code_w2i, open("./dataset/code_w2i.pkl", "wb")) 213 | 214 | 215 | def remove_SimpleName(root): 216 | for node in traverse(root): 217 | if "=" not in node.label and "(SimpleName)" in node.label: 218 | if node.children[0].label[:11] != "identifier=": 219 | raise Exception("ERROR!") 220 | node.label = "SimpleName_" + node.children[0].label[11:] 221 | node.children = [] 222 | elif node.label[:11] == "identifier=": 223 | node.label = "SimpleName_" + node.label[11:] 224 | elif node.label[:6] == "value=": 225 | node.label = "Value_" + node.label[6:] 226 | 227 | return root 228 | 229 | 230 | def modifier(root, dic): 231 | for node in traverse(root): 232 | if is_SimpleName(node.label): 233 | if node.label not in dic: 234 | node.label = "SimpleName_" 235 | elif is_value(node.label): 236 | if node.label not in dic: 237 | if isnum(node.label): 238 | node.label = "Value_" 239 | else: 240 | node.label = "Value_" 241 | else: 242 | node.label = get_bracket(node.label) 243 | if node.label not in dic: 244 | raise Exception("Unknown word", node.label) 245 | 246 | return root 247 | 248 | 249 | def rebuild_tree(path, dst, dic): 250 | root = pickle.load(open(path, "rb")) 251 | root = remove_SimpleName(root) 252 | root = modifier(root, dic) 253 | pickle.dump(root, open(dst, "wb"), 1) 254 | 255 | 256 | def preprocess_trees(): 257 | 258 | dirs = [ 259 | "./dataset", 260 | "./dataset/tree", 261 | "./dataset/tree/train", 262 | "./dataset/tree/valid", 263 | "./dataset/tree/test", 264 | "./dataset/nl" 265 | ] 266 | for d in dirs: 267 | if not os.path.exists(d): 268 | os.mkdir(d) 269 | 270 | sets_name = [ 271 | "./dataset/tree_raw/train/*", 272 | "./dataset/tree_raw/valid/*", 273 | "./dataset/tree_raw/test/*" 274 | ] 275 | 276 | dic = set(pickle.load(open("./dataset/code_i2w.pkl", "rb")).values()) 277 | 278 | for sets in sets_name: 279 | files = sorted(list(glob(sets))) 280 | dst = [x.replace("tree_raw", "tree") for x in files] 281 | Parallel(n_jobs=-1)( 282 | delayed(rebuild_tree)(p, d, dic) for p, d in tqdm( 283 | list(zip(files, dst)), "preprocessing {}".format(sets))) 284 | 285 | 286 | if __name__ == "__main__": 287 | nltk.download('punkt') 288 | sys.setrecursionlimit(10000) 289 | pickling() 290 | make_dict() 291 | preprocess_trees() 292 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | """layers""" 2 | 3 | import tensorflow as tf 4 | from utils import * 5 | tfe = tf.contrib.eager 6 | 7 | 8 | class TreeEmbeddingLayer(tf.keras.Model): 9 | def __init__(self, dim_E, in_vocab): 10 | super(TreeEmbeddingLayer, self).__init__() 11 | self.E = tf.get_variable("E", [in_vocab, dim_E], tf.float32, 12 | initializer=tf.keras.initializers.RandomUniform()) 13 | 14 | def call(self, x): 15 | '''x: list of [1,]''' 16 | x_len = [xx.shape[0] for xx in x] 17 | ex = tf.nn.embedding_lookup(self.E, tf.concat(x, axis=0)) 18 | exs = tf.split(ex, x_len, 0) 19 | return exs 20 | 21 | 22 | class TreeEmbeddingLayerTreeBase(tf.keras.Model): 23 | def __init__(self, dim_E, in_vocab): 24 | super(TreeEmbeddingLayerTreeBase, self).__init__() 25 | self.E = tf.get_variable("E", [in_vocab, dim_E], tf.float32, 26 | initializer=tf.keras.initializers.RandomUniform()) 27 | 28 | def call(self, roots): 29 | return [self.apply_single(root) for root in roots] 30 | 31 | def apply_single(self, root): 32 | labels = traverse_label(root) 33 | embedded = tf.nn.embedding_lookup(self.E, labels) 34 | new_nodes = self.Node2TreeLSTMNode(root, parent=None) 35 | for rep, node in zip(embedded, traverse(new_nodes)): 36 | node.h = rep 37 | return new_nodes 38 | 39 | def Node2TreeLSTMNode(self, node, parent): 40 | children = [self.Node2TreeLSTMNode(c, node) for c in node.children] 41 | return TreeLSTMNode(node.label, parent=parent, children=children, num=node.num) 42 | 43 | 44 | class ChildSumLSTMLayerWithEmbedding(tf.keras.Model): 45 | def __init__(self, in_vocab, dim_in, dim_out): 46 | super(ChildSumLSTMLayerWithEmbedding, self).__init__() 47 | self.dim_in = dim_in 48 | self.dim_out = dim_out 49 | self.E = tf.get_variable("E", [in_vocab, dim_in], tf.float32, 50 | initializer=tf.keras.initializers.RandomUniform()) 51 | self.U_f = tf.keras.layers.Dense(dim_out, use_bias=False) 52 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False) 53 | self.W = tf.keras.layers.Dense(dim_out * 4) 54 | # self.h_init = tfe.Variable( 55 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 56 | # self.c_init = tfe.Variable( 57 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 58 | self.h_init = tf.zeros([1, dim_out], tf.float32) 59 | self.c_init = tf.zeros([1, dim_out], tf.float32) 60 | 61 | @staticmethod 62 | def get_nums(roots): 63 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots] 64 | max_len = max([len(x) for x in res]) 65 | res = tf.keras.preprocessing.sequence.pad_sequences( 66 | res, max_len, padding="post", value=-1.) 67 | return tf.constant(res, tf.int32) 68 | 69 | def call(self, roots): 70 | depthes = [x[1] for x in sorted(depth_split_batch2( 71 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes 72 | indices = [self.get_nums(nodes) for nodes in depthes] 73 | 74 | h_tensor = self.h_init 75 | c_tensor = self.c_init 76 | for indice, nodes in zip(indices, depthes): 77 | x = tf.nn.embedding_lookup(self.E, [node.label for node in nodes]) # [nodes, dim_in] 78 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes) 79 | h_tensor = tf.concat([self.h_init, h_tensor], 0) 80 | c_tensor = tf.concat([self.c_init, c_tensor], 0) 81 | return depthes[-1] 82 | 83 | def apply(self, x, h_tensor, c_tensor, indice, nodes): 84 | 85 | mask_bool = tf.not_equal(indice, -1.) 86 | mask = tf.cast(mask_bool, tf.float32) # [batch, child] 87 | 88 | h = tf.gather(h_tensor, tf.where(mask_bool, 89 | indice, tf.zeros_like(indice))) # [nodes, child, dim] 90 | c = tf.gather(c_tensor, tf.where(mask_bool, 91 | indice, tf.zeros_like(indice))) 92 | h_sum = tf.reduce_sum(h * tf.expand_dims(mask, -1), 1) # [nodes, dim_out] 93 | 94 | W_x = self.W(x) # [nodes, dim_out * 4] 95 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out] 96 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2] 97 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3] 98 | W_o_x = W_x[:, self.dim_out * 3:] 99 | 100 | branch_f_k = tf.reshape(self.U_f(tf.reshape(h, [-1, h.shape[-1]])), h.shape) 101 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k) 102 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out] 103 | 104 | branch_iuo = self.U_iuo(h_sum) # [nodes, dim_out * 3] 105 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out] 106 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x) 107 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x) 108 | 109 | new_c = branch_i * branch_u + branch_f # [node, dim_out] 110 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out] 111 | 112 | for n, c, h in zip(nodes, new_c, new_h): 113 | n.c = c 114 | n.h = h 115 | 116 | return new_h, new_c 117 | 118 | 119 | class ChildSumLSTMLayer(tf.keras.Model): 120 | def __init__(self, dim_in, dim_out): 121 | super(ChildSumLSTMLayer, self).__init__() 122 | self.dim_in = dim_in 123 | self.dim_out = dim_out 124 | self.U_f = tf.keras.layers.Dense(dim_out, use_bias=False) 125 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False) 126 | self.W = tf.keras.layers.Dense(dim_out * 4) 127 | # self.h_init = tfe.Variable( 128 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 129 | # self.c_init = tfe.Variable( 130 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 131 | self.h_init = tf.zeros([1, dim_out], tf.float32) 132 | self.c_init = tf.zeros([1, dim_out], tf.float32) 133 | 134 | def call(self, tensor, indices): 135 | h_tensor = self.h_init 136 | c_tensor = self.c_init 137 | res_h, res_c = [], [] 138 | for indice, x in zip(indices, tensor): 139 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice) 140 | h_tensor = tf.concat([self.h_init, h_tensor], 0) 141 | c_tensor = tf.concat([self.c_init, c_tensor], 0) 142 | res_h.append(h_tensor[1:, :]) 143 | res_c.append(c_tensor[1:, :]) 144 | return res_h, res_c 145 | 146 | def apply(self, x, h_tensor, c_tensor, indice): 147 | 148 | mask_bool = tf.not_equal(indice, -1.) 149 | mask = tf.cast(mask_bool, tf.float32) # [batch, child] 150 | 151 | h = tf.gather(h_tensor, tf.where(mask_bool, 152 | indice, tf.zeros_like(indice))) # [nodes, child, dim] 153 | c = tf.gather(c_tensor, tf.where(mask_bool, 154 | indice, tf.zeros_like(indice))) 155 | h_sum = tf.reduce_sum(h * tf.expand_dims(mask, -1), 1) # [nodes, dim_out] 156 | 157 | W_x = self.W(x) # [nodes, dim_out * 4] 158 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out] 159 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2] 160 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3] 161 | W_o_x = W_x[:, self.dim_out * 3:] 162 | 163 | branch_f_k = tf.reshape(self.U_f(tf.reshape(h, [-1, h.shape[-1]])), h.shape) 164 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k) 165 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out] 166 | 167 | branch_iuo = self.U_iuo(h_sum) # [nodes, dim_out * 3] 168 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out] 169 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x) 170 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x) 171 | 172 | new_c = branch_i * branch_u + branch_f # [node, dim_out] 173 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out] 174 | 175 | return new_h, new_c 176 | 177 | 178 | class ChildSumLSTMLayerTreeBase(tf.keras.Model): 179 | def __init__(self, dim_in, dim_out): 180 | super(ChildSumLSTMLayerTreeBase, self).__init__() 181 | self.dim_in = dim_in 182 | self.dim_out = dim_out 183 | self.U_f = tf.keras.layers.Dense(dim_out, use_bias=False) 184 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False) 185 | self.W = tf.keras.layers.Dense(dim_out * 4) 186 | # self.h_init = tfe.Variable( 187 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 188 | # self.c_init = tfe.Variable( 189 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 190 | self.h_init = tf.zeros([1, dim_out], tf.float32) 191 | self.c_init = tf.zeros([1, dim_out], tf.float32) 192 | 193 | @staticmethod 194 | def get_nums(roots): 195 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots] 196 | max_len = max([len(x) for x in res]) 197 | res = tf.keras.preprocessing.sequence.pad_sequences( 198 | res, max_len, padding="post", value=-1.) 199 | return tf.constant(res, tf.int32) 200 | 201 | def call(self, roots): 202 | depthes = [x[1] for x in sorted(depth_split_batch2( 203 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes 204 | indices = [self.get_nums(nodes) for nodes in depthes] 205 | 206 | h_tensor = self.h_init 207 | c_tensor = self.c_init 208 | for indice, nodes in zip(indices, depthes): 209 | x = tf.stack([node.h for node in nodes]) # [nodes, dim_in] 210 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes) 211 | h_tensor = tf.concat([self.h_init, h_tensor], 0) 212 | c_tensor = tf.concat([self.c_init, c_tensor], 0) 213 | return depthes[-1] 214 | 215 | def apply(self, x, h_tensor, c_tensor, indice, nodes): 216 | 217 | mask_bool = tf.not_equal(indice, -1.) 218 | mask = tf.cast(mask_bool, tf.float32) # [batch, child] 219 | 220 | h = tf.gather(h_tensor, tf.where(mask_bool, 221 | indice, tf.zeros_like(indice))) # [nodes, child, dim] 222 | c = tf.gather(c_tensor, tf.where(mask_bool, 223 | indice, tf.zeros_like(indice))) 224 | h_sum = tf.reduce_sum(h * tf.expand_dims(mask, -1), 1) # [nodes, dim_out] 225 | 226 | W_x = self.W(x) # [nodes, dim_out * 4] 227 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out] 228 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2] 229 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3] 230 | W_o_x = W_x[:, self.dim_out * 3:] 231 | 232 | branch_f_k = tf.reshape(self.U_f(tf.reshape(h, [-1, h.shape[-1]])), h.shape) 233 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k) 234 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out] 235 | 236 | branch_iuo = self.U_iuo(h_sum) # [nodes, dim_out * 3] 237 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out] 238 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x) 239 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x) 240 | 241 | new_c = branch_i * branch_u + branch_f # [node, dim_out] 242 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out] 243 | 244 | for n, c, h in zip(nodes, new_c, new_h): 245 | n.c = c 246 | n.h = h 247 | 248 | return new_h, new_c 249 | 250 | 251 | class NaryLSTMLayer(tf.keras.Model): 252 | def __init__(self, dim_in, dim_out): 253 | super(NaryLSTMLayer, self).__init__() 254 | self.dim_in = dim_in 255 | self.dim_out = dim_out 256 | self.U_f1 = tf.keras.layers.Dense(dim_out, use_bias=False) 257 | self.U_f2 = tf.keras.layers.Dense(dim_out, use_bias=False) 258 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False) 259 | self.W = tf.keras.layers.Dense(dim_out * 4) 260 | # self.h_init = tfe.Variable( 261 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 262 | # self.c_init = tfe.Variable( 263 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 264 | self.h_init = tf.zeros([1, dim_out], tf.float32) 265 | self.c_init = tf.zeros([1, dim_out], tf.float32) 266 | 267 | def call(self, tensor, indices): 268 | h_tensor = self.h_init 269 | c_tensor = self.c_init 270 | res_h, res_c = [], [] 271 | for indice, x in zip(indices, tensor): 272 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice) 273 | h_tensor = tf.concat([self.h_init, h_tensor], 0) 274 | c_tensor = tf.concat([self.c_init, c_tensor], 0) 275 | res_h.append(h_tensor[1:, :]) 276 | res_c.append(c_tensor[1:, :]) 277 | return res_h, res_c 278 | 279 | def apply(self, x, h_tensor, c_tensor, indice): 280 | 281 | mask_bool = tf.not_equal(indice, -1.) 282 | 283 | h = tf.gather(h_tensor, tf.where(mask_bool, 284 | indice, tf.zeros_like(indice))) # [nodes, child, dim] 285 | c = tf.gather(c_tensor, tf.where(mask_bool, 286 | indice, tf.zeros_like(indice))) 287 | 288 | W_x = self.W(x) # [nodes, dim_out * 4] 289 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out] 290 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2] 291 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3] 292 | W_o_x = W_x[:, self.dim_out * 3:] 293 | 294 | if h.shape[1] <= 1: 295 | h = tf.concat([h, tf.zeros_like(h)], 1) # [nodes, 2, dim] 296 | c = tf.concat([c, tf.zeros_like(c)], 1) 297 | 298 | h_concat = tf.reshape(h, [h.shape[0], -1]) 299 | 300 | branch_f1 = self.U_f1(h_concat) 301 | branch_f1 = tf.sigmoid(W_f_x + branch_f1) 302 | branch_f2 = self.U_f2(h_concat) 303 | branch_f2 = tf.sigmoid(W_f_x + branch_f2) 304 | branch_f = branch_f1 * c[:, 0] + branch_f2 * c[:, 1] 305 | 306 | branch_iuo = self.U_iuo(h_concat) # [nodes, dim_out * 3] 307 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out] 308 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x) 309 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x) 310 | 311 | new_c = branch_i * branch_u + branch_f # [node, dim_out] 312 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out] 313 | 314 | return new_h, new_c 315 | 316 | 317 | class BiLSTM_(tf.keras.Model): 318 | def __init__(self, dim, return_seq=False): 319 | super(BiLSTM_, self).__init__() 320 | self.dim = dim 321 | # self.c_init_f = tfe.Variable(tf.get_variable("c_init_f", [1, dim], tf.float32, 322 | # initializer=he_normal())) 323 | # self.h_init_f = tfe.Variable(tf.get_variable("h_initf", [1, dim], tf.float32, 324 | # initializer=he_normal())) 325 | # self.c_init_b = tfe.Variable(tf.get_variable("c_init_b", [1, dim], tf.float32, 326 | # initializer=he_normal())) 327 | # self.h_init_b = tfe.Variable(tf.get_variable("h_init_b", [1, dim], tf.float32, 328 | # initializer=he_normal())) 329 | self.c_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 330 | self.h_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 331 | self.c_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 332 | self.h_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 333 | self.Cell_f = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(dim) 334 | self.Cell_b = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(dim) 335 | self.fc = tf.keras.layers.Dense(dim, use_bias=False) 336 | self.return_seq = return_seq 337 | 338 | def call(self, x, length): 339 | '''x: [batch, length, dim]''' 340 | batch = x.shape[0] 341 | ys, states = tf.nn.bidirectional_dynamic_rnn(self.Cell_f, self.Cell_b, x, 342 | length, 343 | tf.nn.rnn_cell.LSTMStateTuple( 344 | tf.tile(self.c_init_f, [batch, 1]), 345 | tf.tile(self.h_init_f, [batch, 1])), 346 | tf.nn.rnn_cell.LSTMStateTuple( 347 | tf.tile(self.c_init_b, [batch, 1]), 348 | tf.tile(self.h_init_b, [batch, 1]))) 349 | if self.return_seq: 350 | return self.fc(tf.concat(ys, -1)) 351 | else: 352 | state_f, state_b = states 353 | state_concat = tf.concat([state_f.h, state_b.h], -1) 354 | return self.fc(state_concat) 355 | 356 | 357 | class BiLSTM(tf.keras.Model): 358 | def __init__(self, dim, return_seq=False): 359 | super(BiLSTM, self).__init__() 360 | self.dim = dim 361 | self.c_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 362 | self.h_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 363 | self.c_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 364 | self.h_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32)) 365 | self.lay_f = tf.keras.layers.CuDNNLSTM(dim, return_sequences=True, return_state=True) 366 | self.lay_b = tf.keras.layers.CuDNNLSTM(dim, return_sequences=True, return_state=True) 367 | self.fc = tf.keras.layers.Dense(dim, use_bias=False) 368 | self.return_seq = return_seq 369 | 370 | def call(self, x, length): 371 | '''x: [batch, length, dim]''' 372 | batch = x.shape[0] 373 | x_back = tf.reverse_sequence(x, length, 1) 374 | 375 | init_state_f = (tf.tile(self.h_init_f, [batch, 1]), tf.tile(self.c_init_f, [batch, 1])) 376 | init_state_b = (tf.tile(self.h_init_b, [batch, 1]), tf.tile(self.c_init_b, [batch, 1])) 377 | 378 | y_f, h_f, c_f = self.lay_f(x, init_state_f) 379 | y_b, h_b, c_b = self.lay_b(x_back, init_state_b) 380 | 381 | y = tf.concat([y_f, y_b], -1) 382 | 383 | if self.return_seq: 384 | return self.fc(y) 385 | else: 386 | y_last = tf.gather_nd(y, tf.stack([tf.range(batch), length - 1], 1)) 387 | return self.fc(y_last) 388 | 389 | 390 | class ShidoTreeLSTMLayer(tf.keras.Model): 391 | def __init__(self, dim_in, dim_out): 392 | super(ShidoTreeLSTMLayer, self).__init__() 393 | self.dim_in = dim_in 394 | self.dim_out = dim_out 395 | self.U_f = BiLSTM(dim_out, return_seq=True) 396 | self.U_i = BiLSTM(dim_out) 397 | self.U_u = BiLSTM(dim_out) 398 | self.U_o = BiLSTM(dim_out) 399 | self.W = tf.keras.layers.Dense(dim_out * 4) 400 | # self.h_init = tfe.Variable( 401 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 402 | # self.c_init = tfe.Variable( 403 | # tf.get_variable("c_init", [1, dim_out], tf.float32, initializer=he_normal())) 404 | self.h_init = tf.zeros([1, dim_out], tf.float32) 405 | self.c_init = tf.zeros([1, dim_out], tf.float32) 406 | 407 | def call(self, tensor, indices): 408 | h_tensor = self.h_init 409 | c_tensor = self.c_init 410 | res_h, res_c = [], [] 411 | for indice, x in zip(indices, tensor): 412 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice) 413 | res_h.append(h_tensor[:, :]) 414 | res_c.append(c_tensor[:, :]) 415 | h_tensor = tf.concat([self.h_init, h_tensor], 0) 416 | c_tensor = tf.concat([self.c_init, c_tensor], 0) 417 | return res_h, res_c 418 | 419 | def apply(self, x, h_tensor, c_tensor, indice): 420 | 421 | mask_bool = tf.not_equal(indice, -1.) 422 | mask = tf.cast(mask_bool, tf.float32) # [nodes, child] 423 | length = tf.cast(tf.reduce_sum(mask, 1), tf.int32) 424 | 425 | h = tf.gather(h_tensor, tf.where(mask_bool, 426 | indice, tf.zeros_like(indice))) # [nodes, child, dim] 427 | c = tf.gather(c_tensor, tf.where(mask_bool, 428 | indice, tf.zeros_like(indice))) 429 | 430 | W_x = self.W(x) # [nodes, dim_out * 4] 431 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out] 432 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2] 433 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3] 434 | W_o_x = W_x[:, self.dim_out * 3:] 435 | 436 | branch_f_k = self.U_f(h, length) 437 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k) 438 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out] 439 | 440 | branch_i = self.U_i(h, length) # [nodes, dim_out] 441 | branch_i = tf.sigmoid(branch_i + W_i_x) # [nodes, dim_out] 442 | branch_u = self.U_u(h, length) # [nodes, dim_out] 443 | branch_u = tf.tanh(branch_u + W_u_x) 444 | branch_o = self.U_o(h, length) # [nodes, dim_out] 445 | branch_o = tf.sigmoid(branch_o + W_o_x) 446 | 447 | new_c = branch_i * branch_u + branch_f # [node, dim_out] 448 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out] 449 | 450 | return new_h, new_c 451 | 452 | 453 | class ShidoTreeLSTMLayerTreeBase(tf.keras.Model): 454 | def __init__(self, dim_in, dim_out): 455 | super(ShidoTreeLSTMLayerTreeBase, self).__init__() 456 | self.dim_in = dim_in 457 | self.dim_out = dim_out 458 | self.U_f = BiLSTM(dim_out, return_seq=True) 459 | self.U_i = BiLSTM(dim_out) 460 | self.U_u = BiLSTM(dim_out) 461 | self.U_o = BiLSTM(dim_out) 462 | self.W = tf.keras.layers.Dense(dim_out * 4) 463 | # self.h_init = tfe.Variable( 464 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 465 | # self.c_init = tfe.Variable( 466 | # tf.get_variable("c_init", [1, dim_out], tf.float32, initializer=he_normal())) 467 | self.h_init = tf.zeros([1, dim_out], tf.float32) 468 | self.c_init = tf.zeros([1, dim_out], tf.float32) 469 | 470 | @staticmethod 471 | def get_nums(roots): 472 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots] 473 | max_len = max([len(x) for x in res]) 474 | res = tf.keras.preprocessing.sequence.pad_sequences( 475 | res, max_len, padding="post", value=-1.) 476 | return tf.constant(res, tf.int32) 477 | 478 | def call(self, roots): 479 | depthes = [x[1] for x in sorted(depth_split_batch2( 480 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes 481 | indices = [self.get_nums(nodes) for nodes in depthes] 482 | 483 | h_tensor = self.h_init 484 | c_tensor = self.c_init 485 | for indice, nodes in zip(indices, depthes): 486 | x = tf.stack([node.h for node in nodes]) # [nodes, dim_in] 487 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes) 488 | h_tensor = tf.concat([self.h_init, h_tensor], 0) 489 | c_tensor = tf.concat([self.c_init, c_tensor], 0) 490 | return depthes[-1] 491 | 492 | def apply(self, x, h_tensor, c_tensor, indice, nodes): 493 | 494 | mask_bool = tf.not_equal(indice, -1.) 495 | mask = tf.cast(mask_bool, tf.float32) # [nodes, child] 496 | length = tf.cast(tf.reduce_sum(mask, 1), tf.int32) 497 | 498 | h = tf.gather(h_tensor, tf.where(mask_bool, 499 | indice, tf.zeros_like(indice))) # [nodes, child, dim] 500 | c = tf.gather(c_tensor, tf.where(mask_bool, 501 | indice, tf.zeros_like(indice))) 502 | 503 | W_x = self.W(x) # [nodes, dim_out * 4] 504 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out] 505 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2] 506 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3] 507 | W_o_x = W_x[:, self.dim_out * 3:] 508 | 509 | branch_f_k = self.U_f(h, length) 510 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k) 511 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out] 512 | 513 | branch_i = self.U_i(h, length) # [nodes, dim_out] 514 | branch_i = tf.sigmoid(branch_i + W_i_x) # [nodes, dim_out] 515 | branch_u = self.U_u(h, length) # [nodes, dim_out] 516 | branch_u = tf.tanh(branch_u + W_u_x) 517 | branch_o = self.U_o(h, length) # [nodes, dim_out] 518 | branch_o = tf.sigmoid(branch_o + W_o_x) 519 | 520 | new_c = branch_i * branch_u + branch_f # [node, dim_out] 521 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out] 522 | 523 | for n, c, h in zip(nodes, new_c, new_h): 524 | n.c = c 525 | n.h = h 526 | 527 | return new_h, new_c 528 | 529 | 530 | class ShidoTreeLSTMWithEmbedding(ShidoTreeLSTMLayer): 531 | def __init__(self, in_vocab, dim_in, dim_out): 532 | super(ShidoTreeLSTMWithEmbedding, self).__init__(dim_in, dim_out) 533 | self.E = tf.get_variable("E", [in_vocab, dim_in], tf.float32, 534 | initializer=tf.keras.initializers.RandomUniform()) 535 | self.dim_in = dim_in 536 | self.dim_out = dim_out 537 | self.U_f = BiLSTM(dim_out, return_seq=True) 538 | self.U_i = BiLSTM(dim_out) 539 | self.U_u = BiLSTM(dim_out) 540 | self.U_o = BiLSTM(dim_out) 541 | self.W = tf.keras.layers.Dense(dim_out * 4) 542 | # self.h_init = tfe.Variable( 543 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal())) 544 | # self.c_init = tfe.Variable( 545 | # tf.get_variable("c_init", [1, dim_out], tf.float32, initializer=he_normal())) 546 | self.h_init = tf.zeros([1, dim_out], tf.float32) 547 | self.c_init = tf.zeros([1, dim_out], tf.float32) 548 | 549 | def call(self, roots): 550 | depthes = [x[1] for x in sorted(depth_split_batch2( 551 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes 552 | indices = [self.get_nums(nodes) for nodes in depthes] 553 | 554 | h_tensor = self.h_init 555 | c_tensor = self.c_init 556 | for indice, nodes in zip(indices, depthes): 557 | x = tf.nn.embedding_lookup(self.E, [node.label for node in nodes]) # [nodes, dim_in] 558 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes) 559 | h_tensor = tf.concat([self.h_init, h_tensor], 0) 560 | c_tensor = tf.concat([self.c_init, c_tensor], 0) 561 | return depthes[-1] 562 | 563 | 564 | class TreeDropout(tf.keras.Model): 565 | def __init__(self, rate): 566 | super(TreeDropout, self).__init__() 567 | self.dropout_layer = tf.keras.layers.Dropout(rate) 568 | 569 | def call(self, roots): 570 | nodes = [node for root in roots for node in traverse(root)] 571 | ys = [node.h for node in nodes] 572 | tensor = tf.stack(ys) 573 | dropped = self.dropout_layer(tensor) 574 | for e, v in enumerate(tf.split(dropped, len(ys))): 575 | nodes[e].h = tf.squeeze(v) 576 | return roots 577 | 578 | 579 | class SetEmbeddingLayer(tf.keras.Model): 580 | def __init__(self, dim_E, in_vocab): 581 | super(SetEmbeddingLayer, self).__init__() 582 | self.E = tf.get_variable("E", [in_vocab, dim_E], tf.float32, 583 | initializer=tf.keras.initializers.RandomUniform()) 584 | 585 | def call(self, sets): 586 | length = [len(s) for s in sets] 587 | concatenated = tf.concat(sets, 0) 588 | embedded = tf.nn.embedding_lookup(self.E, concatenated) 589 | y = tf.split(embedded, length) 590 | return y 591 | 592 | 593 | class LSTMEncoder(tf.keras.Model): 594 | def __init__(self, dim, layer=1): 595 | super(LSTMEncoder, self).__init__() 596 | self.dim = dim 597 | # self.c_init_f = tfe.Variable(tf.get_variable("c_init_f", [1, dim], tf.float32, 598 | # initializer=he_normal())) 599 | # self.h_init_f = tfe.Variable(tf.get_variable("h_initf", [1, dim], tf.float32, 600 | # initializer=he_normal())) 601 | self.Cell_f = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(dim) 602 | self.h_init_f = tf.zeros([1, dim], tf.float32) 603 | self.c_init_f = tf.zeros([1, dim], tf.float32) 604 | 605 | def call(self, x, length): 606 | '''x: [batch, length, dim]''' 607 | batch = x.shape[0] 608 | ys, states = tf.nn.dynamic_rnn(self.Cell_f, x, 609 | length, 610 | tf.nn.rnn_cell.LSTMStateTuple( 611 | tf.tile(self.c_init_f, [batch, 1]), 612 | tf.tile(self.h_init_f, [batch, 1]))) 613 | return ys, states 614 | 615 | 616 | class SequenceEmbeddingLayer(tf.keras.Model): 617 | def __init__(self, dim_E, in_vocab): 618 | super(SequenceEmbeddingLayer, self).__init__() 619 | self.E = tf.keras.layers.Embedding(in_vocab, dim_E) 620 | 621 | def call(self, y): 622 | y = self.E(y) 623 | return y 624 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import pad_tensor 3 | from layers import * 4 | import numpy as np 5 | 6 | 7 | class AttentionDecoder(tf.keras.Model): 8 | def __init__(self, dim_F, dim_rep, vocab_size, layer=1): 9 | super(AttentionDecoder, self).__init__() 10 | self.layer = layer 11 | self.dim_rep = dim_rep 12 | self.F = tf.keras.layers.Embedding(vocab_size, dim_F) 13 | for i in range(layer): 14 | self.__setattr__("layer{}".format(i), 15 | tf.keras.layers.CuDNNLSTM(dim_rep, 16 | return_sequences=True, 17 | return_state=True, 18 | recurrent_initializer='glorot_uniform')) 19 | self.fc = tf.keras.layers.Dense(vocab_size) 20 | 21 | # used for attention 22 | self.W1 = tf.keras.layers.Dense(self.dim_rep) 23 | self.W2 = tf.keras.layers.Dense(self.dim_rep) 24 | self.V = tf.keras.layers.Dense(1) 25 | print("I am Decoder, dim is {} and {} layered".format(str(self.dim_rep), str(self.layer))) 26 | 27 | @staticmethod 28 | def loss_function(real, pred): 29 | loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) 30 | return tf.reduce_sum(loss_) 31 | 32 | def get_loss(self, enc_y, states, target, dropout=0.0): 33 | ''' 34 | enc_y: batch_size([seq_len, dim]) 35 | states: ([batch, dim], [batch, dim]) 36 | target: [batch, max_len] (padded with -1.) 37 | ''' 38 | mask = tf.not_equal(target, -1.) 39 | h, c = states 40 | enc_y, _ = pad_tensor(enc_y) 41 | enc_y = tf.nn.dropout(enc_y, 1. - dropout) 42 | dec_hidden = tf.nn.dropout(h, 1. - dropout) 43 | dec_cell = tf.nn.dropout(c, 1. - dropout) 44 | 45 | l_states = [(dec_hidden, dec_cell) for _ in range(self.layer)] 46 | target = tf.nn.relu(target) 47 | dec_input = target[:, 0] 48 | loss = 0 49 | for t in range(1, target.shape[1]): 50 | # passing enc_output to the decoder 51 | predictions, l_states, att = self.call( 52 | dec_input, l_states, enc_y) 53 | real = tf.boolean_mask(target[:, t], mask[:, t]) 54 | pred = tf.boolean_mask(predictions, mask[:, t]) 55 | loss += self.loss_function(real, pred) 56 | # using teacher forcing 57 | dec_input = target[:, t] 58 | 59 | return loss / tf.reduce_sum(tf.cast(mask, tf.float32)) 60 | 61 | def translate(self, y_enc, states, max_length, start_token, end_token): 62 | ''' 63 | enc_y: [seq_len, dim] 64 | states: ([dim,], [dim,]) 65 | ''' 66 | attention_plot = np.zeros((max_length, y_enc.shape[0])) 67 | 68 | h, c = states 69 | y_enc = tf.expand_dims(y_enc, 0) 70 | dec_hidden = tf.expand_dims(h, 0) 71 | dec_cell = tf.expand_dims(c, 0) 72 | dec_input = tf.constant(start_token, tf.int32, [1]) 73 | result = [] 74 | 75 | l_states = [(dec_hidden, dec_cell) for _ in range(self.layer)] 76 | 77 | for t in range(max_length): 78 | predictions, l_states, attention_weights = self.call( 79 | dec_input, l_states, y_enc) 80 | 81 | attention_weights = tf.reshape(attention_weights, (-1,)) 82 | attention_plot[t] = attention_weights.numpy() 83 | 84 | predicted_id = tf.argmax(predictions[0]).numpy() 85 | result.append(predicted_id) 86 | 87 | if predicted_id == end_token: 88 | return result[:-1], attention_plot[:t] 89 | 90 | # the predicted ID is fed back into the model 91 | dec_input = tf.expand_dims(predicted_id, 0) 92 | 93 | return result, attention_plot 94 | 95 | def call(self, x, l_states, enc_y): 96 | # enc_y shape == (batch_size, max_length, hidden_size) 97 | 98 | # hidden shape == (batch_size, hidden size) 99 | # hidden_with_time_axis shape == (batch_size, 1, hidden size) 100 | # we are doing this to perform addition to calculate the score 101 | hidden_with_time_axis = tf.expand_dims(l_states[-1][0], 1) 102 | 103 | # score shape == (batch_size, max_length, hidden_size) 104 | score = tf.nn.tanh(self.W1(enc_y) + self.W2(hidden_with_time_axis)) 105 | 106 | # attention_weights shape == (batch_size, max_length, 1) 107 | # we get 1 at the last axis because we are applying score to self.V 108 | attention_weights = tf.nn.softmax(self.V(score), axis=1) 109 | 110 | # context_vector shape after sum == (batch_size, hidden_size) 111 | context_vector = attention_weights * enc_y 112 | context_vector = tf.reduce_sum(context_vector, axis=1) 113 | 114 | # x shape after passing through embedding == (batch_size, 1, embedding_dim) 115 | x = tf.expand_dims(x, 1) 116 | x = self.F(x) 117 | 118 | # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size) 119 | # x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1) 120 | 121 | # passing the concatenated vector to the GRU 122 | new_l_states = [] 123 | for i, states in zip(range(self.layer), l_states): 124 | if i < self.layer - 1: 125 | skip = x 126 | x, h, c = getattr(self, "layer{}".format(i))(x, states) 127 | x += skip 128 | else: 129 | x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1) 130 | x, h, c = getattr(self, "layer{}".format(i))(x, states) 131 | n_states = (h, c) 132 | new_l_states.append(n_states) 133 | 134 | # output shape == (batch_size * 1, hidden_size) 135 | x = tf.reshape(x, (-1, x.shape[2])) 136 | 137 | # output shape == (batch_size * 1, vocab) 138 | x = self.fc(x) 139 | 140 | return x, new_l_states, attention_weights 141 | 142 | 143 | class BaseModel(tf.keras.Model): 144 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0., lr=1e-3): 145 | super(BaseModel, self).__init__() 146 | self.dim_E = dim_E 147 | self.dim_F = dim_F 148 | self.dim_rep = dim_rep 149 | self.in_vocab = in_vocab 150 | self.out_vocab = out_vocab 151 | self.dropout = dropout 152 | self.decoder = AttentionDecoder(dim_F, dim_rep, out_vocab, layer) 153 | self.optimizer = tf.train.AdamOptimizer(lr) 154 | 155 | def encode(self, trees): 156 | ''' 157 | ys: list of [seq_len, dim] 158 | hx, cx: [batch, dim] 159 | return: ys, [hx, cx] 160 | ''' 161 | 162 | def train_on_batch(self, x, y): 163 | with tf.GradientTape() as tape: 164 | y_enc, (c, h) = self.encode(x) 165 | loss = self.decoder.get_loss(y_enc, (c, h), y, dropout=self.dropout) 166 | variables = self.variables 167 | gradients = tape.gradient(loss, variables) 168 | self.optimizer.apply_gradients(zip(gradients, variables)) 169 | return loss.numpy() 170 | 171 | def translate(self, x, nl_i2w, nl_w2i, max_length=100): 172 | res = [] 173 | y_enc, (c, h) = self.encode(x) 174 | batch_size = len(y_enc) 175 | for i in range(batch_size): 176 | nl, _ = self.decoder.translate( 177 | y_enc[i], (c[i], h[i]), max_length, nl_w2i[""], nl_w2i[""]) 178 | res.append([nl_i2w[n] for n in nl]) 179 | return res 180 | 181 | def evaluate_on_batch(self, x, y): 182 | y_enc, (c, h) = self.encode(x) 183 | loss = self.decoder.get_loss(y_enc, (c, h), y) 184 | return loss.numpy() 185 | 186 | 187 | class CodennModel(BaseModel): 188 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-3): 189 | super(CodennModel, self).__init__(dim_E, dim_F, dim_rep, in_vocab, 190 | out_vocab, layer, dropout, lr) 191 | self.dropout = dropout 192 | self.E = SetEmbeddingLayer(dim_E, in_vocab) 193 | print("I am CodeNNModel, dim is {} and {} layered".format( 194 | str(self.dim_rep), "0")) 195 | 196 | def encode(self, sets): 197 | sets = self.E(sets) 198 | # sets = [tf.nn.dropout(t, 1. - self.dropout) for t in sets] 199 | 200 | hx = tf.zeros([len(sets), self.dim_rep]) 201 | cx = tf.zeros([len(sets), self.dim_rep]) 202 | ys = sets 203 | 204 | return ys, [hx, cx] 205 | 206 | 207 | class Seq2seqModel(BaseModel): 208 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-3): 209 | super(Seq2seqModel, self).__init__(dim_E, dim_F, 210 | dim_rep, in_vocab, out_vocab, layer, dropout, lr) 211 | self.layer = layer 212 | self.dropout = dropout 213 | self.E = tf.keras.layers.Embedding(in_vocab + 1, dim_E, mask_zero=True) 214 | for i in range(layer): 215 | self.__setattr__("layer{}".format(i), 216 | tf.keras.layers.CuDNNLSTM(dim_rep, 217 | return_sequences=True, 218 | return_state=True)) 219 | print("I am seq2seq model, dim is {} and {} layered".format( 220 | str(self.dim_rep), str(self.layer))) 221 | 222 | def encode(self, seq): 223 | length = get_length(seq) 224 | tensor = self.E(seq + 1) 225 | # tensor = tf.nn.dropout(tensor, 1. - self.dropout) 226 | for i in range(self.layer): 227 | skip = tensor 228 | tensor, h, c = getattr(self, "layer{}".format(i))(tensor) 229 | tensor += skip 230 | 231 | cx = c 232 | hx = h 233 | ys = [y[:i] for y, i in zip(tf.unstack(tensor, axis=0), length.numpy())] 234 | 235 | return ys, [hx, cx] 236 | 237 | 238 | class ChildsumModel(BaseModel): 239 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-4): 240 | super(ChildsumModel, self).__init__(dim_E, dim_F, 241 | dim_rep, in_vocab, out_vocab, layer, dropout, lr) 242 | self.layer = layer 243 | self.dropout = dropout 244 | self.E = TreeEmbeddingLayer(dim_E, in_vocab) 245 | for i in range(layer): 246 | self.__setattr__("layer{}".format(i), ChildSumLSTMLayer(dim_E, dim_rep)) 247 | print("I am Child-sum model, dim is {} and {} layered".format( 248 | str(self.dim_rep), str(self.layer))) 249 | 250 | def encode(self, x): 251 | tensor, indice, tree_num = x 252 | tensor = self.E(tensor) 253 | # tensor = [tf.nn.dropout(t, 1. - self.dropout) for t in tensor] 254 | for i in range(self.layer): 255 | skip = tensor 256 | tensor, c = getattr(self, "layer{}".format(i))(tensor, indice) 257 | tensor = [t + s for t, s in zip(tensor, skip)] 258 | 259 | hx = tensor[-1] 260 | cx = c[-1] 261 | ys = [] 262 | batch_size = tensor[-1].shape[0] 263 | tensor = tf.concat(tensor, 0) 264 | tree_num = tf.concat(tree_num, 0) 265 | for batch in range(batch_size): 266 | ys.append(tf.boolean_mask(tensor, tf.equal(tree_num, batch))) 267 | return ys, [hx, cx] 268 | 269 | 270 | class NaryModel(BaseModel): 271 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-4): 272 | super(NaryModel, self).__init__(dim_E, dim_F, 273 | dim_rep, in_vocab, out_vocab, layer, dropout, lr) 274 | self.layer = layer 275 | self.dropout = dropout 276 | self.E = TreeEmbeddingLayer(dim_E, in_vocab) 277 | for i in range(layer): 278 | self.__setattr__("layer{}".format(i), NaryLSTMLayer(dim_E, dim_rep)) 279 | print("I am N-ary model, dim is {} and {} layered".format( 280 | str(self.dim_rep), str(self.layer))) 281 | 282 | def encode(self, x): 283 | tensor, indice, tree_num = x 284 | tensor = self.E(tensor) 285 | # tensor = [tf.nn.dropout(t, 1. - self.dropout) for t in tensor] 286 | for i in range(self.layer): 287 | skip = tensor 288 | tensor, c = getattr(self, "layer{}".format(i))(tensor, indice) 289 | tensor = [t + s for t, s in zip(tensor, skip)] 290 | 291 | hx = tensor[-1] 292 | cx = c[-1] 293 | ys = [] 294 | batch_size = tensor[-1].shape[0] 295 | tensor = tf.concat(tensor, 0) 296 | tree_num = tf.concat(tree_num, 0) 297 | for batch in range(batch_size): 298 | ys.append(tf.boolean_mask(tensor, tf.equal(tree_num, batch))) 299 | return ys, [hx, cx] 300 | 301 | 302 | class MultiwayModel(BaseModel): 303 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.0, lr=1e-4): 304 | super(MultiwayModel, self).__init__(dim_E, dim_F, 305 | dim_rep, in_vocab, out_vocab, layer, dropout, lr) 306 | self.layer = layer 307 | self.dropout = dropout 308 | self.E = TreeEmbeddingLayer(dim_E, in_vocab) 309 | for i in range(layer): 310 | self.__setattr__("layer{}".format(i), ShidoTreeLSTMLayer(dim_E, dim_rep)) 311 | print("I am Multi-way model, dim is {} and {} layered".format( 312 | str(self.dim_rep), str(self.layer))) 313 | 314 | def encode(self, x): 315 | tensor, indice, tree_num = x 316 | tensor = self.E(tensor) 317 | # tensor = [tf.nn.dropout(t, 1. - self.dropout) for t in tensor] 318 | for i in range(self.layer): 319 | skip = tensor 320 | tensor, c = getattr(self, "layer{}".format(i))(tensor, indice) 321 | tensor = [t + s for t, s in zip(tensor, skip)] 322 | 323 | hx = tensor[-1] 324 | cx = c[-1] 325 | ys = [] 326 | batch_size = tensor[-1].shape[0] 327 | tensor = tf.concat(tensor, 0) 328 | tree_num = tf.concat(tree_num, 0) 329 | for batch in range(batch_size): 330 | ys.append(tf.boolean_mask(tensor, tf.equal(tree_num, batch))) 331 | return ys, [hx, cx] 332 | -------------------------------------------------------------------------------- /notebooks/check_result.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import sys\n", 11 | "sys.path.append(\"../\")\n", 12 | "from matplotlib import pylab as plt\n", 13 | "import numpy as np\n", 14 | "from glob import glob\n", 15 | "import json\n", 16 | "from utils import *\n", 17 | "from tqdm import tqdm\n", 18 | "import pandas as pd" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "!CUDA_VISIBLE_DEVICE=" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "codes = [json.loads(s)['code'] for s in open(\"/home/shido/summarization_java/test.json\").readlines()]" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "trn_data = read_pickle(\"../dataset/nl/train.pkl\")\n", 46 | "vld_data = read_pickle(\"../dataset/nl/valid.pkl\")\n", 47 | "tst_data = read_pickle(\"../dataset/nl/test.pkl\")\n", 48 | "code_i2w = read_pickle(\"../dataset/code_i2w.pkl\")\n", 49 | "code_w2i = read_pickle(\"../dataset/code_w2i.pkl\")\n", 50 | "nl_i2w = read_pickle(\"../dataset/nl_i2w.pkl\")\n", 51 | "nl_w2i = read_pickle(\"../dataset/nl_w2i.pkl\")\n", 52 | "\n", 53 | "trn_x, trn_y_raw = zip(*trn_data.items())\n", 54 | "vld_x, vld_y_raw = zip(*vld_data.items())\n", 55 | "tst_x, tst_y_raw = zip(*tst_data.items())\n", 56 | "\n", 57 | "trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in trn_y_raw]\n", 58 | "vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in vld_y_raw]\n", 59 | "tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in tst_y_raw]" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "sorted([len(x) for x in trn_y])[::-1]" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "print(len(trn_y))\n", 78 | "print(len(vld_y))\n", 79 | "print(len(tst_y))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "lengthes = [len(traverse_label(read_pickle(x))) for x in tst_x]" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "files = sorted(glob(\"../models/*/history.json\"))\n", 98 | "dirs = [x.split(\"/\")[-2] for x in files]\n", 99 | "histories = {name: json.load(open(x)) for name, x in zip(dirs, files)}\n", 100 | "dirs" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "names = [\n", 110 | " \"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\",\n", 111 | " \"childsum_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\",\n", 112 | " \"deepcom_dim256_embed256_drop0.5_lr0.001_batch64_epochs30_layer1\",\n", 113 | " \"codenn_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\",\n", 114 | "]" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "plt.figure(figsize=(15, 10))\n", 124 | "# for name, his in [(view, histories[name]) for view, name in zip([\"Ours\", \"Child-Sum\", \"[Hu+, 18]\", \"[Iyer+, 16]\"], names)]:\n", 125 | "for name, his in histories.items():\n", 126 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n", 127 | " plt.plot(his[\"bleu_val\"], \"-\", label=name)\n", 128 | "# plt.plot(his[\"loss_val\"], \"-x\", label=name + \"_valid\")\n", 129 | "plt.grid()\n", 130 | "plt.legend()" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "for name, his in histories.items():\n", 140 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n", 141 | " print(name, \":\", np.mean(his[\"bleus\"]))" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "from nltk.translate.gleu_score import sentence_gleu\n", 151 | "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n", 152 | "def get_gleu(true, pred):\n", 153 | " return(sentence_gleu([true], pred))\n", 154 | "def get_bleu(true, pred):\n", 155 | " return(sentence_bleu([true], pred, smoothing_function=SmoothingFunction().method4))" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "# BLEU\n", 165 | "for name, his in histories.items():\n", 166 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n", 167 | " trues = histories[name][\"trues\"]\n", 168 | " preds = histories[name][\"preds\"]\n", 169 | " gleu = np.mean([get_bleu(x, y) for x, y in zip(trues, preds)])\n", 170 | " print(name, gleu)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "# GLEU\n", 180 | "for name, his in histories.items():\n", 181 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n", 182 | " trues = histories[name][\"trues\"]\n", 183 | " preds = histories[name][\"preds\"]\n", 184 | " gleu = np.mean([get_gleu(x, y) for x, y in zip(trues, preds)])\n", 185 | " print(name, gleu)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "codes = [codes[int(i)] for i in histories[names[0]][\"numbers\"]]" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "df_data = [[\" \".join(x) for x in histories[name][\"preds\"]] for name in names] + [[\" \".join(x) for x in histories[names[0]][\"trues\"]]]\n", 204 | "df_data += [histories[name][\"bleus\"] for name in names] + [codes]\n", 205 | "df_index = [\"PREDICTION \" + name for name in names] + [\"GROUND TRUTH\"]\n", 206 | "df_index += [\"BLEU-4 \" + name for name in names] + [\"SOURCE CODE\"]\n", 207 | "df = pd.DataFrame(data=df_data, index=df_index).T" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "df.to_csv(\"for_hitachi_lab.csv\")" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "df.get([\"SOURCE CODE\", \"GROUND TRUTH\"]).head()" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "shido = histories[\"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"bleus\"]\n", 235 | "np.mean(shido)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "child = histories[\"childsum_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"bleus\"]\n", 245 | "np.mean(child)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "dif = np.array(shido) - np.array(child)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "index = np.argsort(dif)[::-1]" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "trues = histories[\"childsumlstm_1layer\"][\"trues\"]" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "childsum = histories[\"childsum_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"preds\"]" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "ours = histories[\"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"preds\"]" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "number = histories[\"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"numbers\"]" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "for i in index[:500]:\n", 309 | " if \" \".join(trues[i]) != \" \".join(ours[i]):\n", 310 | " print(\"GT: \", \" \".join(trues[i]))\n", 311 | " print(\"CSum: \", \" \".join(childsum[i]))\n", 312 | " print(\"Ours: \", \" \".join(ours[i]))\n", 313 | " print(\"Codes:\\n\" + codes[i])\n", 314 | "# print(\"Num: \", number[i])\n", 315 | " print(\"-\" * 100)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "trn_data = read_pickle(\"../dataset/nl/train.pkl\")\n", 325 | "vld_data = read_pickle(\"../dataset/nl/valid.pkl\")\n", 326 | "tst_data = read_pickle(\"../dataset/nl/test.pkl\")\n", 327 | "code_i2w = read_pickle(\"../dataset/code_i2w.pkl\")\n", 328 | "code_w2i = read_pickle(\"../dataset/code_w2i.pkl\")\n", 329 | "nl_i2w = read_pickle(\"../dataset/nl_i2w.pkl\")\n", 330 | "nl_w2i = read_pickle(\"../dataset/nl_w2i.pkl\")" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "trn_data" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "xx = [traverse_label(read_pickle(t)) for t in tst_x]" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "x = [\" \".join([str(code_w2i[w]) for w in t]) + \"\\n\" for t in xx]" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "y = [\" \".join([str(w) for w in t[1:-1]]) + \"\\n\" for t in tst_y]" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "xy = [xw + \"\\t\" + yw for xw, yw in zip(x, y)]" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "y[0]" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "open(\"x.tst\", \"w\").writelines(x)" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "len(x) / 32" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "y[0]" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "a = open(\"x.tst\", \"r\").read()" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": null, 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "a.split(\"\\n\")[-2]" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [] 445 | } 446 | ], 447 | "metadata": { 448 | "kernelspec": { 449 | "display_name": "Python 3", 450 | "language": "python", 451 | "name": "python3" 452 | }, 453 | "language_info": { 454 | "codemirror_mode": { 455 | "name": "ipython", 456 | "version": 3 457 | }, 458 | "file_extension": ".py", 459 | "mimetype": "text/x-python", 460 | "name": "python", 461 | "nbconvert_exporter": "python", 462 | "pygments_lexer": "ipython3", 463 | "version": "3.6.1" 464 | } 465 | }, 466 | "nbformat": 4, 467 | "nbformat_minor": 2 468 | } 469 | -------------------------------------------------------------------------------- /notebooks/example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import sys\n", 11 | "sys.path.append(\"../\")\n", 12 | "import pickle\n", 13 | "import numpy as np\n", 14 | "from tqdm import tqdm_notebook\n", 15 | "from prefetch_generator import BackgroundGenerator\n", 16 | "from matplotlib import pylab as plt\n", 17 | "from IPython.display import clear_output\n", 18 | "import os\n", 19 | "from joblib import Parallel, delayed\n", 20 | "from tqdm import tqdm\n", 21 | "import nltk\n", 22 | "from glob import glob\n", 23 | "from joblib import Parallel, delayed\n", 24 | "from collections import Counter\n", 25 | "from layers import *\n", 26 | "from utils import *\n", 27 | "from models import *\n", 28 | "import json\n", 29 | "import tensorflow as tf\n", 30 | "tfe = tf.contrib.eager \n", 31 | "config = tf.ConfigProto(\n", 32 | " gpu_options=tf.GPUOptions(\n", 33 | " visible_device_list=\"0\"))\n", 34 | "config.gpu_options.allow_growth = True\n", 35 | "session = tf.Session(config=config)\n", 36 | "tf.enable_eager_execution(config=config)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "checkpoint_dir = \"../models/path_to_dir\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "trn_data = read_pickle(\"../dataset/nl/train.pkl\")\n", 55 | "vld_data = read_pickle(\"../dataset/nl/valid.pkl\")\n", 56 | "tst_data = read_pickle(\"../dataset/nl/test.pkl\")\n", 57 | "code_i2w = read_pickle(\"../dataset/code_i2w.pkl\")\n", 58 | "code_w2i = read_pickle(\"../dataset/code_w2i.pkl\")\n", 59 | "nl_i2w = read_pickle(\"../dataset/nl_i2w.pkl\")\n", 60 | "nl_w2i = read_pickle(\"../dataset/nl_w2i.pkl\")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "trn_x, trn_y_raw = zip(*trn_data.items())\n", 70 | "vld_x, vld_y_raw = zip(*vld_data.items())\n", 71 | "tst_x, tst_y_raw = zip(*tst_data.items())" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in trn_y_raw]\n", 81 | "vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in vld_y_raw]\n", 82 | "tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in tst_y_raw]" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# model defining\n", 92 | "class Model(BaseModel):\n", 93 | " def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-4):\n", 94 | " super(Model, self).__init__(dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer, dropout, lr)\n", 95 | " self.E = TreeEmbeddingLayer(dim_E, in_vocab)\n", 96 | " self.encoder = ChildSumLSTMLayer(dim_E, dim_rep)\n", 97 | " \n", 98 | " def encode(self, trees):\n", 99 | " trees = self.E(trees)\n", 100 | " trees = self.encoder(trees)\n", 101 | " \n", 102 | " hx = tf.stack([tree.h for tree in trees])\n", 103 | " cx = tf.stack([tree.c for tree in trees])\n", 104 | " ys = [tf.stack([node.h for node in traverse(tree)]) for tree in trees]\n", 105 | " \n", 106 | " return ys, [hx, cx]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# setting model\n", 116 | "model = Model(512, 512, 512, len(code_w2i), len(nl_w2i), dropout=0.5, lr=1e-4)\n", 117 | "epochs = 15\n", 118 | "batch_size = 64\n", 119 | "os.makedirs(checkpoint_dir, exist_ok=True)\n", 120 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", 121 | "root = tfe.Checkpoint(model=model)\n", 122 | "history = {\"loss\":[], \"loss_val\":[]}" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "# Setting Data Generator\n", 132 | "trn_gen = Datagen_tree(trn_x, trn_y, batch_size, code_w2i, nl_i2w, train=True)\n", 133 | "vld_gen = Datagen_tree(vld_x, vld_y, batch_size, code_w2i, nl_i2w, train=False)\n", 134 | "tst_gen = Datagen_tree(tst_x, tst_y, batch_size, code_w2i, nl_i2w, train=False)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# training\n", 144 | "for epoch in range(epochs):\n", 145 | " \n", 146 | " # train\n", 147 | " loss_tmp = []\n", 148 | " t = tqdm(trn_gen(epoch))\n", 149 | " for x, y, _, _ in t:\n", 150 | " loss_tmp.append(model.train_on_batch(x, y))\n", 151 | " t.set_description(\"epoch:{:03d}, loss = {}\".format(epoch + 1, np.mean(loss_tmp)))\n", 152 | " history[\"loss\"].append(np.sum(loss_tmp) / len(t))\n", 153 | " \n", 154 | " loss_tmp = []\n", 155 | " t = tqdm(vld_gen(epoch))\n", 156 | " for x, y, _, _ in t:\n", 157 | " loss_tmp.append(model.evaluate_on_batch(x, y))\n", 158 | " t.set_description(\"epoch:{:03d}, loss_val = {}\".format(epoch + 1, np.mean(loss_tmp)))\n", 159 | " history[\"loss_val\"].append(np.sum(loss_tmp) / len(t))\n", 160 | " \n", 161 | " # checkpoint\n", 162 | " if history[\"loss_val\"][-1] == min(history[\"loss_val\"]):\n", 163 | " checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", 164 | " root.save(file_prefix=checkpoint_prefix)\n", 165 | " \n", 166 | " # print\n", 167 | " clear_output()\n", 168 | " for key, val in history.items():\n", 169 | " if \"loss\" in key:\n", 170 | " plt.plot(val, label=key)\n", 171 | " plt.legend()\n", 172 | " plt.show()" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "root.restore(tf.train.latest_checkpoint(checkpoint_dir))" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "preds = []\n", 191 | "trues = []\n", 192 | "for x, y, _, y_raw in tqdm(tst_gen(0)):\n", 193 | " res = model.translate(x, nl_i2w, nl_w2i)\n", 194 | " preds += res\n", 195 | " trues += [s[1:-1] for s in y_raw]" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "bleus = Parallel(n_jobs=-1)(delayed(bleu4)(t, p) for t, p in tqdm(list(zip(trues, preds))))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "history[\"bleus\"] = bleus\n", 214 | "history[\"preds\"] = preds\n", 215 | "history[\"trues\"] = trues\n", 216 | "history[\"numbers\"] = [int(x.split(\"/\")[-1]) for x in tst_x]" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "with open(os.path.join(checkpoint_dir, \"history.json\"), \"w\") as f:\n", 226 | " json.dump(history, f)" 227 | ] 228 | } 229 | ], 230 | "metadata": { 231 | "anaconda-cloud": {}, 232 | "kernelspec": { 233 | "display_name": "Python 3", 234 | "language": "python", 235 | "name": "python3" 236 | }, 237 | "language_info": { 238 | "codemirror_mode": { 239 | "name": "ipython", 240 | "version": 3 241 | }, 242 | "file_extension": ".py", 243 | "mimetype": "text/x-python", 244 | "name": "python", 245 | "nbconvert_exporter": "python", 246 | "pygments_lexer": "ipython3", 247 | "version": "3.6.1" 248 | } 249 | }, 250 | "nbformat": 4, 251 | "nbformat_minor": 2 252 | } 253 | -------------------------------------------------------------------------------- /parser/README.md: -------------------------------------------------------------------------------- 1 | # parser 2 | 3 | Run `java -jar parser.jar -f [filename] -d [dirname]`. 4 | 5 | # example 6 | 7 | `java -jar parser.jar -f valid.json -d valid` 8 | 9 | # requirement 10 | Java 1.8 11 | -------------------------------------------------------------------------------- /parser/parser.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sh1doy/summarization_tf/2f14f2c28c63140288acc6515db236e486ab7152/parser/parser.jar -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.5.0 2 | alabaster==0.7.10 3 | anaconda-client==1.6.3 4 | anaconda-navigator==1.6.2 5 | anaconda-project==0.6.0 6 | asn1crypto==0.22.0 7 | astor==0.7.1 8 | astroid==1.4.9 9 | astropy==1.3.2 10 | Babel==2.4.0 11 | backports.shutil-get-terminal-size==1.0.0 12 | beautifulsoup4==4.6.0 13 | bitarray==0.8.1 14 | blaze==0.10.1 15 | bleach==1.5.0 16 | bokeh==0.12.5 17 | boto==2.46.1 18 | Bottleneck==1.2.1 19 | cffi==1.10.0 20 | chardet==3.0.3 21 | click==6.7 22 | cloudpickle==0.2.2 23 | clyent==1.2.2 24 | colorama==0.3.9 25 | conda==4.3.21 26 | contextlib2==0.5.5 27 | cryptography==1.8.1 28 | cycler==0.10.0 29 | Cython==0.25.2 30 | cytoolz==0.8.2 31 | dask==0.14.3 32 | datashape==0.5.4 33 | decorator==4.0.11 34 | distributed==1.16.3 35 | docutils==0.13.1 36 | entrypoints==0.2.2 37 | et-xmlfile==1.0.1 38 | fastcache==1.0.2 39 | Flask==0.12.2 40 | Flask-Cors==3.0.2 41 | gast==0.2.0 42 | gevent==1.2.1 43 | greenlet==0.4.12 44 | grpcio==1.15.0 45 | h5py==2.8.0 46 | HeapDict==1.0.0 47 | html5lib==0.999 48 | idna==2.5 49 | imagesize==0.7.1 50 | ipykernel==4.6.1 51 | ipython==5.3.0 52 | ipython-genutils==0.2.0 53 | ipywidgets==6.0.0 54 | isort==4.2.5 55 | itsdangerous==0.24 56 | jdcal==1.3 57 | jedi==0.10.2 58 | Jinja2==2.9.6 59 | joblib==0.12.5 60 | jsonschema==2.6.0 61 | jupyter==1.0.0 62 | jupyter-client==5.0.1 63 | jupyter-console==5.1.0 64 | jupyter-core==4.3.0 65 | jupyterthemes==0.17.0 66 | lazy-object-proxy==1.2.2 67 | lesscpy==0.13.0 68 | llvmlite==0.18.0 69 | locket==0.2.0 70 | lxml==3.7.3 71 | Markdown==2.6.11 72 | MarkupSafe==0.23 73 | matplotlib==2.0.2 74 | mistune==0.7.4 75 | mpmath==0.19 76 | msgpack-python==0.4.8 77 | multipledispatch==0.4.9 78 | navigator-updater==0.1.0 79 | nbconvert==5.1.1 80 | nbformat==4.3.0 81 | networkx==1.11 82 | nltk==3.2.3 83 | nose==1.3.7 84 | notebook==5.0.0 85 | numba==0.33.0 86 | numexpr==2.6.2 87 | numpy==1.14.5 88 | numpydoc==0.6.0 89 | odo==0.5.0 90 | olefile==0.44 91 | openpyxl==2.4.7 92 | packaging==16.8 93 | pandas==0.20.1 94 | pandocfilters==1.4.1 95 | partd==0.3.8 96 | pathlib2==2.2.1 97 | patsy==0.4.1 98 | pep8==1.7.0 99 | pexpect==4.2.1 100 | pickleshare==0.7.4 101 | Pillow==4.1.1 102 | ply==3.10 103 | prefetch-generator==1.0.0 104 | prometheus-client==0.3.1 105 | prompt-toolkit==1.0.14 106 | protobuf==3.6.1 107 | psutil==5.2.2 108 | ptyprocess==0.5.1 109 | py==1.4.33 110 | pycosat==0.6.2 111 | pycparser==2.17 112 | pycrypto==2.6.1 113 | pycurl==7.43.0 114 | pyflakes==1.5.0 115 | Pygments==2.2.0 116 | pylint==1.6.4 117 | pyodbc==4.0.16 118 | pyOpenSSL==17.0.0 119 | pyparsing==2.1.4 120 | pytest==3.0.7 121 | python-dateutil==2.6.0 122 | pytz==2017.2 123 | PyWavelets==0.5.2 124 | PyYAML==3.12 125 | pyzmq==16.0.2 126 | QtAwesome==0.4.4 127 | qtconsole==4.3.0 128 | QtPy==1.2.1 129 | requests==2.14.2 130 | rope-py3k==0.9.4.post1 131 | scikit-image==0.13.0 132 | scikit-learn==0.18.1 133 | scipy==0.19.0 134 | seaborn==0.7.1 135 | simplegeneric==0.8.1 136 | singledispatch==3.4.0.3 137 | six==1.10.0 138 | snowballstemmer==1.2.1 139 | sortedcollections==0.5.3 140 | sortedcontainers==1.5.7 141 | Sphinx==1.5.6 142 | spyder==3.1.4 143 | SQLAlchemy==1.1.9 144 | statsmodels==0.8.0 145 | sympy==1.0 146 | tables==3.3.0 147 | tblib==1.3.2 148 | tensorboard==1.10.0 149 | tensorflow-gpu==1.10.1 150 | termcolor==1.1.0 151 | terminado==0.6 152 | testpath==0.3 153 | toolz==0.8.2 154 | tornado==4.5.1 155 | tqdm==4.26.0 156 | traitlets==4.3.2 157 | unicodecsv==0.14.1 158 | wcwidth==0.1.7 159 | Werkzeug==0.12.2 160 | widgetsnbextension==2.0.0 161 | wrapt==1.10.10 162 | xlrd==1.0.0 163 | XlsxWriter==0.9.6 164 | xlwt==1.2.0 165 | zict==0.1.2 166 | -------------------------------------------------------------------------------- /retrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import read_pickle, Datagen_set, Datagen_deepcom, Datagen_tree, Datagen_binary, bleu4 3 | from models import Seq2seqModel, CodennModel, ChildsumModel, MultiwayModel, NaryModel 4 | import numpy as np 5 | import os 6 | import tensorflow as tf 7 | from tqdm import tqdm 8 | from joblib import delayed, Parallel 9 | import json 10 | 11 | 12 | # parse argments 13 | 14 | parser = argparse.ArgumentParser(description='Source Code Generation') 15 | 16 | parser.add_argument('-m', "--method", type=str, nargs="?", required=True, 17 | choices=['seq2seq', 'deepcom', 'codenn', 'childsum', 'multiway', "nary"], 18 | help='Encoder method') 19 | parser.add_argument('-d', "--dim", type=int, nargs="?", required=False, default=512, 20 | help='Representation dimension') 21 | parser.add_argument("--embed", type=int, nargs="?", required=False, default=256, 22 | help='Representation dimension') 23 | parser.add_argument("--drop", type=float, nargs="?", required=False, default=.5, 24 | help="Dropout rate") 25 | parser.add_argument('-r', "--lr", type=float, nargs="?", required=True, 26 | help='Learning rate') 27 | parser.add_argument('-b', "--batch", type=int, nargs="?", required=True, 28 | help='Mini batch size') 29 | parser.add_argument('-e', "--epochs", type=int, nargs="?", required=True, 30 | help='Epoch number') 31 | parser.add_argument('-g', "--gpu", type=str, nargs="?", required=True, 32 | help='What GPU to use') 33 | parser.add_argument('-l', "--layer", type=int, nargs="?", required=False, default=1, 34 | help='Number of layers') 35 | parser.add_argument("--val", type=str, nargs="?", required=False, default="BLEU", 36 | help='Validation method') 37 | 38 | args = parser.parse_args() 39 | 40 | name = args.method + "_dim" + str(args.dim) + "_embed" + str(args.embed) 41 | name = name + "_drop" + str(args.drop) 42 | name = name + "_lr" + str(args.lr) + "_batch" + str(args.batch) 43 | name = name + "_epochs" + str(args.epochs) + "_layer" + str(args.layer) 44 | 45 | checkpoint_dir = "./models/" + name 46 | 47 | 48 | # set tf eager 49 | 50 | tfe = tf.contrib.eager 51 | config = tf.ConfigProto( 52 | gpu_options=tf.GPUOptions( 53 | visible_device_list=args.gpu)) 54 | # config.gpu_options.allow_growth = True 55 | session = tf.Session(config=config) 56 | tf.enable_eager_execution(config=config) 57 | os.makedirs("./logs/" + name, exist_ok=True) 58 | writer = tf.contrib.summary.create_file_writer("./logs/" + name, flush_millis=10000) 59 | 60 | 61 | # load data 62 | 63 | trn_data = read_pickle("dataset/nl/train.pkl") 64 | vld_data = read_pickle("dataset/nl/valid.pkl") 65 | tst_data = read_pickle("dataset/nl/test.pkl") 66 | code_i2w = read_pickle("dataset/code_i2w.pkl") 67 | code_w2i = read_pickle("dataset/code_w2i.pkl") 68 | nl_i2w = read_pickle("dataset/nl_i2w.pkl") 69 | nl_w2i = read_pickle("dataset/nl_w2i.pkl") 70 | 71 | trn_x, trn_y_raw = zip(*trn_data.items()) 72 | vld_x, vld_y_raw = zip(*vld_data.items()) 73 | tst_x, tst_y_raw = zip(*tst_data.items()) 74 | 75 | trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in trn_y_raw] 76 | vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in vld_y_raw] 77 | tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in tst_y_raw] 78 | 79 | 80 | # setting model 81 | 82 | if args.method in ['seq2seq', 'deepcom']: 83 | Model = Seq2seqModel 84 | elif args.method in ['codenn']: 85 | Model = CodennModel 86 | elif args.method in ['childsum']: 87 | Model = ChildsumModel 88 | elif args.method in ['multiway']: 89 | Model = MultiwayModel 90 | elif args.method in ['nary']: 91 | Model = NaryModel 92 | 93 | 94 | model = Model(args.dim, args.dim, args.dim, len(code_w2i), len(nl_w2i), 95 | dropout=args.drop, lr=args.lr, layer=args.layer) 96 | epochs = args.epochs 97 | batch_size = args.batch 98 | os.makedirs(checkpoint_dir, exist_ok=True) 99 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 100 | root = tfe.Checkpoint(model=model) 101 | history = {"loss": [], "loss_val": [], "bleu_val": []} 102 | 103 | root.restore(tf.train.latest_checkpoint(checkpoint_dir)) 104 | 105 | # Setting Data Generator 106 | 107 | if args.method in ['deepcom']: 108 | Datagen = Datagen_deepcom 109 | elif args.method in ['codenn']: 110 | Datagen = Datagen_set 111 | elif args.method in ['childsum', 'multiway']: 112 | Datagen = Datagen_tree 113 | elif args.method in ['nary']: 114 | Datagen = Datagen_binary 115 | 116 | 117 | trn_gen = Datagen(trn_x, trn_y, batch_size, code_w2i, nl_i2w, train=True) 118 | vld_gen = Datagen(vld_x, vld_y, batch_size, code_w2i, nl_i2w, train=False) 119 | tst_gen = Datagen(tst_x, tst_y, batch_size, code_w2i, nl_i2w, train=False) 120 | 121 | 122 | # training 123 | with writer.as_default(), tf.contrib.summary.always_record_summaries(): 124 | 125 | for epoch in range(1, epochs + 1): 126 | 127 | # train 128 | loss_tmp = [] 129 | t = tqdm(trn_gen(0)) 130 | for x, y, _, _ in t: 131 | loss_tmp.append(model.train_on_batch(x, y)) 132 | t.set_description("epoch:{:03d}, loss = {}".format(epoch, np.mean(loss_tmp))) 133 | history["loss"].append(np.sum(loss_tmp) / len(t)) 134 | tf.contrib.summary.scalar("loss", np.sum(loss_tmp) / len(t), step=epoch) 135 | 136 | # validate loss 137 | loss_tmp = [] 138 | t = tqdm(vld_gen(0)) 139 | for x, y, _, _ in t: 140 | loss_tmp.append(model.evaluate_on_batch(x, y)) 141 | t.set_description("epoch:{:03d}, loss_val = {}".format(epoch, np.mean(loss_tmp))) 142 | history["loss_val"].append(np.sum(loss_tmp) / len(t)) 143 | tf.contrib.summary.scalar("loss_val", np.sum(loss_tmp) / len(t), step=epoch) 144 | 145 | # validate bleu 146 | preds = [] 147 | trues = [] 148 | bleus = [] 149 | t = tqdm(vld_gen(0)) 150 | for x, y, _, y_raw in t: 151 | res = model.translate(x, nl_i2w, nl_w2i) 152 | preds += res 153 | trues += [s[1:-1] for s in y_raw] 154 | bleus += [bleu4(tt, p) for tt, p in zip(trues, preds)] 155 | t.set_description("epoch:{:03d}, bleu_val = {}".format(epoch, np.mean(bleus))) 156 | history["bleu_val"].append(np.mean(bleus)) 157 | tf.contrib.summary.scalar("bleu_val", np.mean(bleus), step=epoch) 158 | 159 | # checkpoint 160 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 161 | hoge = root.save(file_prefix=checkpoint_prefix) 162 | if history["bleu_val"][-1] == max(history["bleu_val"]): 163 | best_model = hoge 164 | print("Now best model is {}".format(best_model)) 165 | 166 | 167 | # load final weight 168 | 169 | print("Restore {}".format(best_model)) 170 | root.restore(best_model) 171 | 172 | # evaluation 173 | 174 | preds = [] 175 | trues = [] 176 | for x, y, _, y_raw in tqdm(tst_gen(0), "Testing"): 177 | res = model.translate(x, nl_i2w, nl_w2i) 178 | preds += res 179 | trues += [s[1:-1] for s in y_raw] 180 | 181 | bleus = Parallel(n_jobs=-1)(delayed(bleu4)(t, p) for t, p in (list(zip(trues, preds)))) 182 | 183 | history["bleus"] = bleus 184 | history["preds"] = preds 185 | history["trues"] = trues 186 | history["numbers"] = [int(x.split("/")[-1]) for x in tst_x] 187 | 188 | with open(os.path.join(checkpoint_dir, "history.json"), "w") as f: 189 | json.dump(history, f) 190 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import read_pickle, Datagen_set, Datagen_deepcom, Datagen_tree, Datagen_binary, bleu4 3 | from models import Seq2seqModel, CodennModel, ChildsumModel, MultiwayModel, NaryModel 4 | import numpy as np 5 | import os 6 | import tensorflow as tf 7 | from tqdm import tqdm 8 | from joblib import delayed, Parallel 9 | import json 10 | 11 | 12 | # parse argments 13 | 14 | parser = argparse.ArgumentParser(description='Source Code Generation') 15 | 16 | parser.add_argument('-m', "--method", type=str, nargs="?", required=True, 17 | choices=['seq2seq', 'deepcom', 'codenn', 'childsum', 'multiway', "nary"], 18 | help='Encoder method') 19 | parser.add_argument('-d', "--dim", type=int, nargs="?", required=False, default=512, 20 | help='Representation dimension') 21 | parser.add_argument("--embed", type=int, nargs="?", required=False, default=256, 22 | help='Representation dimension') 23 | parser.add_argument("--drop", type=float, nargs="?", required=False, default=.5, 24 | help="Dropout rate") 25 | parser.add_argument('-r', "--lr", type=float, nargs="?", required=True, 26 | help='Learning rate') 27 | parser.add_argument('-b', "--batch", type=int, nargs="?", required=True, 28 | help='Mini batch size') 29 | parser.add_argument('-e', "--epochs", type=int, nargs="?", required=True, 30 | help='Epoch number') 31 | parser.add_argument('-g', "--gpu", type=str, nargs="?", required=True, 32 | help='What GPU to use') 33 | parser.add_argument('-l', "--layer", type=int, nargs="?", required=False, default=1, 34 | help='Number of layers') 35 | parser.add_argument("--val", type=str, nargs="?", required=False, default="BLEU", 36 | help='Validation method') 37 | 38 | args = parser.parse_args() 39 | 40 | name = args.method + "_dim" + str(args.dim) + "_embed" + str(args.embed) 41 | name = name + "_drop" + str(args.drop) 42 | name = name + "_lr" + str(args.lr) + "_batch" + str(args.batch) 43 | name = name + "_epochs" + str(args.epochs) + "_layer" + str(args.layer) + "NEW_skip_size100" 44 | 45 | checkpoint_dir = "./models/" + name 46 | 47 | 48 | # set tf eager 49 | 50 | tfe = tf.contrib.eager 51 | config = tf.ConfigProto( 52 | gpu_options=tf.GPUOptions( 53 | visible_device_list=args.gpu)) 54 | # config.gpu_options.allow_growth = True 55 | session = tf.Session(config=config) 56 | tf.enable_eager_execution(config=config) 57 | os.makedirs("./logs/" + name, exist_ok=True) 58 | writer = tf.contrib.summary.create_file_writer("./logs/" + name, flush_millis=10000) 59 | 60 | 61 | # load data 62 | 63 | trn_data = read_pickle("dataset/nl/train.pkl") 64 | vld_data = read_pickle("dataset/nl/valid.pkl") 65 | tst_data = read_pickle("dataset/nl/test.pkl") 66 | code_i2w = read_pickle("dataset/code_i2w.pkl") 67 | code_w2i = read_pickle("dataset/code_w2i.pkl") 68 | nl_i2w = read_pickle("dataset/nl_i2w.pkl") 69 | nl_w2i = read_pickle("dataset/nl_w2i.pkl") 70 | 71 | trn_x, trn_y_raw = zip(*sorted(trn_data.items())) 72 | vld_x, vld_y_raw = zip(*sorted(vld_data.items())) 73 | tst_x, tst_y_raw = zip(*sorted(tst_data.items())) 74 | 75 | trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in trn_y_raw] 76 | vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in vld_y_raw] 77 | tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in tst_y_raw] 78 | 79 | 80 | # setting model 81 | 82 | if args.method in ['seq2seq', 'deepcom']: 83 | Model = Seq2seqModel 84 | elif args.method in ['codenn']: 85 | Model = CodennModel 86 | elif args.method in ['childsum']: 87 | Model = ChildsumModel 88 | elif args.method in ['multiway']: 89 | Model = MultiwayModel 90 | elif args.method in ['nary']: 91 | Model = NaryModel 92 | 93 | 94 | model = Model(args.dim, args.dim, args.dim, len(code_w2i), len(nl_w2i), 95 | dropout=args.drop, lr=args.lr, layer=args.layer) 96 | epochs = args.epochs 97 | batch_size = args.batch 98 | os.makedirs(checkpoint_dir, exist_ok=True) 99 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 100 | root = tfe.Checkpoint(model=model) 101 | history = {"loss": [], "loss_val": [], "bleu_val": []} 102 | 103 | 104 | # Setting Data Generator 105 | 106 | if args.method in ['deepcom']: 107 | Datagen = Datagen_deepcom 108 | elif args.method in ['codenn']: 109 | Datagen = Datagen_set 110 | elif args.method in ['childsum', 'multiway']: 111 | Datagen = Datagen_tree 112 | elif args.method in ['nary']: 113 | Datagen = Datagen_binary 114 | 115 | 116 | trn_gen = Datagen(trn_x, trn_y, batch_size, code_w2i, nl_i2w, train=True) 117 | vld_gen = Datagen(vld_x, vld_y, batch_size, code_w2i, nl_i2w, train=False) 118 | tst_gen = Datagen(tst_x, tst_y, batch_size, code_w2i, nl_i2w, train=False) 119 | 120 | 121 | # training 122 | with writer.as_default(), tf.contrib.summary.always_record_summaries(): 123 | 124 | for epoch in range(1, epochs + 1): 125 | 126 | # train 127 | loss_tmp = [] 128 | t = tqdm(trn_gen(0)) 129 | for x, y, _, _ in t: 130 | loss_tmp.append(model.train_on_batch(x, y)) 131 | t.set_description("epoch:{:03d}, loss = {}".format(epoch, np.mean(loss_tmp))) 132 | history["loss"].append(np.sum(loss_tmp) / len(t)) 133 | tf.contrib.summary.scalar("loss", np.sum(loss_tmp) / len(t), step=epoch) 134 | 135 | # validate loss 136 | loss_tmp = [] 137 | t = tqdm(vld_gen(0)) 138 | for x, y, _, _ in t: 139 | loss_tmp.append(model.evaluate_on_batch(x, y)) 140 | t.set_description("epoch:{:03d}, loss_val = {}".format(epoch, np.mean(loss_tmp))) 141 | history["loss_val"].append(np.sum(loss_tmp) / len(t)) 142 | tf.contrib.summary.scalar("loss_val", np.sum(loss_tmp) / len(t), step=epoch) 143 | 144 | # validate bleu 145 | preds = [] 146 | trues = [] 147 | bleus = [] 148 | t = tqdm(vld_gen(0)) 149 | for x, y, _, y_raw in t: 150 | res = model.translate(x, nl_i2w, nl_w2i) 151 | preds += res 152 | trues += [s[1:-1] for s in y_raw] 153 | bleus += [bleu4(tt, p) for tt, p in zip(trues, preds)] 154 | t.set_description("epoch:{:03d}, bleu_val = {}".format(epoch, np.mean(bleus))) 155 | history["bleu_val"].append(np.mean(bleus)) 156 | tf.contrib.summary.scalar("bleu_val", np.mean(bleus), step=epoch) 157 | 158 | # checkpoint 159 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 160 | hoge = root.save(file_prefix=checkpoint_prefix) 161 | if history["bleu_val"][-1] == max(history["bleu_val"]): 162 | best_model = hoge 163 | print("Now best model is {}".format(best_model)) 164 | 165 | 166 | # load final weight 167 | 168 | print("Restore {}".format(best_model)) 169 | root.restore(best_model) 170 | 171 | # evaluation 172 | 173 | preds = [] 174 | trues = [] 175 | for x, y, _, y_raw in tqdm(tst_gen(0), "Testing"): 176 | res = model.translate(x, nl_i2w, nl_w2i) 177 | preds += res 178 | trues += [s[1:-1] for s in y_raw] 179 | 180 | bleus = Parallel(n_jobs=-1)(delayed(bleu4)(t, p) for t, p in (list(zip(trues, preds)))) 181 | 182 | history["bleus"] = bleus 183 | history["preds"] = preds 184 | history["trues"] = trues 185 | history["numbers"] = [int(x.split("/")[-1]) for x in tst_x] 186 | 187 | with open(os.path.join(checkpoint_dir, "history.json"), "w") as f: 188 | json.dump(history, f) 189 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Utilities""" 2 | import tensorflow as tf 3 | import numpy as np 4 | import math 5 | from collections import defaultdict 6 | import pickle 7 | from prefetch_generator import BackgroundGenerator 8 | 9 | 10 | def get_nums(roots): 11 | '''convert roots to indices''' 12 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots] 13 | max_len = max([len(x) for x in res]) 14 | res = tf.keras.preprocessing.sequence.pad_sequences( 15 | res, max_len, padding="post", value=-1.) 16 | return tf.constant(res, tf.int32) 17 | 18 | 19 | def tree2binary(trees): 20 | def helper(root): 21 | if len(root.children) > 2: 22 | tmp = root.children[0] 23 | for child in root.children[1:]: 24 | tmp.children += [child] 25 | tmp = child 26 | root.children = root.children[0:1] 27 | for child in root.children: 28 | helper(child) 29 | return root 30 | return [helper(x) for x in trees] 31 | 32 | 33 | def tree2tensor(trees): 34 | ''' 35 | indice: 36 | this has structure data. 37 | 0 represent init state, 38 | 1 r else np.exp(1 - r / (c + 1e-10)) 268 | score = 0 269 | for i in range(1, 5): 270 | true_ngram = set(ngram(true, i)) 271 | pred_ngram = ngram(pred, i) 272 | length = float(len(pred_ngram)) + 1e-10 273 | count = sum([1. if t in true_ngram else 0. for t in pred_ngram]) 274 | score += math.log(1e-10 + (count / length)) 275 | score = math.exp(score * .25) 276 | bleu = bp * score 277 | return bleu 278 | 279 | 280 | class Datagen_tree: 281 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True, binary=False): 282 | self.X = X 283 | self.Y = Y 284 | self.batch_size = batch_size 285 | self.code_dic = code_dic 286 | self.nl_dic = nl_dic 287 | self.train = train 288 | self.binary = binary 289 | 290 | def __len__(self): 291 | return len(range(0, len(self.X), self.batch_size)) 292 | 293 | def __call__(self, epoch=0): 294 | return GeneratorLen(BackgroundGenerator(self.gen(epoch), 1), len(self)) 295 | 296 | def gen(self, epoch): 297 | if self.train: 298 | np.random.seed(epoch) 299 | newindex = list(np.random.permutation(len(self.X))) 300 | X = [self.X[i] for i in newindex] 301 | Y = [self.Y[i] for i in newindex] 302 | else: 303 | X = [x for x in self.X] 304 | Y = [y for y in self.Y] 305 | for i in range(0, len(self.X), self.batch_size): 306 | x = X[i:i + self.batch_size] 307 | y = Y[i:i + self.batch_size] 308 | x_raw = [read_pickle(n) for n in x] 309 | if self.binary: 310 | x_raw = tree2binary(x_raw) 311 | y_raw = [[self.nl_dic[t] for t in s] for s in y] 312 | x = [consult_tree(n, self.code_dic) for n in x_raw] 313 | x_raw = [traverse_label(n) for n in x_raw] 314 | y = tf.keras.preprocessing.sequence.pad_sequences( 315 | y, 316 | min(max([len(s) for s in y]), 100), 317 | padding="post", truncating="post", value=-1.) 318 | yield tree2tensor(x), y, x_raw, y_raw 319 | 320 | 321 | class Datagen_binary(Datagen_tree): 322 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True, binary=True): 323 | super(Datagen_binary, self).__init__(X, Y, batch_size, code_dic, 324 | nl_dic, train=True, binary=True) 325 | 326 | 327 | class Datagen_set: 328 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True): 329 | self.X = X 330 | self.Y = Y 331 | self.batch_size = batch_size 332 | self.code_dic = code_dic 333 | self.nl_dic = nl_dic 334 | self.train = train 335 | 336 | def __len__(self): 337 | return len(range(0, len(self.X), self.batch_size)) 338 | 339 | def __call__(self, epoch=0): 340 | return GeneratorLen(BackgroundGenerator(self.gen(epoch), 1), len(self)) 341 | 342 | def gen(self, epoch): 343 | if self.train: 344 | np.random.seed(epoch) 345 | newindex = list(np.random.permutation(len(self.X))) 346 | X = [self.X[i] for i in newindex] 347 | Y = [self.Y[i] for i in newindex] 348 | else: 349 | X = [x for x in self.X] 350 | Y = [y for y in self.Y] 351 | for i in range(0, len(self.X), self.batch_size): 352 | x = X[i:i + self.batch_size] 353 | y = Y[i:i + self.batch_size] 354 | x_raw = [read_pickle(n) for n in x] 355 | y_raw = [[self.nl_dic[t] for t in s] for s in y] 356 | x = [traverse_label(n) for n in x_raw] 357 | x = [np.array([self.code_dic[t] for t in xx], "int32") for xx in x] 358 | x_raw = [traverse_label(n) for n in x_raw] 359 | y = tf.constant( 360 | tf.keras.preprocessing.sequence.pad_sequences( 361 | y, 362 | min(max([len(s) for s in y]), 100), 363 | padding="post", truncating="post", value=-1.)) 364 | yield x, y, x_raw, y_raw 365 | 366 | 367 | def sequencing(root): 368 | li = ["(", root.label] 369 | for child in root.children: 370 | li += sequencing(child) 371 | li += [")", root.label] 372 | return(li) 373 | 374 | 375 | class Datagen_deepcom: 376 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True): 377 | self.X = X 378 | self.Y = Y 379 | self.batch_size = batch_size 380 | self.code_dic = code_dic 381 | self.nl_dic = nl_dic 382 | self.train = train 383 | 384 | def __len__(self): 385 | return len(range(0, len(self.X), self.batch_size)) 386 | 387 | def __call__(self, epoch=0): 388 | return GeneratorLen(BackgroundGenerator(self.gen(epoch), 1), len(self)) 389 | 390 | def gen(self, epoch): 391 | if self.train: 392 | np.random.seed(epoch) 393 | newindex = list(np.random.permutation(len(self.X))) 394 | X = [self.X[i] for i in newindex] 395 | Y = [self.Y[i] for i in newindex] 396 | else: 397 | X = [x for x in self.X] 398 | Y = [y for y in self.Y] 399 | for i in range(0, len(self.X), self.batch_size): 400 | x = X[i:i + self.batch_size] 401 | y = Y[i:i + self.batch_size] 402 | x_raw = [read_pickle(n) for n in x] 403 | y_raw = [[self.nl_dic[t] for t in s] for s in y] 404 | x = [sequencing(n) for n in x_raw] 405 | x = [np.array([self.code_dic[t] for t in xx], "int32") for xx in x] 406 | x = tf.constant( 407 | tf.keras.preprocessing.sequence.pad_sequences( 408 | x, 409 | min(max([len(s) for s in x]), 400), 410 | padding="post", truncating="post", value=-1.)) 411 | x_raw = [traverse_label(n) for n in x_raw] 412 | y = tf.constant( 413 | tf.keras.preprocessing.sequence.pad_sequences( 414 | y, 415 | min(max([len(s) for s in y]), 100), 416 | padding="post", truncating="post", value=-1.)) 417 | yield x, y, x_raw, y_raw 418 | 419 | 420 | def get_length(tensor, pad_value=-1.): 421 | '''tensor: [batch, max_len]''' 422 | mask = tf.not_equal(tensor, pad_value) 423 | return tf.reduce_sum(tf.cast(mask, tf.int32), 1) 424 | --------------------------------------------------------------------------------