├── README.md
├── dataset.py
├── layers.py
├── models.py
├── notebooks
│   ├── check_result.ipynb
│   └── example.ipynb
├── parser
│   ├── README.md
│   └── parser.jar
├── requirements.txt
├── retrain.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Attention-based Tree-to-Sequence Code Summarization Model
2 |
3 | A TensorFlow Eager Execution implementation of [Source Code Summarization with Extended Tree-LSTM](https://arxiv.org/abs/1906.08094) (Shido+, 2019).
4 |
5 | Implemented models:
6 |
7 | - **Multi-way Tree-LSTM model (Ours)**
8 | - Child-sum Tree-LSTM model
9 | - N-ary Tree-LSTM model
10 | - DeepCom (Hu et al.)
11 | - CODE-NN (Iyer et al.)
12 |
13 | ## Dataset
14 |
15 | 1. Download the raw dataset from [DeepCom](https://github.com/xing-hu/DeepCom)
16 | 2. Parse the downloaded code with `parser.jar` (in the `parser/` directory)
17 |
18 | ## Usage
19 |
20 | 1. Prepare tree-structured data with `dataset.py`
21 | - Run `$ python dataset.py [dir]`
22 | 2. Train and evaluate model with `train.py`
23 | - See `$ python train.py -h`
24 |
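25 | For reference, `dataset.py` writes its output under `./dataset/`, roughly in the following layout, which the training step then consumes:
26 | 
27 | ```
28 | dataset/
29 | ├── code_i2w.pkl, code_w2i.pkl     # AST-label vocabulary
30 | ├── nl_i2w.pkl, nl_w2i.pkl         # comment-token vocabulary
31 | ├── nl/{train,valid,test}.pkl      # tokenized comments keyed by tree path
32 | ├── tree_raw/{train,valid,test}/   # parsed ASTs (intermediate)
33 | └── tree/{train,valid,test}/       # preprocessed ASTs used for training
34 | ```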
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from tqdm import tqdm
3 | from glob import glob
4 | from utils import Node, traverse_label, traverse
5 | import pickle
6 | import os
7 | from joblib import Parallel, delayed
8 | from collections import Counter
9 | import re
10 | from os.path import abspath
11 | import nltk
12 |
13 |
14 | def parse(path):
15 | with open(path, "r") as f:
16 | num_objects = f.readline()
17 | nodes = [Node(num=i, children=[]) for i in range(int(num_objects))]
18 | for i in range(int(num_objects)):
19 | label = " ".join(f.readline().split(" ")[1:])[:-1]
20 | nodes[i].label = label
21 | while 1:
22 | line = f.readline()
23 | if line == "\n":
24 | break
25 | p, c = map(int, line.split(" "))
26 | nodes[p].children.append(nodes[c])
27 | nodes[c].parent = nodes[p]
28 | nl = f.readline()[:-1]
29 | return nodes[0], nl
30 |
31 |
32 | def is_invalid_com(s):
33 | return s[:2] == "/*" and len(s) > 1
34 |
35 |
36 | def is_invalid_seq(s):
37 | return len(s) < 4
38 |
39 |
40 | def get_method_name(root):
41 | for c in root.children:
42 | if c.label == "name (SimpleName)":
43 | return c.children[0].label[12:-1]
44 |
45 |
46 | def is_invalid_tree(root):
47 | labels = traverse_label(root)
48 | if root.label == 'root (ConstructorDeclaration)':
49 | return True
50 | if len(labels) >= 100:
51 | return True
52 | method_name = get_method_name(root)
53 | for word in ["test", "Test", "set", "Set", "get", "Get"]:
54 | if method_name[:len(word)] == word:
55 | return True
56 | return False
57 |
58 |
59 | def clean_nl(s):
60 | if s[-1] == ".":
61 | s = s[:-1]
62 | s = s.split(". ")[0]
63 | s = re.sub("[<].+?[>]", "", s)
64 | s = re.sub(r"[\[\]%]", "", s)
65 | s = s[0:1].lower() + s[1:]
66 | return s
67 |
68 |
69 | def tokenize(s):
70 | return ["<s>"] + nltk.word_tokenize(s) + ["</s>"]
71 |
72 |
73 | def parse_dir(path_to_dir):
74 | files = sorted(glob(path_to_dir + "/*"))
75 | set_name = path_to_dir.split("/")[-1]
76 |
77 | nls = {}
78 | skip = 0
79 |
80 | for file in tqdm(files, "parsing {}".format(path_to_dir)):
81 | tree, nl = parse(file)
82 | nl = clean_nl(nl)
83 | if is_invalid_com(nl):
84 | skip += 1
85 | continue
86 | if is_invalid_tree(tree):
87 | skip += 1
88 | continue
89 | number = int(file.split("/")[-1])
90 | seq = tokenize(nl)
91 | if is_invalid_seq(seq):
92 | skip += 1
93 | continue
94 | nls[abspath("./dataset/tree/" + set_name + "/" + str(number))] = seq
95 | with open("./dataset/tree_raw/" + set_name + "/" + str(number), "wb", 1) as f:
96 | pickle.dump(tree, f)
97 |
98 | print("{} files skipped".format(skip))
99 |
100 | if set_name == "train":
101 | vocab = Counter([x for l in nls.values() for x in l])
102 | nl_i2w = {i: w for i, w in enumerate(
103 | ["<PAD>", "<UNK>"] + sorted([x[0] for x in vocab.most_common(30000)]))}
104 | nl_w2i = {w: i for i, w in enumerate(
105 | ["<PAD>", "<UNK>"] + sorted([x[0] for x in vocab.most_common(30000)]))}
106 | pickle.dump(nl_i2w, open("./dataset/nl_i2w.pkl", "wb"))
107 | pickle.dump(nl_w2i, open("./dataset/nl_w2i.pkl", "wb"))
108 |
109 | return nls
110 |
111 |
112 | def pickling():
113 | args = sys.argv
114 |
115 | if len(args) <= 1:
116 | raise Exception("(usage) $ python dataset.py [dir]")
117 |
118 | data_dir = args[1]
119 |
120 | dirs = [
121 | "dataset",
122 | "dataset/tree_raw",
123 | "dataset/tree_raw/train",
124 | "dataset/tree_raw/valid",
125 | "dataset/tree_raw/test",
126 | "dataset/nl"
127 | ]
128 | for d in dirs:
129 | if not os.path.exists(d):
130 | os.mkdir(d)
131 |
132 | for path in [data_dir + "/" + s for s in ["train", "valid", "test"]]:
133 | set_name = path.split("/")[-1]
134 | nl = parse_dir(path)
135 | with open("./dataset/nl/" + set_name + ".pkl", "wb", 1) as f:
136 | pickle.dump(nl, f)
137 |
138 |
139 | def isnum(s):
140 | try:
141 | float(s)
142 | except ValueError:
143 | return False
144 | else:
145 | return True
146 |
147 |
148 | def get_labels(path):
149 | tree = pickle.load(open(path, "rb"))
150 | return traverse_label(tree)
151 |
152 |
153 | def get_bracket(s):
154 | if "value=" == s[:6] or "identifier=" in s[:11]:
155 | return None
156 | p = r"\(.+?\)"
157 | res = re.findall(p, s)
158 | if len(res) == 1:
159 | return res[0]
160 | return s
161 |
162 |
163 | def get_identifier(s):
164 | if "identifier=" == s[:11]:
165 | return "SimpleName_" + s[11:]
166 | else:
167 | return None
168 |
169 |
170 | def is_SimpleName(s):
171 | return "SimpleName_" == s[:11]
172 |
173 |
174 | def get_values(s):
175 | if "value=" == s[:6]:
176 | return "Value_" + s[6:]
177 | else:
178 | return None
179 |
180 |
181 | def is_value(s):
182 | return "Value_" == s[:6]
183 |
184 |
185 | def make_dict():
186 | labels = Parallel(n_jobs=-1)(delayed(get_labels)(p) for p in tqdm(
187 | glob("./dataset/tree_raw/train/*"), "reading all labels"))
188 | labels = [l for s in labels for l in s]
189 |
190 | non_terminals = set(
191 | [get_bracket(x) for x in tqdm(
192 | list(set(labels)), "collect non-terminals")]) - set([None, "(SimpleName)"])
193 | non_terminals = sorted(list(non_terminals))
194 |
195 | ids = Counter(
196 | [y for y in [get_identifier(x) for x in tqdm(
197 | labels, "collect identifiers")] if y is not None])
198 | ids_list = [x[0] for x in ids.most_common(30000)]
199 |
200 | values = Counter(
201 | [y for y in [get_values(x) for x in tqdm(
202 | labels, "collect values")] if y is not None])
203 | values_list = [x[0] for x in values.most_common(1000)]
204 |
205 | vocab = ["<UNK>", "SimpleName_<UNK>", "Value_<NUM>", "Value_<STR>"]
206 | vocab += non_terminals + ids_list + values_list + ["(", ")"]
207 |
208 | code_i2w = {i: w for i, w in enumerate(vocab)}
209 | code_w2i = {w: i for i, w in enumerate(vocab)}
210 |
211 | pickle.dump(code_i2w, open("./dataset/code_i2w.pkl", "wb"))
212 | pickle.dump(code_w2i, open("./dataset/code_w2i.pkl", "wb"))
213 |
214 |
215 | def remove_SimpleName(root):
216 | for node in traverse(root):
217 | if "=" not in node.label and "(SimpleName)" in node.label:
218 | if node.children[0].label[:11] != "identifier=":
219 | raise Exception("ERROR!")
220 | node.label = "SimpleName_" + node.children[0].label[11:]
221 | node.children = []
222 | elif node.label[:11] == "identifier=":
223 | node.label = "SimpleName_" + node.label[11:]
224 | elif node.label[:6] == "value=":
225 | node.label = "Value_" + node.label[6:]
226 |
227 | return root
228 |
229 |
230 | def modifier(root, dic):
231 | for node in traverse(root):
232 | if is_SimpleName(node.label):
233 | if node.label not in dic:
234 | node.label = "SimpleName_<UNK>"
235 | elif is_value(node.label):
236 | if node.label not in dic:
237 | if isnum(node.label):
238 | node.label = "Value_<NUM>"
239 | else:
240 | node.label = "Value_<STR>"
241 | else:
242 | node.label = get_bracket(node.label)
243 | if node.label not in dic:
244 | raise Exception("Unknown word", node.label)
245 |
246 | return root
247 |
248 |
249 | def rebuild_tree(path, dst, dic):
250 | root = pickle.load(open(path, "rb"))
251 | root = remove_SimpleName(root)
252 | root = modifier(root, dic)
253 | pickle.dump(root, open(dst, "wb"), 1)
254 |
255 |
256 | def preprocess_trees():
257 |
258 | dirs = [
259 | "./dataset",
260 | "./dataset/tree",
261 | "./dataset/tree/train",
262 | "./dataset/tree/valid",
263 | "./dataset/tree/test",
264 | "./dataset/nl"
265 | ]
266 | for d in dirs:
267 | if not os.path.exists(d):
268 | os.mkdir(d)
269 |
270 | sets_name = [
271 | "./dataset/tree_raw/train/*",
272 | "./dataset/tree_raw/valid/*",
273 | "./dataset/tree_raw/test/*"
274 | ]
275 |
276 | dic = set(pickle.load(open("./dataset/code_i2w.pkl", "rb")).values())
277 |
278 | for sets in sets_name:
279 | files = sorted(list(glob(sets)))
280 | dst = [x.replace("tree_raw", "tree") for x in files]
281 | Parallel(n_jobs=-1)(
282 | delayed(rebuild_tree)(p, d, dic) for p, d in tqdm(
283 | list(zip(files, dst)), "preprocessing {}".format(sets)))
284 |
285 |
286 | if __name__ == "__main__":
287 | nltk.download('punkt')
288 | sys.setrecursionlimit(10000)
289 | pickling()
290 | make_dict()
291 | preprocess_trees()
292 |
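293 | 
294 | # Illustrative example (not called by the pipeline above): once the script has
295 | # finished, a single preprocessed sample can be inspected like this.
296 | def _inspect_one_sample():
297 |     nls = pickle.load(open("./dataset/nl/train.pkl", "rb"))
298 |     path, comment = next(iter(nls.items()))  # abspath of a tree pickle -> tokenized comment
299 |     tree = pickle.load(open(path, "rb"))     # utils.Node with SimpleName_/Value_ labels
300 |     print(traverse_label(tree)[:10])
301 |     print(comment)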
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | """layers"""
2 |
3 | import tensorflow as tf
4 | from utils import *
5 | tfe = tf.contrib.eager
6 |
7 |
8 | class TreeEmbeddingLayer(tf.keras.Model):
9 | def __init__(self, dim_E, in_vocab):
10 | super(TreeEmbeddingLayer, self).__init__()
11 | self.E = tf.get_variable("E", [in_vocab, dim_E], tf.float32,
12 | initializer=tf.keras.initializers.RandomUniform())
13 |
14 | def call(self, x):
15 | '''x: list of [1,]'''
16 | x_len = [xx.shape[0] for xx in x]
17 | ex = tf.nn.embedding_lookup(self.E, tf.concat(x, axis=0))
18 | exs = tf.split(ex, x_len, 0)
19 | return exs
20 |
21 |
22 | class TreeEmbeddingLayerTreeBase(tf.keras.Model):
23 | def __init__(self, dim_E, in_vocab):
24 | super(TreeEmbeddingLayerTreeBase, self).__init__()
25 | self.E = tf.get_variable("E", [in_vocab, dim_E], tf.float32,
26 | initializer=tf.keras.initializers.RandomUniform())
27 |
28 | def call(self, roots):
29 | return [self.apply_single(root) for root in roots]
30 |
31 | def apply_single(self, root):
32 | labels = traverse_label(root)
33 | embedded = tf.nn.embedding_lookup(self.E, labels)
34 | new_nodes = self.Node2TreeLSTMNode(root, parent=None)
35 | for rep, node in zip(embedded, traverse(new_nodes)):
36 | node.h = rep
37 | return new_nodes
38 |
39 | def Node2TreeLSTMNode(self, node, parent):
40 | children = [self.Node2TreeLSTMNode(c, node) for c in node.children]
41 | return TreeLSTMNode(node.label, parent=parent, children=children, num=node.num)
42 |
43 |
44 | class ChildSumLSTMLayerWithEmbedding(tf.keras.Model):
45 | def __init__(self, in_vocab, dim_in, dim_out):
46 | super(ChildSumLSTMLayerWithEmbedding, self).__init__()
47 | self.dim_in = dim_in
48 | self.dim_out = dim_out
49 | self.E = tf.get_variable("E", [in_vocab, dim_in], tf.float32,
50 | initializer=tf.keras.initializers.RandomUniform())
51 | self.U_f = tf.keras.layers.Dense(dim_out, use_bias=False)
52 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False)
53 | self.W = tf.keras.layers.Dense(dim_out * 4)
54 | # self.h_init = tfe.Variable(
55 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
56 | # self.c_init = tfe.Variable(
57 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
58 | self.h_init = tf.zeros([1, dim_out], tf.float32)
59 | self.c_init = tf.zeros([1, dim_out], tf.float32)
60 |
61 | @staticmethod
62 | def get_nums(roots):
63 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots]
64 | max_len = max([len(x) for x in res])
65 | res = tf.keras.preprocessing.sequence.pad_sequences(
66 | res, max_len, padding="post", value=-1.)
67 | return tf.constant(res, tf.int32)
68 |
69 | def call(self, roots):
70 | depthes = [x[1] for x in sorted(depth_split_batch2(
71 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes
72 | indices = [self.get_nums(nodes) for nodes in depthes]
73 |
74 | h_tensor = self.h_init
75 | c_tensor = self.c_init
76 | for indice, nodes in zip(indices, depthes):
77 | x = tf.nn.embedding_lookup(self.E, [node.label for node in nodes]) # [nodes, dim_in]
78 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes)
79 | h_tensor = tf.concat([self.h_init, h_tensor], 0)
80 | c_tensor = tf.concat([self.c_init, c_tensor], 0)
81 | return depthes[-1]
82 |
83 | def apply(self, x, h_tensor, c_tensor, indice, nodes):
84 |
85 | mask_bool = tf.not_equal(indice, -1.)
86 | mask = tf.cast(mask_bool, tf.float32) # [batch, child]
87 |
88 | h = tf.gather(h_tensor, tf.where(mask_bool,
89 | indice, tf.zeros_like(indice))) # [nodes, child, dim]
90 | c = tf.gather(c_tensor, tf.where(mask_bool,
91 | indice, tf.zeros_like(indice)))
92 | h_sum = tf.reduce_sum(h * tf.expand_dims(mask, -1), 1) # [nodes, dim_out]
93 |
94 | W_x = self.W(x) # [nodes, dim_out * 4]
95 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out]
96 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2]
97 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3]
98 | W_o_x = W_x[:, self.dim_out * 3:]
99 |
100 | branch_f_k = tf.reshape(self.U_f(tf.reshape(h, [-1, h.shape[-1]])), h.shape)
101 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k)
102 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out]
103 |
104 | branch_iuo = self.U_iuo(h_sum) # [nodes, dim_out * 3]
105 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out]
106 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x)
107 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x)
108 |
109 | new_c = branch_i * branch_u + branch_f # [node, dim_out]
110 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out]
111 |
112 | for n, c, h in zip(nodes, new_c, new_h):
113 | n.c = c
114 | n.h = h
115 |
116 | return new_h, new_c
117 |
118 |
119 | class ChildSumLSTMLayer(tf.keras.Model):
120 | def __init__(self, dim_in, dim_out):
121 | super(ChildSumLSTMLayer, self).__init__()
122 | self.dim_in = dim_in
123 | self.dim_out = dim_out
124 | self.U_f = tf.keras.layers.Dense(dim_out, use_bias=False)
125 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False)
126 | self.W = tf.keras.layers.Dense(dim_out * 4)
127 | # self.h_init = tfe.Variable(
128 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
129 | # self.c_init = tfe.Variable(
130 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
131 | self.h_init = tf.zeros([1, dim_out], tf.float32)
132 | self.c_init = tf.zeros([1, dim_out], tf.float32)
133 |
134 | def call(self, tensor, indices):
135 | h_tensor = self.h_init
136 | c_tensor = self.c_init
137 | res_h, res_c = [], []
138 | for indice, x in zip(indices, tensor):
139 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice)
140 | h_tensor = tf.concat([self.h_init, h_tensor], 0)
141 | c_tensor = tf.concat([self.c_init, c_tensor], 0)
142 | res_h.append(h_tensor[1:, :])
143 | res_c.append(c_tensor[1:, :])
144 | return res_h, res_c
145 |
146 | def apply(self, x, h_tensor, c_tensor, indice):
147 |
148 | mask_bool = tf.not_equal(indice, -1.)
149 | mask = tf.cast(mask_bool, tf.float32) # [batch, child]
150 |
151 | h = tf.gather(h_tensor, tf.where(mask_bool,
152 | indice, tf.zeros_like(indice))) # [nodes, child, dim]
153 | c = tf.gather(c_tensor, tf.where(mask_bool,
154 | indice, tf.zeros_like(indice)))
155 | h_sum = tf.reduce_sum(h * tf.expand_dims(mask, -1), 1) # [nodes, dim_out]
156 |
157 | W_x = self.W(x) # [nodes, dim_out * 4]
158 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out]
159 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2]
160 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3]
161 | W_o_x = W_x[:, self.dim_out * 3:]
162 |
163 | branch_f_k = tf.reshape(self.U_f(tf.reshape(h, [-1, h.shape[-1]])), h.shape)
164 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k)
165 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out]
166 |
167 | branch_iuo = self.U_iuo(h_sum) # [nodes, dim_out * 3]
168 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out]
169 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x)
170 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x)
171 |
172 | new_c = branch_i * branch_u + branch_f # [node, dim_out]
173 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out]
174 |
175 | return new_h, new_c
176 |
177 |
178 | class ChildSumLSTMLayerTreeBase(tf.keras.Model):
179 | def __init__(self, dim_in, dim_out):
180 | super(ChildSumLSTMLayerTreeBase, self).__init__()
181 | self.dim_in = dim_in
182 | self.dim_out = dim_out
183 | self.U_f = tf.keras.layers.Dense(dim_out, use_bias=False)
184 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False)
185 | self.W = tf.keras.layers.Dense(dim_out * 4)
186 | # self.h_init = tfe.Variable(
187 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
188 | # self.c_init = tfe.Variable(
189 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
190 | self.h_init = tf.zeros([1, dim_out], tf.float32)
191 | self.c_init = tf.zeros([1, dim_out], tf.float32)
192 |
193 | @staticmethod
194 | def get_nums(roots):
195 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots]
196 | max_len = max([len(x) for x in res])
197 | res = tf.keras.preprocessing.sequence.pad_sequences(
198 | res, max_len, padding="post", value=-1.)
199 | return tf.constant(res, tf.int32)
200 |
201 | def call(self, roots):
202 | depthes = [x[1] for x in sorted(depth_split_batch2(
203 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes
204 | indices = [self.get_nums(nodes) for nodes in depthes]
205 |
206 | h_tensor = self.h_init
207 | c_tensor = self.c_init
208 | for indice, nodes in zip(indices, depthes):
209 | x = tf.stack([node.h for node in nodes]) # [nodes, dim_in]
210 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes)
211 | h_tensor = tf.concat([self.h_init, h_tensor], 0)
212 | c_tensor = tf.concat([self.c_init, c_tensor], 0)
213 | return depthes[-1]
214 |
215 | def apply(self, x, h_tensor, c_tensor, indice, nodes):
216 |
217 | mask_bool = tf.not_equal(indice, -1.)
218 | mask = tf.cast(mask_bool, tf.float32) # [batch, child]
219 |
220 | h = tf.gather(h_tensor, tf.where(mask_bool,
221 | indice, tf.zeros_like(indice))) # [nodes, child, dim]
222 | c = tf.gather(c_tensor, tf.where(mask_bool,
223 | indice, tf.zeros_like(indice)))
224 | h_sum = tf.reduce_sum(h * tf.expand_dims(mask, -1), 1) # [nodes, dim_out]
225 |
226 | W_x = self.W(x) # [nodes, dim_out * 4]
227 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out]
228 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2]
229 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3]
230 | W_o_x = W_x[:, self.dim_out * 3:]
231 |
232 | branch_f_k = tf.reshape(self.U_f(tf.reshape(h, [-1, h.shape[-1]])), h.shape)
233 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k)
234 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out]
235 |
236 | branch_iuo = self.U_iuo(h_sum) # [nodes, dim_out * 3]
237 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out]
238 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x)
239 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x)
240 |
241 | new_c = branch_i * branch_u + branch_f # [node, dim_out]
242 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out]
243 |
244 | for n, c, h in zip(nodes, new_c, new_h):
245 | n.c = c
246 | n.h = h
247 |
248 | return new_h, new_c
249 |
250 |
251 | class NaryLSTMLayer(tf.keras.Model):
252 | def __init__(self, dim_in, dim_out):
253 | super(NaryLSTMLayer, self).__init__()
254 | self.dim_in = dim_in
255 | self.dim_out = dim_out
256 | self.U_f1 = tf.keras.layers.Dense(dim_out, use_bias=False)
257 | self.U_f2 = tf.keras.layers.Dense(dim_out, use_bias=False)
258 | self.U_iuo = tf.keras.layers.Dense(dim_out * 3, use_bias=False)
259 | self.W = tf.keras.layers.Dense(dim_out * 4)
260 | # self.h_init = tfe.Variable(
261 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
262 | # self.c_init = tfe.Variable(
263 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
264 | self.h_init = tf.zeros([1, dim_out], tf.float32)
265 | self.c_init = tf.zeros([1, dim_out], tf.float32)
266 |
267 | def call(self, tensor, indices):
268 | h_tensor = self.h_init
269 | c_tensor = self.c_init
270 | res_h, res_c = [], []
271 | for indice, x in zip(indices, tensor):
272 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice)
273 | h_tensor = tf.concat([self.h_init, h_tensor], 0)
274 | c_tensor = tf.concat([self.c_init, c_tensor], 0)
275 | res_h.append(h_tensor[1:, :])
276 | res_c.append(c_tensor[1:, :])
277 | return res_h, res_c
278 |
279 | def apply(self, x, h_tensor, c_tensor, indice):
280 |
281 | mask_bool = tf.not_equal(indice, -1.)
282 |
283 | h = tf.gather(h_tensor, tf.where(mask_bool,
284 | indice, tf.zeros_like(indice))) # [nodes, child, dim]
285 | c = tf.gather(c_tensor, tf.where(mask_bool,
286 | indice, tf.zeros_like(indice)))
287 |
288 | W_x = self.W(x) # [nodes, dim_out * 4]
289 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out]
290 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2]
291 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3]
292 | W_o_x = W_x[:, self.dim_out * 3:]
293 |
294 | if h.shape[1] <= 1:
295 | h = tf.concat([h, tf.zeros_like(h)], 1) # [nodes, 2, dim]
296 | c = tf.concat([c, tf.zeros_like(c)], 1)
297 |
298 | h_concat = tf.reshape(h, [h.shape[0], -1])
299 |
300 | branch_f1 = self.U_f1(h_concat)
301 | branch_f1 = tf.sigmoid(W_f_x + branch_f1)
302 | branch_f2 = self.U_f2(h_concat)
303 | branch_f2 = tf.sigmoid(W_f_x + branch_f2)
304 | branch_f = branch_f1 * c[:, 0] + branch_f2 * c[:, 1]
305 |
306 | branch_iuo = self.U_iuo(h_concat) # [nodes, dim_out * 3]
307 | branch_i = tf.sigmoid(branch_iuo[:, :self.dim_out * 1] + W_i_x) # [nodes, dim_out]
308 | branch_u = tf.tanh(branch_iuo[:, self.dim_out * 1:self.dim_out * 2] + W_u_x)
309 | branch_o = tf.sigmoid(branch_iuo[:, self.dim_out * 2:] + W_o_x)
310 |
311 | new_c = branch_i * branch_u + branch_f # [node, dim_out]
312 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out]
313 |
314 | return new_h, new_c
315 |
316 |
317 | class BiLSTM_(tf.keras.Model):
318 | def __init__(self, dim, return_seq=False):
319 | super(BiLSTM_, self).__init__()
320 | self.dim = dim
321 | # self.c_init_f = tfe.Variable(tf.get_variable("c_init_f", [1, dim], tf.float32,
322 | # initializer=he_normal()))
323 | # self.h_init_f = tfe.Variable(tf.get_variable("h_initf", [1, dim], tf.float32,
324 | # initializer=he_normal()))
325 | # self.c_init_b = tfe.Variable(tf.get_variable("c_init_b", [1, dim], tf.float32,
326 | # initializer=he_normal()))
327 | # self.h_init_b = tfe.Variable(tf.get_variable("h_init_b", [1, dim], tf.float32,
328 | # initializer=he_normal()))
329 | self.c_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
330 | self.h_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
331 | self.c_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
332 | self.h_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
333 | self.Cell_f = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(dim)
334 | self.Cell_b = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(dim)
335 | self.fc = tf.keras.layers.Dense(dim, use_bias=False)
336 | self.return_seq = return_seq
337 |
338 | def call(self, x, length):
339 | '''x: [batch, length, dim]'''
340 | batch = x.shape[0]
341 | ys, states = tf.nn.bidirectional_dynamic_rnn(self.Cell_f, self.Cell_b, x,
342 | length,
343 | tf.nn.rnn_cell.LSTMStateTuple(
344 | tf.tile(self.c_init_f, [batch, 1]),
345 | tf.tile(self.h_init_f, [batch, 1])),
346 | tf.nn.rnn_cell.LSTMStateTuple(
347 | tf.tile(self.c_init_b, [batch, 1]),
348 | tf.tile(self.h_init_b, [batch, 1])))
349 | if self.return_seq:
350 | return self.fc(tf.concat(ys, -1))
351 | else:
352 | state_f, state_b = states
353 | state_concat = tf.concat([state_f.h, state_b.h], -1)
354 | return self.fc(state_concat)
355 |
356 |
357 | class BiLSTM(tf.keras.Model):
358 | def __init__(self, dim, return_seq=False):
359 | super(BiLSTM, self).__init__()
360 | self.dim = dim
361 | self.c_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
362 | self.h_init_f = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
363 | self.c_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
364 | self.h_init_b = tfe.Variable(tf.random_normal([1, dim], stddev=0.01, dtype=tf.float32))
365 | self.lay_f = tf.keras.layers.CuDNNLSTM(dim, return_sequences=True, return_state=True)
366 | self.lay_b = tf.keras.layers.CuDNNLSTM(dim, return_sequences=True, return_state=True)
367 | self.fc = tf.keras.layers.Dense(dim, use_bias=False)
368 | self.return_seq = return_seq
369 |
370 | def call(self, x, length):
371 | '''x: [batch, length, dim]'''
372 | batch = x.shape[0]
373 | x_back = tf.reverse_sequence(x, length, 1)
374 |
375 | init_state_f = (tf.tile(self.h_init_f, [batch, 1]), tf.tile(self.c_init_f, [batch, 1]))
376 | init_state_b = (tf.tile(self.h_init_b, [batch, 1]), tf.tile(self.c_init_b, [batch, 1]))
377 |
378 | y_f, h_f, c_f = self.lay_f(x, init_state_f)
379 | y_b, h_b, c_b = self.lay_b(x_back, init_state_b)
380 |
381 | y = tf.concat([y_f, y_b], -1)
382 |
383 | if self.return_seq:
384 | return self.fc(y)
385 | else:
386 | y_last = tf.gather_nd(y, tf.stack([tf.range(batch), length - 1], 1))
387 | return self.fc(y_last)
388 |
389 |
390 | class ShidoTreeLSTMLayer(tf.keras.Model):
391 | def __init__(self, dim_in, dim_out):
392 | super(ShidoTreeLSTMLayer, self).__init__()
393 | self.dim_in = dim_in
394 | self.dim_out = dim_out
395 | self.U_f = BiLSTM(dim_out, return_seq=True)
396 | self.U_i = BiLSTM(dim_out)
397 | self.U_u = BiLSTM(dim_out)
398 | self.U_o = BiLSTM(dim_out)
399 | self.W = tf.keras.layers.Dense(dim_out * 4)
400 | # self.h_init = tfe.Variable(
401 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
402 | # self.c_init = tfe.Variable(
403 | # tf.get_variable("c_init", [1, dim_out], tf.float32, initializer=he_normal()))
404 | self.h_init = tf.zeros([1, dim_out], tf.float32)
405 | self.c_init = tf.zeros([1, dim_out], tf.float32)
406 |
407 | def call(self, tensor, indices):
408 | h_tensor = self.h_init
409 | c_tensor = self.c_init
410 | res_h, res_c = [], []
411 | for indice, x in zip(indices, tensor):
412 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice)
413 | res_h.append(h_tensor[:, :])
414 | res_c.append(c_tensor[:, :])
415 | h_tensor = tf.concat([self.h_init, h_tensor], 0)
416 | c_tensor = tf.concat([self.c_init, c_tensor], 0)
417 | return res_h, res_c
418 |
419 | def apply(self, x, h_tensor, c_tensor, indice):
420 |
421 | mask_bool = tf.not_equal(indice, -1.)
422 | mask = tf.cast(mask_bool, tf.float32) # [nodes, child]
423 | length = tf.cast(tf.reduce_sum(mask, 1), tf.int32)
424 |
425 | h = tf.gather(h_tensor, tf.where(mask_bool,
426 | indice, tf.zeros_like(indice))) # [nodes, child, dim]
427 | c = tf.gather(c_tensor, tf.where(mask_bool,
428 | indice, tf.zeros_like(indice)))
429 |
430 | W_x = self.W(x) # [nodes, dim_out * 4]
431 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out]
432 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2]
433 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3]
434 | W_o_x = W_x[:, self.dim_out * 3:]
435 |
436 | branch_f_k = self.U_f(h, length)
437 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k)
438 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out]
439 |
440 | branch_i = self.U_i(h, length) # [nodes, dim_out]
441 | branch_i = tf.sigmoid(branch_i + W_i_x) # [nodes, dim_out]
442 | branch_u = self.U_u(h, length) # [nodes, dim_out]
443 | branch_u = tf.tanh(branch_u + W_u_x)
444 | branch_o = self.U_o(h, length) # [nodes, dim_out]
445 | branch_o = tf.sigmoid(branch_o + W_o_x)
446 |
447 | new_c = branch_i * branch_u + branch_f # [node, dim_out]
448 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out]
449 |
450 | return new_h, new_c
451 |
452 |
453 | class ShidoTreeLSTMLayerTreeBase(tf.keras.Model):
454 | def __init__(self, dim_in, dim_out):
455 | super(ShidoTreeLSTMLayerTreeBase, self).__init__()
456 | self.dim_in = dim_in
457 | self.dim_out = dim_out
458 | self.U_f = BiLSTM(dim_out, return_seq=True)
459 | self.U_i = BiLSTM(dim_out)
460 | self.U_u = BiLSTM(dim_out)
461 | self.U_o = BiLSTM(dim_out)
462 | self.W = tf.keras.layers.Dense(dim_out * 4)
463 | # self.h_init = tfe.Variable(
464 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
465 | # self.c_init = tfe.Variable(
466 | # tf.get_variable("c_init", [1, dim_out], tf.float32, initializer=he_normal()))
467 | self.h_init = tf.zeros([1, dim_out], tf.float32)
468 | self.c_init = tf.zeros([1, dim_out], tf.float32)
469 |
470 | @staticmethod
471 | def get_nums(roots):
472 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots]
473 | max_len = max([len(x) for x in res])
474 | res = tf.keras.preprocessing.sequence.pad_sequences(
475 | res, max_len, padding="post", value=-1.)
476 | return tf.constant(res, tf.int32)
477 |
478 | def call(self, roots):
479 | depthes = [x[1] for x in sorted(depth_split_batch2(
480 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes
481 | indices = [self.get_nums(nodes) for nodes in depthes]
482 |
483 | h_tensor = self.h_init
484 | c_tensor = self.c_init
485 | for indice, nodes in zip(indices, depthes):
486 | x = tf.stack([node.h for node in nodes]) # [nodes, dim_in]
487 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes)
488 | h_tensor = tf.concat([self.h_init, h_tensor], 0)
489 | c_tensor = tf.concat([self.c_init, c_tensor], 0)
490 | return depthes[-1]
491 |
492 | def apply(self, x, h_tensor, c_tensor, indice, nodes):
493 |
494 | mask_bool = tf.not_equal(indice, -1.)
495 | mask = tf.cast(mask_bool, tf.float32) # [nodes, child]
496 | length = tf.cast(tf.reduce_sum(mask, 1), tf.int32)
497 |
498 | h = tf.gather(h_tensor, tf.where(mask_bool,
499 | indice, tf.zeros_like(indice))) # [nodes, child, dim]
500 | c = tf.gather(c_tensor, tf.where(mask_bool,
501 | indice, tf.zeros_like(indice)))
502 |
503 | W_x = self.W(x) # [nodes, dim_out * 4]
504 | W_f_x = W_x[:, :self.dim_out * 1] # [nodes, dim_out]
505 | W_i_x = W_x[:, self.dim_out * 1:self.dim_out * 2]
506 | W_u_x = W_x[:, self.dim_out * 2:self.dim_out * 3]
507 | W_o_x = W_x[:, self.dim_out * 3:]
508 |
509 | branch_f_k = self.U_f(h, length)
510 | branch_f_k = tf.sigmoid(tf.expand_dims(W_f_x, 1) + branch_f_k)
511 | branch_f = tf.reduce_sum(branch_f_k * c * tf.expand_dims(mask, -1), 1) # [node, dim_out]
512 |
513 | branch_i = self.U_i(h, length) # [nodes, dim_out]
514 | branch_i = tf.sigmoid(branch_i + W_i_x) # [nodes, dim_out]
515 | branch_u = self.U_u(h, length) # [nodes, dim_out]
516 | branch_u = tf.tanh(branch_u + W_u_x)
517 | branch_o = self.U_o(h, length) # [nodes, dim_out]
518 | branch_o = tf.sigmoid(branch_o + W_o_x)
519 |
520 | new_c = branch_i * branch_u + branch_f # [node, dim_out]
521 | new_h = branch_o * tf.tanh(new_c) # [node, dim_out]
522 |
523 | for n, c, h in zip(nodes, new_c, new_h):
524 | n.c = c
525 | n.h = h
526 |
527 | return new_h, new_c
528 |
529 |
530 | class ShidoTreeLSTMWithEmbedding(ShidoTreeLSTMLayer):
531 | def __init__(self, in_vocab, dim_in, dim_out):
532 | super(ShidoTreeLSTMWithEmbedding, self).__init__(dim_in, dim_out)
533 | self.E = tf.get_variable("E", [in_vocab, dim_in], tf.float32,
534 | initializer=tf.keras.initializers.RandomUniform())
535 | self.dim_in = dim_in
536 | self.dim_out = dim_out
537 | self.U_f = BiLSTM(dim_out, return_seq=True)
538 | self.U_i = BiLSTM(dim_out)
539 | self.U_u = BiLSTM(dim_out)
540 | self.U_o = BiLSTM(dim_out)
541 | self.W = tf.keras.layers.Dense(dim_out * 4)
542 | # self.h_init = tfe.Variable(
543 | # tf.get_variable("h_init", [1, dim_out], tf.float32, initializer=he_normal()))
544 | # self.c_init = tfe.Variable(
545 | # tf.get_variable("c_init", [1, dim_out], tf.float32, initializer=he_normal()))
546 | self.h_init = tf.zeros([1, dim_out], tf.float32)
547 | self.c_init = tf.zeros([1, dim_out], tf.float32)
548 |
549 | def call(self, roots):
550 | depthes = [x[1] for x in sorted(depth_split_batch2(
551 | roots).items(), key=lambda x:-x[0])] # list of list of Nodes
552 | indices = [self.get_nums(nodes) for nodes in depthes]
553 |
554 | h_tensor = self.h_init
555 | c_tensor = self.c_init
556 | for indice, nodes in zip(indices, depthes):
557 | x = tf.nn.embedding_lookup(self.E, [node.label for node in nodes]) # [nodes, dim_in]
558 | h_tensor, c_tensor = self.apply(x, h_tensor, c_tensor, indice, nodes)
559 | h_tensor = tf.concat([self.h_init, h_tensor], 0)
560 | c_tensor = tf.concat([self.c_init, c_tensor], 0)
561 | return depthes[-1]
562 |
563 |
564 | class TreeDropout(tf.keras.Model):
565 | def __init__(self, rate):
566 | super(TreeDropout, self).__init__()
567 | self.dropout_layer = tf.keras.layers.Dropout(rate)
568 |
569 | def call(self, roots):
570 | nodes = [node for root in roots for node in traverse(root)]
571 | ys = [node.h for node in nodes]
572 | tensor = tf.stack(ys)
573 | dropped = self.dropout_layer(tensor)
574 | for e, v in enumerate(tf.split(dropped, len(ys))):
575 | nodes[e].h = tf.squeeze(v)
576 | return roots
577 |
578 |
579 | class SetEmbeddingLayer(tf.keras.Model):
580 | def __init__(self, dim_E, in_vocab):
581 | super(SetEmbeddingLayer, self).__init__()
582 | self.E = tf.get_variable("E", [in_vocab, dim_E], tf.float32,
583 | initializer=tf.keras.initializers.RandomUniform())
584 |
585 | def call(self, sets):
586 | length = [len(s) for s in sets]
587 | concatenated = tf.concat(sets, 0)
588 | embedded = tf.nn.embedding_lookup(self.E, concatenated)
589 | y = tf.split(embedded, length)
590 | return y
591 |
592 |
593 | class LSTMEncoder(tf.keras.Model):
594 | def __init__(self, dim, layer=1):
595 | super(LSTMEncoder, self).__init__()
596 | self.dim = dim
597 | # self.c_init_f = tfe.Variable(tf.get_variable("c_init_f", [1, dim], tf.float32,
598 | # initializer=he_normal()))
599 | # self.h_init_f = tfe.Variable(tf.get_variable("h_initf", [1, dim], tf.float32,
600 | # initializer=he_normal()))
601 | self.Cell_f = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(dim)
602 | self.h_init_f = tf.zeros([1, dim], tf.float32)
603 | self.c_init_f = tf.zeros([1, dim], tf.float32)
604 |
605 | def call(self, x, length):
606 | '''x: [batch, length, dim]'''
607 | batch = x.shape[0]
608 | ys, states = tf.nn.dynamic_rnn(self.Cell_f, x,
609 | length,
610 | tf.nn.rnn_cell.LSTMStateTuple(
611 | tf.tile(self.c_init_f, [batch, 1]),
612 | tf.tile(self.h_init_f, [batch, 1])))
613 | return ys, states
614 |
615 |
616 | class SequenceEmbeddingLayer(tf.keras.Model):
617 | def __init__(self, dim_E, in_vocab):
618 | super(SequenceEmbeddingLayer, self).__init__()
619 | self.E = tf.keras.layers.Embedding(in_vocab, dim_E)
620 |
621 | def call(self, y):
622 | y = self.E(y)
623 | return y
624 |
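625 | 
626 | # Illustrative only: a minimal sketch of the per-depth batch format these layers
627 | # expect, inferred from get_nums() and the *TreeBase variants above. Each depth
628 | # level is a [nodes, dim] tensor plus an int32 index matrix whose entries point
629 | # into the previous (deeper) level's hidden states: 0 selects the zero initial
630 | # state and -1 is padding. The real batches are built in train.py / utils.py.
631 | if __name__ == "__main__":
632 |     tf.enable_eager_execution()
633 |     dim = 4
634 |     layer = ChildSumLSTMLayer(dim_in=dim, dim_out=dim)
635 |     # Toy tree: one root with two leaf children, processed deepest level first.
636 |     tensor = [tf.random_normal([2, dim]),          # embeddings of the two leaves
637 |               tf.random_normal([1, dim])]          # embedding of the root
638 |     indices = [tf.constant([[0], [0]], tf.int32),  # leaves: no children -> zero state
639 |                tf.constant([[1, 2]], tf.int32)]    # root: children are rows 1 and 2
640 |     res_h, _ = layer(tensor, indices)
641 |     print(res_h[-1].shape)                         # (1, dim): hidden state of the root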
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from utils import pad_tensor
3 | from layers import *
4 | import numpy as np
5 |
6 |
7 | class AttentionDecoder(tf.keras.Model):
8 | def __init__(self, dim_F, dim_rep, vocab_size, layer=1):
9 | super(AttentionDecoder, self).__init__()
10 | self.layer = layer
11 | self.dim_rep = dim_rep
12 | self.F = tf.keras.layers.Embedding(vocab_size, dim_F)
13 | for i in range(layer):
14 | self.__setattr__("layer{}".format(i),
15 | tf.keras.layers.CuDNNLSTM(dim_rep,
16 | return_sequences=True,
17 | return_state=True,
18 | recurrent_initializer='glorot_uniform'))
19 | self.fc = tf.keras.layers.Dense(vocab_size)
20 |
21 | # used for attention
22 | self.W1 = tf.keras.layers.Dense(self.dim_rep)
23 | self.W2 = tf.keras.layers.Dense(self.dim_rep)
24 | self.V = tf.keras.layers.Dense(1)
25 | print("I am Decoder, dim is {} and {} layered".format(str(self.dim_rep), str(self.layer)))
26 |
27 | @staticmethod
28 | def loss_function(real, pred):
29 | loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
30 | return tf.reduce_sum(loss_)
31 |
32 | def get_loss(self, enc_y, states, target, dropout=0.0):
33 | '''
34 | enc_y: list of length batch_size, each [seq_len, dim]
35 | states: ([batch, dim], [batch, dim])
36 | target: [batch, max_len] (padded with -1.)
37 | '''
38 | mask = tf.not_equal(target, -1.)
39 | h, c = states
40 | enc_y, _ = pad_tensor(enc_y)
41 | enc_y = tf.nn.dropout(enc_y, 1. - dropout)
42 | dec_hidden = tf.nn.dropout(h, 1. - dropout)
43 | dec_cell = tf.nn.dropout(c, 1. - dropout)
44 |
45 | l_states = [(dec_hidden, dec_cell) for _ in range(self.layer)]
46 | target = tf.nn.relu(target)
47 | dec_input = target[:, 0]
48 | loss = 0
49 | for t in range(1, target.shape[1]):
50 | # passing enc_output to the decoder
51 | predictions, l_states, att = self.call(
52 | dec_input, l_states, enc_y)
53 | real = tf.boolean_mask(target[:, t], mask[:, t])
54 | pred = tf.boolean_mask(predictions, mask[:, t])
55 | loss += self.loss_function(real, pred)
56 | # using teacher forcing
57 | dec_input = target[:, t]
58 |
59 | return loss / tf.reduce_sum(tf.cast(mask, tf.float32))
60 |
61 | def translate(self, y_enc, states, max_length, start_token, end_token):
62 | '''
63 | y_enc: [seq_len, dim]
64 | states: ([dim,], [dim,])
65 | '''
66 | attention_plot = np.zeros((max_length, y_enc.shape[0]))
67 |
68 | h, c = states
69 | y_enc = tf.expand_dims(y_enc, 0)
70 | dec_hidden = tf.expand_dims(h, 0)
71 | dec_cell = tf.expand_dims(c, 0)
72 | dec_input = tf.constant(start_token, tf.int32, [1])
73 | result = []
74 |
75 | l_states = [(dec_hidden, dec_cell) for _ in range(self.layer)]
76 |
77 | for t in range(max_length):
78 | predictions, l_states, attention_weights = self.call(
79 | dec_input, l_states, y_enc)
80 |
81 | attention_weights = tf.reshape(attention_weights, (-1,))
82 | attention_plot[t] = attention_weights.numpy()
83 |
84 | predicted_id = tf.argmax(predictions[0]).numpy()
85 | result.append(predicted_id)
86 |
87 | if predicted_id == end_token:
88 | return result[:-1], attention_plot[:t]
89 |
90 | # the predicted ID is fed back into the model
91 | dec_input = tf.expand_dims(predicted_id, 0)
92 |
93 | return result, attention_plot
94 |
95 | def call(self, x, l_states, enc_y):
96 | # enc_y shape == (batch_size, max_length, hidden_size)
97 |
98 | # hidden shape == (batch_size, hidden size)
99 | # hidden_with_time_axis shape == (batch_size, 1, hidden size)
100 | # we are doing this to perform addition to calculate the score
101 | hidden_with_time_axis = tf.expand_dims(l_states[-1][0], 1)
102 |
103 | # score shape == (batch_size, max_length, hidden_size)
104 | score = tf.nn.tanh(self.W1(enc_y) + self.W2(hidden_with_time_axis))
105 |
106 | # attention_weights shape == (batch_size, max_length, 1)
107 | # we get 1 at the last axis because we are applying score to self.V
108 | attention_weights = tf.nn.softmax(self.V(score), axis=1)
109 |
110 | # context_vector shape after sum == (batch_size, hidden_size)
111 | context_vector = attention_weights * enc_y
112 | context_vector = tf.reduce_sum(context_vector, axis=1)
113 |
114 | # x shape after passing through embedding == (batch_size, 1, embedding_dim)
115 | x = tf.expand_dims(x, 1)
116 | x = self.F(x)
117 |
118 | # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
119 | # x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
120 |
121 | # passing the concatenated vector to the GRU
122 | new_l_states = []
123 | for i, states in zip(range(self.layer), l_states):
124 | if i < self.layer - 1:
125 | skip = x
126 | x, h, c = getattr(self, "layer{}".format(i))(x, states)
127 | x += skip
128 | else:
129 | x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
130 | x, h, c = getattr(self, "layer{}".format(i))(x, states)
131 | n_states = (h, c)
132 | new_l_states.append(n_states)
133 |
134 | # output shape == (batch_size * 1, hidden_size)
135 | x = tf.reshape(x, (-1, x.shape[2]))
136 |
137 | # output shape == (batch_size * 1, vocab)
138 | x = self.fc(x)
139 |
140 | return x, new_l_states, attention_weights
141 |
142 |
143 | class BaseModel(tf.keras.Model):
144 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0., lr=1e-3):
145 | super(BaseModel, self).__init__()
146 | self.dim_E = dim_E
147 | self.dim_F = dim_F
148 | self.dim_rep = dim_rep
149 | self.in_vocab = in_vocab
150 | self.out_vocab = out_vocab
151 | self.dropout = dropout
152 | self.decoder = AttentionDecoder(dim_F, dim_rep, out_vocab, layer)
153 | self.optimizer = tf.train.AdamOptimizer(lr)
154 |
155 | def encode(self, trees):
156 | '''
157 | ys: list of [seq_len, dim]
158 | hx, cx: [batch, dim]
159 | return: ys, [hx, cx]
160 | '''
161 |
162 | def train_on_batch(self, x, y):
163 | with tf.GradientTape() as tape:
164 | y_enc, (c, h) = self.encode(x)
165 | loss = self.decoder.get_loss(y_enc, (c, h), y, dropout=self.dropout)
166 | variables = self.variables
167 | gradients = tape.gradient(loss, variables)
168 | self.optimizer.apply_gradients(zip(gradients, variables))
169 | return loss.numpy()
170 |
171 | def translate(self, x, nl_i2w, nl_w2i, max_length=100):
172 | res = []
173 | y_enc, (c, h) = self.encode(x)
174 | batch_size = len(y_enc)
175 | for i in range(batch_size):
176 | nl, _ = self.decoder.translate(
177 | y_enc[i], (c[i], h[i]), max_length, nl_w2i["<s>"], nl_w2i["</s>"])
178 | res.append([nl_i2w[n] for n in nl])
179 | return res
180 |
181 | def evaluate_on_batch(self, x, y):
182 | y_enc, (c, h) = self.encode(x)
183 | loss = self.decoder.get_loss(y_enc, (c, h), y)
184 | return loss.numpy()
185 |
186 |
187 | class CodennModel(BaseModel):
188 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-3):
189 | super(CodennModel, self).__init__(dim_E, dim_F, dim_rep, in_vocab,
190 | out_vocab, layer, dropout, lr)
191 | self.dropout = dropout
192 | self.E = SetEmbeddingLayer(dim_E, in_vocab)
193 | print("I am CodeNNModel, dim is {} and {} layered".format(
194 | str(self.dim_rep), "0"))
195 |
196 | def encode(self, sets):
197 | sets = self.E(sets)
198 | # sets = [tf.nn.dropout(t, 1. - self.dropout) for t in sets]
199 |
200 | hx = tf.zeros([len(sets), self.dim_rep])
201 | cx = tf.zeros([len(sets), self.dim_rep])
202 | ys = sets
203 |
204 | return ys, [hx, cx]
205 |
206 |
207 | class Seq2seqModel(BaseModel):
208 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-3):
209 | super(Seq2seqModel, self).__init__(dim_E, dim_F,
210 | dim_rep, in_vocab, out_vocab, layer, dropout, lr)
211 | self.layer = layer
212 | self.dropout = dropout
213 | self.E = tf.keras.layers.Embedding(in_vocab + 1, dim_E, mask_zero=True)
214 | for i in range(layer):
215 | self.__setattr__("layer{}".format(i),
216 | tf.keras.layers.CuDNNLSTM(dim_rep,
217 | return_sequences=True,
218 | return_state=True))
219 | print("I am seq2seq model, dim is {} and {} layered".format(
220 | str(self.dim_rep), str(self.layer)))
221 |
222 | def encode(self, seq):
223 | length = get_length(seq)
224 | tensor = self.E(seq + 1)
225 | # tensor = tf.nn.dropout(tensor, 1. - self.dropout)
226 | for i in range(self.layer):
227 | skip = tensor
228 | tensor, h, c = getattr(self, "layer{}".format(i))(tensor)
229 | tensor += skip
230 |
231 | cx = c
232 | hx = h
233 | ys = [y[:i] for y, i in zip(tf.unstack(tensor, axis=0), length.numpy())]
234 |
235 | return ys, [hx, cx]
236 |
237 |
238 | class ChildsumModel(BaseModel):
239 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-4):
240 | super(ChildsumModel, self).__init__(dim_E, dim_F,
241 | dim_rep, in_vocab, out_vocab, layer, dropout, lr)
242 | self.layer = layer
243 | self.dropout = dropout
244 | self.E = TreeEmbeddingLayer(dim_E, in_vocab)
245 | for i in range(layer):
246 | self.__setattr__("layer{}".format(i), ChildSumLSTMLayer(dim_E, dim_rep))
247 | print("I am Child-sum model, dim is {} and {} layered".format(
248 | str(self.dim_rep), str(self.layer)))
249 |
250 | def encode(self, x):
251 | tensor, indice, tree_num = x
252 | tensor = self.E(tensor)
253 | # tensor = [tf.nn.dropout(t, 1. - self.dropout) for t in tensor]
254 | for i in range(self.layer):
255 | skip = tensor
256 | tensor, c = getattr(self, "layer{}".format(i))(tensor, indice)
257 | tensor = [t + s for t, s in zip(tensor, skip)]
258 |
259 | hx = tensor[-1]
260 | cx = c[-1]
261 | ys = []
262 | batch_size = tensor[-1].shape[0]
263 | tensor = tf.concat(tensor, 0)
264 | tree_num = tf.concat(tree_num, 0)
265 | for batch in range(batch_size):
266 | ys.append(tf.boolean_mask(tensor, tf.equal(tree_num, batch)))
267 | return ys, [hx, cx]
268 |
269 |
270 | class NaryModel(BaseModel):
271 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-4):
272 | super(NaryModel, self).__init__(dim_E, dim_F,
273 | dim_rep, in_vocab, out_vocab, layer, dropout, lr)
274 | self.layer = layer
275 | self.dropout = dropout
276 | self.E = TreeEmbeddingLayer(dim_E, in_vocab)
277 | for i in range(layer):
278 | self.__setattr__("layer{}".format(i), NaryLSTMLayer(dim_E, dim_rep))
279 | print("I am N-ary model, dim is {} and {} layered".format(
280 | str(self.dim_rep), str(self.layer)))
281 |
282 | def encode(self, x):
283 | tensor, indice, tree_num = x
284 | tensor = self.E(tensor)
285 | # tensor = [tf.nn.dropout(t, 1. - self.dropout) for t in tensor]
286 | for i in range(self.layer):
287 | skip = tensor
288 | tensor, c = getattr(self, "layer{}".format(i))(tensor, indice)
289 | tensor = [t + s for t, s in zip(tensor, skip)]
290 |
291 | hx = tensor[-1]
292 | cx = c[-1]
293 | ys = []
294 | batch_size = tensor[-1].shape[0]
295 | tensor = tf.concat(tensor, 0)
296 | tree_num = tf.concat(tree_num, 0)
297 | for batch in range(batch_size):
298 | ys.append(tf.boolean_mask(tensor, tf.equal(tree_num, batch)))
299 | return ys, [hx, cx]
300 |
301 |
302 | class MultiwayModel(BaseModel):
303 | def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.0, lr=1e-4):
304 | super(MultiwayModel, self).__init__(dim_E, dim_F,
305 | dim_rep, in_vocab, out_vocab, layer, dropout, lr)
306 | self.layer = layer
307 | self.dropout = dropout
308 | self.E = TreeEmbeddingLayer(dim_E, in_vocab)
309 | for i in range(layer):
310 | self.__setattr__("layer{}".format(i), ShidoTreeLSTMLayer(dim_E, dim_rep))
311 | print("I am Multi-way model, dim is {} and {} layered".format(
312 | str(self.dim_rep), str(self.layer)))
313 |
314 | def encode(self, x):
315 | tensor, indice, tree_num = x
316 | tensor = self.E(tensor)
317 | # tensor = [tf.nn.dropout(t, 1. - self.dropout) for t in tensor]
318 | for i in range(self.layer):
319 | skip = tensor
320 | tensor, c = getattr(self, "layer{}".format(i))(tensor, indice)
321 | tensor = [t + s for t, s in zip(tensor, skip)]
322 |
323 | hx = tensor[-1]
324 | cx = c[-1]
325 | ys = []
326 | batch_size = tensor[-1].shape[0]
327 | tensor = tf.concat(tensor, 0)
328 | tree_num = tf.concat(tree_num, 0)
329 | for batch in range(batch_size):
330 | ys.append(tf.boolean_mask(tensor, tf.equal(tree_num, batch)))
331 | return ys, [hx, cx]
332 |
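333 | 
334 | # Illustrative only: a sketch of how a model might be instantiated with the
335 | # hyperparameters that appear in notebooks/check_result.ipynb (dim 256, dropout 0.5,
336 | # lr 0.001, 1 layer). Vocabulary sizes come from the pickles written by dataset.py;
337 | # the batches fed to train_on_batch / translate are assembled in train.py (not shown).
338 | #
339 | #   import pickle
340 | #   code_w2i = pickle.load(open("./dataset/code_w2i.pkl", "rb"))
341 | #   nl_w2i = pickle.load(open("./dataset/nl_w2i.pkl", "rb"))
342 | #   model = MultiwayModel(256, 256, 256, len(code_w2i), len(nl_w2i),
343 | #                         layer=1, dropout=0.5, lr=1e-3)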
--------------------------------------------------------------------------------
/notebooks/check_result.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "import sys\n",
11 | "sys.path.append(\"../\")\n",
12 | "from matplotlib import pylab as plt\n",
13 | "import numpy as np\n",
14 | "from glob import glob\n",
15 | "import json\n",
16 | "from utils import *\n",
17 | "from tqdm import tqdm\n",
18 | "import pandas as pd"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "!CUDA_VISIBLE_DEVICES="
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "codes = [json.loads(s)['code'] for s in open(\"/home/shido/summarization_java/test.json\").readlines()]"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "trn_data = read_pickle(\"../dataset/nl/train.pkl\")\n",
46 | "vld_data = read_pickle(\"../dataset/nl/valid.pkl\")\n",
47 | "tst_data = read_pickle(\"../dataset/nl/test.pkl\")\n",
48 | "code_i2w = read_pickle(\"../dataset/code_i2w.pkl\")\n",
49 | "code_w2i = read_pickle(\"../dataset/code_w2i.pkl\")\n",
50 | "nl_i2w = read_pickle(\"../dataset/nl_i2w.pkl\")\n",
51 | "nl_w2i = read_pickle(\"../dataset/nl_w2i.pkl\")\n",
52 | "\n",
53 | "trn_x, trn_y_raw = zip(*trn_data.items())\n",
54 | "vld_x, vld_y_raw = zip(*vld_data.items())\n",
55 | "tst_x, tst_y_raw = zip(*tst_data.items())\n",
56 | "\n",
57 | "trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"<UNK>\"] for t in l] for l in trn_y_raw]\n",
58 | "vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"<UNK>\"] for t in l] for l in vld_y_raw]\n",
59 | "tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"<UNK>\"] for t in l] for l in tst_y_raw]"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "sorted([len(x) for x in trn_y])[::-1]"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "print(len(trn_y))\n",
78 | "print(len(vld_y))\n",
79 | "print(len(tst_y))"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "lengthes = [len(traverse_label(read_pickle(x))) for x in tst_x]"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "files = sorted(glob(\"../models/*/history.json\"))\n",
98 | "dirs = [x.split(\"/\")[-2] for x in files]\n",
99 | "histories = {name: json.load(open(x)) for name, x in zip(dirs, files)}\n",
100 | "dirs"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "names = [\n",
110 | " \"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\",\n",
111 | " \"childsum_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\",\n",
112 | " \"deepcom_dim256_embed256_drop0.5_lr0.001_batch64_epochs30_layer1\",\n",
113 | " \"codenn_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\",\n",
114 | "]"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "plt.figure(figsize=(15, 10))\n",
124 | "# for name, his in [(view, histories[name]) for view, name in zip([\"Ours\", \"Child-Sum\", \"[Hu+, 18]\", \"[Iyer+, 16]\"], names)]:\n",
125 | "for name, his in histories.items():\n",
126 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n",
127 | " plt.plot(his[\"bleu_val\"], \"-\", label=name)\n",
128 | "# plt.plot(his[\"loss_val\"], \"-x\", label=name + \"_valid\")\n",
129 | "plt.grid()\n",
130 | "plt.legend()"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "for name, his in histories.items():\n",
140 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n",
141 | " print(name, \":\", np.mean(his[\"bleus\"]))"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "from nltk.translate.gleu_score import sentence_gleu\n",
151 | "from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction\n",
152 | "def get_gleu(true, pred):\n",
153 | " return(sentence_gleu([true], pred))\n",
154 | "def get_bleu(true, pred):\n",
155 | " return(sentence_bleu([true], pred, smoothing_function=SmoothingFunction().method4))"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "# BLEU\n",
165 | "for name, his in histories.items():\n",
166 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n",
167 | " trues = histories[name][\"trues\"]\n",
168 | " preds = histories[name][\"preds\"]\n",
169 | " gleu = np.mean([get_bleu(x, y) for x, y in zip(trues, preds)])\n",
170 | " print(name, gleu)"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "# GLEU\n",
180 | "for name, his in histories.items():\n",
181 | " if \"drop0.5_lr0.001_batch160_epochs50_layer1\" in name:\n",
182 | " trues = histories[name][\"trues\"]\n",
183 | " preds = histories[name][\"preds\"]\n",
184 | " gleu = np.mean([get_gleu(x, y) for x, y in zip(trues, preds)])\n",
185 | " print(name, gleu)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "codes = [codes[int(i)] for i in histories[names[0]][\"numbers\"]]"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "df_data = [[\" \".join(x) for x in histories[name][\"preds\"]] for name in names] + [[\" \".join(x) for x in histories[names[0]][\"trues\"]]]\n",
204 | "df_data += [histories[name][\"bleus\"] for name in names] + [codes]\n",
205 | "df_index = [\"PREDICTION \" + name for name in names] + [\"GROUND TRUTH\"]\n",
206 | "df_index += [\"BLEU-4 \" + name for name in names] + [\"SOURCE CODE\"]\n",
207 | "df = pd.DataFrame(data=df_data, index=df_index).T"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "metadata": {},
214 | "outputs": [],
215 | "source": [
216 | "df.to_csv(\"for_hitachi_lab.csv\")"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "metadata": {},
223 | "outputs": [],
224 | "source": [
225 | "df.get([\"SOURCE CODE\", \"GROUND TRUTH\"]).head()"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": null,
231 | "metadata": {},
232 | "outputs": [],
233 | "source": [
234 | "shido = histories[\"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"bleus\"]\n",
235 | "np.mean(shido)"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": null,
241 | "metadata": {},
242 | "outputs": [],
243 | "source": [
244 | "child = histories[\"childsum_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"bleus\"]\n",
245 | "np.mean(child)"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "dif = np.array(shido) - np.array(child)"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "index = np.argsort(dif)[::-1]"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "trues = histories[\"childsumlstm_1layer\"][\"trues\"]"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "metadata": {},
279 | "outputs": [],
280 | "source": [
281 | "childsum = histories[\"childsum_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"preds\"]"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": null,
287 | "metadata": {},
288 | "outputs": [],
289 | "source": [
290 | "ours = histories[\"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"preds\"]"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "number = histories[\"multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1\"][\"numbers\"]"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "for i in index[:500]:\n",
309 | " if \" \".join(trues[i]) != \" \".join(ours[i]):\n",
310 | " print(\"GT: \", \" \".join(trues[i]))\n",
311 | " print(\"CSum: \", \" \".join(childsum[i]))\n",
312 | " print(\"Ours: \", \" \".join(ours[i]))\n",
313 | " print(\"Codes:\\n\" + codes[i])\n",
314 | "# print(\"Num: \", number[i])\n",
315 | " print(\"-\" * 100)"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "trn_data = read_pickle(\"../dataset/nl/train.pkl\")\n",
325 | "vld_data = read_pickle(\"../dataset/nl/valid.pkl\")\n",
326 | "tst_data = read_pickle(\"../dataset/nl/test.pkl\")\n",
327 | "code_i2w = read_pickle(\"../dataset/code_i2w.pkl\")\n",
328 | "code_w2i = read_pickle(\"../dataset/code_w2i.pkl\")\n",
329 | "nl_i2w = read_pickle(\"../dataset/nl_i2w.pkl\")\n",
330 | "nl_w2i = read_pickle(\"../dataset/nl_w2i.pkl\")"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": null,
336 | "metadata": {},
337 | "outputs": [],
338 | "source": [
339 | "trn_data"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {},
346 | "outputs": [],
347 | "source": []
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "xx = [traverse_label(read_pickle(t)) for t in tst_x]"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": null,
361 | "metadata": {},
362 | "outputs": [],
363 | "source": [
364 | "x = [\" \".join([str(code_w2i[w]) for w in t]) + \"\\n\" for t in xx]"
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": null,
370 | "metadata": {},
371 | "outputs": [],
372 | "source": [
373 | "y = [\" \".join([str(w) for w in t[1:-1]]) + \"\\n\" for t in tst_y]"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": null,
379 | "metadata": {},
380 | "outputs": [],
381 | "source": [
382 | "xy = [xw + \"\\t\" + yw for xw, yw in zip(x, y)]"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "y[0]"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": null,
397 | "metadata": {},
398 | "outputs": [],
399 | "source": [
400 | "open(\"x.tst\", \"w\").writelines(x)"
401 | ]
402 | },
403 | {
404 | "cell_type": "code",
405 | "execution_count": null,
406 | "metadata": {},
407 | "outputs": [],
408 | "source": [
409 | "len(x) / 32"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": null,
415 | "metadata": {},
416 | "outputs": [],
417 | "source": [
418 | "y[0]"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": null,
424 | "metadata": {},
425 | "outputs": [],
426 | "source": [
427 | "a = open(\"x.tst\", \"r\").read()"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": null,
433 | "metadata": {},
434 | "outputs": [],
435 | "source": [
436 | "a.split(\"\\n\")[-2]"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "metadata": {},
443 | "outputs": [],
444 | "source": []
445 | }
446 | ],
447 | "metadata": {
448 | "kernelspec": {
449 | "display_name": "Python 3",
450 | "language": "python",
451 | "name": "python3"
452 | },
453 | "language_info": {
454 | "codemirror_mode": {
455 | "name": "ipython",
456 | "version": 3
457 | },
458 | "file_extension": ".py",
459 | "mimetype": "text/x-python",
460 | "name": "python",
461 | "nbconvert_exporter": "python",
462 | "pygments_lexer": "ipython3",
463 | "version": "3.6.1"
464 | }
465 | },
466 | "nbformat": 4,
467 | "nbformat_minor": 2
468 | }
469 |
--------------------------------------------------------------------------------
/notebooks/example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "import sys\n",
11 | "sys.path.append(\"../\")\n",
12 | "import pickle\n",
13 | "import numpy as np\n",
14 | "from tqdm import tqdm_notebook\n",
15 | "from prefetch_generator import BackgroundGenerator\n",
16 | "from matplotlib import pylab as plt\n",
17 | "from IPython.display import clear_output\n",
18 | "import os\n",
19 | "from joblib import Parallel, delayed\n",
20 | "from tqdm import tqdm\n",
21 | "import nltk\n",
22 | "from glob import glob\n",
23 | "from joblib import Parallel, delayed\n",
24 | "from collections import Counter\n",
25 | "from layers import *\n",
26 | "from utils import *\n",
27 | "from models import *\n",
28 | "import json\n",
29 | "import tensorflow as tf\n",
30 | "tfe = tf.contrib.eager \n",
31 | "config = tf.ConfigProto(\n",
32 | " gpu_options=tf.GPUOptions(\n",
33 | " visible_device_list=\"0\"))\n",
34 | "config.gpu_options.allow_growth = True\n",
35 | "session = tf.Session(config=config)\n",
36 | "tf.enable_eager_execution(config=config)"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "checkpoint_dir = \"../models/path_to_dir\""
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "trn_data = read_pickle(\"../dataset/nl/train.pkl\")\n",
55 | "vld_data = read_pickle(\"../dataset/nl/valid.pkl\")\n",
56 | "tst_data = read_pickle(\"../dataset/nl/test.pkl\")\n",
57 | "code_i2w = read_pickle(\"../dataset/code_i2w.pkl\")\n",
58 | "code_w2i = read_pickle(\"../dataset/code_w2i.pkl\")\n",
59 | "nl_i2w = read_pickle(\"../dataset/nl_i2w.pkl\")\n",
60 | "nl_w2i = read_pickle(\"../dataset/nl_w2i.pkl\")"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "trn_x, trn_y_raw = zip(*trn_data.items())\n",
70 | "vld_x, vld_y_raw = zip(*vld_data.items())\n",
71 | "tst_x, tst_y_raw = zip(*tst_data.items())"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in trn_y_raw]\n",
81 | "vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in vld_y_raw]\n",
82 | "tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[\"\"] for t in l] for l in tst_y_raw]"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {},
89 | "outputs": [],
90 | "source": [
91 | "# model defining\n",
92 | "class Model(BaseModel):\n",
93 | " def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-4):\n",
94 | " super(Model, self).__init__(dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer, dropout, lr)\n",
95 | " self.E = TreeEmbeddingLayer(dim_E, in_vocab)\n",
96 | " self.encoder = ChildSumLSTMLayer(dim_E, dim_rep)\n",
97 | " \n",
98 | " def encode(self, trees):\n",
99 | " trees = self.E(trees)\n",
100 | " trees = self.encoder(trees)\n",
101 | " \n",
102 | " hx = tf.stack([tree.h for tree in trees])\n",
103 | " cx = tf.stack([tree.c for tree in trees])\n",
104 | " ys = [tf.stack([node.h for node in traverse(tree)]) for tree in trees]\n",
105 | " \n",
106 | " return ys, [hx, cx]"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "# setting model\n",
116 | "model = Model(512, 512, 512, len(code_w2i), len(nl_w2i), dropout=0.5, lr=1e-4)\n",
117 | "epochs = 15\n",
118 | "batch_size = 64\n",
119 | "os.makedirs(checkpoint_dir, exist_ok=True)\n",
120 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
121 | "root = tfe.Checkpoint(model=model)\n",
122 | "history = {\"loss\":[], \"loss_val\":[]}"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "# Setting Data Generator\n",
132 | "trn_gen = Datagen_tree(trn_x, trn_y, batch_size, code_w2i, nl_i2w, train=True)\n",
133 | "vld_gen = Datagen_tree(vld_x, vld_y, batch_size, code_w2i, nl_i2w, train=False)\n",
134 | "tst_gen = Datagen_tree(tst_x, tst_y, batch_size, code_w2i, nl_i2w, train=False)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "# training\n",
144 | "for epoch in range(epochs):\n",
145 | " \n",
146 | " # train\n",
147 | " loss_tmp = []\n",
148 | " t = tqdm(trn_gen(epoch))\n",
149 | " for x, y, _, _ in t:\n",
150 | " loss_tmp.append(model.train_on_batch(x, y))\n",
151 | " t.set_description(\"epoch:{:03d}, loss = {}\".format(epoch + 1, np.mean(loss_tmp)))\n",
152 | " history[\"loss\"].append(np.sum(loss_tmp) / len(t))\n",
153 | " \n",
154 | " loss_tmp = []\n",
155 | " t = tqdm(vld_gen(epoch))\n",
156 | " for x, y, _, _ in t:\n",
157 | " loss_tmp.append(model.evaluate_on_batch(x, y))\n",
158 | " t.set_description(\"epoch:{:03d}, loss_val = {}\".format(epoch + 1, np.mean(loss_tmp)))\n",
159 | " history[\"loss_val\"].append(np.sum(loss_tmp) / len(t))\n",
160 | " \n",
161 | " # checkpoint\n",
162 | " if history[\"loss_val\"][-1] == min(history[\"loss_val\"]):\n",
163 | " checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
164 | " root.save(file_prefix=checkpoint_prefix)\n",
165 | " \n",
166 | " # print\n",
167 | " clear_output()\n",
168 | " for key, val in history.items():\n",
169 | " if \"loss\" in key:\n",
170 | " plt.plot(val, label=key)\n",
171 | " plt.legend()\n",
172 | " plt.show()"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "root.restore(tf.train.latest_checkpoint(checkpoint_dir))"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "preds = []\n",
191 | "trues = []\n",
192 | "for x, y, _, y_raw in tqdm(tst_gen(0)):\n",
193 | " res = model.translate(x, nl_i2w, nl_w2i)\n",
194 | " preds += res\n",
195 | " trues += [s[1:-1] for s in y_raw]"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "bleus = Parallel(n_jobs=-1)(delayed(bleu4)(t, p) for t, p in tqdm(list(zip(trues, preds))))"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "history[\"bleus\"] = bleus\n",
214 | "history[\"preds\"] = preds\n",
215 | "history[\"trues\"] = trues\n",
216 | "history[\"numbers\"] = [int(x.split(\"/\")[-1]) for x in tst_x]"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "metadata": {},
223 | "outputs": [],
224 | "source": [
225 | "with open(os.path.join(checkpoint_dir, \"history.json\"), \"w\") as f:\n",
226 | " json.dump(history, f)"
227 | ]
228 | }
229 | ],
230 | "metadata": {
231 | "anaconda-cloud": {},
232 | "kernelspec": {
233 | "display_name": "Python 3",
234 | "language": "python",
235 | "name": "python3"
236 | },
237 | "language_info": {
238 | "codemirror_mode": {
239 | "name": "ipython",
240 | "version": 3
241 | },
242 | "file_extension": ".py",
243 | "mimetype": "text/x-python",
244 | "name": "python",
245 | "nbconvert_exporter": "python",
246 | "pygments_lexer": "ipython3",
247 | "version": "3.6.1"
248 | }
249 | },
250 | "nbformat": 4,
251 | "nbformat_minor": 2
252 | }
253 |
--------------------------------------------------------------------------------
/parser/README.md:
--------------------------------------------------------------------------------
1 | # parser
2 |
3 | Run `java -jar parser.jar -f [filename] -d [dirname]`.
4 |
5 | # example
6 |
7 | `java -jar parser.jar -f valid.json -d valid`
8 |
9 | # requirement
10 | Java 1.8
11 |
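12 | # batch parsing (sketch)
13 | 
14 | A minimal sketch for parsing all three DeepCom splits in one go, assuming the split JSON files (`train.json`, `valid.json`, `test.json`) sit next to `parser.jar` and that `-d` names the output directory later read by `dataset.py`:
15 | 
16 | ```python
17 | import subprocess
18 | 
19 | for split in ["train", "valid", "test"]:
20 |     # equivalent to: java -jar parser.jar -f <split>.json -d <split>
21 |     subprocess.run(["java", "-jar", "parser.jar",
22 |                     "-f", split + ".json", "-d", split], check=True)
23 | ```
24 | 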
--------------------------------------------------------------------------------
/parser/parser.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sh1doy/summarization_tf/2f14f2c28c63140288acc6515db236e486ab7152/parser/parser.jar
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.5.0
2 | alabaster==0.7.10
3 | anaconda-client==1.6.3
4 | anaconda-navigator==1.6.2
5 | anaconda-project==0.6.0
6 | asn1crypto==0.22.0
7 | astor==0.7.1
8 | astroid==1.4.9
9 | astropy==1.3.2
10 | Babel==2.4.0
11 | backports.shutil-get-terminal-size==1.0.0
12 | beautifulsoup4==4.6.0
13 | bitarray==0.8.1
14 | blaze==0.10.1
15 | bleach==1.5.0
16 | bokeh==0.12.5
17 | boto==2.46.1
18 | Bottleneck==1.2.1
19 | cffi==1.10.0
20 | chardet==3.0.3
21 | click==6.7
22 | cloudpickle==0.2.2
23 | clyent==1.2.2
24 | colorama==0.3.9
25 | conda==4.3.21
26 | contextlib2==0.5.5
27 | cryptography==1.8.1
28 | cycler==0.10.0
29 | Cython==0.25.2
30 | cytoolz==0.8.2
31 | dask==0.14.3
32 | datashape==0.5.4
33 | decorator==4.0.11
34 | distributed==1.16.3
35 | docutils==0.13.1
36 | entrypoints==0.2.2
37 | et-xmlfile==1.0.1
38 | fastcache==1.0.2
39 | Flask==0.12.2
40 | Flask-Cors==3.0.2
41 | gast==0.2.0
42 | gevent==1.2.1
43 | greenlet==0.4.12
44 | grpcio==1.15.0
45 | h5py==2.8.0
46 | HeapDict==1.0.0
47 | html5lib==0.999
48 | idna==2.5
49 | imagesize==0.7.1
50 | ipykernel==4.6.1
51 | ipython==5.3.0
52 | ipython-genutils==0.2.0
53 | ipywidgets==6.0.0
54 | isort==4.2.5
55 | itsdangerous==0.24
56 | jdcal==1.3
57 | jedi==0.10.2
58 | Jinja2==2.9.6
59 | joblib==0.12.5
60 | jsonschema==2.6.0
61 | jupyter==1.0.0
62 | jupyter-client==5.0.1
63 | jupyter-console==5.1.0
64 | jupyter-core==4.3.0
65 | jupyterthemes==0.17.0
66 | lazy-object-proxy==1.2.2
67 | lesscpy==0.13.0
68 | llvmlite==0.18.0
69 | locket==0.2.0
70 | lxml==3.7.3
71 | Markdown==2.6.11
72 | MarkupSafe==0.23
73 | matplotlib==2.0.2
74 | mistune==0.7.4
75 | mpmath==0.19
76 | msgpack-python==0.4.8
77 | multipledispatch==0.4.9
78 | navigator-updater==0.1.0
79 | nbconvert==5.1.1
80 | nbformat==4.3.0
81 | networkx==1.11
82 | nltk==3.2.3
83 | nose==1.3.7
84 | notebook==5.0.0
85 | numba==0.33.0
86 | numexpr==2.6.2
87 | numpy==1.14.5
88 | numpydoc==0.6.0
89 | odo==0.5.0
90 | olefile==0.44
91 | openpyxl==2.4.7
92 | packaging==16.8
93 | pandas==0.20.1
94 | pandocfilters==1.4.1
95 | partd==0.3.8
96 | pathlib2==2.2.1
97 | patsy==0.4.1
98 | pep8==1.7.0
99 | pexpect==4.2.1
100 | pickleshare==0.7.4
101 | Pillow==4.1.1
102 | ply==3.10
103 | prefetch-generator==1.0.0
104 | prometheus-client==0.3.1
105 | prompt-toolkit==1.0.14
106 | protobuf==3.6.1
107 | psutil==5.2.2
108 | ptyprocess==0.5.1
109 | py==1.4.33
110 | pycosat==0.6.2
111 | pycparser==2.17
112 | pycrypto==2.6.1
113 | pycurl==7.43.0
114 | pyflakes==1.5.0
115 | Pygments==2.2.0
116 | pylint==1.6.4
117 | pyodbc==4.0.16
118 | pyOpenSSL==17.0.0
119 | pyparsing==2.1.4
120 | pytest==3.0.7
121 | python-dateutil==2.6.0
122 | pytz==2017.2
123 | PyWavelets==0.5.2
124 | PyYAML==3.12
125 | pyzmq==16.0.2
126 | QtAwesome==0.4.4
127 | qtconsole==4.3.0
128 | QtPy==1.2.1
129 | requests==2.14.2
130 | rope-py3k==0.9.4.post1
131 | scikit-image==0.13.0
132 | scikit-learn==0.18.1
133 | scipy==0.19.0
134 | seaborn==0.7.1
135 | simplegeneric==0.8.1
136 | singledispatch==3.4.0.3
137 | six==1.10.0
138 | snowballstemmer==1.2.1
139 | sortedcollections==0.5.3
140 | sortedcontainers==1.5.7
141 | Sphinx==1.5.6
142 | spyder==3.1.4
143 | SQLAlchemy==1.1.9
144 | statsmodels==0.8.0
145 | sympy==1.0
146 | tables==3.3.0
147 | tblib==1.3.2
148 | tensorboard==1.10.0
149 | tensorflow-gpu==1.10.1
150 | termcolor==1.1.0
151 | terminado==0.6
152 | testpath==0.3
153 | toolz==0.8.2
154 | tornado==4.5.1
155 | tqdm==4.26.0
156 | traitlets==4.3.2
157 | unicodecsv==0.14.1
158 | wcwidth==0.1.7
159 | Werkzeug==0.12.2
160 | widgetsnbextension==2.0.0
161 | wrapt==1.10.10
162 | xlrd==1.0.0
163 | XlsxWriter==0.9.6
164 | xlwt==1.2.0
165 | zict==0.1.2
166 |
--------------------------------------------------------------------------------
/retrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils import read_pickle, Datagen_set, Datagen_deepcom, Datagen_tree, Datagen_binary, bleu4
3 | from models import Seq2seqModel, CodennModel, ChildsumModel, MultiwayModel, NaryModel
4 | import numpy as np
5 | import os
6 | import tensorflow as tf
7 | from tqdm import tqdm
8 | from joblib import delayed, Parallel
9 | import json
10 |
11 |
12 | # parse arguments
13 |
14 | parser = argparse.ArgumentParser(description='Source Code Generation')
15 |
16 | parser.add_argument('-m', "--method", type=str, nargs="?", required=True,
17 | choices=['seq2seq', 'deepcom', 'codenn', 'childsum', 'multiway', "nary"],
18 | help='Encoder method')
19 | parser.add_argument('-d', "--dim", type=int, nargs="?", required=False, default=512,
20 | help='Representation dimension')
21 | parser.add_argument("--embed", type=int, nargs="?", required=False, default=256,
22 |                     help='Embedding dimension')
23 | parser.add_argument("--drop", type=float, nargs="?", required=False, default=.5,
24 | help="Dropout rate")
25 | parser.add_argument('-r', "--lr", type=float, nargs="?", required=True,
26 | help='Learning rate')
27 | parser.add_argument('-b', "--batch", type=int, nargs="?", required=True,
28 | help='Mini batch size')
29 | parser.add_argument('-e', "--epochs", type=int, nargs="?", required=True,
30 | help='Epoch number')
31 | parser.add_argument('-g', "--gpu", type=str, nargs="?", required=True,
32 | help='What GPU to use')
33 | parser.add_argument('-l', "--layer", type=int, nargs="?", required=False, default=1,
34 | help='Number of layers')
35 | parser.add_argument("--val", type=str, nargs="?", required=False, default="BLEU",
36 | help='Validation method')
37 |
38 | args = parser.parse_args()
39 |
40 | name = args.method + "_dim" + str(args.dim) + "_embed" + str(args.embed)
41 | name = name + "_drop" + str(args.drop)
42 | name = name + "_lr" + str(args.lr) + "_batch" + str(args.batch)
43 | name = name + "_epochs" + str(args.epochs) + "_layer" + str(args.layer)
44 |
45 | checkpoint_dir = "./models/" + name
46 |
47 |
48 | # set tf eager
49 |
50 | tfe = tf.contrib.eager
51 | config = tf.ConfigProto(
52 | gpu_options=tf.GPUOptions(
53 | visible_device_list=args.gpu))
54 | # config.gpu_options.allow_growth = True
55 | session = tf.Session(config=config)
56 | tf.enable_eager_execution(config=config)
57 | os.makedirs("./logs/" + name, exist_ok=True)
58 | writer = tf.contrib.summary.create_file_writer("./logs/" + name, flush_millis=10000)
59 |
60 |
61 | # load data
62 |
63 | trn_data = read_pickle("dataset/nl/train.pkl")
64 | vld_data = read_pickle("dataset/nl/valid.pkl")
65 | tst_data = read_pickle("dataset/nl/test.pkl")
66 | code_i2w = read_pickle("dataset/code_i2w.pkl")
67 | code_w2i = read_pickle("dataset/code_w2i.pkl")
68 | nl_i2w = read_pickle("dataset/nl_i2w.pkl")
69 | nl_w2i = read_pickle("dataset/nl_w2i.pkl")
70 |
71 | trn_x, trn_y_raw = zip(*trn_data.items())
72 | vld_x, vld_y_raw = zip(*vld_data.items())
73 | tst_x, tst_y_raw = zip(*tst_data.items())
74 |
75 | trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in trn_y_raw]
76 | vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in vld_y_raw]
77 | tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in tst_y_raw]
78 |
79 |
80 | # setting model
81 |
82 | if args.method in ['seq2seq', 'deepcom']:
83 | Model = Seq2seqModel
84 | elif args.method in ['codenn']:
85 | Model = CodennModel
86 | elif args.method in ['childsum']:
87 | Model = ChildsumModel
88 | elif args.method in ['multiway']:
89 | Model = MultiwayModel
90 | elif args.method in ['nary']:
91 | Model = NaryModel
92 |
93 |
94 | model = Model(args.dim, args.dim, args.dim, len(code_w2i), len(nl_w2i),
95 | dropout=args.drop, lr=args.lr, layer=args.layer)
96 | epochs = args.epochs
97 | batch_size = args.batch
98 | os.makedirs(checkpoint_dir, exist_ok=True)
99 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
100 | root = tfe.Checkpoint(model=model)
101 | history = {"loss": [], "loss_val": [], "bleu_val": []}
102 |
103 | root.restore(tf.train.latest_checkpoint(checkpoint_dir))  # resume from the latest checkpoint in checkpoint_dir
104 |
105 | # Setting Data Generator
106 |
107 | if args.method in ['deepcom']:
108 | Datagen = Datagen_deepcom
109 | elif args.method in ['codenn']:
110 | Datagen = Datagen_set
111 | elif args.method in ['childsum', 'multiway']:
112 | Datagen = Datagen_tree
113 | elif args.method in ['nary']:
114 | Datagen = Datagen_binary
115 |
116 |
117 | trn_gen = Datagen(trn_x, trn_y, batch_size, code_w2i, nl_i2w, train=True)
118 | vld_gen = Datagen(vld_x, vld_y, batch_size, code_w2i, nl_i2w, train=False)
119 | tst_gen = Datagen(tst_x, tst_y, batch_size, code_w2i, nl_i2w, train=False)
120 |
121 |
122 | # training
123 | with writer.as_default(), tf.contrib.summary.always_record_summaries():
124 |
125 | for epoch in range(1, epochs + 1):
126 |
127 | # train
128 | loss_tmp = []
129 | t = tqdm(trn_gen(0))
130 | for x, y, _, _ in t:
131 | loss_tmp.append(model.train_on_batch(x, y))
132 | t.set_description("epoch:{:03d}, loss = {}".format(epoch, np.mean(loss_tmp)))
133 | history["loss"].append(np.sum(loss_tmp) / len(t))
134 | tf.contrib.summary.scalar("loss", np.sum(loss_tmp) / len(t), step=epoch)
135 |
136 | # validate loss
137 | loss_tmp = []
138 | t = tqdm(vld_gen(0))
139 | for x, y, _, _ in t:
140 | loss_tmp.append(model.evaluate_on_batch(x, y))
141 | t.set_description("epoch:{:03d}, loss_val = {}".format(epoch, np.mean(loss_tmp)))
142 | history["loss_val"].append(np.sum(loss_tmp) / len(t))
143 | tf.contrib.summary.scalar("loss_val", np.sum(loss_tmp) / len(t), step=epoch)
144 |
145 | # validate bleu
146 | preds = []
147 | trues = []
148 | bleus = []
149 | t = tqdm(vld_gen(0))
150 | for x, y, _, y_raw in t:
151 | res = model.translate(x, nl_i2w, nl_w2i)
152 | preds += res
153 | trues += [s[1:-1] for s in y_raw]
154 |             bleus += [bleu4(s[1:-1], p) for s, p in zip(y_raw, res)]  # score only the current batch
155 | t.set_description("epoch:{:03d}, bleu_val = {}".format(epoch, np.mean(bleus)))
156 | history["bleu_val"].append(np.mean(bleus))
157 | tf.contrib.summary.scalar("bleu_val", np.mean(bleus), step=epoch)
158 |
159 | # checkpoint
160 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
161 | hoge = root.save(file_prefix=checkpoint_prefix)
162 | if history["bleu_val"][-1] == max(history["bleu_val"]):
163 | best_model = hoge
164 | print("Now best model is {}".format(best_model))
165 |
166 |
167 | # load final weight
168 |
169 | print("Restore {}".format(best_model))
170 | root.restore(best_model)
171 |
172 | # evaluation
173 |
174 | preds = []
175 | trues = []
176 | for x, y, _, y_raw in tqdm(tst_gen(0), "Testing"):
177 | res = model.translate(x, nl_i2w, nl_w2i)
178 | preds += res
179 | trues += [s[1:-1] for s in y_raw]
180 |
181 | bleus = Parallel(n_jobs=-1)(delayed(bleu4)(t, p) for t, p in (list(zip(trues, preds))))
182 |
183 | history["bleus"] = bleus
184 | history["preds"] = preds
185 | history["trues"] = trues
186 | history["numbers"] = [int(x.split("/")[-1]) for x in tst_x]
187 |
188 | with open(os.path.join(checkpoint_dir, "history.json"), "w") as f:
189 | json.dump(history, f)
190 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils import read_pickle, Datagen_set, Datagen_deepcom, Datagen_tree, Datagen_binary, bleu4
3 | from models import Seq2seqModel, CodennModel, ChildsumModel, MultiwayModel, NaryModel
4 | import numpy as np
5 | import os
6 | import tensorflow as tf
7 | from tqdm import tqdm
8 | from joblib import delayed, Parallel
9 | import json
10 |
11 |
12 | # parse arguments
13 |
14 | parser = argparse.ArgumentParser(description='Source Code Generation')
15 |
16 | parser.add_argument('-m', "--method", type=str, nargs="?", required=True,
17 | choices=['seq2seq', 'deepcom', 'codenn', 'childsum', 'multiway', "nary"],
18 | help='Encoder method')
19 | parser.add_argument('-d', "--dim", type=int, nargs="?", required=False, default=512,
20 | help='Representation dimension')
21 | parser.add_argument("--embed", type=int, nargs="?", required=False, default=256,
22 |                     help='Embedding dimension')
23 | parser.add_argument("--drop", type=float, nargs="?", required=False, default=.5,
24 | help="Dropout rate")
25 | parser.add_argument('-r', "--lr", type=float, nargs="?", required=True,
26 | help='Learning rate')
27 | parser.add_argument('-b', "--batch", type=int, nargs="?", required=True,
28 | help='Mini batch size')
29 | parser.add_argument('-e', "--epochs", type=int, nargs="?", required=True,
30 | help='Epoch number')
31 | parser.add_argument('-g', "--gpu", type=str, nargs="?", required=True,
32 | help='What GPU to use')
33 | parser.add_argument('-l', "--layer", type=int, nargs="?", required=False, default=1,
34 | help='Number of layers')
35 | parser.add_argument("--val", type=str, nargs="?", required=False, default="BLEU",
36 | help='Validation method')
37 |
38 | args = parser.parse_args()
39 |
40 | name = args.method + "_dim" + str(args.dim) + "_embed" + str(args.embed)
41 | name = name + "_drop" + str(args.drop)
42 | name = name + "_lr" + str(args.lr) + "_batch" + str(args.batch)
43 | name = name + "_epochs" + str(args.epochs) + "_layer" + str(args.layer) + "NEW_skip_size100"
44 |
45 | checkpoint_dir = "./models/" + name
46 |
47 |
48 | # set tf eager
49 |
50 | tfe = tf.contrib.eager
51 | config = tf.ConfigProto(
52 | gpu_options=tf.GPUOptions(
53 | visible_device_list=args.gpu))
54 | # config.gpu_options.allow_growth = True
55 | session = tf.Session(config=config)
56 | tf.enable_eager_execution(config=config)
57 | os.makedirs("./logs/" + name, exist_ok=True)
58 | writer = tf.contrib.summary.create_file_writer("./logs/" + name, flush_millis=10000)
59 |
60 |
61 | # load data
62 |
63 | trn_data = read_pickle("dataset/nl/train.pkl")
64 | vld_data = read_pickle("dataset/nl/valid.pkl")
65 | tst_data = read_pickle("dataset/nl/test.pkl")
66 | code_i2w = read_pickle("dataset/code_i2w.pkl")
67 | code_w2i = read_pickle("dataset/code_w2i.pkl")
68 | nl_i2w = read_pickle("dataset/nl_i2w.pkl")
69 | nl_w2i = read_pickle("dataset/nl_w2i.pkl")
70 |
71 | trn_x, trn_y_raw = zip(*sorted(trn_data.items()))
72 | vld_x, vld_y_raw = zip(*sorted(vld_data.items()))
73 | tst_x, tst_y_raw = zip(*sorted(tst_data.items()))
74 |
75 | trn_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in trn_y_raw]
76 | vld_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in vld_y_raw]
77 | tst_y = [[nl_w2i[t] if t in nl_w2i.keys() else nl_w2i[""] for t in l] for l in tst_y_raw]
78 |
79 |
80 | # setting model
81 |
82 | if args.method in ['seq2seq', 'deepcom']:
83 | Model = Seq2seqModel
84 | elif args.method in ['codenn']:
85 | Model = CodennModel
86 | elif args.method in ['childsum']:
87 | Model = ChildsumModel
88 | elif args.method in ['multiway']:
89 | Model = MultiwayModel
90 | elif args.method in ['nary']:
91 | Model = NaryModel
92 |
93 |
94 | model = Model(args.dim, args.dim, args.dim, len(code_w2i), len(nl_w2i),
95 | dropout=args.drop, lr=args.lr, layer=args.layer)
96 | epochs = args.epochs
97 | batch_size = args.batch
98 | os.makedirs(checkpoint_dir, exist_ok=True)
99 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
100 | root = tfe.Checkpoint(model=model)
101 | history = {"loss": [], "loss_val": [], "bleu_val": []}
102 |
103 |
104 | # Setting Data Generator
105 |
106 | if args.method in ['deepcom']:
107 | Datagen = Datagen_deepcom
108 | elif args.method in ['codenn']:
109 | Datagen = Datagen_set
110 | elif args.method in ['childsum', 'multiway']:
111 | Datagen = Datagen_tree
112 | elif args.method in ['nary']:
113 | Datagen = Datagen_binary
114 |
115 |
116 | trn_gen = Datagen(trn_x, trn_y, batch_size, code_w2i, nl_i2w, train=True)
117 | vld_gen = Datagen(vld_x, vld_y, batch_size, code_w2i, nl_i2w, train=False)
118 | tst_gen = Datagen(tst_x, tst_y, batch_size, code_w2i, nl_i2w, train=False)
119 |
120 |
121 | # training
122 | with writer.as_default(), tf.contrib.summary.always_record_summaries():
123 |
124 | for epoch in range(1, epochs + 1):
125 |
126 | # train
127 | loss_tmp = []
128 | t = tqdm(trn_gen(0))
129 | for x, y, _, _ in t:
130 | loss_tmp.append(model.train_on_batch(x, y))
131 | t.set_description("epoch:{:03d}, loss = {}".format(epoch, np.mean(loss_tmp)))
132 | history["loss"].append(np.sum(loss_tmp) / len(t))
133 | tf.contrib.summary.scalar("loss", np.sum(loss_tmp) / len(t), step=epoch)
134 |
135 | # validate loss
136 | loss_tmp = []
137 | t = tqdm(vld_gen(0))
138 | for x, y, _, _ in t:
139 | loss_tmp.append(model.evaluate_on_batch(x, y))
140 | t.set_description("epoch:{:03d}, loss_val = {}".format(epoch, np.mean(loss_tmp)))
141 | history["loss_val"].append(np.sum(loss_tmp) / len(t))
142 | tf.contrib.summary.scalar("loss_val", np.sum(loss_tmp) / len(t), step=epoch)
143 |
144 | # validate bleu
145 | preds = []
146 | trues = []
147 | bleus = []
148 | t = tqdm(vld_gen(0))
149 | for x, y, _, y_raw in t:
150 | res = model.translate(x, nl_i2w, nl_w2i)
151 | preds += res
152 | trues += [s[1:-1] for s in y_raw]
153 |             bleus += [bleu4(s[1:-1], p) for s, p in zip(y_raw, res)]  # score only the current batch
154 | t.set_description("epoch:{:03d}, bleu_val = {}".format(epoch, np.mean(bleus)))
155 | history["bleu_val"].append(np.mean(bleus))
156 | tf.contrib.summary.scalar("bleu_val", np.mean(bleus), step=epoch)
157 |
158 | # checkpoint
159 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
160 | hoge = root.save(file_prefix=checkpoint_prefix)
161 | if history["bleu_val"][-1] == max(history["bleu_val"]):
162 | best_model = hoge
163 | print("Now best model is {}".format(best_model))
164 |
165 |
166 | # load final weight
167 |
168 | print("Restore {}".format(best_model))
169 | root.restore(best_model)
170 |
171 | # evaluation
172 |
173 | preds = []
174 | trues = []
175 | for x, y, _, y_raw in tqdm(tst_gen(0), "Testing"):
176 | res = model.translate(x, nl_i2w, nl_w2i)
177 | preds += res
178 | trues += [s[1:-1] for s in y_raw]
179 |
180 | bleus = Parallel(n_jobs=-1)(delayed(bleu4)(t, p) for t, p in (list(zip(trues, preds))))
181 |
182 | history["bleus"] = bleus
183 | history["preds"] = preds
184 | history["trues"] = trues
185 | history["numbers"] = [int(x.split("/")[-1]) for x in tst_x]
186 |
187 | with open(os.path.join(checkpoint_dir, "history.json"), "w") as f:
188 | json.dump(history, f)
189 |
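190 | # Example invocation (illustrative values only; they mirror the run names seen in
191 | # notebooks/check_result.ipynb, e.g. multiway_dim256_embed256_drop0.5_lr0.001_batch128_epochs30_layer1):
192 | #   python train.py -m multiway -d 256 --embed 256 --drop 0.5 -r 0.001 -b 128 -e 30 -g 0 -l 1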
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities"""
2 | import tensorflow as tf
3 | import numpy as np
4 | import math
5 | from collections import defaultdict
6 | import pickle
7 | from prefetch_generator import BackgroundGenerator
8 |
9 |
10 | def get_nums(roots):
11 | '''convert roots to indices'''
12 | res = [[x.num for x in n.children] if n.children != [] else [0] for n in roots]
13 | max_len = max([len(x) for x in res])
14 | res = tf.keras.preprocessing.sequence.pad_sequences(
15 | res, max_len, padding="post", value=-1.)
16 | return tf.constant(res, tf.int32)
17 |
18 |
19 | def tree2binary(trees):
20 | def helper(root):
21 | if len(root.children) > 2:
22 | tmp = root.children[0]
23 | for child in root.children[1:]:
24 | tmp.children += [child]
25 | tmp = child
26 | root.children = root.children[0:1]
27 | for child in root.children:
28 | helper(child)
29 | return root
30 | return [helper(x) for x in trees]
31 |
32 |
33 | def tree2tensor(trees):
34 | '''
35 | indice:
36 | this has structure data.
37 | 0 represent init state,
267 |     bp = 1. if c > r else np.exp(1 - r / (c + 1e-10))
268 | score = 0
269 | for i in range(1, 5):
270 | true_ngram = set(ngram(true, i))
271 | pred_ngram = ngram(pred, i)
272 | length = float(len(pred_ngram)) + 1e-10
273 | count = sum([1. if t in true_ngram else 0. for t in pred_ngram])
274 | score += math.log(1e-10 + (count / length))
275 | score = math.exp(score * .25)
276 | bleu = bp * score
277 | return bleu
278 |
279 |
280 | class Datagen_tree:
281 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True, binary=False):
282 | self.X = X
283 | self.Y = Y
284 | self.batch_size = batch_size
285 | self.code_dic = code_dic
286 | self.nl_dic = nl_dic
287 | self.train = train
288 | self.binary = binary
289 |
290 | def __len__(self):
291 | return len(range(0, len(self.X), self.batch_size))
292 |
293 | def __call__(self, epoch=0):
294 | return GeneratorLen(BackgroundGenerator(self.gen(epoch), 1), len(self))
295 |
296 | def gen(self, epoch):
297 | if self.train:
298 | np.random.seed(epoch)
299 | newindex = list(np.random.permutation(len(self.X)))
300 | X = [self.X[i] for i in newindex]
301 | Y = [self.Y[i] for i in newindex]
302 | else:
303 | X = [x for x in self.X]
304 | Y = [y for y in self.Y]
305 | for i in range(0, len(self.X), self.batch_size):
306 | x = X[i:i + self.batch_size]
307 | y = Y[i:i + self.batch_size]
308 | x_raw = [read_pickle(n) for n in x]
309 | if self.binary:
310 | x_raw = tree2binary(x_raw)
311 | y_raw = [[self.nl_dic[t] for t in s] for s in y]
312 | x = [consult_tree(n, self.code_dic) for n in x_raw]
313 | x_raw = [traverse_label(n) for n in x_raw]
314 | y = tf.keras.preprocessing.sequence.pad_sequences(
315 | y,
316 | min(max([len(s) for s in y]), 100),
317 | padding="post", truncating="post", value=-1.)
318 | yield tree2tensor(x), y, x_raw, y_raw
319 |
320 |
321 | class Datagen_binary(Datagen_tree):
322 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True, binary=True):
323 | super(Datagen_binary, self).__init__(X, Y, batch_size, code_dic,
324 |                                              nl_dic, train=train, binary=True)
325 |
326 |
327 | class Datagen_set:
328 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True):
329 | self.X = X
330 | self.Y = Y
331 | self.batch_size = batch_size
332 | self.code_dic = code_dic
333 | self.nl_dic = nl_dic
334 | self.train = train
335 |
336 | def __len__(self):
337 | return len(range(0, len(self.X), self.batch_size))
338 |
339 | def __call__(self, epoch=0):
340 | return GeneratorLen(BackgroundGenerator(self.gen(epoch), 1), len(self))
341 |
342 | def gen(self, epoch):
343 | if self.train:
344 | np.random.seed(epoch)
345 | newindex = list(np.random.permutation(len(self.X)))
346 | X = [self.X[i] for i in newindex]
347 | Y = [self.Y[i] for i in newindex]
348 | else:
349 | X = [x for x in self.X]
350 | Y = [y for y in self.Y]
351 | for i in range(0, len(self.X), self.batch_size):
352 | x = X[i:i + self.batch_size]
353 | y = Y[i:i + self.batch_size]
354 | x_raw = [read_pickle(n) for n in x]
355 | y_raw = [[self.nl_dic[t] for t in s] for s in y]
356 | x = [traverse_label(n) for n in x_raw]
357 | x = [np.array([self.code_dic[t] for t in xx], "int32") for xx in x]
358 | x_raw = [traverse_label(n) for n in x_raw]
359 | y = tf.constant(
360 | tf.keras.preprocessing.sequence.pad_sequences(
361 | y,
362 | min(max([len(s) for s in y]), 100),
363 | padding="post", truncating="post", value=-1.))
364 | yield x, y, x_raw, y_raw
365 |
366 |
367 | def sequencing(root):
368 | li = ["(", root.label]
369 | for child in root.children:
370 | li += sequencing(child)
371 | li += [")", root.label]
372 |     return li
373 |
374 |
375 | class Datagen_deepcom:
376 | def __init__(self, X, Y, batch_size, code_dic, nl_dic, train=True):
377 | self.X = X
378 | self.Y = Y
379 | self.batch_size = batch_size
380 | self.code_dic = code_dic
381 | self.nl_dic = nl_dic
382 | self.train = train
383 |
384 | def __len__(self):
385 | return len(range(0, len(self.X), self.batch_size))
386 |
387 | def __call__(self, epoch=0):
388 | return GeneratorLen(BackgroundGenerator(self.gen(epoch), 1), len(self))
389 |
390 | def gen(self, epoch):
391 | if self.train:
392 | np.random.seed(epoch)
393 | newindex = list(np.random.permutation(len(self.X)))
394 | X = [self.X[i] for i in newindex]
395 | Y = [self.Y[i] for i in newindex]
396 | else:
397 | X = [x for x in self.X]
398 | Y = [y for y in self.Y]
399 | for i in range(0, len(self.X), self.batch_size):
400 | x = X[i:i + self.batch_size]
401 | y = Y[i:i + self.batch_size]
402 | x_raw = [read_pickle(n) for n in x]
403 | y_raw = [[self.nl_dic[t] for t in s] for s in y]
404 | x = [sequencing(n) for n in x_raw]
405 | x = [np.array([self.code_dic[t] for t in xx], "int32") for xx in x]
406 | x = tf.constant(
407 | tf.keras.preprocessing.sequence.pad_sequences(
408 | x,
409 | min(max([len(s) for s in x]), 400),
410 | padding="post", truncating="post", value=-1.))
411 | x_raw = [traverse_label(n) for n in x_raw]
412 | y = tf.constant(
413 | tf.keras.preprocessing.sequence.pad_sequences(
414 | y,
415 | min(max([len(s) for s in y]), 100),
416 | padding="post", truncating="post", value=-1.))
417 | yield x, y, x_raw, y_raw
418 |
419 |
420 | def get_length(tensor, pad_value=-1.):
421 | '''tensor: [batch, max_len]'''
422 | mask = tf.not_equal(tensor, pad_value)
423 | return tf.reduce_sum(tf.cast(mask, tf.int32), 1)
424 |
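425 | # Usage sketch (illustrative only; nothing in the pipeline calls this block):
426 | # bleu4 expects two token lists, and get_length counts non-padded entries per row
427 | # of a padded batch.
428 | #
429 | #   bleu4("returns the sum of two values".split(),
430 | #         "returns the sum of two numbers".split())
431 | #   get_length(tf.constant([[3., 5., -1.], [2., -1., -1.]]))  # -> [2, 1]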
--------------------------------------------------------------------------------