├── LICENSE
├── README.md
├── __init__.py
├── ast_graph_encoder.py
├── baselines
│   └── codebert_bow.py
├── comment_update
│   ├── SARI.py
│   ├── comment_generation.py
│   ├── decoder.py
│   ├── embedding_store.py
│   ├── external_cache.py
│   ├── generation_decoder.py
│   ├── tensor_utils.py
│   ├── update_decoder.py
│   └── update_evaluation_utils.py
├── constants.py
├── data_loader.py
├── data_processing
│   ├── ast_diffing
│   │   ├── code_samples
│   │   │   ├── new.java
│   │   │   └── old.java
│   │   └── python
│   │       └── xml_diff_parser.py
│   ├── build_example.py
│   ├── data_formatting_utils.py
│   ├── high_level_feature_extractor.py
│   └── tokenization_feature_extractor.py
├── data_utils.py
├── detection_evaluation_utils.py
├── detection_module.py
├── diff_utils.py
├── display_scores.py
├── encoder.py
├── gleu
│   ├── README.md
│   ├── data
│   │   ├── all_judgments.csv
│   │   └── all_judgments.xml
│   ├── gleu_update_2016.pdf
│   └── scripts
│       ├── compute_gleu
│       ├── gleu.py
│       └── original_gleu
│           ├── compute_gleu
│           └── gleu.py
├── gnn.py
├── module_manager.py
├── run_comment_model.py
└── update_module.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 panthap2
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Just-In-Time Inconsistency Detection Between Comments and Source Code
2 |
3 | **Code and datasets for our AAAI-2021 paper "Deep Just-In-Time Inconsistency Detection Between Comments and Source Code"**
4 | which can be found [here](https://arxiv.org/pdf/2010.01625.pdf).
5 |
6 | If you find this work useful, please consider citing our paper:
7 |
8 | ```
9 | @inproceedings{PanthaplackelETAL21DeepJITInconsistency,
10 | author = {Panthaplackel, Sheena and Li, Junyi Jessy and Gligoric, Milos and Mooney, Raymond J.},
11 | title = {Deep Just-In-Time Inconsistency Detection Between Comments and Source Code},
12 | booktitle = {AAAI},
13 | pages = {427--435},
14 | year = {2021},
15 | }
16 | ```
17 | The code base shares components with our prior work, [Learning to Update Natural Language Comments Based on Code Changes](https://github.com/panthap2/LearningToUpdateNLComments).
18 |
19 | Download data from [here](https://drive.google.com/drive/folders/1heqEQGZHgO6gZzCjuQD1EyYertN4SAYZ?usp=sharing). Download additional model resources from [here](https://drive.google.com/drive/folders/1cutxr4rMDkT1g2BbmCAR2wqKTxeFH11K?usp=sharing). Edit configurations in `constants.py` to specify data, resource, and output locations.
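
As a reference, a minimal sketch of what the location settings in `constants.py` might look like. `RESOURCES_PATH`, `NL_EMBEDDING_PATH`, and `CODE_EMBEDDING_PATH` are names the code base actually imports; `DATA_PATH` and the specific file names below are illustrative.

```
# Sketch only: point these at the downloaded data and resources.
DATA_PATH = '/path/to/downloaded/data'              # illustrative name
RESOURCES_PATH = '/path/to/downloaded/resources'    # used by the feature-extraction code
NL_EMBEDDING_PATH = '/path/to/resources/nl_embeddings.json'      # illustrative file name
CODE_EMBEDDING_PATH = '/path/to/resources/code_embeddings.json'  # illustrative file name
```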
20 |
21 | **Inconsistency Detection:**
22 |
23 | *SEQ(C, Medit) + features*
24 | ```
25 | python3 run_comment_model.py --task=detect --attend_code_sequence_states --features --model_path=detect_attend_code_sequence_states_features.pkl.gz --model_name=detect_attend_code_sequence_states_features
26 | ```
27 |
28 | *GRAPH(C, Tedit) + features*
29 | (The GGNN used for this approach is derived from [here](https://github.com/pcyin/pytorch-gated-graph-neural-network/blob/master/gnn.py).)
30 | ```
31 | python3 run_comment_model.py --task=detect --attend_code_graph_states --features --model_path=detect_attend_code_graph_states_features.pkl.gz --model_name=detect_attend_code_graph_states_features
32 | ```
33 |
34 | *HYBRID(C, Medit, Tedit) + features*
35 | ```
36 | python3 run_comment_model.py --task=detect --attend_code_sequence_states --attend_code_graph_states --features --model_path=detect_attend_code_sequence_states_attend_code_graph_states_features.pkl.gz --model_name=detect_attend_code_sequence_states_attend_code_graph_states_features
37 | ```
38 |
39 | To run inference on a detection model, add `--test_mode` to the command used to train the model.
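
For example, inference for the *SEQ(C, Medit) + features* model trained above:

```
python3 run_comment_model.py --task=detect --attend_code_sequence_states --features --model_path=detect_attend_code_sequence_states_features.pkl.gz --model_name=detect_attend_code_sequence_states_features --test_mode
```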
40 |
41 | **Combined Detection + Update:**
42 |
43 | *Update w/ implicit detection*
44 | ```
45 | python3 run_comment_model.py --task=update --features --model_path=update_features.pkl.gz --model_name=update_features
46 | ```
47 |
48 | To run inference, add `--test_mode --rerank` to the command used to train the model.
49 |
50 | *Pretrained update + detection*
51 | ```
52 | python3 run_comment_model.py --task=update --features --positive_only --model_path=update_features_positive_only.pkl.gz --model_name=update_features_positive_only
53 | ```
54 |
55 | One of the detection models must also be trained, following the instructions in the "Inconsistency Detection" section above. To run inference on the update model, add `--test_mode --rerank` to the command used to train it; run inference on the detection model as described in that same section.
56 |
57 | *Jointly trained update + detection*
58 |
59 | To train, simply replace `--task=detect` with `--task=dual` in the configurations given for "Inconsistency Detection." For inference, additionally include `--test_mode --rerank`.
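
For example, the jointly trained counterpart of *SEQ(C, Medit) + features*, evaluated with reranking (the model path and name below are illustrative):

```
python3 run_comment_model.py --task=dual --attend_code_sequence_states --features --model_path=dual_attend_code_sequence_states_features.pkl.gz --model_name=dual_attend_code_sequence_states_features --test_mode --rerank
```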
60 |
61 | **Displaying metrics:**
62 |
63 | To display metrics for the full test set as well as the cleaned test sample, run:
64 |
65 | ```
66 | python3 display_scores.py --detection_output_file=[PATH TO DETECTION PREDICTIONS] --update_output_file=[PATH TO UPDATE PREDICTIONS]
67 | ```
68 |
69 | For evaluating in the pretrained update + detection setting, both filepaths are required. For all other settings, only one should be specified.
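
For example, in the pretrained update + detection setting (the prediction file paths below are illustrative):

```
python3 display_scores.py --detection_output_file=predictions/detect_attend_code_sequence_states_features.txt --update_output_file=predictions/update_features_positive_only.txt
```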
70 |
71 | **AST Diffing:**
72 |
73 | The AST diffs were built using Java files provided by [Pengyu Nie](https://github.com/pengyunie). First, download `ast-diffing-1.6-jar-with-dependencies.jar` from [here](https://drive.google.com/file/d/1JVfIfJoDDSFBaFOhK18UsBOmC39z03am/view?usp=sharing). Then, go to `data_processing/ast_diffing/python` and run:
74 |
75 | ```
76 | python3 xml_diff_parser.py --old_sample_path=[PATH TO OLD VERSION OF CODE] --new_sample_path=[PATH TO NEW VERSION OF CODE] --jar_path=[PATH TO DOWNLOADED JAR FILE]
77 | ```
78 |
79 | You can see an example by running:
80 |
81 | ```
82 | python3 xml_diff_parser.py --old_sample_path=../code_samples/old.java --new_sample_path=../code_samples/new.java --jar_path=[PATH TO DOWNLOADED JAR FILE]
83 | ```
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/panthap2/deep-jit-inconsistency-detection/dacf8513c155f35157eedc2bf630212bf815544c/__init__.py
--------------------------------------------------------------------------------
/ast_graph_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from constants import *
6 | from gnn import GatedGraphNeuralNetwork, AdjacencyList
7 |
8 | class ASTGraphEncoder(nn.Module):
9 | """Encoder which learns a representation of a method's AST. The underlying network is a Gated Graph Neural Network."""
10 | def __init__(self, hidden_size, num_edge_types):
11 | super(ASTGraphEncoder, self).__init__()
12 | self.hidden_size = hidden_size
13 | self.num_edge_types = num_edge_types
14 | self.gnn = GatedGraphNeuralNetwork(self.hidden_size, self.num_edge_types,
15 | [GNN_LAYER_TIMESTEPS], {}, GNN_DROPOUT_RATE, GNN_DROPOUT_RATE)
16 |
17 | def forward(self, initial_node_representation, graph_batch, device):
18 | adjacency_lists = []
19 | for edge_type in range(self.num_edge_types):
20 | adjacency_lists.append(AdjacencyList(node_num=graph_batch.num_nodes,
21 | adj_list=graph_batch.edges[edge_type], device=device))
22 | node_representations = self.gnn.compute_node_representations(
23 | initial_node_representation=initial_node_representation, adjacency_lists=adjacency_lists)
24 | hidden_states = node_representations[graph_batch.node_positions]
25 | return hidden_states
--------------------------------------------------------------------------------
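
A minimal sketch (not part of the repository) of driving `ASTGraphEncoder`, assuming `constants.py` is configured and that the graph batch exposes the `num_nodes`, `edges`, and `node_positions` fields the forward pass reads; the `GraphBatch` class below is a hypothetical stand-in for the object built by the data pipeline.

```
import torch
from ast_graph_encoder import ASTGraphEncoder

HIDDEN_SIZE = 64
NUM_EDGE_TYPES = 2

class GraphBatch:
    # Hypothetical stand-in for the batched AST structure.
    def __init__(self, num_nodes, edges, node_positions):
        self.num_nodes = num_nodes            # total nodes across the batch
        self.edges = edges                    # edges[t]: (src, dst) pairs for edge type t
        self.node_positions = node_positions  # per-example node indices to gather

device = torch.device('cpu')
encoder = ASTGraphEncoder(HIDDEN_SIZE, NUM_EDGE_TYPES)
batch = GraphBatch(
    num_nodes=4,
    edges=[[(0, 1), (1, 2)], [(2, 3)]],       # one adjacency list per edge type
    node_positions=torch.tensor([[0, 1, 2, 3]]),
)
initial = torch.randn(4, HIDDEN_SIZE)             # one initial vector per node
hidden_states = encoder(initial, batch, device)   # -> (1, 4, HIDDEN_SIZE)
```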
/baselines/codebert_bow.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import numpy as np
4 | import os
5 | import random
6 | import torch
7 | from torch import nn
8 | from transformers import *
9 | import sys
10 | import json
11 |
12 | sys.path.append('../')
13 | sys.path.append('../comment_update')
14 | from constants import *
15 | from data_loader import get_data_splits
16 | from detection_evaluation_utils import compute_score
17 |
18 | BERT_HIDDEN_SIZE = 768
19 | DROPOUT_RATE = 0.6
20 | BATCH_SIZE = 100
21 | CLASSIFICATION_HIDDEN_SIZE = 256
22 | # TRANSFORMERS_CACHE='' # TODO: Fill in
23 |
24 | class BERTBatch():
25 | def __init__(self, old_comment_ids, old_comment_lengths,
26 | new_code_ids, new_code_lengths, diff_code_ids, diff_code_lengths, labels):
27 | self.old_comment_ids = old_comment_ids
28 | self.old_comment_lengths = old_comment_lengths
29 | self.new_code_ids = new_code_ids
30 | self.new_code_lengths = new_code_lengths
31 | self.diff_code_ids = diff_code_ids
32 | self.diff_code_lengths = diff_code_lengths
33 | self.labels = labels
34 |
35 | class BERTClassifier(nn.Module):
36 | def __init__(self, model_path, new_code, diff_code):
37 | super(BERTClassifier, self).__init__()
38 | self.model_path = model_path
39 | self.new_code = new_code
40 | self.diff_code = diff_code
41 |
42 | self.code_tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base", cache_dir=TRANSFORMERS_CACHE)
43 | self.code_model = RobertaModel.from_pretrained("microsoft/codebert-base", cache_dir=TRANSFORMERS_CACHE)
44 | self.comment_tokenizer = self.code_tokenizer
45 | self.comment_model = self.code_model
46 |
47 | self.torch_device_name = 'cpu'
48 | self.max_nl_length = 0
49 | self.max_code_length = 0
50 |
51 | print('Model path: {}'.format(self.model_path))
52 | print('New code: {}'.format(self.new_code))
53 | print('Diff code: {}'.format(self.diff_code))
54 | sys.stdout.flush()
55 |
56 | def initialize(self, train_examples):
57 | self.max_nl_length = 200
58 | self.max_code_length = 200
59 |
60 | output_size = BERT_HIDDEN_SIZE
61 |
62 | if self.new_code:
63 | output_size += BERT_HIDDEN_SIZE
64 | if self.diff_code:
65 | output_size += BERT_HIDDEN_SIZE
66 |
67 | self.classification_dropout_layer = nn.Dropout(p=DROPOUT_RATE)
68 | self.fc1 = nn.Linear(output_size, CLASSIFICATION_HIDDEN_SIZE)
69 | self.fc2 = nn.Linear(CLASSIFICATION_HIDDEN_SIZE, CLASSIFICATION_HIDDEN_SIZE)
70 | self.output_layer = nn.Linear(CLASSIFICATION_HIDDEN_SIZE, NUM_CLASSES)
71 |
72 | self.optimizer = torch.optim.Adam(self.parameters(), lr=LR)
73 |
74 | def get_code_inputs(self, input_text, max_length):
75 | tokens = self.code_tokenizer.tokenize(input_text)
76 | length = min(len(tokens), max_length)
77 | tokens = tokens[:length]
78 | token_ids = self.code_tokenizer.convert_tokens_to_ids(tokens)
79 |
80 | padding_length = max_length - len(tokens)
81 | token_ids += [self.code_tokenizer.pad_token_id]*padding_length
82 | return token_ids, length
83 |
84 | def get_comment_inputs(self, input_text, max_length):
85 | tokens = self.comment_tokenizer.tokenize(input_text)
86 | length = min(len(tokens), max_length)
87 | tokens = tokens[:length]
88 | token_ids = self.comment_tokenizer.convert_tokens_to_ids(tokens)
89 |
90 | padding_length = max_length - len(tokens)
91 | token_ids += [self.comment_tokenizer.pad_token_id]*padding_length
92 | return token_ids, length
93 |
94 | def get_batches(self, dataset, shuffle=False):
95 | batches = []
96 | if shuffle:
97 | random.shuffle(dataset)
98 |
99 | curr_idx = 0
100 | while curr_idx < len(dataset):
101 | batch_idx = 0
102 |
103 | start_idx = curr_idx
104 | end_idx = min(start_idx + BATCH_SIZE, len(dataset))
105 | labels = []
106 | old_comment_ids = []
107 | old_comment_lengths = []
108 | new_code_ids = []
109 | new_code_lengths = []
110 | diff_code_ids = []
111 | diff_code_lengths = []
112 |
113 | for i in range(start_idx, end_idx):
114 | comment_ids, comment_length = self.get_comment_inputs(dataset[i].old_comment_raw, self.max_nl_length)
115 | old_comment_ids.append(comment_ids)
116 | old_comment_lengths.append(comment_length)
117 |
118 | if self.new_code:
119 | code_ids, code_length = self.get_code_inputs(dataset[i].new_code_raw, self.max_code_length)
120 | new_code_ids.append(code_ids)
121 | new_code_lengths.append(code_length)
122 |
123 | if self.diff_code:
124 | code_ids, code_length = self.get_code_inputs(' '.join(dataset[i].span_diff_code_tokens), self.max_code_length)
125 | diff_code_ids.append(code_ids)
126 | diff_code_lengths.append(code_length)
127 |
128 | labels.append(dataset[i].label)
129 |
130 | curr_idx = end_idx
131 | batches.append(BERTBatch(
132 | torch.tensor(old_comment_ids, dtype=torch.int64, device=self.get_device()),
133 | torch.tensor(old_comment_lengths, dtype=torch.int64, device=self.get_device()),
134 | torch.tensor(new_code_ids, dtype=torch.int64, device=self.get_device()),
135 | torch.tensor(new_code_lengths, dtype=torch.int64, device=self.get_device()),
136 | torch.tensor(diff_code_ids, dtype=torch.int64, device=self.get_device()),
137 | torch.tensor(diff_code_lengths, dtype=torch.int64, device=self.get_device()),
138 | torch.tensor(labels, dtype=torch.int64, device=self.get_device())
139 | ))
140 |
141 | return batches
142 |
143 | def get_code_representation(self, input_ids, masks):
144 | embeddings = self.code_model.embeddings(input_ids)
145 | if self.torch_device_name == 'cpu':
146 | factor = masks.type(torch.FloatTensor).unsqueeze(-1)
147 | else:
148 | factor = masks.type(torch.FloatTensor).cuda(self.get_device()).unsqueeze(-1)
149 | embeddings = embeddings * factor
150 | vector = torch.sum(embeddings, dim=1)/torch.sum(factor, dim=1)
151 | return embeddings, vector
152 |
153 | def get_comment_representation(self, input_ids, masks):
154 | embeddings = self.comment_model.embeddings(input_ids)
155 | if self.torch_device_name == 'cpu':
156 | factor = masks.type(torch.FloatTensor).unsqueeze(-1)
157 | else:
158 | factor = masks.type(torch.FloatTensor).cuda(self.get_device()).unsqueeze(-1)
159 | embeddings = embeddings * factor
160 | vector = torch.sum(embeddings, dim=1)/torch.sum(factor, dim=1)
161 | return embeddings, vector
162 |
163 | def get_input_features(self, batch_data):
164 | old_comment_masks = (torch.arange(
165 | batch_data.old_comment_ids.shape[1], device=self.get_device()).view(1, -1) < batch_data.old_comment_lengths.view(-1, 1))
166 | old_comment_hidden_states, old_comment_final_state = self.get_comment_representation(batch_data.old_comment_ids, old_comment_masks)
167 | final_state = old_comment_final_state
168 |
169 | if self.new_code:
170 | new_code_masks = (torch.arange(
171 | batch_data.new_code_ids.shape[1], device=self.get_device()).view(1, -1) < batch_data.new_code_lengths.view(-1, 1))
172 | new_code_hidden_states, new_code_final_state = self.get_code_representation(batch_data.new_code_ids, new_code_masks)
173 | final_state = torch.cat([final_state, new_code_final_state], dim=-1)
174 |
175 | if self.diff_code:
176 | diff_code_masks = (torch.arange(
177 | batch_data.diff_code_ids.shape[1], device=self.get_device()).view(1, -1) < batch_data.diff_code_lengths.view(-1, 1))
178 | diff_code_hidden_states, diff_code_final_state = self.get_code_representation(batch_data.diff_code_ids, diff_code_masks)
179 | final_state = torch.cat([final_state, diff_code_final_state], dim=-1)
180 |
181 | return final_state
182 |
183 | def get_logits(self, batch_data):
184 | all_features = self.get_input_features(batch_data)
185 | all_features = self.classification_dropout_layer(torch.nn.functional.relu(self.fc1(all_features)))
186 | all_features = self.classification_dropout_layer(torch.nn.functional.relu(self.fc2(all_features)))
187 |
188 | return self.output_layer(all_features)
189 |
190 | def get_logprobs(self, batch_data):
191 | logits = self.get_logits(batch_data)
192 | return torch.nn.functional.log_softmax(logits, dim=-1)
193 |
194 | def forward(self, batch_data, is_training=True):
195 | logprobs = self.get_logprobs(batch_data)
196 | loss = torch.nn.functional.nll_loss(logprobs, batch_data.labels)
197 | return loss, logprobs
198 |
199 | def run_train(self, train_examples, valid_examples):
200 | best_loss = float('inf')
201 | best_f1 = 0.0
202 | patience_tally = 0
203 | valid_batches = self.get_batches(valid_examples)
204 |
205 | for epoch in range(MAX_EPOCHS):
206 | if patience_tally > PATIENCE:
207 | print('Terminating')
208 | break
209 |
210 | self.train()
211 | train_batches = self.get_batches(train_examples, shuffle=True)
212 |
213 | train_loss = 0
214 | for batch_data in train_batches:
215 | train_loss += self.run_gradient_step(batch_data)
216 |
217 | self.eval()
218 | validation_loss = 0
219 | validation_predicted_labels = []
220 | validation_gold_labels = []
221 | with torch.no_grad():
222 | for batch_data in valid_batches:
223 | b_loss, b_logprobs = self.forward(batch_data)
224 | validation_loss += float(b_loss.cpu())
225 | validation_predicted_labels.extend(b_logprobs.argmax(-1).tolist())
226 | validation_gold_labels.extend(batch_data.labels.tolist())
227 |
228 | validation_loss = validation_loss/len(valid_batches)
229 | validation_precision, validation_recall, validation_f1 = compute_score(
230 | validation_predicted_labels, validation_gold_labels, verbose=False)
231 |
232 | if validation_f1 >= best_f1:
233 | best_f1 = validation_f1
234 | torch.save(self, self.model_path)
235 | saved = True
236 | patience_tally = 0
237 | else:
238 | saved = False
239 | patience_tally += 1
240 |
241 | print('Epoch: {}'.format(epoch))
242 | print('Training loss: {:.3f}'.format(train_loss/len(train_batches)))
243 | print('Validation loss: {:.3f}'.format(validation_loss))
244 | print('Validation precision: {:.3f}'.format(validation_precision))
245 | print('Validation recall: {:.3f}'.format(validation_recall))
246 | print('Validation f1: {:.3f}'.format(validation_f1))
247 | if saved:
248 | print('Saved')
249 | print('-----------------------------------')
250 | sys.stdout.flush()
251 |
252 | def get_device(self):
253 | """Returns the proper device."""
254 | if self.torch_device_name == 'gpu':
255 | return torch.device('cuda')
256 | else:
257 | return torch.device('cpu')
258 |
259 | def run_gradient_step(self, batch_data):
260 | """Performs gradient step."""
261 | self.optimizer.zero_grad()
262 | loss, _ = self.forward(batch_data)
263 | loss.backward()
264 | self.optimizer.step()
265 | return float(loss.cpu())
266 |
267 | def run_evaluation(self, test_examples, write_file):
268 | self.eval()
269 |
270 | test_batches = self.get_batches(test_examples)
271 | test_predictions = []
272 |
273 | with torch.no_grad():
274 | for b, batch in enumerate(test_batches):
275 | print('Testing batch {}/{}'.format(b, len(test_batches)))
276 | sys.stdout.flush()
277 | batch_logprobs = self.get_logprobs(batch)
278 | test_predictions.extend(batch_logprobs.argmax(dim=-1).tolist())
279 |
280 | self.compute_metrics(test_predictions, test_examples, write_file)
281 |
282 | def compute_metrics(self, predicted_labels, test_examples, write_file):
283 | gold_labels = []
284 | correct = 0
285 |
286 | print('Writing to: {}'.format(write_file))
287 | with open(write_file, 'w+') as f:
288 | for e, ex in enumerate(test_examples):
289 | f.write('{} {}\n'.format(ex.id, predicted_labels[e]))
290 | gold_label = ex.label
291 | if gold_label == predicted_labels[e]:
292 | correct += 1
293 | gold_labels.append(gold_label)
294 |
295 | accuracy = float(correct)/len(test_examples)
296 | precision, recall, f1 = compute_score(predicted_labels, gold_labels, False)
297 |
298 | print('Precision: {}'.format(precision))
299 | print('Recall: {}'.format(recall))
300 | print('F1: {}'.format(f1))
301 | print('Accuracy: {}'.format(accuracy))
302 |
303 | if __name__ == "__main__":
304 | parser = argparse.ArgumentParser()
305 | parser.add_argument('--new_code', action='store_true')
306 | parser.add_argument('--diff_code', action='store_true')
307 | parser.add_argument('--comment_type')
308 | parser.add_argument('--trial')
309 | parser.add_argument('--test_mode', action='store_true')
310 | args = parser.parse_args()
311 |
312 | print('Starting')
313 | sys.stdout.flush()
314 |
315 | train_examples, valid_examples, test_examples, high_level_details = get_data_splits()
316 |
317 | print('Train: {}'.format(len(train_examples)))
318 | print('Valid: {}'.format(len(valid_examples)))
319 | print('Test: {}'.format(len(test_examples)))
320 | sys.stdout.flush()
321 |
322 | model_name = 'bert'
323 |
324 | if args.new_code:
325 | model_name += '-new_code'
326 | if args.diff_code:
327 | model_name += '-diff_code'
328 |
329 | if args.comment_type:
330 | model_name += '-{}'.format(args.comment_type)
331 | if args.trial:
332 | model_name += '-{}'.format(args.trial)
333 |
334 | # Assumes that saved_bert_models directory exists
335 | model_path = 'saved_bert_models/{}.pkl.gz'.format(model_name)
336 | sys.stdout.flush()
337 |
338 | if args.test_mode:
339 | print('Loading model from: {}'.format(model_path))
340 | print('Starting evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
341 | sys.stdout.flush()
342 | model = torch.load(model_path)
343 | if torch.cuda.is_available():
344 | model.torch_device_name = 'gpu'
345 | model.cuda()
346 | for c in model.children():
347 | c.cuda()
348 | else:
349 | model.torch_device_name = 'cpu'
350 | model.cpu()
351 | for c in model.children():
352 | c.cpu()
353 |
354 | # Assumes that bert_predictions directory exists
355 | write_file = os.path.join('bert_predictions', '{}.txt'.format(model_name))
356 | model.run_evaluation(test_examples, write_file)
357 | print('Terminating evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
358 | else:
359 | print('Starting training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
360 | sys.stdout.flush()
361 | model = BERTClassifier(model_path, args.new_code, args.diff_code)
362 | model.initialize(train_examples)
363 |
364 | if torch.cuda.is_available():
365 | model.torch_device_name = 'gpu'
366 | model.cuda()
367 | for c in model.children():
368 | c.cuda()
369 | else:
370 | model.torch_device_name = 'cpu'
371 | model.cpu()
372 | for c in model.children():
373 | c.cpu()
374 |
375 | model.run_train(train_examples, valid_examples)
376 | print('Terminating training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
377 |
378 |
379 |
--------------------------------------------------------------------------------
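
The baseline's command-line interface is defined in the `__main__` block above. For example, running from `baselines/` (the `saved_bert_models` and `bert_predictions` directories must already exist, as noted in the code):

```
python3 codebert_bow.py --new_code --diff_code               # train
python3 codebert_bow.py --new_code --diff_code --test_mode   # evaluate
```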
/comment_update/SARI.py:
--------------------------------------------------------------------------------
1 | # =======================================================
2 | # SARI -- Text Simplification Tunable Evaluation Metric
3 | # =======================================================
4 | #
5 | # Author: Wei Xu (UPenn xwe@cis.upenn.edu)
6 | #
7 | # A Python implementation of the SARI metric for text simplification
8 | # evaluation in the following paper
9 | #
10 | # "Optimizing Statistical Machine Translation for Text Simplification"
11 | # Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch
12 | # In Transactions of the Association for Computational Linguistics (TACL) 2015
13 | #
14 | # There is also a Java implementation of the SARI metric
15 | # that is integrated into the Joshua MT Decoder. It can
16 | # be used for tuning Joshua models for a real end-to-end
17 | # text simplification model.
18 | #
19 |
20 | from __future__ import division
21 | from collections import Counter
22 | import sys
23 |
24 |
25 |
26 | def ReadInFile (filename):
27 |
28 | with open(filename) as f:
29 | lines = f.readlines()
30 | lines = [x.strip() for x in lines]
31 | return lines
32 |
33 |
34 | def SARIngram(sgrams, cgrams, rgramslist, numref):
35 | rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams]
36 | rgramcounter = Counter(rgramsall)
37 |
38 | sgramcounter = Counter(sgrams)
39 | sgramcounter_rep = Counter()
40 | for sgram, scount in sgramcounter.items():
41 | sgramcounter_rep[sgram] = scount * numref
42 |
43 | cgramcounter = Counter(cgrams)
44 | cgramcounter_rep = Counter()
45 | for cgram, ccount in cgramcounter.items():
46 | cgramcounter_rep[cgram] = ccount * numref
47 |
48 |
49 | # KEEP
50 | keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep
51 | keepgramcountergood_rep = keepgramcounter_rep & rgramcounter
52 | keepgramcounterall_rep = sgramcounter_rep & rgramcounter
53 |
54 | keeptmpscore1 = 0
55 | keeptmpscore2 = 0
56 | for keepgram in keepgramcountergood_rep:
57 | keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram]
58 | keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram]
59 | #print "KEEP", keepgram, keepscore, cgramcounter[keepgram], sgramcounter[keepgram], rgramcounter[keepgram]
60 | keepscore_precision = 0
61 | if len(keepgramcounter_rep) > 0:
62 | keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep)
63 | keepscore_recall = 0
64 | if len(keepgramcounterall_rep) > 0:
65 | keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep)
66 | keepscore = 0
67 | if keepscore_precision > 0 or keepscore_recall > 0:
68 | keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall)
69 |
70 |
71 | # DELETION
72 | delgramcounter_rep = sgramcounter_rep - cgramcounter_rep
73 | delgramcountergood_rep = delgramcounter_rep - rgramcounter
74 | delgramcounterall_rep = sgramcounter_rep - rgramcounter
75 | deltmpscore1 = 0
76 | deltmpscore2 = 0
77 | for delgram in delgramcountergood_rep:
78 | deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram]
79 | deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram]
80 | delscore_precision = 0
81 | if len(delgramcounter_rep) > 0:
82 | delscore_precision = deltmpscore1 / len(delgramcounter_rep)
83 | delscore_recall = 0
84 | if len(delgramcounterall_rep) > 0:
 85 |         delscore_recall = deltmpscore2 / len(delgramcounterall_rep)
86 | delscore = 0
87 | if delscore_precision > 0 or delscore_recall > 0:
88 | delscore = 2 * delscore_precision * delscore_recall / (delscore_precision + delscore_recall)
89 |
90 |
91 | # ADDITION
92 | addgramcounter = set(cgramcounter) - set(sgramcounter)
93 | addgramcountergood = set(addgramcounter) & set(rgramcounter)
94 | addgramcounterall = set(rgramcounter) - set(sgramcounter)
95 |
96 | addtmpscore = 0
97 | for addgram in addgramcountergood:
98 | addtmpscore += 1
99 |
100 | addscore_precision = 0
101 | addscore_recall = 0
102 | if len(addgramcounter) > 0:
103 | addscore_precision = addtmpscore / len(addgramcounter)
104 | if len(addgramcounterall) > 0:
105 | addscore_recall = addtmpscore / len(addgramcounterall)
106 | addscore = 0
107 | if addscore_precision > 0 or addscore_recall > 0:
108 | addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)
109 |
110 | return (keepscore, delscore_precision, addscore)
111 |
112 |
113 | def SARIsent (ssent, csent, rsents) :
114 | numref = len(rsents)
115 |
116 | s1grams = ssent.lower().split(" ")
117 | c1grams = csent.lower().split(" ")
118 | s2grams = []
119 | c2grams = []
120 | s3grams = []
121 | c3grams = []
122 | s4grams = []
123 | c4grams = []
124 |
125 | r1gramslist = []
126 | r2gramslist = []
127 | r3gramslist = []
128 | r4gramslist = []
129 | for rsent in rsents:
130 | r1grams = rsent.lower().split(" ")
131 | r2grams = []
132 | r3grams = []
133 | r4grams = []
134 | r1gramslist.append(r1grams)
135 | for i in range(0, len(r1grams)-1) :
136 | if i < len(r1grams) - 1:
137 | r2gram = r1grams[i] + " " + r1grams[i+1]
138 | r2grams.append(r2gram)
139 | if i < len(r1grams)-2:
140 | r3gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2]
141 | r3grams.append(r3gram)
142 | if i < len(r1grams)-3:
143 | r4gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2] + " " + r1grams[i+3]
144 | r4grams.append(r4gram)
145 | r2gramslist.append(r2grams)
146 | r3gramslist.append(r3grams)
147 | r4gramslist.append(r4grams)
148 |
149 | for i in range(0, len(s1grams)-1) :
150 | if i < len(s1grams) - 1:
151 | s2gram = s1grams[i] + " " + s1grams[i+1]
152 | s2grams.append(s2gram)
153 | if i < len(s1grams)-2:
154 | s3gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2]
155 | s3grams.append(s3gram)
156 | if i < len(s1grams)-3:
157 | s4gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2] + " " + s1grams[i+3]
158 | s4grams.append(s4gram)
159 |
160 | for i in range(0, len(c1grams)-1) :
161 | if i < len(c1grams) - 1:
162 | c2gram = c1grams[i] + " " + c1grams[i+1]
163 | c2grams.append(c2gram)
164 | if i < len(c1grams)-2:
165 | c3gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2]
166 | c3grams.append(c3gram)
167 | if i < len(c1grams)-3:
168 | c4gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2] + " " + c1grams[i+3]
169 | c4grams.append(c4gram)
170 |
171 |
172 | (keep1score, del1score, add1score) = SARIngram(s1grams, c1grams, r1gramslist, numref)
173 | (keep2score, del2score, add2score) = SARIngram(s2grams, c2grams, r2gramslist, numref)
174 | (keep3score, del3score, add3score) = SARIngram(s3grams, c3grams, r3gramslist, numref)
175 | (keep4score, del4score, add4score) = SARIngram(s4grams, c4grams, r4gramslist, numref)
176 | avgkeepscore = sum([keep1score,keep2score,keep3score,keep4score])/4
177 | avgdelscore = sum([del1score,del2score,del3score,del4score])/4
178 | avgaddscore = sum([add1score,add2score,add3score,add4score])/4
179 | finalscore = (avgkeepscore + avgdelscore + avgaddscore ) / 3
180 |
181 | return finalscore
182 |
183 |
184 | def main():
185 |
186 | fnamenorm = "./turkcorpus/test.8turkers.tok.norm"
187 | fnamesimp = "./turkcorpus/test.8turkers.tok.simp"
188 | fnameturk = "./turkcorpus/test.8turkers.tok.turk."
189 |
190 |
191 | ssent = "About 95 species are currently accepted ."
192 | csent1 = "About 95 you now get in ."
193 | csent2 = "About 95 species are now agreed ."
194 | csent3 = "About 95 species are currently agreed ."
195 | rsents = ["About 95 species are currently known .", "About 95 species are now accepted .", "95 species are now accepted ."]
196 |
197 | print(SARIsent(ssent, csent1, rsents))
198 | print(SARIsent(ssent, csent2, rsents))
199 | print(SARIsent(ssent, csent3, rsents))
200 |
201 |
202 | if __name__ == '__main__':
203 | main()
--------------------------------------------------------------------------------
/comment_update/decoder.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class Decoder(nn.Module):
7 | def __init__(self, input_size, hidden_size, attention_state_size, embedding_store,
8 | embedding_size, dropout_rate):
9 | super(Decoder, self).__init__()
10 | self.input_size = input_size # Dimension of input into decoder cell
11 | self.hidden_size = hidden_size # Dimension of output from decoder cell
12 | self.attention_state_size = attention_state_size # Dimension of the encoder hidden states to attend to
13 | self.embedding_store = embedding_store
14 | self.gen_vocabulary_size = len(self.embedding_store.nl_vocabulary)
15 | self.embedding_size = embedding_size
16 | self.dropout_rate = dropout_rate
17 |
18 | self.gru = nn.GRU(
19 | input_size=self.input_size,
20 | hidden_size=self.hidden_size,
21 | batch_first=True
22 | )
23 |
24 | # Parameters for attention
25 | self.attention_encoder_hidden_transform_matrix = nn.Parameter(
26 | torch.randn(self.attention_state_size, self.hidden_size,
27 | dtype=torch.float, requires_grad=True)
28 | )
29 | self.attention_output_layer = nn.Linear(self.attention_state_size + self.hidden_size,
30 | self.hidden_size, bias=False)
31 |
32 | # Parameters for generating/copying
33 | self.generation_output_matrix = nn.Parameter(
34 | torch.randn(self.hidden_size, self.gen_vocabulary_size,
35 | dtype=torch.float, requires_grad=True)
36 | )
37 |
38 | self.copy_encoder_hidden_transform_matrix = nn.Parameter(
39 | torch.randn(self.attention_state_size, self.hidden_size,
40 | dtype=torch.float, requires_grad=True)
41 | )
42 |
43 | @abstractmethod
44 | def decode(self):
45 |         raise NotImplementedError
46 |
47 | @abstractmethod
48 | def forward(self, initial_state, decoder_input_embeddings, encoder_hidden_states, masks):
49 |         raise NotImplementedError
--------------------------------------------------------------------------------
/comment_update/embedding_store.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import Counter, defaultdict
3 | from dpu_utils.mlutils import Vocabulary
4 | import heapq
5 | import json
6 | import logging
7 | import numpy as np
8 | import os
9 | import random
10 | import sys
11 | import torch
12 | from torch import nn
13 |
14 | from constants import START, END, NL_EMBEDDING_PATH, CODE_EMBEDDING_PATH, MAX_VOCAB_SIZE,\
15 | NL_EMBEDDING_SIZE, CODE_EMBEDDING_SIZE
16 | from diff_utils import get_edit_keywords
17 |
18 | class EmbeddingStore(nn.Module):
19 | def __init__(self, nl_threshold, nl_embedding_size, nl_token_counter,
20 | code_threshold, code_embedding_size, code_token_counter,
21 | dropout_rate, num_src_embeddings, src_embedding_size, node_embedding_size,
22 | load_pretrained_embeddings=False):
23 | """Keeps track of the NL and code vocabularies and embeddings."""
24 | super(EmbeddingStore, self).__init__()
25 | edit_keywords = get_edit_keywords()
26 | self.__nl_vocabulary = Vocabulary.create_vocabulary(tokens=edit_keywords,
27 | max_size=MAX_VOCAB_SIZE,
28 | count_threshold=1,
29 | add_pad=True)
30 | self.__nl_vocabulary.update(nl_token_counter, MAX_VOCAB_SIZE, nl_threshold)
31 | self.__nl_embedding_layer = nn.Embedding(num_embeddings=len(self.__nl_vocabulary),
32 | embedding_dim=nl_embedding_size,
33 | padding_idx=self.__nl_vocabulary.get_id_or_unk(
34 | Vocabulary.get_pad()))
35 | self.nl_embedding_dropout_layer = nn.Dropout(p=dropout_rate)
36 |
37 |
38 | self.__code_vocabulary = Vocabulary.create_vocabulary(tokens=edit_keywords,
39 | max_size=MAX_VOCAB_SIZE,
40 | count_threshold=1,
41 | add_pad=True)
42 | self.__code_vocabulary.update(code_token_counter, MAX_VOCAB_SIZE, code_threshold)
43 | self.__code_embedding_layer = nn.Embedding(num_embeddings=len(self.__code_vocabulary),
44 | embedding_dim=code_embedding_size,
45 | padding_idx=self.__code_vocabulary.get_id_or_unk(
46 | Vocabulary.get_pad()))
47 | self.code_embedding_dropout_layer = nn.Dropout(p=dropout_rate)
48 |
49 | self.src_embedding_layer = nn.Embedding(num_embeddings=num_src_embeddings, embedding_dim=src_embedding_size)
50 | self.src_embedding_dropout_layer = nn.Dropout(p=dropout_rate)
51 | self.node_synthesis_layer = nn.Linear(code_embedding_size+src_embedding_size, node_embedding_size, bias=False)
52 |
53 | print('NL vocabulary size: {}'.format(len(self.__nl_vocabulary)))
54 | print('Code vocabulary size: {}'.format(len(self.__code_vocabulary)))
55 |
56 | if load_pretrained_embeddings:
57 | self.initialize_embeddings()
58 |
59 | def initialize_embeddings(self):
60 | with open(NL_EMBEDDING_PATH) as f:
61 | nl_embeddings = json.load(f)
62 |
63 | nl_weights_matrix = np.zeros((len(self.__nl_vocabulary), NL_EMBEDDING_SIZE), dtype=np.float64)
64 | nl_word_count = 0
65 | for i, word in enumerate(self.__nl_vocabulary.id_to_token):
66 | try:
67 | nl_weights_matrix[i] = nl_embeddings[word]
68 | nl_word_count += 1
69 | except KeyError:
70 | nl_weights_matrix[i] = np.random.normal(scale=0.6, size=(NL_EMBEDDING_SIZE, ))
71 |
72 | self.__nl_embedding_layer.weight = torch.nn.Parameter(torch.FloatTensor(nl_weights_matrix),
73 | requires_grad=True)
74 |
75 | with open(CODE_EMBEDDING_PATH) as f:
76 | code_embeddings = json.load(f)
77 |
78 | code_weights_matrix = np.zeros((len(self.__code_vocabulary), CODE_EMBEDDING_SIZE))
79 | code_word_count = 0
80 | for i, word in enumerate(self.__code_vocabulary.id_to_token):
81 | try:
82 | code_weights_matrix[i] = code_embeddings[word]
83 | code_word_count += 1
84 | except KeyError:
85 | code_weights_matrix[i] = np.random.normal(scale=0.6, size=(CODE_EMBEDDING_SIZE, ))
86 |
87 | self.__code_embedding_layer.weight = torch.nn.Parameter(torch.FloatTensor(code_weights_matrix),
88 | requires_grad=True)
89 |
90 | print('Using {} pre-trained NL embeddings'.format(nl_word_count))
91 | print('Using {} pre-trained code embeddings'.format(code_word_count))
92 |
93 | def get_nl_embeddings(self, token_ids):
94 | return self.nl_embedding_dropout_layer(self.__nl_embedding_layer(token_ids))
95 |
96 | def get_code_embeddings(self, token_ids):
97 | return self.code_embedding_dropout_layer(self.__code_embedding_layer(token_ids))
98 |
99 | def get_src_embeddings(self, src_ids):
100 | return self.src_embedding_dropout_layer(self.src_embedding_layer(src_ids))
101 |
102 | def get_node_embeddings(self, lookup_ids, src_ids):
103 | lookup_embeddings = self.get_code_embeddings(lookup_ids)
104 | src_embeddings = self.get_src_embeddings(src_ids)
105 |
106 | embeddings = torch.cat([lookup_embeddings, src_embeddings], dim=-1)
107 | node_embeddings = self.node_synthesis_layer(embeddings)
108 | return node_embeddings
109 |
110 | @property
111 | def nl_vocabulary(self):
112 | return self.__nl_vocabulary
113 |
114 | @property
115 | def code_vocabulary(self):
116 | return self.__code_vocabulary
117 |
118 | @property
119 | def nl_embedding_layer(self):
120 | return self.__nl_embedding_layer
121 |
122 | @property
123 | def code_embedding_layer(self):
124 | return self.__code_embedding_layer
125 |
126 | def get_padded_code_ids(self, code_sequence, pad_length):
127 | return self.__code_vocabulary.get_id_or_unk_multiple(code_sequence,
128 | pad_to_size=pad_length,
129 | padding_element=self.__code_vocabulary.get_id_or_unk(
130 | Vocabulary.get_pad()),
131 | )
132 |
133 | def get_padded_nl_ids(self, nl_sequence, pad_length):
134 | return self.__nl_vocabulary.get_id_or_unk_multiple(nl_sequence,
135 | pad_to_size=pad_length,
136 | padding_element=self.__nl_vocabulary.get_id_or_unk(
137 | Vocabulary.get_pad()),
138 | )
139 |
140 | def get_extended_padded_nl_ids(self, nl_sequence, pad_length, inp_ids, inp_tokens):
141 | # Derived from: https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/mlutils/vocabulary.py
142 | nl_ids = []
143 | for token in nl_sequence:
144 | nl_id = self.get_nl_id(token)
145 | if self.is_nl_unk(nl_id) and token in inp_tokens:
146 | copy_idx = inp_tokens.index(token)
147 | nl_id = inp_ids[copy_idx]
148 | nl_ids.append(nl_id)
149 |
150 | if len(nl_ids) > pad_length:
151 | return nl_ids[:pad_length]
152 | else:
153 | padding = [self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_pad())] * (pad_length - len(nl_ids))
154 | return nl_ids + padding
155 |
156 | def pad_length(self, sequence, target_length):
157 | if len(sequence) >= target_length:
158 | return sequence[:target_length]
159 | else:
160 | return sequence + [self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_pad()) for _ in range(target_length-len(sequence))]
161 |
162 | def get_code_id(self, token):
163 | return self.__code_vocabulary.get_id_or_unk(token)
164 |
165 | def is_code_unk(self, id):
166 | return id == self.__code_vocabulary.get_id_or_unk(Vocabulary.get_unk())
167 |
168 | def get_code_token(self, token_id):
169 | return self.__code_vocabulary.get_name_for_id(token_id)
170 |
171 | def get_nl_id(self, token):
172 | return self.__nl_vocabulary.get_id_or_unk(token)
173 |
174 | def is_nl_unk(self, id):
175 | return id == self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_unk())
176 |
177 | def get_nl_token(self, token_id):
178 | return self.__nl_vocabulary.get_name_for_id(token_id)
179 |
180 | def get_vocab_extended_nl_token(self, token_id, inp_ids, inp_tokens):
181 | if token_id < len(self.__nl_vocabulary):
182 | return self.get_nl_token(token_id)
183 | elif token_id in inp_ids:
184 | copy_idx = inp_ids.index(token_id)
185 | return inp_tokens[copy_idx]
186 | else:
187 | return Vocabulary.get_unk()
188 |
189 | def get_nl_tokens(self, token_ids, inp_ids, inp_tokens):
190 | tokens = [self.get_vocab_extended_nl_token(t, inp_ids, inp_tokens) for t in token_ids]
191 | if END in tokens:
192 | return tokens[:tokens.index(END)]
193 | return tokens
194 |
195 | def get_end_id(self):
196 | return self.get_nl_id(END)
197 |
198 | def get_nl_pad_id(self):
199 | return self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_pad())
200 |
201 | def get_code_pad_id(self):
202 | return self.__code_vocabulary.get_id_or_unk(Vocabulary.get_pad())
--------------------------------------------------------------------------------
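
A small sketch (not part of the repository) of the copy-extension behavior in `get_extended_padded_nl_ids`: an out-of-vocabulary target token that also appears among the input tokens is mapped to that input's (possibly vocabulary-extended) ID rather than UNK. The toy counters and sizes are illustrative; run from within `comment_update/` with `constants.py` configured.

```
from collections import Counter
from embedding_store import EmbeddingStore

nl_counter = Counter({'returns': 5, 'the': 9, 'index': 4})
code_counter = Counter({'return': 7, 'idx': 3, 'int': 6})

store = EmbeddingStore(
    nl_threshold=1, nl_embedding_size=64, nl_token_counter=nl_counter,
    code_threshold=1, code_embedding_size=64, code_token_counter=code_counter,
    dropout_rate=0.1, num_src_embeddings=8, src_embedding_size=16,
    node_embedding_size=64)

# 'zeroIfNone' is out-of-vocabulary; give it an extended ID one past the
# vocabulary so copy logits can reference it.
inp_tokens = ['index', 'zeroIfNone']
inp_ids = [store.get_nl_id('index'), len(store.nl_vocabulary)]

ids = store.get_extended_padded_nl_ids(['returns', 'zeroIfNone'], 4, inp_ids, inp_tokens)
tokens = store.get_nl_tokens(ids, inp_ids, inp_tokens)  # recovers 'zeroIfNone'
```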
/comment_update/external_cache.py:
--------------------------------------------------------------------------------
1 | import json
2 | from nltk.corpus import stopwords
3 | from nltk.tokenize import word_tokenize
4 | from nltk import pos_tag
5 | import numpy as np
6 | import os
7 | import re
8 |
9 | from constants import *
10 | from diff_utils import *
11 |
12 | method_details = dict()
13 | tokenization_features = dict()
14 | for d in os.listdir(RESOURCES_PATH):
15 | try:
16 | with open(os.path.join(RESOURCES_PATH, d, 'high_level_details.json')) as f:
17 | method_details.update(json.load(f))
18 | with open(os.path.join(RESOURCES_PATH, d, 'tokenization_features.json')) as f:
19 | tokenization_features.update(json.load(f))
 20 |     except Exception:
21 | print('Failed parsing: {}'.format(d))
22 |
23 | stop_words = set(stopwords.words('english'))
24 | java_keywords = set(['abstract', 'assert', 'boolean', 'break', 'byte', 'case', 'catch', 'char', 'class',
25 | 'continue', 'default', 'do', 'double', 'else', 'enum', 'extends', 'final', 'finally',
26 | 'float', 'for', 'if', 'implements', 'import', 'instanceof', 'int', 'interface', 'long',
27 | 'native', 'new', 'null', 'package', 'private', 'protected', 'public', 'return', 'short',
28 | 'static', 'strictfp', 'super', 'switch', 'synchronized', 'this', 'throw', 'throws', 'transient',
29 | 'try', 'void', 'volatile', 'while'])
30 |
31 | tags = ['CC','CD','DT','EX','FW','IN','JJ','JJR','JJS','LS','MD','NN','NNS','NNP','NNPS','PDT',
32 | 'POS','PRP','PRP$','RB','RBR','RBS','RP','TO','UH','VB','VBD','VBG','VBN','VBP','VBZ','WDT','WP','WP$','WRB',
33 | 'OTHER']
34 |
35 | NUM_CODE_FEATURES = 19
36 | NUM_NL_FEATURES = 17 + len(tags)
37 |
38 | def get_num_code_features():
39 | return NUM_CODE_FEATURES
40 |
41 | def get_num_nl_features():
42 | return NUM_NL_FEATURES
43 |
44 | def is_java_keyword(token):
45 | return token in java_keywords
46 |
47 | def is_operator(token):
48 | for s in token:
49 | if s.isalnum():
50 | return False
51 | return True
52 |
53 | def get_return_type_subtokens(example):
54 | return method_details[example.id]['new']['subtoken']['return_type']
55 |
56 | def get_old_return_type_subtokens(example):
57 | return method_details[example.id]['old']['subtoken']['return_type']
58 |
59 | def get_method_name_subtokens(example):
60 | return method_details[example.id]['new']['subtoken']['method_name']
61 |
62 | def get_new_return_sequence(example):
63 | return method_details[example.id]['new']['subtoken']['return_statement']
64 |
65 | def get_old_return_sequence(example):
66 | return method_details[example.id]['old']['subtoken']['return_statement']
67 |
68 | def get_old_argument_type_subtokens(example):
69 | return method_details[example.id]['old']['subtoken']['argument_type']
70 |
71 | def get_new_argument_type_subtokens(example):
72 | return method_details[example.id]['new']['subtoken']['argument_type']
73 |
74 | def get_old_argument_name_subtokens(example):
75 | return method_details[example.id]['old']['subtoken']['argument_name']
76 |
77 | def get_new_argument_name_subtokens(example):
78 | return method_details[example.id]['new']['subtoken']['argument_name']
79 |
80 | def get_old_code(example):
81 | return example.old_code_raw
82 |
83 | def get_new_code(example):
84 | return example.new_code_raw
85 |
86 | def get_edit_span_subtoken_tokenization_labels(example):
87 | return tokenization_features[example.id]['edit_span_subtoken_labels']
88 |
89 | def get_edit_span_subtoken_tokenization_indices(example):
90 | return tokenization_features[example.id]['edit_span_subtoken_indices']
91 |
92 | def get_nl_subtoken_tokenization_labels(example):
93 | return tokenization_features[example.id]['old_nl_subtoken_labels']
94 |
95 | def get_nl_subtoken_tokenization_indices(example):
96 | return tokenization_features[example.id]['old_nl_subtoken_indices']
97 |
98 | def get_node_features(nodes, example, max_ast_length):
99 | old_return_type_subtokens = get_old_return_type_subtokens(example)
100 | new_return_type_subtokens = get_return_type_subtokens(example)
101 | method_name_subtokens = get_method_name_subtokens(example)
102 |
103 | old_return_sequence = get_old_return_sequence(example)
104 | new_return_sequence = get_new_return_sequence(example)
105 |
106 | old_return_line_terms = set([t for t in old_return_sequence if not is_java_keyword(t) and not is_operator(t)])
107 | new_return_line_terms = set([t for t in new_return_sequence if not is_java_keyword(t) and not is_operator(t)])
108 | return_line_intersection = old_return_line_terms.intersection(new_return_line_terms)
109 |
110 | old_set = set(old_return_type_subtokens)
111 | new_set = set(new_return_type_subtokens)
112 |
113 | intersection = old_set.intersection(new_set)
114 |
115 | features = np.zeros((len(nodes), get_num_code_features()), dtype=np.int64)
116 |
117 | old_nl_tokens = set(example.old_comment_subtokens)
118 | last_command = None
119 |
120 | for i, node in enumerate(nodes):
121 | if not node.is_leaf:
122 | continue
123 |
124 | token = node.value
125 |
126 | if token in intersection:
127 | features[i][0] = True
128 | elif token in old_set:
129 | features[i][1] = True
130 | elif token in new_set:
131 | features[i][2] = True
132 | else:
133 | features[i][3] = True
134 |
135 | if token in return_line_intersection:
136 | features[i][4] = True
137 | elif token in old_return_line_terms:
138 | features[i][5] = True
139 | elif token in new_return_line_terms:
140 | features[i][6] = True
141 | else:
142 | features[i][7] = True
143 |
144 | if is_edit_keyword(token):
145 | features[i][8] = True
146 | if is_java_keyword(token):
147 | features[i][9] = True
148 | if is_operator(token):
149 | features[i][10] = True
150 | if token in old_nl_tokens:
151 | features[i][11] = True
152 |
153 | if not is_edit_keyword(token):
154 | if last_command == KEEP:
155 | features[i][12] = 1
156 | elif last_command == INSERT:
157 | features[i][13] = 1
158 | elif last_command == DELETE:
159 | features[i][14] = 1
160 | elif last_command == REPLACE_NEW:
161 | features[i][15] = 1
162 | else:
163 | features[i][16] = 1
164 | else:
165 | last_command = token
166 |
167 | if len(node.subtoken_children) > 0 or len(node.subtoken_parents) > 0:
168 | features[i][17] = True
169 |
170 | if len(node.subtoken_parents) == 1:
171 | features[i][18] = node.subtoken_parents[0].subtoken_children.index(node)
172 |
173 | return features.astype(np.float32)
174 |
175 | def get_code_features(code_sequence, example, max_code_length):
176 | old_return_type_subtokens = get_old_return_type_subtokens(example)
177 | new_return_type_subtokens = get_return_type_subtokens(example)
178 | method_name_subtokens = get_method_name_subtokens(example)
179 |
180 | old_return_sequence = get_old_return_sequence(example)
181 | new_return_sequence = get_new_return_sequence(example)
182 |
183 | old_return_line_terms = set([t for t in old_return_sequence if not is_java_keyword(t) and not is_operator(t)])
184 | new_return_line_terms = set([t for t in new_return_sequence if not is_java_keyword(t) and not is_operator(t)])
185 | return_line_intersection = old_return_line_terms.intersection(new_return_line_terms)
186 |
187 | old_set = set(old_return_type_subtokens)
188 | new_set = set(new_return_type_subtokens)
189 |
190 | intersection = old_set.intersection(new_set)
191 |
192 | features = np.zeros((max_code_length, get_num_code_features()), dtype=np.int64)
193 |
194 | old_nl_tokens = set(example.old_comment_subtokens)
195 | last_command = None
196 |
197 | subtoken_labels = get_edit_span_subtoken_tokenization_labels(example)
198 | subtoken_indices = get_edit_span_subtoken_tokenization_indices(example)
199 |
200 | for i, token in enumerate(code_sequence):
201 | if i >= max_code_length:
202 | break
203 | if token in intersection:
204 | features[i][0] = True
205 | elif token in old_set:
206 | features[i][1] = True
207 | elif token in new_set:
208 | features[i][2] = True
209 | else:
210 | features[i][3] = True
211 |
212 | if token in return_line_intersection:
213 | features[i][4] = True
214 | elif token in old_return_line_terms:
215 | features[i][5] = True
216 | elif token in new_return_line_terms:
217 | features[i][6] = True
218 | else:
219 | features[i][7] = True
220 |
221 | if is_edit_keyword(token):
222 | features[i][8] = True
223 | if is_java_keyword(token):
224 | features[i][9] = True
225 | if is_operator(token):
226 | features[i][10] = True
227 | if token in old_nl_tokens:
228 | features[i][11] = True
229 |
230 | if not is_edit_keyword(token):
231 | if last_command == KEEP:
232 | features[i][12] = 1
233 | elif last_command == INSERT:
234 | features[i][13] = 1
235 | elif last_command == DELETE:
236 | features[i][14] = 1
237 | elif last_command == REPLACE_NEW:
238 | features[i][15] = 1
239 | else:
240 | features[i][16] = 1
241 | else:
242 | last_command = token
243 |
244 | features[i][17] = subtoken_labels[i]
245 | features[i][18] = subtoken_indices[i]
246 |
247 | return features.astype(np.float32)
248 |
249 | def get_nl_features(old_nl_sequence, example, max_nl_length):
250 | insert_code_tokens = set()
251 | keep_code_tokens = set()
252 | delete_code_tokens = set()
253 | replace_old_code_tokens = set()
254 | replace_new_code_tokens = set()
255 |
256 | frequency_map = dict()
257 | for tok in old_nl_sequence:
258 | if tok not in frequency_map:
259 | frequency_map[tok] = 0
260 | frequency_map[tok] += 1
261 |
262 | pos_tags = pos_tag(word_tokenize(' '.join(old_nl_sequence)))
263 | pos_tag_indices = []
264 | for _, t in pos_tags:
265 | if t in tags:
266 | pos_tag_indices.append(tags.index(t))
267 | else:
268 | pos_tag_indices.append(tags.index('OTHER'))
269 |
270 | i = 0
271 | code_tokens = example.token_diff_code_subtokens
272 |
273 | while i < len(code_tokens):
274 | if code_tokens[i] == INSERT:
275 | insert_code_tokens.add(code_tokens[i+1].lower())
276 | i += 2
277 | elif code_tokens[i] == KEEP:
278 | keep_code_tokens.add(code_tokens[i+1].lower())
279 | i += 2
280 | elif code_tokens[i] == DELETE:
281 | delete_code_tokens.add(code_tokens[i+1].lower())
282 | i += 2
283 | elif code_tokens[i] == REPLACE_OLD:
284 | replace_old_code_tokens.add(code_tokens[i+1].lower())
285 | i += 2
286 | elif code_tokens[i] == REPLACE_NEW:
287 | replace_new_code_tokens.add(code_tokens[i+1].lower())
288 | i += 2
289 |
290 | old_return_type_subtokens = get_old_return_type_subtokens(example)
291 | new_return_type_subtokens = get_return_type_subtokens(example)
292 |
293 | old_return_sequence = get_old_return_sequence(example)
294 | new_return_sequence = get_new_return_sequence(example)
295 |
296 | old_return_line_terms = set([t for t in old_return_sequence if not is_java_keyword(t) and not is_operator(t)])
297 | new_return_line_terms = set([t for t in new_return_sequence if not is_java_keyword(t) and not is_operator(t)])
298 | return_line_intersection = old_return_line_terms.intersection(new_return_line_terms)
299 |
300 | old_set = set(old_return_type_subtokens)
301 | new_set = set(new_return_type_subtokens)
302 |
303 | intersection = old_set.intersection(new_set)
304 |
305 |     method_name_subtokens = get_method_name_subtokens(example)
306 |
307 | nl_subtoken_labels = get_nl_subtoken_tokenization_labels(example)
308 | nl_subtoken_indices = get_nl_subtoken_tokenization_indices(example)
309 |
310 | features = np.zeros((max_nl_length, get_num_nl_features()), dtype=np.int64)
311 | for i in range(len(old_nl_sequence)):
312 | if i >= max_nl_length:
313 | break
314 | token = old_nl_sequence[i].lower()
315 | if token in intersection:
316 | features[i][0] = True
317 | elif token in old_set:
318 | features[i][1] = True
319 | elif token in new_set:
320 | features[i][2] = True
321 | else:
322 | features[i][3] = True
323 |
324 | if token in return_line_intersection:
325 | features[i][4] = True
326 | elif token in old_return_line_terms:
327 | features[i][5] = True
328 | elif token in new_return_line_terms:
329 | features[i][6] = True
330 | else:
331 | features[i][7] = True
332 |
333 | features[i][8] = token in insert_code_tokens
334 | features[i][9] = token in keep_code_tokens
335 | features[i][10] = token in delete_code_tokens
336 | features[i][11] = token in replace_old_code_tokens
337 | features[i][12] = token in replace_new_code_tokens
338 | features[i][13] = token in stop_words
339 | features[i][14] = frequency_map[token] > 1
340 |
341 | features[i][15] = nl_subtoken_labels[i]
342 | features[i][16] = nl_subtoken_indices[i]
343 | features[i][17 + pos_tag_indices[i]] = 1
344 |
345 | return features.astype(np.float32)
346 |
--------------------------------------------------------------------------------
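
A quick illustration (not part of the repository) of the two lexical predicates used when building the features above. Importing the module populates the caches at import time, so this assumes `RESOURCES_PATH` in `constants.py` points at the downloaded resources and that `comment_update/` is on the path.

```
from external_cache import is_java_keyword, is_operator

print(is_java_keyword('return'))  # True: in the hard-coded Java keyword set
print(is_operator('+='))          # True: no alphanumeric characters
print(is_operator('idx'))         # False
```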
/comment_update/generation_decoder.py:
--------------------------------------------------------------------------------
1 | from dpu_utils.mlutils import Vocabulary
2 | import logging
3 | import numpy as np
4 | import os
5 | import random
6 | import sys
7 | import torch
8 | from torch import nn
9 |
10 | from constants import START, BEAM_SIZE
11 | from decoder import Decoder
12 |
13 | class GenerationDecoder(Decoder):
14 | def __init__(self, input_size, hidden_size, attention_state_size, embedding_store,
15 | embedding_size, dropout_rate):
16 | """Decoder for the generation model which generates a comment based on a
17 | learned representation of a method."""
18 | super(GenerationDecoder, self).__init__(input_size, hidden_size, attention_state_size,
19 | embedding_store, embedding_size, dropout_rate)
20 |
21 | def decode(self, initial_state, decoder_input_embeddings, encoder_hidden_states, masks):
22 | """Decoding with attention and copy."""
23 | decoder_states, decoder_final_state = self.gru.forward(decoder_input_embeddings,
24 | initial_state.unsqueeze(0))
25 |
26 | # https://stackoverflow.com/questions/50571991/implementing-luong-attention-in-pytorch
27 | attn_alignment = torch.einsum('ijk,km,inm->inj', encoder_hidden_states,
28 | self.attention_encoder_hidden_transform_matrix, decoder_states)
29 | attn_alignment.masked_fill_(masks, float('-inf'))
30 | attention_scores = nn.functional.softmax(attn_alignment, dim=-1)
31 | contexts = torch.einsum('ijk,ikm->ijm', attention_scores, encoder_hidden_states)
32 | decoder_states = torch.tanh(self.attention_output_layer(torch.cat([contexts, decoder_states], dim=-1)))
33 |
34 | generation_scores = torch.einsum('ijk,km->ijm', decoder_states, self.generation_output_matrix)
35 | copy_scores = torch.einsum('ijk,km,inm->inj', encoder_hidden_states,
36 | self.copy_encoder_hidden_transform_matrix, decoder_states)
37 | copy_scores.masked_fill_(masks, float('-inf'))
38 |
39 | combined_logprobs = nn.functional.log_softmax(torch.cat([generation_scores, copy_scores], dim=-1), dim=-1)
40 | generation_logprobs = combined_logprobs[:,:,:len(self.embedding_store.nl_vocabulary)]
41 | copy_logprobs = combined_logprobs[:, :,len(self.embedding_store.nl_vocabulary):]
42 |
43 | return decoder_states, decoder_final_state, generation_logprobs, copy_logprobs
44 |
45 | def forward(self, initial_state, decoder_input_embeddings, encoder_hidden_states, masks):
46 | """Runs decoding."""
47 | return self.decode(initial_state, decoder_input_embeddings, encoder_hidden_states, masks)
48 |
49 | def greedy_decode(self, initial_state, encoder_hidden_states, masks, max_out_len, batch_data, device):
50 | """Greedily generates the output sequence."""
51 | # Derived from https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/9f6b66f43d2e05175dabcc024f79e1d37a667070/decode_beam.py#L163
52 | batch_size = initial_state.shape[0]
53 | decoder_state = initial_state
54 | decoder_input = torch.tensor(
55 | [[self.embedding_store.get_nl_id(START)]] * batch_size,
56 | device=device
57 | )
58 |
59 | decoded_batch = np.zeros([batch_size, max_out_len], dtype=np.int64)
60 | decoded_batch_scores = np.zeros([batch_size, max_out_len])
61 |
62 | for i in range(max_out_len):
63 | decoder_input_embeddings = self.embedding_store.get_nl_embeddings(decoder_input)
64 | decoder_attention_states, decoder_state, generation_logprobs, copy_logprobs = self.decode(decoder_state,
65 | decoder_input_embeddings, encoder_hidden_states, masks)
66 |
67 | generation_logprobs = generation_logprobs.squeeze(1)
68 | copy_logprobs = copy_logprobs.squeeze(1)
69 |
70 | prob_scores = torch.zeros([generation_logprobs.shape[0],
71 | generation_logprobs.shape[-1] + copy_logprobs.shape[-1]], dtype=torch.float32, device=device)
72 | prob_scores[:, :generation_logprobs.shape[-1]] = torch.exp(generation_logprobs)
73 | for b in range(generation_logprobs.shape[0]):
74 | for c, inp_id in enumerate(batch_data.input_ids[b]):
75 | prob_scores[b, inp_id] = prob_scores[b, inp_id] + torch.exp(copy_logprobs[b,c])
76 |
77 | predicted_ids = torch.argmax(prob_scores, dim=-1)
78 | decoded_batch_scores[:, i] = prob_scores[torch.arange(prob_scores.shape[0]), predicted_ids].cpu()
79 | decoded_batch[:, i] = predicted_ids.cpu()
80 |
81 | unks = torch.ones(
82 | predicted_ids.shape[0], dtype=torch.int64, device=device) * self.embedding_store.get_nl_id(Vocabulary.get_unk())
83 | decoder_input = torch.where(predicted_ids < len(self.embedding_store.nl_vocabulary), predicted_ids, unks).unsqueeze(1)
84 | decoder_state = decoder_state.squeeze(0)
85 |
86 | return decoded_batch, decoded_batch_scores
87 |
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
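The copy mechanism in `greedy_decode` above merges two distributions: generation probabilities over the NL vocabulary and copy probabilities over source positions, with each position's copy mass added onto the output slot indexed by that source token's id. A minimal standalone sketch of that merge step, with toy shapes and made-up token ids (not code from this repo):

import torch

vocab_size = 6                                      # hypothetical tiny NL vocabulary
input_ids = torch.tensor([[2, 4, 7]])               # 3 source tokens; id 7 is an OOV slot
generation_probs = torch.softmax(torch.randn(1, vocab_size), dim=-1)
copy_probs = torch.softmax(torch.randn(1, 3), dim=-1)

# Extended distribution: vocabulary slots followed by one slot per source position.
scores = torch.zeros(1, vocab_size + 3)
scores[:, :vocab_size] = generation_probs
for b in range(scores.shape[0]):
    for c, inp_id in enumerate(input_ids[b]):
        scores[b, inp_id] += copy_probs[b, c]       # copy mass lands on the token's slot

predicted_id = torch.argmax(scores, dim=-1)         # id >= vocab_size means copying an OOV token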
/comment_update/tensor_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def merge_encoder_outputs(a_states, a_lengths, b_states, b_lengths, device):
5 | a_max_len = a_states.size(1)
6 | b_max_len = b_states.size(1)
7 | combined_len = a_max_len + b_max_len
8 | padded_b_states = torch.zeros([b_states.size(0), combined_len, b_states.size(-1)], device=device)
9 | padded_b_states[:, :b_max_len, :] = b_states
10 | full_matrix = torch.cat([a_states, padded_b_states], dim=1)
11 | a_idxs = torch.arange(combined_len, dtype=torch.long, device=device).view(-1, 1)
12 | b_idxs = torch.arange(combined_len, dtype=torch.long,
13 | device=device).view(-1,1) - a_lengths.view(1, -1) + a_max_len
14 | idxs = torch.where(b_idxs < a_max_len, a_idxs, b_idxs).permute(1, 0)
15 | offset = torch.arange(0, full_matrix.size(0) * full_matrix.size(1), full_matrix.size(1), device=device)
16 | idxs = idxs + offset.unsqueeze(1)
17 | combined_states = full_matrix.reshape(-1, full_matrix.shape[-1])[idxs]
18 | combined_lengths = a_lengths + b_lengths
19 |
20 | return combined_states, combined_lengths
21 |
22 | def get_invalid_copy_locations(input_sequence, max_input_length, output_sequence, max_output_length):
23 | input_length = min(len(input_sequence), max_input_length)
24 | output_length = min(len(output_sequence), max_output_length)
25 |
26 | invalid_copy_locations = np.ones([max_output_length, max_input_length], dtype=bool)
27 | for o in range(output_length):
28 | for i in range(input_length):
29 | invalid_copy_locations[o,i] = output_sequence[o] != input_sequence[i]
30 |
31 | return invalid_copy_locations
32 |
33 | def compute_attention_states(key_states, masks, query_states, transformation_matrix=None, multihead_attention=None):
34 | if multihead_attention is not None:
35 | if transformation_matrix is not None:
36 | key = torch.einsum('bsh,hd->sbd', key_states, transformation_matrix) # S x B x D
37 | else:
38 | key = key_states.permute(1,0,2) # S x B x D
39 |
40 | query = query_states.permute(1,0,2) # T x B x D
41 | value = key
42 | attn_output, attn_output_weights = multihead_attention(query, key, value, key_padding_mask=masks.squeeze(1))
43 | return attn_output.permute(1,0,2)
44 | else:
45 | if transformation_matrix is not None:
46 | alignment = torch.einsum('bsh,hd,btd->bts', key_states, transformation_matrix, query_states)
47 | else:
48 | alignment = torch.einsum('bsh,bth->bts', key_states, query_states)
49 | alignment.masked_fill_(masks, float('-inf'))
50 | attention_scores = torch.nn.functional.softmax(alignment, dim=-1)
51 | return torch.einsum('ijk,ikm->ijm', attention_scores, key_states)
--------------------------------------------------------------------------------
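As a usage sketch (assuming `comment_update` is on `sys.path`), `get_invalid_copy_locations` builds the boolean mask that forbids copying an output token from any source position holding a different token; padded positions remain invalid:

from tensor_utils import get_invalid_copy_locations

inp = ['return', 'the', 'max', 'score']
out = ['the', 'max', 'value']
mask = get_invalid_copy_locations(inp, 5, out, 4)   # shape: [max_output_length, max_input_length]
assert not mask[0, 1]    # out[0] == 'the' matches inp[1], so copying is allowed
assert mask[2].all()     # 'value' appears nowhere in the input
assert mask[:, 4].all()  # the padded input column is never a valid copy source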
/comment_update/update_decoder.py:
--------------------------------------------------------------------------------
1 | from dpu_utils.mlutils import Vocabulary
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch_scatter import scatter_add
6 |
7 | from constants import START, BEAM_SIZE
8 | from decoder import Decoder
9 | from tensor_utils import compute_attention_states
10 |
11 | class UpdateDecoder(Decoder):
12 | def __init__(self, input_size, hidden_size, attention_state_size, embedding_store,
13 | embedding_size, dropout_rate, attn_input_size):
14 | """Decoder for the edit model which generates a sequence of NL edits based on learned representations of
15 | the old comment and code edits."""
16 | super(UpdateDecoder, self).__init__(input_size, hidden_size, attention_state_size,
17 | embedding_store, embedding_size, dropout_rate)
18 |
19 | self.sequence_attention_code_transform_matrix = nn.Parameter(
20 | torch.randn(self.attention_state_size, self.hidden_size,
21 | dtype=torch.float, requires_grad=True)
22 | )
23 | self.attention_old_nl_hidden_transform_matrix = nn.Parameter(
24 | torch.randn(self.attention_state_size, self.hidden_size,
25 | dtype=torch.float, requires_grad=True)
26 | )
27 |
28 | self.attention_output_layer = nn.Linear(attn_input_size + self.hidden_size,
29 | self.hidden_size, bias=False)
30 |
31 | def decode(self, initial_state, decoder_input_embeddings, encoder_hidden_states,
32 | code_hidden_states, old_nl_hidden_states, masks, code_masks, old_nl_masks):
33 | """Decoding with attention and copy. Attention is computed separately for each set of encoder hidden states."""
34 | decoder_states, decoder_final_state = self.gru.forward(decoder_input_embeddings, initial_state.unsqueeze(0))
35 |
36 | attention_context_states = compute_attention_states(old_nl_hidden_states, old_nl_masks,
37 | decoder_states, self.attention_old_nl_hidden_transform_matrix, None)
38 |
39 | code_contexts = compute_attention_states(code_hidden_states, code_masks,
40 | decoder_states, self.sequence_attention_code_transform_matrix, None)
41 | attention_context_states = torch.cat([attention_context_states, code_contexts], dim=-1)
42 |
43 | decoder_states = torch.tanh(self.attention_output_layer(
44 | torch.cat([attention_context_states, decoder_states], dim=-1)))
45 |
46 | generation_scores = torch.einsum('ijk,km->ijm', decoder_states, self.generation_output_matrix)
47 | copy_scores = torch.einsum('ijk,km,inm->inj', encoder_hidden_states,
48 | self.copy_encoder_hidden_transform_matrix, decoder_states)
49 | copy_scores.masked_fill_(masks, float('-inf'))
50 |
51 | combined_logprobs = nn.functional.log_softmax(torch.cat([generation_scores, copy_scores], dim=-1), dim=-1)
52 | generation_logprobs = combined_logprobs[:,:,:len(self.embedding_store.nl_vocabulary)]
53 | copy_logprobs = combined_logprobs[:, :,len(self.embedding_store.nl_vocabulary):]
54 |
55 | return decoder_states, decoder_final_state, generation_logprobs, copy_logprobs
56 |
57 | def forward(self, initial_state, decoder_input_embeddings, encoder_hidden_states,
58 | code_hidden_states, old_nl_hidden_states, masks, code_masks, old_nl_masks):
59 | """Runs decoding."""
60 | return self.decode(initial_state, decoder_input_embeddings, encoder_hidden_states,
61 | code_hidden_states, old_nl_hidden_states, masks, code_masks, old_nl_masks)
62 |
63 | def beam_decode(self, initial_state, encoder_hidden_states, code_hidden_states, old_nl_hidden_states,
64 | masks, max_out_len, batch_data, code_masks, old_nl_masks, device):
65 | """Beam search. Generates the top K candidate predictions."""
66 | batch_size = initial_state.shape[0]
67 | decoded_batch = [list() for _ in range(batch_size)]
68 | decoded_batch_scores = np.zeros([batch_size, BEAM_SIZE])
69 |
70 | decoder_input = torch.tensor(
71 | [[self.embedding_store.get_nl_id(START)]] * batch_size, device=device)
72 | decoder_input = decoder_input.unsqueeze(1)
73 | decoder_state = initial_state.unsqueeze(1).expand(
74 | -1, decoder_input.shape[1], -1).reshape(-1, initial_state.shape[-1])
75 |
76 | beam_scores = torch.ones([batch_size, 1], dtype=torch.float32, device=device)
77 | beam_status = torch.zeros([batch_size, 1], dtype=torch.uint8, device=device)
78 | beam_predicted_ids = torch.full([batch_size, 1, max_out_len], self.embedding_store.get_end_id(),
79 | dtype=torch.int64, device=device)
80 |
81 | for i in range(max_out_len):
82 | beam_size = decoder_input.shape[1]
83 | if beam_status[:,0].sum() == batch_size:
84 | break
85 |
86 | tiled_encoder_states = encoder_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1)
87 | tiled_masks = masks.unsqueeze(1).expand(-1, beam_size, -1, -1)
88 | tiled_code_hidden_states = code_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1)
89 | tiled_code_masks = code_masks.unsqueeze(1).expand(-1, beam_size, -1, -1)
90 | tiled_old_nl_hidden_states = old_nl_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1)
91 | tiled_old_nl_masks = old_nl_masks.unsqueeze(1).expand(-1, beam_size, -1, -1)
92 |
93 | flat_decoder_input = decoder_input.reshape(-1, decoder_input.shape[-1])
94 | flat_encoder_states = tiled_encoder_states.reshape(-1, tiled_encoder_states.shape[-2], tiled_encoder_states.shape[-1])
95 | flat_masks = tiled_masks.reshape(-1, tiled_masks.shape[-2], tiled_masks.shape[-1])
96 | flat_code_hidden_states = tiled_code_hidden_states.reshape(-1, tiled_code_hidden_states.shape[-2], tiled_code_hidden_states.shape[-1])
97 | flat_code_masks = tiled_code_masks.reshape(-1, tiled_code_masks.shape[-2], tiled_code_masks.shape[-1])
98 | flat_old_nl_hidden_states = tiled_old_nl_hidden_states.reshape(-1, tiled_old_nl_hidden_states.shape[-2], tiled_old_nl_hidden_states.shape[-1])
99 | flat_old_nl_masks = tiled_old_nl_masks.reshape(-1, tiled_old_nl_masks.shape[-2], tiled_old_nl_masks.shape[-1])
100 |
101 | decoder_input_embeddings = self.embedding_store.get_nl_embeddings(flat_decoder_input)
102 | decoder_attention_states, flat_decoder_state, generation_logprobs, copy_logprobs = self.decode(
103 | decoder_state, decoder_input_embeddings, flat_encoder_states, flat_code_hidden_states,
104 | flat_old_nl_hidden_states, flat_masks, flat_code_masks, flat_old_nl_masks)
105 |
106 | generation_logprobs = generation_logprobs.squeeze(1)
107 | copy_logprobs = copy_logprobs.squeeze(1)
108 |
109 | generation_logprobs = generation_logprobs.reshape(batch_size, beam_size, generation_logprobs.shape[-1])
110 | copy_logprobs = copy_logprobs.reshape(batch_size, beam_size, copy_logprobs.shape[-1])
111 |
112 | prob_scores = torch.zeros([batch_size, beam_size,
113 | generation_logprobs.shape[-1] + copy_logprobs.shape[-1]], dtype=torch.float32, device=device)
114 | prob_scores[:, :, :generation_logprobs.shape[-1]] = torch.exp(generation_logprobs)
115 |
116 | # Factoring in the copy scores
117 | expanded_token_ids = batch_data.input_ids.unsqueeze(1).expand(-1, beam_size, -1)
118 | prob_scores += scatter_add(src=torch.exp(copy_logprobs), index=expanded_token_ids, out=torch.zeros_like(prob_scores))
119 |
120 | top_scores_per_beam, top_indices_per_beam = torch.topk(prob_scores, k=BEAM_SIZE, dim=-1)
121 |
122 | updated_scores = torch.einsum('eb,ebm->ebm', beam_scores, top_scores_per_beam)
123 | retained_scores = beam_scores.unsqueeze(-1).expand(-1, -1, top_scores_per_beam.shape[-1])
124 |
125 | # For beams that have already completed, keep at most one hypothesis (in the first slot)
126 | end_mask = (torch.arange(beam_size) == 0).type(torch.float32).to(device)
127 | end_scores = torch.einsum('b,ebm->ebm', end_mask, retained_scores)
128 |
129 | possible_next_scores = torch.where(beam_status.unsqueeze(-1) == 1, end_scores, updated_scores)
130 | possible_next_status = torch.where(top_indices_per_beam == self.embedding_store.get_end_id(),
131 | torch.ones([batch_size, beam_size, top_scores_per_beam.shape[-1]], dtype=torch.uint8, device=device),
132 | beam_status.unsqueeze(-1).expand(-1,-1,top_scores_per_beam.shape[-1]))
133 |
134 | possible_beam_predicted_ids = beam_predicted_ids.unsqueeze(2).expand(-1, -1, top_scores_per_beam.shape[-1], -1)
135 | pool_next_scores = possible_next_scores.reshape(batch_size, -1)
136 | pool_next_status = possible_next_status.reshape(batch_size, -1)
137 | pool_next_ids = top_indices_per_beam.reshape(batch_size, -1)
138 | pool_predicted_ids = possible_beam_predicted_ids.reshape(batch_size, -1, beam_predicted_ids.shape[-1])
139 |
140 | possible_decoder_state = flat_decoder_state.reshape(batch_size, beam_size, flat_decoder_state.shape[-1])
141 | possible_decoder_state = possible_decoder_state.unsqueeze(2).expand(-1, -1, top_scores_per_beam.shape[-1], -1)
142 | pool_decoder_state = possible_decoder_state.reshape(batch_size, -1, possible_decoder_state.shape[-1])
143 |
144 | top_scores, top_indices = torch.topk(pool_next_scores, k=BEAM_SIZE, dim=-1)
145 | next_step_ids = torch.gather(pool_next_ids, -1, top_indices)
146 |
147 | decoder_state = torch.gather(pool_decoder_state, 1, top_indices.unsqueeze(-1).expand(-1,-1, pool_decoder_state.shape[-1]))
148 | decoder_state = decoder_state.reshape(-1, decoder_state.shape[-1])
149 | beam_status = torch.gather(pool_next_status, -1, top_indices)
150 | beam_scores = torch.gather(pool_next_scores, -1, top_indices)
151 |
152 | end_tags = torch.full_like(next_step_ids, self.embedding_store.get_end_id())
153 | next_step_ids = torch.where(beam_status == 1, end_tags, next_step_ids)
154 |
155 | beam_predicted_ids = torch.gather(pool_predicted_ids, 1, top_indices.unsqueeze(-1).expand(-1, -1, pool_predicted_ids.shape[-1]))
156 | beam_predicted_ids[:,:,i] = next_step_ids
157 |
158 | unks = torch.full_like(next_step_ids, self.embedding_store.get_nl_id(Vocabulary.get_unk()))
159 | decoder_input = torch.where(next_step_ids < len(self.embedding_store.nl_vocabulary), next_step_ids, unks).unsqueeze(-1)
160 |
161 | return beam_predicted_ids, beam_scores
--------------------------------------------------------------------------------
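The `scatter_add` call above is the vectorized counterpart of the per-position loop in `greedy_decode`: every beam's copy probabilities are added onto the output slots named by `input_ids` in a single call. A toy check with made-up shapes and ids (assuming `torch_scatter` is installed):

import torch
from torch_scatter import scatter_add

copy_probs = torch.tensor([[[0.2, 0.5, 0.3]]])      # batch=1, beam=1, 3 source tokens
input_ids = torch.tensor([[[4, 1, 4]]])             # source position -> output slot
scores = torch.zeros(1, 1, 6)
scores += scatter_add(src=copy_probs, index=input_ids, out=torch.zeros_like(scores))
assert torch.isclose(scores[0, 0, 4], torch.tensor(0.5))  # 0.2 + 0.3 accumulate on slot 4
assert torch.isclose(scores[0, 0, 1], torch.tensor(0.5))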
/comment_update/update_evaluation_utils.py:
--------------------------------------------------------------------------------
1 | import difflib
2 | import logging
3 | import os
4 | import numpy as np
5 | from typing import List, NamedTuple
6 | import subprocess
7 |
8 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
9 | from pycocoevalcap.meteor.meteor import Meteor
10 | from SARI import SARIsent
11 |
12 | from data_utils import get_processed_comment_str
13 |
14 | def compute_accuracy(reference_strings, predicted_strings):
15 | assert(len(reference_strings) == len(predicted_strings))
16 | correct = 0.0
17 | for i in range(len(reference_strings)):
18 | if reference_strings[i] == predicted_strings[i]:
19 | correct += 1
20 | return 100 * correct/float(len(reference_strings))
21 |
22 | def compute_bleu(references, hypotheses):
23 | bleu_4_sentence_scores = []
24 | for ref, hyp in zip(references, hypotheses):
25 | bleu_4_sentence_scores.append(sentence_bleu(ref, hyp,
26 | smoothing_function=SmoothingFunction().method2))
27 | return 100*sum(bleu_4_sentence_scores)/float(len(bleu_4_sentence_scores))
28 |
29 | def compute_sentence_bleu(ref, hyp):
30 | return sentence_bleu(ref, hyp, smoothing_function=SmoothingFunction().method2)
31 |
32 | def compute_sentence_meteor(reference_list, sentences):
33 | preds = dict()
34 | refs = dict()
35 |
36 | for i in range(len(sentences)):
37 | preds[i] = [' '.join([s for s in sentences[i]])]
38 | refs[i] = [' '.join(l) for l in reference_list[i]]
39 |
40 | final_scores = dict()
41 |
42 | scorers = [
43 | (Meteor(),"METEOR")
44 | ]
45 |
46 | for scorer, method in scorers:
47 | score, scores = scorer.compute_score(refs, preds)
48 | if type(method) == list:
49 | for sc, scs, m in zip(score, scores, method):
50 | final_scores[m] = scs
51 | else:
52 | final_scores[method] = scores
53 |
54 | meteor_scores = final_scores["METEOR"]
55 | return meteor_scores
56 |
57 | def compute_meteor(reference_list, sentences):
58 | meteor_scores = compute_sentence_meteor(reference_list, sentences)
59 | return 100 * sum(meteor_scores)/len(meteor_scores)
60 |
61 | def compute_unchanged(test_data, predictions):
62 | source_sentences = [get_processed_comment_str(ex.old_comment_subtokens) for ex in test_data]
63 | predicted_sentences = [' '.join(p) for p in predictions]
64 | unchanged = 0
65 |
66 | for source, predicted in zip(source_sentences, predicted_sentences):
67 | if source == predicted:
68 | unchanged += 1
69 |
70 | return 100*(unchanged)/len(test_data)
71 |
72 | def compute_sari(test_data, predictions):
73 | source_sentences = [get_processed_comment_str(ex.old_comment_subtokens) for ex in test_data]
74 | target_sentences = [[get_processed_comment_str(ex.new_comment_subtokens)] for ex in test_data]
75 | predicted_sentences = [' '.join(p) for p in predictions]
76 |
77 | inp = zip(source_sentences, target_sentences, predicted_sentences)
78 | scores = []
79 |
80 | for source, target, predicted in inp:
81 | scores.append(SARIsent(source, predicted, target))
82 |
83 | return 100*sum(scores)/float(len(scores))
84 |
85 | def compute_gleu(test_data, orig_file, ref_file, pred_file):
86 | command = 'python2.7 gleu/scripts/compute_gleu -s {} -r {} -o {} -d'.format(orig_file, ref_file, pred_file)
87 | output = subprocess.check_output(command.split())
88 |
89 | output_lines = [l.strip() for l in output.decode("utf-8").split('\n') if len(l.strip()) > 0]
90 | l = 0
91 | while l < len(output_lines):
92 | if output_lines[l][0] == '0':
93 | break
94 | l += 1
95 |
96 | scores = np.zeros(len(test_data), dtype=np.float32)
97 | # Read one per-sentence score line for each test example, starting at the first score line.
98 | for score_line in output_lines[l:l + len(test_data)]:
99 | terms = score_line.split()
100 | idx = int(terms[0])
101 | val = float(terms[1])
102 | scores[idx] = val
103 | scores = scores.tolist()
104 | return 100*sum(scores)/float(len(scores))
105 |
106 | def write_predictions(predicted_strings, write_file):
107 | os.makedirs(os.path.dirname(write_file), exist_ok=True)
108 | with open(write_file, 'w+') as f:
109 | for p in predicted_strings:
110 | f.write('{}\n'.format(p))
--------------------------------------------------------------------------------
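A quick usage sketch of the sentence-level metrics (tokens below are made up; assumes this module and its NLTK/pycocoevalcap dependencies are importable):

from update_evaluation_utils import compute_accuracy, compute_bleu

references = [[['the', 'max', 'score']], [['returns', 'true']]]  # list of reference lists per example
hypotheses = [['the', 'max', 'score'], ['returns', 'false']]

print(compute_accuracy([' '.join(r[0]) for r in references],
                       [' '.join(h) for h in hypotheses]))       # 50.0
print(compute_bleu(references, hypotheses))                      # smoothed BLEU-4, averaged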
/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | START = '<sos>'  # start-of-sequence token
4 | END = '<eos>'  # end-of-sequence token
5 | NL_EMBEDDING_SIZE = 64
6 | CODE_EMBEDDING_SIZE = 64
7 | HIDDEN_SIZE = 64
8 | DROPOUT_RATE = 0.6
9 | NUM_LAYERS = 2
10 | LR = 0.001
11 | BATCH_SIZE = 100
12 | MAX_EPOCHS = 100
13 | PATIENCE = 10
14 | VOCAB_CUTOFF_PCT = 5
15 | LENGTH_CUTOFF_PCT = 95
16 | MAX_VOCAB_EXTENSION = 50
17 | BEAM_SIZE = 20
18 | MAX_VOCAB_SIZE = 10000
19 | FEATURE_DIMENSION = 128
20 | NUM_CLASSES = 2
21 |
22 | GNN_HIDDEN_SIZE = 64
23 | GNN_LAYER_TIMESTEPS = 8
24 | GNN_DROPOUT_RATE = 0.0
25 | SRC_EMBEDDING_SIZE = 8
26 | NODE_EMBEDDING_SIZE = 64
27 |
28 | MODEL_LAMBDA = 0.5
29 | LIKELIHOOD_LAMBDA = 0.3
30 | OLD_METEOR_LAMBDA = 0.2
31 | GEN_MODEL_LAMBDA = 0.5
32 | GEN_OLD_BLEU_LAMBDA = 0.5
33 | DECODER_HIDDEN_SIZE = 128
34 | MULTI_HEADS = 4
35 | NUM_TRANSFORMER_LAYERS = 2
36 |
37 | # Download data from here: https://drive.google.com/drive/folders/1heqEQGZHgO6gZzCjuQD1EyYertN4SAYZ?usp=sharing
38 | # DATA_PATH should point to the location in which the above data is saved locally
39 | DATA_PATH = '[PATH TO DOWNLOADED DATA]' # TODO
40 | RESOURCES_PATH = os.path.join(DATA_PATH, 'resources')
41 |
42 | # Download model resources from here: https://drive.google.com/drive/folders/1cutxr4rMDkT1g2BbmCAR2wqKTxeFH11K?usp=sharing
43 | # MODEL_RESOURCES_PATH should point to the location in which the above resources are saved locally.
44 | MODEL_RESOURCES_PATH = '[PATH TO DOWNLOADED MODEL RESOURCES]' # TODO
45 | NL_EMBEDDING_PATH = os.path.join(MODEL_RESOURCES_PATH, 'nl_embeddings.json')
46 | CODE_EMBEDDING_PATH = os.path.join(MODEL_RESOURCES_PATH, 'code_embeddings.json')
47 | FULL_GENERATION_MODEL_PATH = os.path.join(MODEL_RESOURCES_PATH, 'generation-model.pkl.gz')
48 |
49 | # Should point to where the output is to be saved
50 | PREDICTION_DIR = '[ROOT DIR TO STORE PREDICTED OUTPUT FOR UPDATE AND DUAL MODELS]' # TODO
51 | DETECTION_DIR = '[ROOT DIR TO STORE PREDICTED OUTPUT FOR DETECTION MODELS]' # TODO
--------------------------------------------------------------------------------
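Only the placeholder paths above need editing before running; for example (hypothetical local paths, not part of the release):

DATA_PATH = '/home/user/jit-data'                  # hypothetical
MODEL_RESOURCES_PATH = '/home/user/jit-resources'  # hypothetical
PREDICTION_DIR = '/home/user/jit-output/update'    # hypothetical
DETECTION_DIR = '/home/user/jit-output/detection'  # hypothetical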
/data_loader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from constants import DATA_PATH
5 | from data_utils import DiffAST, DiffExample, DiffASTExample, CommentCategory
6 |
7 | PARTITIONS = ['train', 'valid', 'test']
8 |
9 | def get_data_splits(comment_type_str=None, ignore_ast=False):
10 | """Retrieves train/validation/test sets for the given comment_type_str.
11 | comment_type_str -- Return, Param, Summary, or None (if None, uses all comment types)
12 | ignore_ast -- Skip loading ASTs (they take a long time)"""
13 | dataset, high_level_details = load_processed_data(comment_type_str, ignore_ast)
14 | train_examples = dataset['train']
15 | valid_examples = dataset['valid']
16 | test_examples = dataset['test']
17 | return train_examples, valid_examples, test_examples, high_level_details
18 |
19 | def load_cleaned_test_set(comment_type_str=None):
20 | """Retrieves the ids corresponding to clean examples, for the given comment_type_str.
21 | comment_type_str -- Return, Param, Summary, or None (if None, uses all comment types)"""
22 | if not comment_type_str:
23 | comment_types = [CommentCategory(category).name for category in CommentCategory]
24 | else:
25 | comment_types = [comment_type_str]
26 |
27 | test_ids = []
28 | for comment_type in comment_types:
29 | resources_path = os.path.join(DATA_PATH, 'resources', comment_type, 'clean_test_ids.json')
30 | with open(resources_path) as f:
31 | test_ids.extend(json.load(f))
32 | return test_ids
33 |
34 | def load_processed_data(comment_type_str, ignore_ast):
35 | """Processes saved data for the given comment_type_str.
36 | comment_type_str -- Return, Param, Summary, or None (if None, uses all comment types)
37 | ignore_ast -- Skip loading ASTs (they take a long time)"""
38 | if not comment_type_str:
39 | comment_types = [CommentCategory(category).name for category in CommentCategory]
40 | else:
41 | comment_types = [comment_type_str]
42 |
43 | print('Loading data from: {}'.format(comment_types))
44 |
45 | dataset = dict()
46 | high_level_details = dict()
47 | for comment_type in comment_types:
48 | path = os.path.join(DATA_PATH, comment_type)
49 | loaded = load_raw_data_from_path(path)
50 | category_high_level_details_path = os.path.join(DATA_PATH, 'resources', comment_type, 'high_level_details.json')
51 |
52 | with open(category_high_level_details_path) as f:
53 | category_high_level_details = json.load(f)
54 | high_level_details.update(category_high_level_details)
55 |
56 | if not ignore_ast:
57 | ast_path = os.path.join(DATA_PATH, 'resources', comment_type, 'ast_objs.json')
58 | with open(ast_path) as f:
59 | ast_details = json.load(f)
60 |
61 | for partition, examples in loaded.items():
62 | if partition not in dataset:
63 | dataset[partition] = []
64 |
65 | if ignore_ast:
66 | dataset[partition].extend(examples)
67 | else:
68 | for ex in examples:
69 | ex_ast_info = ast_details[ex.id]
70 | old_ast = DiffAST.from_json(ex_ast_info['old_ast'])
71 | new_ast = DiffAST.from_json(ex_ast_info['new_ast'])
72 | diff_ast = DiffAST.from_json(ex_ast_info['diff_ast'])
73 |
74 | ast_ex = DiffASTExample(ex.id, ex.label, ex.comment_type, ex.old_comment_raw,
75 | ex.old_comment_subtokens, ex.new_comment_raw, ex.new_comment_subtokens, ex.span_minimal_diff_comment_subtokens,
76 | ex.old_code_raw, ex.old_code_subtokens, ex.new_code_raw, ex.new_code_subtokens,
77 | ex.span_diff_code_subtokens, ex.token_diff_code_subtokens, old_ast, new_ast, diff_ast)
78 |
79 | dataset[partition].append(ast_ex)
80 |
81 | return dataset, high_level_details
82 |
83 | def load_raw_data_from_path(path):
84 | """Reads saved partition-level data from a directory path"""
85 | dataset = dict()
86 |
87 | for partition in PARTITIONS:
88 | dataset[partition] = []
89 | dataset[partition].extend(read_diff_examples_from_file(os.path.join(path, '{}.json'.format(partition))))
90 |
91 | return dataset
92 |
93 | def read_diff_examples_from_file(filename):
94 | """Reads saved data from filename"""
95 | with open(filename) as f:
96 | data = json.load(f)
97 | return [DiffExample(**d) for d in data]
--------------------------------------------------------------------------------
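Typical usage, as a sketch (paths come from `constants.py`; `ignore_ast=True` skips the slow AST load):

from data_loader import get_data_splits, load_cleaned_test_set

train, valid, test, details = get_data_splits(comment_type_str='Return', ignore_ast=True)
clean_ids = set(load_cleaned_test_set('Return'))
clean_test = [ex for ex in test if ex.id in clean_ids]
print(len(train), len(valid), len(test), len(clean_test))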
/data_processing/ast_diffing/code_samples/new.java:
--------------------------------------------------------------------------------
1 | /**Computes the highest value from the list of scores.*/
2 | public double getBestScore() {
3 | return Collections.max(scores);
4 | }
--------------------------------------------------------------------------------
/data_processing/ast_diffing/code_samples/old.java:
--------------------------------------------------------------------------------
1 | /**Computes the lowest value from the list of scores.*/
2 | public int getBestScore() {
3 | return Collections.min(scores);
4 | }
--------------------------------------------------------------------------------
/data_processing/ast_diffing/python/xml_diff_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import subprocess
6 | import sys
7 |
8 | import xml.etree.ElementTree as ET
9 |
10 | sys.path.append('../../../')
11 | sys.path.append('../../../comment_update')
12 | from data_utils import DiffTreeNode, DiffAST
13 |
14 |
15 | class Indexer:
16 | def __init__(self):
17 | self.count = 0
18 |
19 | def generate(self):
20 | new_id = self.count
21 | self.count += 1
22 | return new_id
23 |
24 | class XMLNode:
25 | def __init__(self, value, node_id, parent, attribute,
26 | alignment_id, location_id, src, is_leaf=True):
27 | self.value = value
28 | self.node_id = node_id
29 | self.parent = parent
30 | self.attribute = attribute
31 | self.alignment_id = alignment_id
32 | self.location_id = location_id
33 | self.src = src
34 | self.is_leaf = is_leaf
35 | self.children = []
36 | self.pseudo_children = []
37 | self.prev_sibling = None
38 | self.next_sibling = None
39 |
40 | def print_node(self):
41 | parent_value = None
42 | if self.parent:
43 | parent_value = self.parent.value
44 |
45 | print('{}: {} ({}, {})'.format(self.node_id, self.value, parent_value, len(self.children)))
46 | for c in self.children:
47 | c.print_node()
48 |
49 | class AST:
50 | def __init__(self, ast_root):
51 | self.root = ast_root
52 | self.nodes = []
53 | self.traverse(ast_root)
54 |
55 | def traverse(self, curr_node):
56 | self.nodes.append(curr_node)
57 | for c, child_node in enumerate(curr_node.children):
58 | if c > 0:
59 | child_node.prev_sibling = curr_node.children[c-1]
60 | if c < len(curr_node.children) - 1:
61 | child_node.next_sibling = curr_node.children[c+1]
62 | self.traverse(child_node)
63 |
64 | @property
65 | def leaves(self):
66 | return [n for n in self.nodes if n.is_leaf]
67 |
68 | def parse_xml_obj(xml_obj, indexer, parent, src):
69 | fields = xml_obj.attrib
70 | attribute = fields['typeLabel']
71 | is_leaf = False
72 |
73 | if 'label' in fields:
74 | is_leaf = True
75 | value = fields['label']
76 | else:
77 | value = attribute
78 |
79 | alignment_id = None
80 | location_id = '{}-{}-{}-{}'.format(fields['type'], value, fields['pos'], fields['length'])
81 |
82 | if 'other_pos' in fields:
83 | if src == 'old':
84 | alignment_id = '{}-{}-{}-{}'.format(fields['pos'], fields['length'], fields['other_pos'], fields['other_length'])
85 | else:
86 | alignment_id = '{}-{}-{}-{}'.format(fields['other_pos'], fields['other_length'], fields['pos'], fields['length'])
87 |
88 | node = XMLNode(value, indexer.generate(), parent,
89 | attribute, alignment_id, location_id, src, is_leaf)
90 |
91 | for child_obj in xml_obj:
92 | node.children.append(parse_xml_obj(child_obj, indexer, node, src))
93 | return node
94 |
95 | def set_id(diff_node, indexer):
96 | diff_node.node_id = indexer.generate()
97 | for node in diff_node.children:
98 | set_id(node, indexer)
99 |
100 | def print_diff_node(diff_node):
101 | print('{} ({}-{}): {}, {}'.format(diff_node.value, diff_node.src, diff_node.node_id,
102 | [c.value for c in diff_node.children], [p.node_id for p in diff_node.parents]))
103 | for child in diff_node.children:
104 | print_diff_node(child)
105 |
106 | def get_individual_ast_objs(old_sample_path, new_sample_path, actions_json, jar_path):
107 | old_xml_path = os.path.join(XML_DIR, 'old.xml')
108 | new_xml_path = os.path.join(XML_DIR, 'new.xml')
109 |
110 | output = subprocess.check_output(['java', '-jar', jar_path, old_sample_path,
111 | new_sample_path, old_xml_path, new_xml_path, actions_json])
112 |
113 | xml_obj = ET.parse(old_xml_path)
114 | old_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'old')
115 | old_ast = AST(old_root)
116 |
117 | xml_obj = ET.parse(new_xml_path)
118 | new_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'new')
119 | new_ast = AST(new_root)
120 |
121 | old_nodes = old_ast.nodes
122 | old_diff_nodes = [DiffTreeNode(n.value, n.attribute, n.src, n.is_leaf) for n in old_nodes]
123 |
124 | old_diff_nodes_by_alignment = dict()
125 | for n, old_node in enumerate(old_nodes):
126 | old_diff_node = old_diff_nodes[n]
127 | if old_node.parent:
128 | old_diff_node.parents.append(old_diff_nodes[old_node.parent.node_id])
129 |
130 | for c in old_node.children:
131 | old_diff_node.children.append(old_diff_nodes[c.node_id])
132 |
133 | if old_node.prev_sibling:
134 | old_diff_node.prev_siblings.append(old_diff_nodes[old_node.prev_sibling.node_id])
135 |
136 | if old_node.next_sibling:
137 | old_diff_node.next_siblings.append(old_diff_nodes[old_node.next_sibling.node_id])
138 |
139 | if old_node.alignment_id:
140 | old_diff_nodes_by_alignment[old_node.alignment_id] = old_diff_node
141 |
142 | new_nodes = new_ast.nodes
143 | new_diff_nodes = [DiffTreeNode(n.value, n.attribute, n.src, n.is_leaf) for n in new_nodes]
144 |
145 | for n, new_node in enumerate(new_nodes):
146 | new_diff_node = new_diff_nodes[n]
147 | if new_node.parent:
148 | new_diff_node.parents.append(new_diff_nodes[new_node.parent.node_id])
149 |
150 | for c in new_node.children:
151 | new_diff_node.children.append(new_diff_nodes[c.node_id])
152 |
153 | if new_node.prev_sibling:
154 | new_diff_node.prev_siblings.append(new_diff_nodes[new_node.prev_sibling.node_id])
155 |
156 | if new_node.next_sibling:
157 | new_diff_node.next_siblings.append(new_diff_nodes[new_node.next_sibling.node_id])
158 |
159 | old_diff_ast = DiffAST(old_diff_nodes[0])
160 | new_diff_ast = DiffAST(new_diff_nodes[0])
161 |
162 | return old_diff_ast, new_diff_ast
163 |
164 | def get_diff_ast(old_sample_path, new_sample_path, actions_json, jar_path):
165 | old_xml_path = os.path.join(XML_DIR, 'old.xml')
166 | new_xml_path = os.path.join(XML_DIR, 'new.xml')
167 | output = subprocess.check_output(['java', '-jar', jar_path, old_sample_path,
168 | new_sample_path, old_xml_path, new_xml_path, actions_json])
169 |
170 | xml_obj = ET.parse(old_xml_path)
171 | old_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'old')
172 | old_ast = AST(old_root)
173 |
174 | xml_obj = ET.parse(new_xml_path)
175 | new_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'new')
176 | new_ast = AST(new_root)
177 |
178 | with open(actions_json) as f:
179 | actions = json.load(f)
180 |
181 | old_actions = dict()
182 | new_actions = dict()
183 |
184 | for action in actions:
185 | location_id = '{}-{}-{}-{}'.format(action['type'], action['label'], action['position'], action['length'])
186 | if action['action'] == 'Insert':
187 | new_actions[location_id] = action['action']
188 | else:
189 | old_actions[location_id] = action['action']
190 |
191 | old_nodes = old_ast.nodes
192 | old_diff_nodes = []
193 | for n in old_nodes:
194 | old_diff_node = DiffTreeNode(n.value, n.attribute, n.src, n.is_leaf)
195 | if n.location_id in old_actions:
196 | old_diff_node.action_type = old_actions[n.location_id]
197 | old_diff_nodes.append(old_diff_node)
198 |
199 | old_diff_nodes_by_alignment = dict()
200 | for n, old_node in enumerate(old_nodes):
201 | old_diff_node = old_diff_nodes[n]
202 | if old_node.parent:
203 | old_diff_node.parents.append(old_diff_nodes[old_node.parent.node_id])
204 |
205 | for c in old_node.children:
206 | old_diff_node.children.append(old_diff_nodes[c.node_id])
207 |
208 | if old_node.prev_sibling:
209 | old_diff_node.prev_siblings.append(old_diff_nodes[old_node.prev_sibling.node_id])
210 |
211 | if old_node.next_sibling:
212 | old_diff_node.next_siblings.append(old_diff_nodes[old_node.next_sibling.node_id])
213 |
214 | if old_node.alignment_id:
215 | if old_node.alignment_id not in old_diff_nodes_by_alignment:
216 | old_diff_nodes_by_alignment[old_node.alignment_id] = []
217 | old_diff_nodes_by_alignment[old_node.alignment_id].append(old_diff_node)
218 |
219 | new_nodes = new_ast.nodes
220 | new_diff_nodes = []
221 |
222 | for n, new_node in enumerate(new_nodes):
223 | if new_node.alignment_id in old_diff_nodes_by_alignment and len(old_diff_nodes_by_alignment[new_node.alignment_id]) > 0:
224 | old_diff_node = old_diff_nodes_by_alignment[new_node.alignment_id].pop(0)
225 | if new_node.value == old_diff_node.value:
226 | new_diff_node = old_diff_node
227 | new_diff_node.src = 'both'
228 | new_diff_nodes.append(new_diff_node)
229 | else:
230 | new_diff_node = DiffTreeNode(new_node.value, new_node.attribute, new_node.src, new_node.is_leaf)
231 | new_diff_node.aligned_neighbors.append(old_diff_node)
232 | old_diff_node.aligned_neighbors.append(new_diff_node)
233 | new_diff_node.action_type = old_diff_node.action_type
234 |
235 | if new_node.location_id in new_actions:
236 | new_diff_node.action_type = new_actions[new_node.location_id]
237 |
238 | new_diff_nodes.append(new_diff_node)
239 | else:
240 | new_diff_node = DiffTreeNode(new_node.value, new_node.attribute, new_node.src, new_node.is_leaf)
241 | if new_node.location_id in new_actions:
242 | new_diff_node.action_type = new_actions[new_node.location_id]
243 | new_diff_nodes.append(new_diff_node)
244 |
245 | for n, new_node in enumerate(new_nodes):
246 | new_diff_node = new_diff_nodes[n]
247 | if new_node.parent and new_diff_nodes[new_node.parent.node_id] not in new_diff_node.parents:
248 | new_diff_node.parents.append(new_diff_nodes[new_node.parent.node_id])
249 |
250 | for c in new_node.children:
251 | if new_diff_nodes[c.node_id] not in new_diff_node.children:
252 | new_diff_node.children.append(new_diff_nodes[c.node_id])
253 |
254 | if new_node.prev_sibling and new_diff_nodes[new_node.prev_sibling.node_id] not in new_diff_node.prev_siblings:
255 | new_diff_node.prev_siblings.append(new_diff_nodes[new_node.prev_sibling.node_id])
256 |
257 | if new_node.next_sibling and new_diff_nodes[new_node.next_sibling.node_id] not in new_diff_node.next_siblings:
258 | new_diff_node.next_siblings.append(new_diff_nodes[new_node.next_sibling.node_id])
259 |
260 | super_root = DiffTreeNode('SuperRoot', 'SuperRoot', 'both', False)
261 | super_root.children.append(old_diff_nodes[0])
262 | old_diff_nodes[0].parents.append(super_root)
263 |
264 | if old_diff_nodes[0] != new_diff_nodes[0]:
265 | super_root.children.append(new_diff_nodes[0])
266 | new_diff_nodes[0].parents.append(super_root)
267 |
268 | diff_ast = DiffAST(super_root)
269 | return diff_ast
270 |
271 | if __name__ == "__main__":
272 | parser = argparse.ArgumentParser()
273 | parser.add_argument('--old_sample_path', help='path to java file containing old version of method')
274 | parser.add_argument('--new_sample_path', help='path to java file containing new version of method')
275 | parser.add_argument('--jar_path', help='path to downloaded jar file')
276 | args = parser.parse_args()
277 |
278 | logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s')
279 | logging.basicConfig(level=logging.ERROR, format='%(asctime)-15s %(message)s')
280 |
281 | XML_DIR = 'xml_files/'
282 | os.makedirs(XML_DIR, exist_ok=True)
283 |
284 | old_ast, new_ast = get_individual_ast_objs(args.old_sample_path, args.new_sample_path, 'old_new_ast_actions.json', args.jar_path)
285 | diff_ast = get_diff_ast(args.old_sample_path, args.new_sample_path, 'diff_ast_actions.json', args.jar_path)
286 |
287 | print(diff_ast.to_json())
--------------------------------------------------------------------------------
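A toy check, independent of the external AST-diffing jar and its XML output, that `AST.traverse` wires up the sibling pointers the diff construction relies on (node values here are made up):

root = XMLNode('MethodDeclaration', 0, None, 'MethodDeclaration', None, 'loc-root', 'old', is_leaf=False)
left = XMLNode('int', 1, root, 'PrimitiveType', None, 'loc-left', 'old')
right = XMLNode('getBestScore', 2, root, 'SimpleName', None, 'loc-right', 'old')
root.children = [left, right]

ast = AST(root)
assert left.next_sibling is right and right.prev_sibling is left
assert [n.value for n in ast.leaves] == ['int', 'getBestScore']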
/data_processing/build_example.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from data_formatting_utils import subtokenize_code, subtokenize_comment
4 |
5 | sys.path.append('../')
6 | sys.path.append('../comment_update')
7 | from data_utils import DiffASTExample
8 | from diff_utils import compute_minimal_comment_diffs, compute_code_diffs
9 |
10 | # NOTE: javalang will need to be installed for this
11 | def build_test_example():
12 | example_id = 'test-id-0'
13 | label = 1
14 | comment_type = 'Return'
15 | old_comment_raw = '@return the highest score'
16 | old_comment_subtokens = subtokenize_comment(old_comment_raw).split()
17 | new_comment_raw = '@return the lowest score'
18 | new_comment_subtokens = subtokenize_comment(new_comment_raw).split()
19 | span_minimal_diff_comment_subtokens, _, _ = compute_minimal_comment_diffs(
20 | old_comment_subtokens, new_comment_subtokens)
21 | old_code_raw = 'public int getBestScore()\n{\n\treturn Collections.max(scores);\n}'
22 | old_code_subtokens = subtokenize_code(old_code_raw).split()
23 | new_code_raw = 'public int getBestScore()\n{\n\treturn Collections.min(scores);\n}'
24 | new_code_subtokens = subtokenize_code(new_code_raw).split()
25 | span_diff_code_subtokens, token_diff_code_subtokens, _ = compute_code_diffs(old_code_subtokens, new_code_subtokens)
26 |
27 | # TODO: Add code for parsing ASTs
28 | old_ast = None
29 | new_ast = None
30 | diff_ast = None
31 |
32 | return DiffASTExample(example_id, label, comment_type, old_comment_raw, old_comment_subtokens, new_comment_raw,
33 | new_comment_subtokens, span_minimal_diff_comment_subtokens, old_code_raw, old_code_subtokens, new_code_raw,
34 | new_code_subtokens, span_diff_code_subtokens, token_diff_code_subtokens, old_ast, new_ast, diff_ast)
--------------------------------------------------------------------------------
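Running the builder (javalang must be installed) yields subtokenized fields ready for the model; the `@return` tag is stripped during comment subtokenization, so for this example:

from build_example import build_test_example

ex = build_test_example()
print(ex.old_comment_subtokens)   # ['the', 'highest', 'score']
print(ex.new_comment_subtokens)   # ['the', 'lowest', 'score']
print(ex.old_code_subtokens[:6])  # ['public', 'int', 'get', 'best', 'score', '(']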
/data_processing/data_formatting_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import javalang
3 | import json
4 | import numpy as np
5 | import os
6 | import random
7 | import re
8 | import string
9 |
10 | SPECIAL_TAGS = ['{', '}', '@code', '@docRoot', '@inheritDoc', '@link', '@linkplain', '@value']
11 |
12 | def remove_html_tag(line):
13 | clean = re.compile('<.*?>')
14 | line = re.sub(clean, '', line)
15 |
16 | for tag in SPECIAL_TAGS:
17 | line = line.replace(tag, '')
18 |
19 | return line
20 |
21 | def remove_tag_string(line):
22 | search_strings = ['@return', '@ return', '@param', '@ param', '@throws', '@ throws']
23 | for s in search_strings:
24 | line = line.replace(s, '').strip()
25 | return line
26 |
27 | def tokenize_comment(comment_line, remove_tag=True):
28 | if remove_tag:
29 | comment_line = remove_tag_string(comment_line)
30 | comment_line = remove_html_tag(comment_line)
31 | comment_line = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", comment_line.strip())
32 | comment_line = ' '.join(comment_line)
33 | comment_line = comment_line.replace('\n', ' ').strip()
34 |
35 | return comment_line
36 |
37 | def subtokenize_comment(comment_line, remove_tag=True):
38 | if remove_tag:
39 | comment_line = remove_tag_string(comment_line)
40 | comment_line = remove_html_tag(comment_line.replace('/**', '').replace('**/', '').replace('/*', '').replace('*/', '').replace('*', '').strip())
41 | comment_line = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", comment_line.strip())
42 | comment_line = ' '.join(comment_line)
43 | comment_line = comment_line.replace('\n', ' ').strip()
44 |
45 | tokens = comment_line.split(' ')
46 | subtokens = []
47 | for token in tokens:
48 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split()
49 | try:
50 | new_curr = []
51 | for c in curr:
52 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip())
53 | new_curr = new_curr + by_symbol
54 |
55 | curr = new_curr
56 | except:
57 | curr = []
58 | subtokens = subtokens + [c.lower() for c in curr]
59 |
60 | comment_line = ' '.join(subtokens)
61 | return comment_line.lower()
62 |
63 | def subtokenize_code(line):
64 | try:
65 | tokens = get_clean_code(list(javalang.tokenizer.tokenize(line)))
66 | except:
67 | tokens = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", line.strip())
68 | subtokens = []
69 | for token in tokens:
70 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split()
71 | subtokens = subtokens + [c.lower() for c in curr]
72 |
73 | return ' '.join(subtokens)
74 |
75 | def tokenize_code(line):
76 | try:
77 | tokens = [t.value for t in list(javalang.tokenizer.tokenize(line))]
78 | return ' '.join(tokens)
79 | except:
80 | return tokenize_clean_code(line)
81 |
82 | def tokenize_clean_code(line):
83 | try:
84 | return ' '.join(get_clean_code(list(javalang.tokenizer.tokenize(line))))
85 | except:
86 | return ' '.join(re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", line.strip()))
87 |
88 | def get_clean_code(tokenized_code):
89 | token_vals = [t.value for t in tokenized_code]
90 | new_token_vals = []
91 | for t in token_vals:
92 | n = [c for c in re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", t.encode('ascii', errors='ignore').decode().strip()) if len(c) > 0]
93 | new_token_vals = new_token_vals + n
94 |
95 | token_vals = new_token_vals
96 | cleaned_code_tokens = []
97 |
98 | for c in token_vals:
99 | try:
100 | cleaned_code_tokens.append(str(c))
101 | except:
102 | pass
103 |
104 | return cleaned_code_tokens
--------------------------------------------------------------------------------
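A quick sanity check of the two subtokenizers (assuming javalang is installed; the regex fallback gives the same answer for this input):

from data_formatting_utils import subtokenize_code, subtokenize_comment

print(subtokenize_code('public int getBestScore()'))
# -> 'public int get best score ( )'
print(subtokenize_comment('@return the highestScore value'))
# -> 'the highest score value'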
/data_processing/high_level_feature_extractor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import sys
5 |
6 | from build_example import build_test_example
7 | from data_formatting_utils import tokenize_clean_code, subtokenize_code
8 |
9 | sys.path.append('../')
10 | from diff_utils import is_edit_keyword, KEEP, DELETE, INSERT, REPLACE_OLD, REPLACE_NEW
11 |
12 | EDIT_INDICES = [KEEP, DELETE, INSERT, REPLACE_OLD, REPLACE_NEW]
13 |
14 | def extract_arguments(code_block):
15 | i = 0
16 | while i < len(code_block):
17 | line = code_block[i].strip()
18 | if len(line.strip()) == 0:
19 | i += 1
20 | continue
21 | if line[0] == '@' and ' ' not in line:
22 | i += 1
23 | continue
24 | if '//' in line or '*' in line:
25 | i += 1
26 | continue
27 | else:
28 | break
29 |
30 | argument_string = line[line.index('(')+1:]
31 |
32 | if argument_string.count('(') + 1 == argument_string.count(')'):
33 | argument_string = argument_string[:argument_string.rfind(')')]
34 | else:
35 | curr_open_count = argument_string.count('(') + 1
36 | curr_close_count = argument_string.count(')')
37 | i += 1
38 | extension = ''
39 | while i < len(code_block):
40 | for w in code_block[i].strip():
41 | extension += w
42 | if w == '(':
43 | curr_open_count += 1
44 | elif w == ')':
45 | curr_close_count += 1
46 | if curr_open_count == curr_close_count:
47 | break
48 | if curr_open_count == curr_close_count:
49 | break
50 | i += 1
51 |
52 | if curr_open_count != curr_close_count:
53 | raise ValueError('Invalid arguments')
54 |
55 | argument_string = argument_string + extension[:-1]
56 |
57 | argument_types = []
58 | argument_names = []
59 |
60 | argument_string = ' '.join([a for a in argument_string.split() if '@' not in a])
61 | terms = []
62 | a = 0
63 | curr_term = []
64 |
65 | open_count = 0
66 | close_count = 0
67 |
68 | while a < len(argument_string):
69 | t = argument_string[a]
70 | if t == ' ' and open_count == close_count:
71 | terms.append(''.join(curr_term).strip())
72 | curr_term = []
73 | a += 1
74 | continue
75 | if t == ',' and open_count == close_count:
76 | curr_term.append(t)
77 | terms.append(''.join(curr_term).strip())
78 | curr_term = []
79 | a += 1
80 | continue
81 |
82 | if t == ',' and open_count != close_count:
83 | a += 1
84 | continue
85 |
86 | if t == '<':
87 | open_count += 1
88 |
89 | if t == '>':
90 | close_count += 1
91 |
92 | curr_term.append(t)
93 | a += 1
94 |
95 | if len(curr_term) > 0:
96 | terms.append(''.join(curr_term).strip())
97 |
98 | terms = [t for t in terms if t not in ['private', 'protected', 'public', 'final', 'static']]
99 | arguments = ' '.join(terms).split(',')
100 | arguments = [a.strip() for a in arguments if len(a.strip()) > 0]
101 | for argument in arguments:
102 | argument_tokens = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", argument.strip())
103 | argument_types.append(argument_tokens[0])
104 | argument_names.append(argument_tokens[-1])
105 |
106 | return argument_names, argument_types
107 |
108 | def strip_comment(s):
109 | """Checks whether a single line follows the structure of a comment."""
110 | new_s = re.sub(r'\"(.+?)\"', '', s)
111 | matched_obj = re.findall("(?:/\\*(?:[^*]|(?:\\*+[^*/]))*\\*+/)|(?://.*)", new_s)
112 | url_match = re.findall('https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+', new_s)
113 | file_match = re.findall('^(.*/)?(?:$|(.+?)(?:(\.[^.]*$)|$))', new_s)
114 |
115 | if matched_obj and not url_match:
116 | for m in matched_obj:
117 | s = s.replace(m, ' ')
118 | return s.strip()
119 |
120 | def extract_return_statements(code_block):
121 | cleaned_lines = []
122 | for l in code_block:
123 | cleaned_l = strip_comment(l)
124 | if len(cleaned_l) > 0:
125 | cleaned_lines.append(cleaned_l)
126 |
127 | combined_block = ' '.join(cleaned_lines)
128 | if 'return' not in combined_block:
129 | return []
130 | indices = [m.start() for m in re.finditer('return ', combined_block)]
131 | return_statements = []
132 | for idx in indices:
133 | s_idx = idx + len('return ')
134 | e_idx = s_idx + combined_block[s_idx:].index(';')
135 | statement = combined_block[s_idx:e_idx].strip()
136 | if len(statement) > 0:
137 | return_statements.append(statement)
138 |
139 | return return_statements
140 |
141 |
142 | def is_operator(token):
143 | for s in token:
144 | if s.isalnum():
145 | return False
146 | return True
147 |
148 | def extract_method_name(code_block):
149 | i = 0
150 | while i < len(code_block):
151 | line = code_block[i].strip()
152 | if len(line.strip()) == 0:
153 | i += 1
154 | continue
155 | if line[0] == '@' and ' ' not in line:
156 | i += 1
157 | continue
158 | if '//' in line or '*' in line:
159 | i += 1
160 | continue
161 | else:
162 | break
163 |
164 | try:
165 | method_components = line.strip().split('(')[0].split(' ')
166 | method_components = [m for m in method_components if len(m) > 0]
167 | method_name = method_components[-1].strip()
168 | except:
169 | method_name = ''
170 |
171 | return method_name
172 |
173 | def extract_return_type(code_block):
174 | i = 0
175 | while i < len(code_block):
176 | line = code_block[i].strip()
177 | if len(line.strip()) == 0:
178 | i += 1
179 | continue
180 | if line[0] == '@':
181 | i += 1
182 | continue
183 | if '//' in line or '*' in line:
184 | i += 1
185 | continue
186 | else:
187 | break
188 |
189 | before_method_name_tokens = line.split('(')[0].split(' ')[:-1]
190 | return_type_tokens = []
191 | for tok in before_method_name_tokens:
192 | if tok not in ['private', 'protected', 'public', 'final', 'static']:
193 | return_type_tokens.append(tok)
194 | return ' '.join(return_type_tokens)
195 |
196 | def get_change_labels(tokens):
197 | cache = dict()
198 | for label in EDIT_INDICES:
199 | cache[label] = set()
200 |
201 | label = None
202 | for t in tokens:
203 | if is_edit_keyword(t):
204 | label = t
205 | elif is_operator(t):
206 | continue
207 | else:
208 | cache[label].add(t)
209 |
210 | for label, label_set in cache.items():
211 | cache[label] = list(label_set)
212 | return cache
213 |
214 | def extract_throwable_exceptions(code_block):
215 | i = 0
216 | while i < len(code_block):
217 | line = code_block[i].strip()
218 | if 'throws' in line:
219 | break
220 | i += 1
221 |
222 | if 'throws' not in line:
223 | return []
224 |
225 | throws_string = line[line.index('throws') + len('throws'):]
226 | if '{' in throws_string:
227 | throws_string = throws_string[:throws_string.index('{')]
228 | else:
229 | extension = ''
230 | i += 1
231 | while i < len(code_block):
232 | line = code_block[i].strip()
233 | if len(line) == 0:
234 | i += 1
235 | continue
236 | for w in line:
237 | if w == '{':
238 | break
239 | else:
240 | extension += w
241 | if w == '{':
242 | break
243 | else:
244 | i += 1
245 |
246 | throws_string += extension
247 |
248 | exception_tokens = [t for t in tokenize_clean_code(throws_string).split() if not is_operator(t)]
249 | return exception_tokens
250 |
251 | def extract_throw_statements(code_block):
252 | cleaned_lines = []
253 | for l in code_block:
254 | cleaned_l = strip_comment(l)
255 | if len(cleaned_l) > 0:
256 | cleaned_lines.append(cleaned_l)
257 |
258 | combined_block = ' '.join(cleaned_lines)
259 | if 'throw' not in combined_block:
260 | return []
261 | indices = [m.start() for m in re.finditer('throw ', combined_block)]
262 | throw_statements = []
263 | for idx in indices:
264 | s_idx = idx + len('throw ')
265 | e_idx = s_idx + combined_block[s_idx:].index(';')
266 | statement = combined_block[s_idx:e_idx].strip()
267 | if len(statement) > 0:
268 | throw_statements.append(statement)
269 |
270 | return throw_statements
271 |
272 | def get_method_elements(code_block):
273 | argument_names, argument_types = extract_arguments(code_block)
274 | return_statements = extract_return_statements(code_block)
275 | return_type = extract_return_type(code_block)
276 |
277 | throwable_exception_tokens = extract_throwable_exceptions(code_block)
278 | throwable_exception_subtokens = []
279 | for throwable_exception in throwable_exception_tokens:
280 | throwable_exception_subtokens.extend(subtokenize_code(throwable_exception).split())
281 |
282 | throw_statements = extract_throw_statements(code_block)
283 | throw_statement_tokens = []
284 | throw_statement_subtokens = []
285 | for throw_statement in throw_statements:
286 | throw_statement_tokens.extend([t for t in tokenize_clean_code(throw_statement).split() if not is_operator(t)])
287 | throw_statement_subtokens.extend([t for t in subtokenize_code(throw_statement).split() if not is_operator(t)])
288 |
289 | argument_name_tokens = []
290 | argument_name_subtokens = []
291 | argument_type_tokens = []
292 | argument_type_subtokens = []
293 |
294 | for argument_name in argument_names:
295 | argument_name_tokens.extend([t for t in tokenize_clean_code(argument_name).split() if not is_operator(t)])
296 | argument_name_subtokens.extend([t for t in subtokenize_code(argument_name).split() if not is_operator(t)])
297 |
298 | for argument_type in argument_types:
299 | argument_type_tokens.extend([t for t in tokenize_clean_code(argument_type).split() if not is_operator(t)])
300 | argument_type_subtokens.extend([t for t in subtokenize_code(argument_type).split() if not is_operator(t)])
301 |
302 | return_statement_tokens = []
303 | return_statement_subtokens = []
304 | for return_statement in return_statements:
305 | return_statement_tokens.extend([t for t in tokenize_clean_code(return_statement).split() if not is_operator(t)])
306 | return_statement_subtokens.extend([t for t in subtokenize_code(return_statement).split() if not is_operator(t)])
307 |
308 | return_type_tokens = [t for t in tokenize_clean_code(return_type).split() if not is_operator(t)]
309 | return_type_subtokens = [t for t in subtokenize_code(return_type).split() if not is_operator(t)]
310 |
311 | method_name = extract_method_name(code_block)
312 | method_name_tokens = [method_name]
313 | method_name_subtokens = subtokenize_code(method_name).split()
314 |
315 | token_elements = {
316 | 'argument_name': argument_name_tokens,
317 | 'argument_type': argument_type_tokens,
318 | 'return_type': return_type_tokens,
319 | 'return_statement': return_statement_tokens,
320 | 'throwable_exception': throwable_exception_tokens,
321 | 'throw_statement': throw_statement_tokens,
322 | 'method_name': method_name_tokens
323 | }
324 |
325 | subtoken_elements = {
326 | 'argument_name': argument_name_subtokens,
327 | 'argument_type': argument_type_subtokens,
328 | 'return_type': return_type_subtokens,
329 | 'return_statement': return_statement_subtokens,
330 | 'throwable_exception': throwable_exception_subtokens,
331 | 'throw_statement': throw_statement_subtokens,
332 | 'method_name': method_name_subtokens
333 | }
334 |
335 | return {
336 | 'token': token_elements,
337 | 'subtoken': subtoken_elements
338 | }
339 |
340 | if __name__ == "__main__":
341 | # Demo for extracting high level features for one example
342 | # Corresponds to what is written in high_level_features.json files
343 |
344 | ex = build_test_example()
345 | cache = dict()
346 | cache[ex.id] = {
347 | 'old': get_method_elements(ex.old_code_raw.split('\n')),
348 | 'new': get_method_elements(ex.new_code_raw.split('\n')),
349 | 'code_change_labels': {'subtoken': get_change_labels(ex.token_diff_code_subtokens)}
350 | }
--------------------------------------------------------------------------------
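The extractors operate on raw source lines; a small sketch on a hypothetical method, mirroring the `__main__` demo above:

from high_level_feature_extractor import extract_method_name, extract_return_type, extract_return_statements

code = ['public int getBestScore() {',
        'return Collections.max(scores);',
        '}']
print(extract_method_name(code))        # 'getBestScore'
print(extract_return_type(code))        # 'int'
print(extract_return_statements(code))  # ['Collections.max(scores)']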
/data_processing/tokenization_feature_extractor.py:
--------------------------------------------------------------------------------
1 | import difflib
2 | import javalang
3 | import json
4 | import os
5 | import re
6 | import sys
7 |
8 | from build_example import build_test_example
9 | from data_formatting_utils import subtokenize_code, tokenize_clean_code, get_clean_code,\
10 | subtokenize_comment, tokenize_comment
11 |
12 | sys.path.append('../')
13 | from diff_utils import is_edit_keyword, KEEP, KEEP_END, REPLACE_OLD, REPLACE_NEW,\
14 | REPLACE_END, INSERT, INSERT_END, DELETE, DELETE_END, compute_code_diffs
15 |
16 | def subtokenize_token(token, parse_comment=False):
17 | if parse_comment and token in ['@return', '@param', '@throws']:
18 | return [token]
19 | if is_edit_keyword(token):
20 | return [token]
21 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split()
22 |
23 | try:
24 | new_curr = []
25 | for t in curr:
26 | new_curr.extend([c for c in re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", t.encode('ascii', errors='ignore').decode().strip()) if len(c) > 0])
27 | curr = new_curr
28 | except:
29 | pass
30 | try:
31 | new_curr = []
32 | for c in curr:
33 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip())
34 | new_curr = new_curr + by_symbol
35 |
36 | curr = new_curr
37 | except:
38 | curr = []
39 | subtokens = [c.lower() for c in curr]
40 |
41 | return subtokens
42 |
43 | def get_subtoken_labels(gold_subtokens, tokens, parse_comment=False):
44 | labels = []
45 | indices = []
46 | all_subtokens = []
47 |
48 | token_map = []
49 | subtoken_map = []
50 |
51 | gold_idx = 0
52 |
53 | for token in tokens:
54 | subtokens = subtokenize_token(token, parse_comment)
55 | all_subtokens.extend(subtokens)
56 | token_map.append(subtokens)
57 | if len(subtokens) == 1:
58 | label = 0
59 | labels.append(label)
60 | indices.append(0)
61 | subtoken_map.append([token])
62 | else:
63 | label = 1
64 | for s, subtoken in enumerate(subtokens):
65 | labels.append(label)
66 | indices.append(s)
67 | subtoken_map.append([token])
68 | try:
69 | assert len(labels) == len(gold_subtokens)
70 | assert len(indices) == len(gold_subtokens)
71 | assert len(token_map) == len(tokens)
72 | assert len(subtoken_map) == len(gold_subtokens)
73 | except:
74 | print(tokens)
75 | print('\n')
76 | print(gold_subtokens)
77 | print('\n')
78 | for s, subtoken in enumerate(all_subtokens):
79 | print('Parsed: {}'.format(subtoken))
80 | print('True: {}'.format(gold_subtokens[s]))
81 | print('---------------------------------')
82 | if subtoken != gold_subtokens[s]:
83 | break
84 | print(len(labels))
85 | print(len(gold_subtokens))
86 | raise ValueError('stop')
87 | return labels, indices, token_map, subtoken_map
88 |
89 | def get_code_subtoken_labels(gold_subtokens, tokens, raw_code):
90 | labels = []
91 | indices = []
92 | all_subtokens = []
93 |
94 | token_map = []
95 | subtoken_map = []
96 |
97 | for token in tokens:
98 | if is_edit_keyword(token):
99 | token_map.append([token])
100 | else:
101 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split()
102 | new_curr = []
103 | for c in curr:
104 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip())
105 | new_curr = new_curr + by_symbol
106 | token_map.append([s.lower() for s in new_curr])
107 |
108 | try:
109 | parsed_tokens = get_clean_code(list(javalang.tokenizer.tokenize(raw_code)))
110 |     except Exception:
111 | parsed_tokens = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", raw_code.strip())
112 |
113 | subtokens = []
114 | for t, token in enumerate(parsed_tokens):
115 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split()
116 | subtokens = [c.lower() for c in curr]
117 | all_subtokens.extend(subtokens)
118 | if len(subtokens) == 1:
119 | label = 0
120 | labels.append(label)
121 | indices.append(0)
122 | subtoken_map.append([token])
123 | else:
124 | label = 1
125 | for s, subtoken in enumerate(subtokens):
126 | labels.append(label)
127 | indices.append(s)
128 | subtoken_map.append([token])
129 | try:
130 | assert len(labels) == len(gold_subtokens)
131 | assert len(indices) == len(gold_subtokens)
132 | assert len(token_map) == len(tokens)
133 | assert len(subtoken_map) == len(gold_subtokens)
134 |     except AssertionError:
135 | print(tokens)
136 | print('\n')
137 | print(gold_subtokens)
138 | print('\n')
139 | for s, subtoken in enumerate(all_subtokens):
140 | print('Parsed: {}'.format(subtoken))
141 | print('True: {}'.format(gold_subtokens[s]))
142 | print('---------------------------------')
143 | if subtoken != gold_subtokens[s]:
144 | break
145 | print(len(labels))
146 | print(len(gold_subtokens))
147 | raise ValueError('stop')
148 | return labels, indices, token_map, subtoken_map
149 |
150 | def get_diff_subtoken_labels(diff_subtokens, old_subtokens, old_tokens, new_subtokens, new_tokens, diff_tokens, old_code_raw, new_code_raw):
151 | old_labels, old_indices, old_token_map, old_subtoken_map = get_code_subtoken_labels(old_subtokens, old_tokens, old_code_raw)
152 | new_labels, new_indices, new_token_map, new_subtoken_map = get_code_subtoken_labels(new_subtokens, new_tokens, new_code_raw)
153 |
154 | diff_labels = []
155 | diff_indices = []
156 |
157 | diff_token_map = []
158 | diff_subtoken_map = []
159 |
160 | for token in diff_tokens:
161 | if is_edit_keyword(token):
162 | diff_token_map.append([token])
163 | else:
164 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split()
165 | new_curr = []
166 | for c in curr:
167 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip())
168 | new_curr = new_curr + by_symbol
169 | diff_token_map.append([s.lower() for s in new_curr])
170 |
171 | for edit_type, o_start, o_end, n_start, n_end in difflib.SequenceMatcher(None, old_subtokens, new_subtokens).get_opcodes():
172 | if edit_type == 'equal':
173 | diff_labels.extend([0] + old_labels[o_start:o_end] + [0])
174 | diff_indices.extend([0] + old_indices[o_start:o_end] + [0])
175 | diff_subtoken_map.append([KEEP])
176 | diff_subtoken_map.extend(old_subtoken_map[o_start:o_end])
177 | diff_subtoken_map.append([KEEP_END])
178 | elif edit_type == 'replace':
179 | diff_labels.extend([0] + old_labels[o_start:o_end] + [0] + new_labels[n_start:n_end] + [0])
180 | diff_indices.extend([0] + old_indices[o_start:o_end] + [0] + new_indices[n_start:n_end] + [0])
181 | diff_subtoken_map.append([REPLACE_OLD])
182 | diff_subtoken_map.extend(old_subtoken_map[o_start:o_end])
183 | diff_subtoken_map.append([REPLACE_NEW])
184 | diff_subtoken_map.extend(new_subtoken_map[n_start:n_end])
185 | diff_subtoken_map.append([REPLACE_END])
186 | elif edit_type == 'insert':
187 | diff_labels.extend([0] + new_labels[n_start:n_end] + [0])
188 | diff_indices.extend([0] + new_indices[n_start:n_end] + [0])
189 | diff_subtoken_map.append([INSERT])
190 | diff_subtoken_map.extend(new_subtoken_map[n_start:n_end])
191 | diff_subtoken_map.append([INSERT_END])
192 | else:
193 | diff_labels.extend([0] + old_labels[o_start:o_end] + [0])
194 | diff_indices.extend([0] + old_indices[o_start:o_end] + [0])
195 | diff_subtoken_map.append([DELETE])
196 | diff_subtoken_map.extend(old_subtoken_map[o_start:o_end])
197 | diff_subtoken_map.append([DELETE_END])
198 |
199 | assert len(diff_labels) == len(diff_subtokens)
200 | assert len(diff_indices) == len(diff_subtokens)
201 | assert len(diff_subtoken_map) == len(diff_subtokens)
202 | assert len(diff_token_map) == len(diff_tokens)
203 | return diff_labels, diff_indices, diff_token_map, diff_subtoken_map
204 |
205 | if __name__ == "__main__":
206 | # Demo for extracting tokenization features for one example
207 | # Corresponds to what is written in tokenization_features.json files
208 | ex = build_test_example()
209 |
210 | old_code_tokens = tokenize_clean_code(ex.old_code_raw).split()
211 | new_code_tokens = tokenize_clean_code(ex.new_code_raw).split()
212 | span_diff_code_tokens, _, _ = compute_code_diffs(old_code_tokens, new_code_tokens)
213 |
214 | edit_span_subtoken_labels, edit_span_subtoken_indices, edit_span_token_map, edit_span_subtoken_map = get_diff_subtoken_labels(
215 | ex.span_diff_code_subtokens, ex.old_code_subtokens, old_code_tokens, ex.new_code_subtokens, new_code_tokens,
216 | span_diff_code_tokens, ex.old_code_raw, ex.new_code_raw)
217 |
218 | old_comment_tokens = tokenize_comment(ex.old_comment_raw).split()
219 |
220 | prefix = []
221 | if ex.comment_type == 'Return':
222 | prefix = ['@return']
223 | elif ex.comment_type == 'Param':
224 | prefix = ['@param']
225 |
226 | old_nl_subtoken_labels, old_nl_subtoken_indices, old_nl_token_map, old_nl_subtoken_map = get_subtoken_labels(
227 | prefix + ex.old_comment_subtokens, prefix + old_comment_tokens, parse_comment=True)
228 |
229 | cache = dict()
230 | cache[ex.id] = {
231 | 'old_nl_subtoken_labels': old_nl_subtoken_labels,
232 | 'old_nl_subtoken_indices': old_nl_subtoken_indices,
233 | 'edit_span_subtoken_labels': edit_span_subtoken_labels,
234 | 'edit_span_subtoken_indices': edit_span_subtoken_indices,
235 | 'old_nl_token_map': old_nl_token_map,
236 | 'old_nl_subtoken_map': old_nl_subtoken_map,
237 | 'edit_span_token_map': edit_span_token_map,
238 | 'edit_span_subtoken_map': edit_span_subtoken_map
239 | }
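240 |
241 |     # Illustrative only (not part of the original file): subtokenize_token
242 |     # splits camelCase identifiers and lowercases the pieces, e.g.
243 |     #   subtokenize_token('getMaxValue') -> ['get', 'max', 'value']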
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from enum import Enum
3 | import json
4 | import numpy as np
5 | import re
6 | import torch
7 | from typing import List, NamedTuple
8 |
9 | from external_cache import get_node_features
10 |
11 | @enum.unique
12 | class CommentCategory(Enum):
13 | Return = 0
14 | Param = 1
15 | Summary = 2
16 |
17 | @enum.unique
18 | class DiffEdgeType(Enum):
19 | PARENT = 0
20 | CHILD = 1
21 | SUBTOKEN_CHILD = 2
22 | SUBTOKEN_PARENT = 3
23 | PREV_SUBTOKEN = 4
24 | NEXT_SUBTOKEN = 5
25 | ALIGNED_NEIGHBOR = 6
26 |
27 | @enum.unique
28 | class SrcType(Enum):
29 | KEEP = 0
30 | INSERT = 1
31 | DELETE = 2
32 | REPLACE_OLD = 3
33 | REPLACE_NEW = 4
34 | MOVE = 5
35 |
36 | class DiffTreeNode:
37 | def __init__(self, value, attribute, src, is_leaf):
38 | self.value = value
39 | self.node_id = -1
40 | self.parents = []
41 | self.attribute = attribute
42 | self.src = src
43 | self.is_leaf = is_leaf
44 | self.children = []
45 | self.prev_siblings = []
46 | self.next_siblings = []
47 | self.aligned_neighbors = []
48 | self.action_type = None
49 | self.prev_tokens = []
50 | self.next_tokens = []
51 | self.subtokens = []
52 |
53 | self.subtoken_children = []
54 | self.subtoken_parents = []
55 | self.prev_subtokens = []
56 | self.next_subtokens = []
57 |
58 | def to_json(self):
59 | return {
60 | 'value': self.value,
61 | 'node_id': self.node_id,
62 | 'parent_ids': [p.node_id for p in self.parents],
63 | 'attribute': self.attribute,
64 | 'src': self.src,
65 | 'is_leaf': self.is_leaf,
66 | 'children_ids': [c.node_id for c in self.children],
67 | 'prev_sibling_ids': [p.node_id for p in self.prev_siblings],
68 | 'next_sibling_ids': [n.node_id for n in self.next_siblings],
69 | 'aligned_neighbor_ids': [n.node_id for n in self.aligned_neighbors],
70 | 'action_type': self.action_type,
71 | }
72 |
73 | @property
74 | def is_identifier(self):
75 | return self.is_leaf and self.attribute == 'SimpleName'
76 |
77 | class DiffAST:
78 | def __init__(self, ast_root):
79 | self.node_cache = set()
80 | self.root = ast_root
81 | self.nodes = []
82 | self.traverse(self.root)
83 |
84 | def traverse(self, curr_node):
85 | if curr_node not in self.node_cache:
86 | self.node_cache.add(curr_node)
87 | curr_node.node_id = len(self.nodes)
88 | self.nodes.append(curr_node)
89 | for child in curr_node.subtoken_children:
90 | self.traverse(child)
91 | for child in curr_node.children:
92 | self.traverse(child)
93 |
94 | def to_json(self):
95 | return [n.to_json() for n in self.nodes]
96 |
97 | @property
98 | def leaves(self):
99 | return [n for n in self.nodes if n.is_leaf]
100 |
101 | @classmethod
102 | def from_json(cls, obj):
103 | nodes = []
104 | for node_obj in obj:
105 | node = DiffTreeNode(node_obj['value'], node_obj['attribute'], node_obj['src'], False)
106 | if 'action_type' in node_obj:
107 | node.action_type = node_obj['action_type']
108 | nodes.append(node)
109 |
110 | new_nodes = []
111 |
112 | for n, node_obj in enumerate(obj):
113 | nodes[n].parents = [nodes[i] for i in node_obj['parent_ids']]
114 | nodes[n].children = [nodes[i] for i in node_obj['children_ids']]
115 | nodes[n].prev_siblings = [nodes[i] for i in node_obj['prev_sibling_ids']]
116 | nodes[n].next_siblings = [nodes[i] for i in node_obj['next_sibling_ids']]
117 | nodes[n].aligned_neighbors = [nodes[i] for i in node_obj['aligned_neighbor_ids']]
118 | new_nodes.append(nodes[n])
119 |
120 | if len(nodes[n].children) == 0:
121 | nodes[n].is_leaf = True
122 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', nodes[n].value).split()
123 | new_curr = []
124 | for c in curr:
125 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip())
126 | new_curr = new_curr + by_symbol
127 | nodes[n].subtokens = [s.lower() for s in new_curr]
128 |
129 | if len(nodes[n].subtokens) > 1:
130 | for s in nodes[n].subtokens:
131 | sub_node = DiffTreeNode(s, '', nodes[n].src, True)
132 | sub_node.action_type = nodes[n].action_type
133 | sub_node.subtoken_parents.append(nodes[n])
134 |
135 | if len(nodes[n].subtoken_children) > 0:
136 | nodes[n].subtoken_children[-1].next_subtokens.append(sub_node)
137 | sub_node.prev_subtokens.append(nodes[n].subtoken_children[-1])
138 |
139 | nodes[n].subtoken_children.append(sub_node)
140 | new_nodes.append(sub_node)
141 |
142 | nodes[n].value = nodes[n].value.lower()
143 |
144 | return cls(new_nodes[0])
145 |
146 | def insert_graph(batch, ex, ast, vocabulary, use_features, max_ast_length):
147 | batch.root_ids.append(batch.num_nodes)
148 | graph_node_positions = []
149 | for n, node in enumerate(ast.nodes):
150 | batch.graph_ids.append(batch.num_graphs)
151 | batch.is_internal.append(not node.is_leaf)
152 | batch.value_lookup_ids.append(vocabulary.get_id_or_unk(node.value))
153 |
154 | if node.action_type == 'Insert':
155 | src_type = SrcType.INSERT
156 | elif node.action_type == 'Delete':
157 | src_type = SrcType.DELETE
158 | elif node.action_type == 'Move':
159 | src_type = SrcType.MOVE
160 | elif node.src == 'old' and node.action_type == 'Update':
161 | src_type = SrcType.REPLACE_OLD
162 | elif node.src == 'new' and node.action_type == 'Update':
163 | src_type = SrcType.REPLACE_NEW
164 | else:
165 | src_type = SrcType.KEEP
166 |
167 | batch.src_type_ids.append(src_type.value)
168 | graph_node_positions.append(batch.num_nodes + node.node_id)
169 |
170 | for parent in node.parents:
171 | if parent.node_id < len(ast.nodes):
172 | batch.edges[DiffEdgeType.PARENT.value].append(
173 | (batch.num_nodes + node.node_id, batch.num_nodes + parent.node_id))
174 |
175 | for child in node.children:
176 | if child.node_id < len(ast.nodes):
177 | batch.edges[DiffEdgeType.CHILD.value].append(
178 | (batch.num_nodes + node.node_id, batch.num_nodes + child.node_id))
179 |
180 | for subtoken_parent in node.subtoken_parents:
181 | if subtoken_parent.node_id < len(ast.nodes):
182 | batch.edges[DiffEdgeType.SUBTOKEN_PARENT.value].append(
183 | (batch.num_nodes + node.node_id, batch.num_nodes + subtoken_parent.node_id))
184 |
185 | for subtoken_child in node.subtoken_children:
186 | if subtoken_child.node_id < len(ast.nodes):
187 | batch.edges[DiffEdgeType.SUBTOKEN_CHILD.value].append(
188 | (batch.num_nodes + node.node_id, batch.num_nodes + subtoken_child.node_id))
189 |
190 | for next_subtoken in node.next_subtokens:
191 | if next_subtoken.node_id < len(ast.nodes):
192 | batch.edges[DiffEdgeType.NEXT_SUBTOKEN.value].append(
193 | (batch.num_nodes + node.node_id, batch.num_nodes + next_subtoken.node_id))
194 |
195 | for prev_subtoken in node.prev_subtokens:
196 | if prev_subtoken.node_id < len(ast.nodes):
197 | batch.edges[DiffEdgeType.PREV_SUBTOKEN.value].append(
198 | (batch.num_nodes + node.node_id, batch.num_nodes + prev_subtoken.node_id))
199 |
200 | if len(batch.edges) == len(DiffEdgeType):
201 | for aligned_neighbor in node.aligned_neighbors:
202 | if aligned_neighbor.node_id < len(ast.nodes):
203 | batch.edges[DiffEdgeType.ALIGNED_NEIGHBOR.value].append(
204 | (batch.num_nodes + node.node_id, batch.num_nodes + aligned_neighbor.node_id))
205 |
206 | if use_features:
207 | node_features = get_node_features(ast.nodes, ex, max_ast_length)
208 | batch.node_features.extend(node_features)
209 |
210 | batch.node_positions.append(graph_node_positions)
211 | batch.num_nodes_per_graph.append(len(ast.nodes))
212 | batch.num_nodes += len(ast.nodes)
213 | batch.num_graphs += 1
214 | return batch
215 |
216 |
217 | class GraphMethodBatch:
218 | def __init__(self, graph_ids, value_lookup_ids, src_type_ids, root_ids, is_internal,
219 | edges, num_graphs, num_nodes, node_features, node_positions, num_nodes_per_graph):
220 | self.graph_ids = graph_ids
221 | self.value_lookup_ids = value_lookup_ids
222 | self.src_type_ids = src_type_ids
223 | self.root_ids = root_ids
224 | self.is_internal = is_internal
225 | self.edges = edges
226 | self.num_graphs = num_graphs
227 | self.num_nodes = num_nodes
228 | self.node_features = node_features
229 | self.node_positions = node_positions
230 | self.num_nodes_per_graph = num_nodes_per_graph
231 |
232 | def initialize_graph_method_batch(num_edges):
233 | return GraphMethodBatch(
234 | graph_ids = [],
235 | value_lookup_ids = [],
236 | src_type_ids = [],
237 | root_ids = [],
238 | is_internal = [],
239 | edges = [[] for _ in range(num_edges)],
240 | num_graphs = 0,
241 | num_nodes = 0,
242 | node_features = [],
243 | node_positions = [],
244 | num_nodes_per_graph = []
245 | )
246 |
247 | def tensorize_graph_method_batch(batch, device, max_num_nodes_per_graph):
248 | node_positions = np.zeros([batch.num_graphs, max_num_nodes_per_graph], dtype=np.int64)
249 | for g in range(batch.num_graphs):
250 | graph_node_positions = batch.node_positions[g]
251 | node_positions[g,:len(graph_node_positions)] = graph_node_positions
252 | node_positions[g,len(graph_node_positions):] = batch.root_ids[g]
253 |
254 | return GraphMethodBatch(
255 | torch.tensor(batch.graph_ids, dtype=torch.int64, device=device),
256 | torch.tensor(batch.value_lookup_ids, dtype=torch.int64, device=device),
257 | torch.tensor(batch.src_type_ids, dtype=torch.int64, device=device),
258 | torch.tensor(batch.root_ids, dtype=torch.int64, device=device),
259 | torch.tensor(batch.is_internal, dtype=torch.uint8, device=device),
260 | batch.edges, batch.num_graphs, batch.num_nodes,
261 | torch.tensor(batch.node_features, dtype=torch.float32, device=device),
262 | torch.tensor(node_positions, dtype=torch.int64, device=device),
263 | torch.tensor(batch.num_nodes_per_graph, dtype=torch.int64, device=device))
264 |
265 | class GenerationBatchData(NamedTuple):
266 | """Stores tensorized batch used in generation model."""
267 | code_ids: torch.Tensor
268 | code_lengths: torch.Tensor
269 | trg_nl_ids: torch.Tensor
270 | trg_extended_nl_ids: torch.Tensor
271 | trg_nl_lengths: torch.Tensor
272 | invalid_copy_positions: torch.Tensor
273 | input_str_reps: List[List[str]]
274 | input_ids: List[List[str]]
275 |
276 | class UpdateBatchData(NamedTuple):
277 | """Stores tensorized batch used in edit model."""
278 | code_ids: torch.Tensor
279 | code_lengths: torch.Tensor
280 | old_nl_ids: torch.Tensor
281 | old_nl_lengths: torch.Tensor
282 | trg_nl_ids: torch.Tensor
283 | trg_extended_nl_ids: torch.Tensor
284 | trg_nl_lengths: torch.Tensor
285 | invalid_copy_positions: torch.Tensor
286 | input_str_reps: List[List[str]]
287 | input_ids: List[List[str]]
288 | code_features: torch.Tensor
289 | nl_features: torch.Tensor
290 | labels: torch.Tensor
291 | graph_batch: GraphMethodBatch
292 |
293 | class EncoderOutputs(NamedTuple):
294 | """Stores tensorized batch used in edit model."""
295 | encoder_hidden_states: torch.Tensor
296 | masks: torch.Tensor
297 | encoder_final_state: torch.Tensor
298 | code_hidden_states: torch.Tensor
299 | code_masks: torch.Tensor
300 | old_nl_hidden_states: torch.Tensor
301 | old_nl_masks: torch.Tensor
302 | old_nl_final_state: torch.Tensor
303 | attended_old_nl_final_state: torch.Tensor
304 |
305 | class Example(NamedTuple):
306 | """Data format for examples used in generation model."""
307 | id: str
308 | old_comment: str
309 | old_comment_tokens: List[str]
310 | new_comment: str
311 | new_comment_tokens: List[str]
312 | old_code: str
313 | old_code_tokens: List[str]
314 | new_code: str
315 | new_code_tokens: List[str]
316 |
317 | class DiffExample(NamedTuple):
318 | id: str
319 | label: int
320 | comment_type: str
321 | old_comment_raw: str
322 | old_comment_subtokens: List[str]
323 | new_comment_raw: str
324 | new_comment_subtokens: List[str]
325 | span_minimal_diff_comment_subtokens: List[str]
326 | old_code_raw: str
327 | old_code_subtokens: List[str]
328 | new_code_raw: str
329 | new_code_subtokens: List[str]
330 | span_diff_code_subtokens: List[str]
331 | token_diff_code_subtokens: List[str]
332 |
333 | class DiffASTExample(NamedTuple):
334 | id: str
335 | label: int
336 | comment_type: str
337 | old_comment_raw: str
338 | old_comment_subtokens: List[str]
339 | new_comment_raw: str
340 | new_comment_subtokens: List[str]
341 | span_minimal_diff_comment_subtokens: List[str]
342 | old_code_raw: str
343 | old_code_subtokens: List[str]
344 | new_code_raw: str
345 | new_code_subtokens: List[str]
346 | span_diff_code_subtokens: List[str]
347 | token_diff_code_subtokens: List[str]
348 | old_ast: DiffAST
349 | new_ast: DiffAST
350 | diff_ast: DiffAST
351 |
352 | def get_processed_comment_sequence(comment_subtokens):
353 | """Returns sequence without tag string. Tag strings are excluded for evaluation purposes."""
354 | if len(comment_subtokens) > 0 and comment_subtokens[0] in ['@param', '@return']:
355 | return comment_subtokens[1:]
356 |
357 | return comment_subtokens
358 |
359 | def get_processed_comment_str(comment_subtokens):
360 | """Returns string without tag string. Tag strings are excluded for evaluation purposes."""
361 | return ' '.join(get_processed_comment_sequence(comment_subtokens))
362 |
363 | def read_full_examples_from_file(filename):
364 | """Reads in data in the format used for generation model."""
365 | with open(filename) as f:
366 | data = json.load(f)
367 | return [Example(**d) for d in data]
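368 |
369 | # Illustrative only (not part of the original file): tag strings are stripped
370 | # before evaluation, e.g.
371 | #   get_processed_comment_sequence(['@return', 'the', 'max', 'value'])
372 | #   -> ['the', 'max', 'value']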
--------------------------------------------------------------------------------
/detection_evaluation_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from sklearn.metrics import precision_recall_fscore_support
4 |
5 | def compute_average(values):
6 | return sum(values)/float(len(values))
7 |
8 | def compute_score(predicted_labels, gold_labels, verbose=True):
9 | true_positives = 0.0
10 | true_negatives = 0.0
11 | false_positives = 0.0
12 | false_negatives = 0.0
13 |
14 | assert(len(predicted_labels) == len(gold_labels))
15 |
16 | for i in range(len(gold_labels)):
17 | if gold_labels[i]:
18 | if predicted_labels[i]:
19 | true_positives += 1
20 | else:
21 | false_negatives += 1
22 | else:
23 | if predicted_labels[i]:
24 | false_positives += 1
25 | else:
26 | true_negatives += 1
27 |
28 | if verbose:
29 | print('True positives: {}'.format(true_positives))
30 | print('False positives: {}'.format(false_positives))
31 | print('True negatives: {}'.format(true_negatives))
32 | print('False negatives: {}'.format(false_negatives))
33 |
34 | try:
35 | precision = true_positives/(true_positives + false_positives)
36 |     except ZeroDivisionError:
37 | precision = 0.0
38 | try:
39 | recall = true_positives/(true_positives + false_negatives)
40 |     except ZeroDivisionError:
41 | recall = 0.0
42 | try:
43 | f1 = 2*((precision * recall)/(precision + recall))
44 |     except ZeroDivisionError:
45 | f1 = 0.0
46 | return precision, recall, f1
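47 |
48 | # Hand-checked example (illustrative, not part of the original file):
49 | #   compute_score([1, 0, 1, 1], [1, 0, 0, 1], verbose=False)
50 | # gives TP=2, FP=1, FN=0, so precision=2/3, recall=1.0, f1=0.8.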
--------------------------------------------------------------------------------
/detection_module.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import Counter
3 | import numpy as np
4 | import os
5 | import random
6 | import sys
7 | import torch
8 | from torch import nn
9 |
10 | from constants import *
11 | from detection_evaluation_utils import compute_score
12 |
13 |
14 | class DetectionModule(nn.Module):
15 | """Binary classification model for detecting inconsistent comments."""
16 | def __init__(self, model_path, manager):
17 | super(DetectionModule, self).__init__()
18 |
19 | self.model_path = model_path
20 | self.manager = manager
21 | feature_input_dimension = self.manager.out_dim
22 |
23 | self.output_layer = nn.Linear(feature_input_dimension, NUM_CLASSES)
24 | self.optimizer = torch.optim.Adam(self.parameters(), lr=LR)
25 |
26 | def get_logprobs(self, encoder_outputs):
27 | """Computes the class-level log probabilities corresponding to the examples in the batch."""
28 | logits = self.output_layer(encoder_outputs.attended_old_nl_final_state)
29 | return torch.nn.functional.log_softmax(logits, dim=-1)
30 |
31 | def compute_detection_loss(self, encoder_outputs, batch_data):
32 | """Computes the negative log likelihood loss against the gold labels corresponding to the examples in the batch."""
33 | logprobs = self.get_logprobs(encoder_outputs)
34 | return torch.nn.functional.nll_loss(logprobs, batch_data.labels), logprobs
35 |
36 | def forward(self, batch_data):
37 | """Computes prediction loss for given batch."""
38 | encoder_outputs = self.manager.get_encoder_output(batch_data, self.get_device())
39 | loss, logprobs = self.compute_detection_loss(encoder_outputs, batch_data)
40 | return loss, logprobs
41 |
42 | def run_train(self, train_examples, valid_examples):
43 | """Runs training over the entire training set across several epochs. Following each epoch,
44 | F1 on the validation data is computed. If the validation F1 has improved, save the model.
45 | Early-stopping is employed to stop training if validation hasn't improved for a certain number
46 | of epochs."""
47 | valid_batches = self.manager.get_batches(valid_examples, self.get_device())
48 | best_loss = float('inf')
49 | best_f1 = 0.0
50 | patience_tally = 0
51 |
52 | for epoch in range(MAX_EPOCHS):
53 | if patience_tally > PATIENCE:
54 | print('Terminating: {}'.format(epoch))
55 | break
56 |
57 | self.train()
58 | train_batches = self.manager.get_batches(train_examples, self.get_device(), shuffle=True)
59 |
60 | train_loss = 0
61 | for batch_data in train_batches:
62 | train_loss += self.run_gradient_step(batch_data)
63 |
64 | self.eval()
65 | validation_loss = 0
66 | validation_predicted_labels = []
67 | validation_gold_labels = []
68 | with torch.no_grad():
69 | for batch_data in valid_batches:
70 | b_loss, b_logprobs = self.forward(batch_data)
71 | validation_loss += float(b_loss.cpu())
72 | validation_predicted_labels.extend(b_logprobs.argmax(-1).tolist())
73 | validation_gold_labels.extend(batch_data.labels.tolist())
74 |
75 | validation_loss = validation_loss/len(valid_batches)
76 | validation_precision, validation_recall, validation_f1 = compute_score(
77 | validation_predicted_labels, validation_gold_labels, verbose=False)
78 |
79 | if validation_f1 >= best_f1:
80 | best_f1 = validation_f1
81 | torch.save(self, self.model_path)
82 | saved = True
83 | patience_tally = 0
84 | else:
85 | saved = False
86 | patience_tally += 1
87 |
88 | print('Epoch: {}'.format(epoch))
89 | print('Training loss: {:.3f}'.format(train_loss/len(train_batches)))
90 | print('Validation loss: {:.3f}'.format(validation_loss))
91 | print('Validation precision: {:.3f}'.format(validation_precision))
92 | print('Validation recall: {:.3f}'.format(validation_recall))
93 | print('Validation f1: {:.3f}'.format(validation_f1))
94 | if saved:
95 | print('Saved')
96 | print('-----------------------------------')
97 | sys.stdout.flush()
98 |
99 | def get_device(self):
100 | """Returns the proper device."""
101 | if self.torch_device_name == 'gpu':
102 | return torch.device('cuda')
103 | else:
104 | return torch.device('cpu')
105 |
106 | def run_gradient_step(self, batch_data):
107 | """Performs gradient step."""
108 | self.optimizer.zero_grad()
109 | loss, _ = self.forward(batch_data)
110 | loss.backward()
111 | self.optimizer.step()
112 | return float(loss.cpu())
113 |
114 | def run_evaluation(self, test_examples, model_name):
115 | """Predicts labels for all comments in the test set and computes evaluation metrics."""
116 | self.eval()
117 |
118 | test_batches = self.manager.get_batches(test_examples, self.get_device())
119 | test_predictions = []
120 |
121 | with torch.no_grad():
122 | for b, batch in enumerate(test_batches):
123 | print('Testing batch {}/{}'.format(b, len(test_batches)))
124 | sys.stdout.flush()
125 | encoder_outputs = self.manager.get_encoder_output(batch, self.get_device())
126 | batch_logprobs = self.get_logprobs(encoder_outputs)
127 | test_predictions.extend(batch_logprobs.argmax(dim=-1).tolist())
128 |
129 | self.compute_metrics(test_predictions, test_examples, model_name)
130 |
131 | def compute_metrics(self, predicted_labels, test_examples, model_name):
132 | """Computes evaluation metrics."""
133 | gold_labels = []
134 | correct = 0
135 | for e, ex in enumerate(test_examples):
136 | if ex.label == predicted_labels[e]:
137 | correct += 1
138 | gold_labels.append(ex.label)
139 |
140 | accuracy = float(correct)/len(test_examples)
141 | precision, recall, f1 = compute_score(predicted_labels, gold_labels)
142 |
143 | print('Precision: {}'.format(precision))
144 | print('Recall: {}'.format(recall))
145 | print('F1: {}'.format(f1))
146 | print('Accuracy: {}\n'.format(accuracy))
147 |
148 | write_file = os.path.join(DETECTION_DIR, '{}_detection.txt'.format(model_name))
149 | with open(write_file, 'w+') as f:
150 | for e, ex in enumerate(test_examples):
151 | f.write('{} {}\n'.format(ex.id, predicted_labels[e]))
152 |
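153 | # Illustrative usage (not part of the original file); assumes a model trained
154 | # and saved by run_train, and that the driver (run_comment_model.py) sets
155 | # `torch_device_name` on the module before use:
156 | #   model = torch.load(model_path)
157 | #   model.torch_device_name = 'cpu'
158 | #   model.run_evaluation(test_examples, 'detection_demo')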
--------------------------------------------------------------------------------
/display_scores.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | sys.path.append('comment_update')
6 | from data_loader import get_data_splits, load_cleaned_test_set
7 | from data_utils import get_processed_comment_str
8 | from detection_evaluation_utils import compute_score
9 | from update_evaluation_utils import write_predictions, compute_accuracy, compute_bleu,\
10 | compute_meteor, compute_sari, compute_gleu
11 |
12 | """Script for printing update or detection metrics for output, on full and clean test sets."""
13 |
14 | def load_predicted_detection_labels(filepath, selected_positions):
15 | with open(filepath) as f:
16 | lines = f.readlines()
17 |
18 | selected_labels = []
19 | for s in selected_positions:
20 | selected_labels.append(int(lines[s].strip().split()[-1]))
21 | return selected_labels
22 |
23 | def load_predicted_generation_sequences(filepath, selected_positions):
24 | with open(filepath) as f:
25 | lines = f.readlines()
26 |
27 | selected_sequences = []
28 | for s in selected_positions:
29 | selected_sequences.append(lines[s].strip())
30 | return selected_sequences
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('--detection_output_file', help='path to detection output file')
35 | parser.add_argument('--update_output_file', help='path to update output file')
36 | args = parser.parse_args()
37 |
38 | # NOTE: To evaluate the pretrained approach, detection_output_file and
39 |     # update_output_file must both be specified. For all other approaches,
40 | # only one should be specified.
41 |
42 | _, _, test_examples, _ = get_data_splits(ignore_ast=True)
43 | positions = list(range(len(test_examples)))
44 |
45 | clean_ids = load_cleaned_test_set()
46 | clean_positions = []
47 | for e, example in enumerate(test_examples):
48 | if example.id in clean_ids:
49 | clean_positions.append(e)
50 | clean_test_examples = [test_examples[pos] for pos in clean_positions]
51 |
52 | eval_tuples = [(test_examples, positions, 'full'), (clean_test_examples, clean_positions, 'clean')]
53 |
54 | for (examples, indices, test_type) in eval_tuples:
55 | if args.detection_output_file:
56 | predicted_labels = load_predicted_detection_labels(args.detection_output_file, indices)
57 | gold_labels = [ex.label for ex in examples]
58 |
59 | precision, recall, f1 = compute_score(predicted_labels, gold_labels, verbose=False)
60 |
61 | num_correct = 0
62 | for p, p_label in enumerate(predicted_labels):
63 | if p_label == gold_labels[p]:
64 | num_correct += 1
65 |
66 | print('Detection Precision: {}'.format(precision))
67 | print('Detection Recall: {}'.format(recall))
68 | print('Detection F1: {}'.format(f1))
69 | print('Detection Accuracy: {}\n'.format(float(num_correct)/len(predicted_labels)))
70 |
71 | if args.update_output_file:
72 | update_strs = load_predicted_generation_sequences(args.update_output_file, indices)
73 |
74 | references = []
75 | pred_instances = []
76 | src_strs = []
77 | gold_strs = []
78 | pred_strs = []
79 |
80 | for i in range(len(examples)):
81 | src_str = get_processed_comment_str(examples[i].old_comment_subtokens)
82 | src_strs.append(src_str)
83 |
84 | gold_str = get_processed_comment_str(examples[i].new_comment_subtokens)
85 | gold_strs.append(gold_str)
86 | references.append([gold_str.split()])
87 |
88 | if args.detection_output_file and predicted_labels[i] == 0:
89 | pred_instances.append(src_str.split())
90 | pred_strs.append(src_str)
91 | else:
92 | pred_instances.append(update_strs[i].split())
93 | pred_strs.append(update_strs[i])
94 |
95 | prediction_file = os.path.join(os.getcwd(), 'pred.txt')
96 | src_file = os.path.join(os.getcwd(), 'src.txt')
97 | ref_file = os.path.join(os.getcwd(), 'ref.txt')
98 |
99 | write_predictions(pred_strs, prediction_file)
100 | write_predictions(src_strs, src_file)
101 | write_predictions(gold_strs, ref_file)
102 |
103 | predicted_accuracy = compute_accuracy(gold_strs, pred_strs)
104 | predicted_bleu = compute_bleu(references, pred_instances)
105 | predicted_meteor = compute_meteor(references, pred_instances)
106 | predicted_sari = compute_sari(examples, pred_instances)
107 | predicted_gleu = compute_gleu(examples, src_file, ref_file, prediction_file)
108 |
109 | print('Update Accuracy: {}'.format(predicted_accuracy))
110 | print('Update BLEU: {}'.format(predicted_bleu))
111 | print('Update Meteor: {}'.format(predicted_meteor))
112 | print('Update SARI: {}'.format(predicted_sari))
113 | print('Update GLEU: {}\n'.format(predicted_gleu))
114 |
115 | print('Test type: {}'.format(test_type))
116 | print('Detection file: {}'.format(args.detection_output_file))
117 | print('Update file: {}'.format(args.update_output_file))
118 | print('Total: {}'.format(len(examples)))
119 | print('--------------------------------------')
120 |
121 |
122 |
123 |
124 |
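125 | # Example invocations (illustrative; the file names below are placeholders):
126 | #   python display_scores.py --detection_output_file model_detection.txt
127 | #   python display_scores.py --update_output_file model_update.txt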
--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class Encoder(nn.Module):
5 | def __init__(self, embedding_size, hidden_size, num_layers, dropout, bidirectional=True):
6 | super(Encoder, self).__init__()
7 | self.__rnn = nn.GRU(input_size=embedding_size,
8 | hidden_size=hidden_size,
9 | dropout=dropout,
10 | num_layers=num_layers,
11 | batch_first=True,
12 | bidirectional=bidirectional)
13 |
14 | def forward(self, src_embedded_tokens, src_lengths, device):
15 |         encoder_hidden_states, _ = self.__rnn(src_embedded_tokens)
16 | encoder_final_state = encoder_hidden_states[torch.arange(
17 | src_embedded_tokens.size()[0], dtype=torch.int64, device=device), src_lengths-1]
18 | # encoder_final_state, _ = torch.max(encoder_hidden_states, dim=1)
19 | return encoder_hidden_states, encoder_final_state
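20 |
21 | # Illustrative usage (not part of the original file): encode a batch of two
22 | # embedded sequences and inspect the output shapes.
23 | if __name__ == "__main__":
24 |     enc = Encoder(embedding_size=8, hidden_size=16, num_layers=2, dropout=0.1)
25 |     src = torch.randn(2, 5, 8)             # (batch, max_seq_len, embedding_size)
26 |     lengths = torch.tensor([5, 3])         # true length of each sequence
27 |     states, final = enc(src, lengths, torch.device('cpu'))
28 |     print(states.shape)                    # torch.Size([2, 5, 32]) -- bidirectional
29 |     print(final.shape)                     # torch.Size([2, 32])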
--------------------------------------------------------------------------------
/gleu/README.md:
--------------------------------------------------------------------------------
1 | # Ground Truth for Grammatical Error Correction Metrics
2 |
3 |
4 | This repository contains a python implementation of the GLEU metric
5 | (**G**eneral **L**anguage **E**valuation **U**nderstanding), which
6 | can be used for any monolingual "translation" task. It also contains
7 | human rankings of the CoNLL-14 Shared Task system output as well as
8 | scripts to evaluate the rankings to extract an absolute system
9 | ranking.
10 |
11 | These results were described in the ACL 2015 paper:
12 |
13 | > [*Ground Truth for Grammatical Error Correction Metrics*](http://www.aclweb.org/anthology/P/P15/P15-2097.pdf)
14 | by Courtney Napoles, Keisuke Sakaguchi, Joel Tetreault, and Matt Post
15 |
16 | Please cite this work when using this data or the GLEU metric.
17 |
18 | @InProceedings{napoles-EtAl:2015:ACL-IJCNLP,
19 | author = {Napoles, Courtney and Sakaguchi, Keisuke and Post, Matt and Tetreault, Joel},
20 | title = {Ground Truth for Grammatical Error Correction Metrics},
21 | booktitle = {Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 2: Short Papers)},
22 | month = {July},
23 | year = {2015},
24 | address = {Beijing, China},
25 | publisher = {Association for Computational Linguistics},
26 | pages = {588--593},
27 | url = {http://www.aclweb.org/anthology/P15-2097}
28 | }
29 |
30 | ---
31 |
32 | # GLEU Update
33 |
34 | As of May 2, 2016, we have identified a problem with the GLEU metric as the number of references increases.
35 | To resolve this issue, we made a minor adjustment to the metric so that it no longer has a tunable weight and is reliable using any number of reference sets.
36 | This update to GLEU is reflected in `scripts/compute_gleu` and `scripts/gleu.py`.
37 | The original GLEU scripts can be found in `scripts/original_gleu/`.
38 | We do not recommend using the original GLEU code. The new GLEU should be used instead.
39 |
40 | The changes to GLEU and updated results to our ACL 2015 paper are described in the eprint, [*GLEU Without Tuning*](http://arxiv.org/abs/1605.02592).
41 | The citation for the updated metric is
42 |
43 | @Article{napoles2016gleu,
44 | author = {Napoles, Courtney and Sakaguchi, Keisuke and Post, Matt and Tetreault, Joel},
45 | title = {{GLEU} Without Tuning},
46 | journal = {eprint arXiv:1605.02592 [cs.CL]},
47 | year = {2016},
48 | url = {http://arxiv.org/abs/1605.02592}
49 | }
50 |
51 | ---
52 |
53 | ## Instructions
54 |
55 | ### 1. Obtain the raw system output
56 |
57 | The rankings found in the gec-ranking-data correspond to the 12 system outputs
58 | from the CoNLL-14 Shared Task on Grammatical Error Correction, which can be
59 | downloaded from the shared task website.
60 |
61 | Human judgments are located in `gec-ranking/data`.
62 |
63 | ### 2. Run TrueSkill
64 |
65 | To get the human rankings, run TrueSkill on `all_judgments.csv`, following the
66 | instructions in the TrueSkill readme. TrueSkill is available as a separate
67 | download.
68 |
69 | ### 3. Calculate metric scores
70 |
71 | GLEU is included in `gec-ranking/scripts`. To obtain the GLEU scores for
72 | system output, run the following command:
73 |
74 | ```
75 | ./compute_gleu -s source_sentences -r reference [reference ...] \
76 |     -o system_output [system_output ...] -n 4
77 | ```
78 |
79 | where each file contains one sentence per line. GLEU can be run with multiple
80 | references. To get the GLEU scores of multiple outputs, include the path to
81 | each system output file. GLEU was developed using Python 2.7.
82 |
83 | I-measure scores were taken from Felice and Briscoe's 2015 NAACL paper,
84 | *Towards a standard evaluation method for grammatical error detection and
85 | correction*. The I-measure scorer is available from the paper's authors as a
86 | separate download.
87 |
88 | M2 scores were calculated using the official scorer (version 3.2) of the CoNLL-2014 Shared Task.
89 |
90 | ---
91 |
92 | ## Errata
93 |
94 | There was an error in the calculation of the GLEU denominator, which was corrected in the 10 March 2016 commit.
95 |
96 | ---
97 |
98 | Please contact Courtney Napoles (courtneyn[at]jhu[dot]edu) or Keisuke Sakaguchi (keisuke[at]cs[dot]jhu[dot]edu) with any questions.
99 |
100 | Last updated 10 May 2016
101 |
--------------------------------------------------------------------------------
/gleu/gleu_update_2016.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/panthap2/deep-jit-inconsistency-detection/dacf8513c155f35157eedc2bf630212bf815544c/gleu/gleu_update_2016.pdf
--------------------------------------------------------------------------------
/gleu/scripts/compute_gleu:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Courtney Napoles
4 | #
5 | # 21 June 2015
6 | # ##
7 | # compute_gleu
8 | #
9 | # This script calls gleu.py to calculate the GLEU score of a sentence, as
10 | # described in our ACL 2015 paper, Ground Truth for Grammatical Error
11 | # Correction Metrics by Courtney Napoles, Keisuke Sakaguchi, Matt Post,
12 | # and Joel Tetreault.
13 | #
14 | # For instructions on how to get the GLEU score, call "compute_gleu -h"
15 | #
16 | # Updated 2 May 2016: This is an updated version of GLEU that has been
17 | # modified to handle multiple references more fairly.
18 | #
19 | # This script was adapted from compute-bleu by Adam Lopez.
20 | #
21 |
22 | import argparse
23 | import sys
24 | import os
25 | from gleu import GLEU
26 | import scipy.stats
27 | import numpy as np
28 | import random
29 |
30 | def get_gleu_stats(scores) :
31 | mean = np.mean(scores)
32 | std = np.std(scores)
33 | ci = scipy.stats.norm.interval(0.95,loc=mean,scale=std)
34 | return ['%f'%mean,
35 | '%f'%std,
36 | '(%.3f,%.3f)'%(ci[0],ci[1])]
37 |
38 | if __name__ == '__main__' :
39 |
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument("-r", "--reference",
42 | help="Target language reference sentences. Multiple "
43 | "files for multiple references.",
44 | nargs="*",
45 | dest="reference",
46 | required=True)
47 | parser.add_argument("-s", "--source",
48 | help="Source language source sentences",
49 | dest="source",
50 | required=True)
51 | parser.add_argument("-o", "--hypothesis",
52 | help="Target language hypothesis sentences to evaluate "
53 | "(can be more than one file--the GLEU score of each "
54 | "file will be output separately). Use '-o -' to read "
55 | "hypotheses from stdin.",
56 | nargs="*",
57 | dest="hypothesis",
58 | required=True)
59 | parser.add_argument("-n",
60 | help="Maximum order of ngrams",
61 | type=int,
62 | default=4)
63 | parser.add_argument("-d","--debug",
64 | help="Debug; print sentence-level scores",
65 | default=False,
66 | action="store_true")
67 | parser.add_argument('--iter',
68 | type=int,
69 | default=500,
70 | help='the number of iterations to run')
71 |
72 | args = parser.parse_args()
73 |
74 | num_iterations = args.iter
75 |
76 | # if there is only one reference, just do one iteration
77 | if len(args.reference) == 1 :
78 | num_iterations = 1
79 |
80 | gleu_calculator = GLEU(args.n)
81 |
82 | gleu_calculator.load_sources(args.source)
83 | gleu_calculator.load_references(args.reference)
84 |
85 | for hpath in args.hypothesis :
86 | instream = sys.stdin if hpath == '-' else open(hpath)
87 | hyp = [line.split() for line in instream]
88 |
89 | if not args.debug :
90 | print os.path.basename(hpath),
91 |
92 | # first generate a random list of indices, using a different seed
93 | # for each iteration
94 | indices = []
95 | for j in range(num_iterations) :
96 | random.seed(j*101)
97 | indices.append([random.randint(0,len(args.reference)-1)
98 | for i in range(len(hyp))])
99 |
100 | if args.debug :
101 | print
102 | print '===== Sentence-level scores ====='
103 | print 'SID Mean Stdev 95%CI GLEU'
104 |
105 | iter_stats = [ [0 for i in xrange(2*args.n+2)]
106 | for j in range(num_iterations) ]
107 |
108 | for i,h in enumerate(hyp) :
109 |
110 | gleu_calculator.load_hypothesis_sentence(h)
111 | # we are going to store the score of this sentence for each ref
112 | # so we don't have to recalculate them 500 times
113 |
114 | stats_by_ref = [ None for r in range(len(args.reference)) ]
115 |
116 | for j in range(num_iterations) :
117 | ref = indices[j][i]
118 | this_stats = stats_by_ref[ref]
119 |
120 | if this_stats is None :
121 | this_stats = [ s for s in gleu_calculator.gleu_stats(
122 | i,r_ind=ref) ]
123 | stats_by_ref[ref] = this_stats
124 |
125 | iter_stats[j] = [ sum(scores)
126 | for scores in zip(iter_stats[j], this_stats)]
127 |
128 | if args.debug :
129 | # sentence-level GLEU is the mean GLEU of the hypothesis
130 | # compared to each reference
131 | for r in range(len(args.reference)) :
132 | if stats_by_ref[r] is None :
133 | stats_by_ref[r] = [s for s in gleu_calculator.gleu_stats(
134 | i,r_ind=r) ]
135 |
136 | print i,
137 | print ' '.join(get_gleu_stats([gleu_calculator.gleu(stats,smooth=True)
138 | for stats in stats_by_ref]))
139 |
140 | if args.debug :
141 | print '\n==== Overall score ====='
142 | print 'Mean Stdev 95%CI GLEU'
143 | print ' '.join(get_gleu_stats([gleu_calculator.gleu(stats)
144 | for stats in iter_stats ]))
145 | else :
146 | print get_gleu_stats([gleu_calculator.gleu(stats)
147 | for stats in iter_stats ])[0]
148 |
149 |
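150 | # Example invocation (illustrative file names), under Python 2.7:
151 | #   python compute_gleu -s src.txt -r ref0.txt ref1.txt -o hyp.txt -n 4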
--------------------------------------------------------------------------------
/gleu/scripts/gleu.py:
--------------------------------------------------------------------------------
1 | # Courtney Napoles
2 | #
3 | # 21 June 2015
4 | # ##
5 | # gleu.py
6 | #
7 | # This script calculates the GLEU score of a sentence, as described in
8 | # our ACL 2015 paper, Ground Truth for Grammatical Error Correction Metrics
9 | # by Courtney Napoles, Keisuke Sakaguchi, Matt Post, and Joel Tetreault.
10 | #
11 | # For instructions on how to get the GLEU score, call "compute_gleu -h"
12 | #
13 | # Updated 2 May 2016: This is an updated version of GLEU that has been
14 | # modified to handle multiple references more fairly.
15 | #
16 | # Updated 6 9 2017: Fixed inverse brevity penalty
17 | #
18 | # This script was adapted from bleu.py by Adam Lopez.
19 | #
20 |
21 | import math
22 | from collections import Counter
23 |
24 | class GLEU :
25 |
26 | def __init__(self,n=4) :
27 |         self.order = n
28 |
29 | def load_hypothesis_sentence(self,hypothesis) :
30 | self.hlen = len(hypothesis)
31 | self.this_h_ngrams = [ self.get_ngram_counts(hypothesis,n)
32 | for n in range(1,self.order+1) ]
33 |
34 | def load_sources(self,spath) :
35 | self.all_s_ngrams = [ [ self.get_ngram_counts(line.split(),n)
36 | for n in range(1,self.order+1) ]
37 | for line in open(spath) ]
38 |
39 | def load_references(self,rpaths) :
40 | self.refs = [ [] for i in range(len(self.all_s_ngrams)) ]
41 | self.rlens = [ [] for i in range(len(self.all_s_ngrams)) ]
42 | for rpath in rpaths :
43 | for i,line in enumerate(open(rpath)) :
44 | self.refs[i].append(line.split())
45 | self.rlens[i].append(len(line.split()))
46 |
47 |         # count the number of references each n-gram appears in
48 | self.all_rngrams_freq = [ Counter() for i in range(self.order) ]
49 |
50 | self.all_r_ngrams = [ ]
51 | for refset in self.refs :
52 | all_ngrams = []
53 | self.all_r_ngrams.append(all_ngrams)
54 |
55 | for n in range(1,self.order+1) :
56 | ngrams = self.get_ngram_counts(refset[0],n)
57 | all_ngrams.append(ngrams)
58 |
59 | for k in ngrams.keys() :
60 | self.all_rngrams_freq[n-1][k]+=1
61 |
62 | for ref in refset[1:] :
63 | new_ngrams = self.get_ngram_counts(ref,n)
64 | for nn in new_ngrams.elements() :
65 | if new_ngrams[nn] > ngrams.get(nn,0) :
66 | ngrams[nn] = new_ngrams[nn]
67 |
68 | def get_ngram_counts(self,sentence,n) :
69 | return Counter([tuple(sentence[i:i+n])
70 | for i in xrange(len(sentence)+1-n)])
71 |
72 | # returns ngrams in a but not in b
73 | def get_ngram_diff(self,a,b) :
74 | diff = Counter(a)
75 | for k in (set(a) & set(b)) :
76 | del diff[k]
77 | return diff
78 |
79 | def normalization(self,ngram,n) :
80 | return 1.0*self.all_rngrams_freq[n-1][ngram]/len(self.rlens[0])
81 |
82 | # Collect BLEU-relevant statistics for a single hypothesis/reference pair.
83 | # Return value is a generator yielding:
84 | # (c, r, numerator1, denominator1, ... numerator4, denominator4)
85 | # Summing the columns across calls to this function on an entire corpus
86 | # will produce a vector of statistics that can be used to compute GLEU
87 | def gleu_stats(self,i,r_ind=None):
88 |
89 | hlen = self.hlen
90 | rlen = self.rlens[i][r_ind]
91 |
92 | yield hlen
93 | yield rlen
94 |
95 | for n in xrange(1,self.order+1):
96 | h_ngrams = self.this_h_ngrams[n-1]
97 | s_ngrams = self.all_s_ngrams[i][n-1]
98 | r_ngrams = self.get_ngram_counts(self.refs[i][r_ind],n)
99 |
100 | s_ngram_diff = self.get_ngram_diff(s_ngrams,r_ngrams)
101 |
102 | yield max([ sum( (h_ngrams & r_ngrams).values() ) - \
103 | sum( (h_ngrams & s_ngram_diff).values() ), 0 ])
104 |
105 | yield max([hlen+1-n, 0])
106 |
107 | # Compute GLEU from collected statistics obtained by call(s) to gleu_stats
108 | def gleu(self,stats,smooth=False):
109 | # smooth 0 counts for sentence-level scores
110 | if smooth :
111 | stats = [ s if s != 0 else 1 for s in stats ]
112 | if len(filter(lambda x: x==0, stats)) > 0:
113 | return 0
114 | (c, r) = stats[:2]
115 | log_gleu_prec = sum([math.log(float(x)/y)
116 |                              for x,y in zip(stats[2::2],stats[3::2])]) / self.order
117 | return math.exp(min([0, 1-float(r)/c]) + log_gleu_prec)
118 |
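119 | # Restating the formula in gleu() above (comment only, not part of the original
120 | # file):
121 | #   GLEU = BP * exp( (1/n) * sum_{k=1..n} log(x_k / y_k) )
122 | # where BP = exp(min(0, 1 - r/c)), c is the total hypothesis length, r the total
123 | # reference length, and x_k / y_k the clipped k-gram precision after subtracting
124 | # source-only k-gram matches from the numerator.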
--------------------------------------------------------------------------------
/gleu/scripts/original_gleu/compute_gleu:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Courtney Napoles
4 | #
5 | # 21 June 2015
6 | # ##
7 | # compute_gleu
8 | #
9 | # This script calls gleu.py to calculate the GLEU score of a sentence, as
10 | # described in our ACL 2015 paper, Ground Truth for Grammatical Error
11 | # Correction Metrics by Courtney Napoles, Keisuke Sakaguchi, Matt Post,
12 | # and Joel Tetreault.
13 | #
14 | # For instructions on how to get the GLEU score, call "compute_gleu -h"
15 | #
16 | # This script was adapted from compute-bleu by Adam Lopez.
17 | #
18 | #
19 | # THIS IS AN OLD VERSION OF GLEU. Please see the repository for the correct,
20 | # new version (https://github.com/cnap/gec-ranking)
21 |
22 | import argparse
23 | import sys
24 | import os
25 | from gleu import GLEU
26 |
27 | if __name__ == '__main__' :
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("-r", "--reference",
31 | help="Target language reference sentences. Multiple files for "
32 | " multiple references.",
33 | nargs="*",
34 | dest="reference",
35 | default=["data/dev.ref"])
36 | parser.add_argument("-s", "--source",
37 | help="Source language source sentences",
38 | dest="source",
39 | default="data/dev.src")
40 | parser.add_argument("-o", "--hypothesis",
41 | help="Target language hypothesis sentences to evaluate (can "
42 | "be more than one file--the GLEU score of each file will be) "
43 | "output separately. Use '-o -' to read hypotheses from stdin.",
44 | nargs="*",
45 | dest="hypothesis",
46 | default=["data/dev.hyp"])
47 | parser.add_argument("-n",
48 | help="Maximum order of ngrams",
49 | type=int,
50 | default=4)
51 | parser.add_argument("-l",
52 | help="Lambda weight for penalizing incorrectly unchanged n-grams",
53 | nargs='*',
54 | default=[0])
55 | parser.add_argument("-d","--debug",
56 | help="Debug; print sentence-level scores",
57 | default=False,
58 | action="store_true")
59 |
60 | args = parser.parse_args()
61 |
62 | gleu_calculator = GLEU(args.n,args.l)
63 |
64 | gleu_calculator.load_sources(args.source)
65 | gleu_calculator.load_references(args.reference)
66 |
67 | for hpath in args.hypothesis :
68 | instream = sys.stdin if hpath == '-' else open(hpath)
69 | hyp = [line.split() for line in instream]
70 |
71 | for l in args.l :
72 | l = float(l)
73 | gleu_calculator.set_lambda(l)
74 | print os.path.basename(hpath),l,
75 |
76 | if args.debug :
77 | print
78 | print '===== Sentence-level scores ====='
79 | print 'SID\tGLEU'
80 |
81 | stats = [0 for i in xrange(2*args.n+2)]
82 | for i,h in enumerate(hyp):
83 | this_stats = [s for s in gleu_calculator.gleu_stats(h,i)]
84 | if args.debug :
85 | print '%d\t%f'%(i,gleu_calculator.gleu(this_stats))
86 | stats = [sum(scores) for scores in zip(stats, this_stats)]
87 | if args.debug :
88 | print '\n==== Overall score ====='
89 | print gleu_calculator.gleu(stats)
90 |
--------------------------------------------------------------------------------
/gleu/scripts/original_gleu/gleu.py:
--------------------------------------------------------------------------------
1 | # Courtney Napoles
2 | #
3 | # 21 June 2015
4 | # ##
5 | # gleu.py
6 | #
7 | # This script calculates the GLEU score of a sentence, as described in
8 | # our ACL 2015 paper, Ground Truth for Grammatical Error Correction Metrics
9 | # by Courtney Napoles, Keisuke Sakaguchi, Matt Post, and Joel Tetreault.
10 | #
11 | # For instructions on how to get the GLEU score, call "compute_gleu -h"
12 | #
13 | # This script was adapted from bleu.py by Adam Lopez.
14 | #
15 | #
16 | # THIS IS AN OLD VERSION OF GLEU. Please see the repository for the correct,
17 | # new version (https://github.com/cnap/gec-ranking)
18 |
19 | import math
20 | from collections import Counter
21 |
22 | class GLEU :
23 |
24 | def __init__(self,n=4,l=1) :
25 |         self.order = 4  # NOTE: hard-coded to 4-grams; the n argument is not used here
26 | self.weight = l
27 |
28 | def load_sources(self,spath) :
29 | self.all_s_ngrams = [ [ self.get_ngram_counts(line.split(),n) \
30 | for n in range(1,self.order+1) ] \
31 | for line in open(spath) ]
32 |
33 | def load_references(self,rpaths) :
34 | refs = [ [] for i in range(len(self.all_s_ngrams)) ]
35 | self.rlens = [ [] for i in range(len(self.all_s_ngrams)) ]
36 | for rpath in rpaths :
37 | for i,line in enumerate(open(rpath)) :
38 | refs[i].append(line.split())
39 | self.rlens[i].append(len(line.split()))
40 |
41 | self.all_r_ngrams = [ ]
42 | for refset in refs :
43 | all_ngrams = []
44 | self.all_r_ngrams.append(all_ngrams)
45 |
46 | for n in range(1,self.order+1) :
47 | ngrams = self.get_ngram_counts(refset[0],n)
48 | all_ngrams.append(ngrams)
49 | for ref in refset[1:] :
50 | new_ngrams = self.get_ngram_counts(ref,n)
51 | for nn in new_ngrams.elements() :
52 | if new_ngrams[nn] > ngrams.get(nn,0) :
53 | ngrams[nn] = new_ngrams[nn]
54 |
55 |
56 | def get_ngram_counts(self,sentence,n) :
57 | return Counter([tuple(sentence[i:i+n]) for i in xrange(len(sentence)+1-n)])
58 |
59 | def set_lambda(self,l) :
60 | self.weight = l
61 |
62 | # Collect BLEU-relevant statistics for a single hypothesis/reference pair.
63 | # Return value is a generator yielding:
64 | # (c, r, numerator1, denominator1, ... numerator4, denominator4)
65 | # Summing the columns across calls to this function on an entire corpus will
66 | # produce a vector of statistics that can be used to compute BLEU or GLEU
67 | def gleu_stats(self,hypothesis, i):
68 |
69 | hlen=len(hypothesis)
70 | rlen = self.rlens[i][0]
71 |
72 | # set the reference length to be the reference length closest to the hyp length
73 | for r in self.rlens[i][1:] :
74 | if abs(r - hlen) < abs(rlen - hlen) :
75 | rlen = r
76 |
77 | yield rlen
78 | yield hlen
79 |
80 | for n in xrange(1,self.order+1):
81 | h_ngrams = self.get_ngram_counts(hypothesis,n)
82 | s_ngrams = self.all_s_ngrams[i][n-1]
83 | r_ngrams = self.all_r_ngrams[i][n-1]
84 |
85 | r_ngram_diff = r_ngrams - s_ngrams
86 | # some n-grams may appear in both sets but have a higher count in the subtracted
87 | # one so these n-grams should be deleted so a single occurrence of one of those
88 | # n-grams doesn't penalize the precision
89 | for k in r_ngram_diff.keys() :
90 | if k in s_ngrams :
91 | del r_ngram_diff[k]
92 | s_ngram_diff = s_ngrams - r_ngrams
93 | for k in s_ngram_diff.keys() :
94 | if k in r_ngrams :
95 | del s_ngram_diff[k]
96 |
97 | yield sum( (h_ngrams & r_ngram_diff).values() ) + \
98 | max([ sum( (h_ngrams & r_ngrams).values() ) - \
99 | self.weight * sum( (h_ngrams & s_ngram_diff).values() ), 0 ])
100 |
101 | yield sum( (h_ngrams & r_ngram_diff).values() ) + max([hlen+1-n, 0])
102 |
103 | ## here is the original, erroneous way to calculate the denominator
104 | #yield max([sum(r_ngram_diff.values()), 0]) + max([hlen+1-n, 0])
105 |
106 | # Compute GLEU from collected statistics obtained by call(s) to gleu_stats
107 | def gleu(self,stats):
108 | if len(filter(lambda x: x==0, stats)) > 0:
109 | return 0
110 | (c, r) = stats[:2]
111 | log_gleu_prec = sum([math.log(float(x)/y) for x,y in zip(stats[2::2],stats[3::2])]) / 4.
112 |
113 | return math.exp(min([0, 1-float(r)/c]) + log_gleu_prec)
114 |
--------------------------------------------------------------------------------
/gnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.utils
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
8 |
9 | from typing import List, Tuple, Dict, Sequence, Any
10 |
11 | # https://github.com/pcyin/pytorch-gated-graph-neural-network/blob/master/gnn.py
12 |
13 | class AdjacencyList:
14 | """represent the topology of a graph"""
15 | def __init__(self, node_num: int, adj_list: List, device: torch.device):
16 | self.node_num = node_num
17 | self.data = torch.tensor(adj_list, dtype=torch.long, device=device)
18 | self.edge_num = len(adj_list)
19 |
20 | @property
21 | def device(self):
22 | return self.data.device
23 |
24 | def __getitem__(self, item):
25 | return self.data[item]
26 |
27 |
28 | class GatedGraphNeuralNetwork(nn.Module):
29 | def __init__(self, hidden_size, num_edge_types, layer_timesteps,
30 | residual_connections,
31 | state_to_message_dropout=0.3,
32 | rnn_dropout=0.3,
33 | use_bias_for_message_linear=True):
34 |
35 | super(GatedGraphNeuralNetwork, self).__init__()
36 |
37 | self.hidden_size = hidden_size
38 | self.num_edge_types = num_edge_types
39 | self.layer_timesteps = layer_timesteps
40 | self.residual_connections = residual_connections
41 | self.state_to_message_dropout = state_to_message_dropout
42 | self.rnn_dropout = rnn_dropout
43 | self.use_bias_for_message_linear = use_bias_for_message_linear
44 |
45 | # Prepare linear transformations from node states to messages, for each layer and each edge type
46 | # Prepare rnn cells for each layer
47 | self.state_to_message_linears = []
48 | self.rnn_cells = []
49 | for layer_idx in range(len(self.layer_timesteps)):
50 | state_to_msg_linears_cur_layer = []
51 | # Initiate a linear transformation for each edge type
52 | for edge_type_j in range(self.num_edge_types):
53 | # TODO: glorot_init?
54 | state_to_msg_linear_layer_i_type_j = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias_for_message_linear)
55 | setattr(self,
56 | 'state_to_message_linear_layer%d_type%d' % (layer_idx, edge_type_j),
57 | state_to_msg_linear_layer_i_type_j)
58 |
59 | state_to_msg_linears_cur_layer.append(state_to_msg_linear_layer_i_type_j)
60 | self.state_to_message_linears.append(state_to_msg_linears_cur_layer)
61 |
62 | layer_residual_connections = self.residual_connections.get(layer_idx, [])
63 | rnn_cell_layer_i = nn.GRUCell(self.hidden_size * (1 + len(layer_residual_connections)), self.hidden_size)
64 | setattr(self, 'rnn_cell_layer%d' % layer_idx, rnn_cell_layer_i)
65 | self.rnn_cells.append(rnn_cell_layer_i)
66 |
67 | self.state_to_message_dropout_layer = nn.Dropout(self.state_to_message_dropout)
68 | self.rnn_dropout_layer = nn.Dropout(self.rnn_dropout)
69 |
70 | @property
71 | def device(self):
72 | return self.rnn_cells[0].weight_hh.device
73 |
74 | def forward(self,
75 | initial_node_representation: Variable,
76 | adjacency_lists: List[AdjacencyList],
77 | return_all_states=False) -> Variable:
78 | return self.compute_node_representations(initial_node_representation, adjacency_lists,
79 | return_all_states=return_all_states)
80 |
81 | def compute_node_representations(self,
82 | initial_node_representation: Variable,
83 | adjacency_lists: List[AdjacencyList],
84 | return_all_states=False) -> Variable:
85 |         # If the initial node embedding is smaller than the hidden size, zero-pad it first
86 |         init_node_repr_size = initial_node_representation.size(1)
87 |         device = adjacency_lists[0].data.device
88 |         if init_node_repr_size < self.hidden_size:
89 |             pad_size = self.hidden_size - init_node_repr_size
90 |             zero_pads = torch.zeros(initial_node_representation.size(0), pad_size, dtype=torch.float, device=device)
91 |             initial_node_representation = torch.cat([initial_node_representation, zero_pads], dim=-1)
92 |         # One entry per layer (the final state of that layer), each of shape [V, D]
93 |         node_states_per_layer = [initial_node_representation]
94 |
95 | node_num = initial_node_representation.size(0)
96 |
97 | message_targets = [] # list of tensors of message targets of shape [E]
98 | for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
99 | if adjacency_list_for_edge_type.edge_num > 0:
100 | edge_targets = adjacency_list_for_edge_type[:, 1]
101 | message_targets.append(edge_targets)
102 | message_targets = torch.cat(message_targets, dim=0) # Shape [M]
103 |
104 | # sparse matrix of shape [V, M]
105 | # incoming_msg_sparse_matrix = self.get_incoming_message_sparse_matrix(adjacency_lists).to(device)
106 | for layer_idx, num_timesteps in enumerate(self.layer_timesteps):
107 | # Used shape abbreviations:
108 | # V ~ number of nodes
109 | # D ~ state dimension
110 | # E ~ number of edges of current type
111 | # M ~ number of messages (sum of all E)
112 |
113 | # Extract residual messages, if any:
114 | layer_residual_connections = self.residual_connections.get(layer_idx, [])
115 | # List[(V, D)]
116 | layer_residual_states: List[torch.FloatTensor] = [node_states_per_layer[residual_layer_idx]
117 | for residual_layer_idx in layer_residual_connections]
118 |
119 | # Record new states for this layer. Initialised to last state, but will be updated below:
120 | node_states_for_this_layer = node_states_per_layer[-1]
121 | # For each message propagation step
122 | for t in range(num_timesteps):
123 | messages: List[torch.FloatTensor] = [] # list of tensors of messages of shape [E, D]
124 | message_source_states: List[torch.FloatTensor] = [] # list of tensors of edge source states of shape [E, D]
125 |
126 | # Collect incoming messages per edge type
127 | for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
128 | if adjacency_list_for_edge_type.edge_num > 0:
129 | # shape [E]
130 | edge_sources = adjacency_list_for_edge_type[:, 0]
131 | # shape [E, D]
132 | edge_source_states = node_states_for_this_layer[edge_sources]
133 |
134 | f_state_to_message = self.state_to_message_linears[layer_idx][edge_type_idx]
135 | # Shape [E, D]
136 | all_messages_for_edge_type = self.state_to_message_dropout_layer(f_state_to_message(edge_source_states))
137 |
138 | messages.append(all_messages_for_edge_type)
139 | message_source_states.append(edge_source_states)
140 |
141 | # shape [M, D]
142 | messages: torch.FloatTensor = torch.cat(messages, dim=0)
143 |
144 | # Sum up messages that go to the same target node
145 | # shape [V, D]
146 | incoming_messages = torch.zeros(node_num, messages.size(1), device=device)
147 | incoming_messages = incoming_messages.scatter_add_(0,
148 | message_targets.unsqueeze(-1).expand_as(messages),
149 | messages)
150 |
151 | # shape [V, D * (1 + num of residual connections)]
152 | incoming_information = torch.cat(layer_residual_states + [incoming_messages], dim=-1)
153 |
154 | # pass updated vertex features into RNN cell
155 | # Shape [V, D]
156 | updated_node_states = self.rnn_cells[layer_idx](incoming_information, node_states_for_this_layer)
157 | updated_node_states = self.rnn_dropout_layer(updated_node_states)
158 | node_states_for_this_layer = updated_node_states
159 |
160 | node_states_per_layer.append(node_states_for_this_layer)
161 |
162 | if return_all_states:
163 | return node_states_per_layer[1:]
164 | else:
165 | node_states_for_last_layer = node_states_per_layer[-1]
166 | return node_states_for_last_layer
167 |
168 |
169 | def main():
170 | gnn = GatedGraphNeuralNetwork(hidden_size=64, num_edge_types=2,
171 | layer_timesteps=[3, 5, 7, 2], residual_connections={2: [0], 3: [0, 1]})
172 |
173 | adj_list_type1 = AdjacencyList(node_num=4, adj_list=[(0, 2), (2, 1), (1, 3)], device=gnn.device)
174 | adj_list_type2 = AdjacencyList(node_num=4, adj_list=[(0, 0), (0, 1)], device=gnn.device)
175 |
176 | node_representations = gnn.compute_node_representations(initial_node_representation=torch.randn(4, 64),
177 | adjacency_lists=[adj_list_type1, adj_list_type2])
178 |
179 | print(node_representations)
180 |
181 |
182 | if __name__ == '__main__':
183 | main()
--------------------------------------------------------------------------------
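The core message-aggregation step inside `compute_node_representations` — summing all messages that share a target node via `scatter_add_` — in a self-contained sketch with made-up shapes:

```
import torch

num_nodes, dim = 4, 3
messages = torch.ones(5, dim)                    # [M, D]: one message per edge
message_targets = torch.tensor([2, 1, 3, 0, 1])  # [M]: target node of each edge

# Rows of `incoming` accumulate every message addressed to that node
incoming = torch.zeros(num_nodes, dim)
incoming.scatter_add_(0, message_targets.unsqueeze(-1).expand_as(messages), messages)

print(incoming)  # node 1 received two messages, so its row is all 2s
```

--------------------------------------------------------------------------------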
/module_manager.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import Counter
3 | import numpy as np
4 | import os
5 | import random
6 | import sys
7 | import torch
8 | from torch import nn
9 |
10 | from dpu_utils.mlutils import Vocabulary
11 |
12 | from ast_graph_encoder import ASTGraphEncoder
13 | from constants import *
14 | from data_utils import *
15 | import diff_utils
16 | from embedding_store import EmbeddingStore
17 | from encoder import Encoder
18 | from external_cache import get_code_features, get_nl_features, get_num_code_features, get_num_nl_features
19 | from tensor_utils import *
20 |
21 |
22 | class ModuleManager(nn.Module):
23 | """Utility class which helps manage related attributes of the update and detection tasks."""
24 | def __init__(self, attend_code_sequence_states, attend_code_graph_states, features, posthoc, task):
25 | super(ModuleManager, self).__init__()
26 | self.attend_code_sequence_states = attend_code_sequence_states
27 | self.attend_code_graph_states = attend_code_graph_states
28 | self.features = features
29 | self.posthoc = posthoc
30 | self.task = task
31 |
32 | self.num_encoders = 0
33 | self.num_seq_encoders = 0
34 | self.out_dim = 0
35 | self.attention_state_size = 0
36 | self.update_encoder_state_size = 0
37 | self.max_ast_length = 0
38 | self.max_code_length = 0
39 | self.max_nl_length = 0
40 | self.generate = task in ['update', 'dual']
41 | self.classify = task in ['detect', 'dual']
42 |
43 | self.encode_code_sequence = self.generate or self.attend_code_sequence_states
44 |
45 | print('Attend code sequence states: {}'.format(self.attend_code_sequence_states))
46 | print('Attend code graph states: {}'.format(self.attend_code_graph_states))
47 | print('Features: {}'.format(self.features))
48 | print('Task: {}'.format(self.task))
49 | sys.stdout.flush()
50 |
51 | def get_code_representation(self, ex, data_type):
52 | if self.posthoc:
53 | if data_type == 'sequence':
54 | return ex.new_code_subtokens
55 | else:
56 | return ex.new_ast
57 | else:
58 | if data_type == 'sequence':
59 | return ex.span_diff_code_subtokens
60 | else:
61 | return ex.diff_ast
62 |
63 | def initialize(self, train_data):
64 | """Initializes model parameters from pre-defined hyperparameters and other hyperparameters
65 | that are computed based on statistics over the training data."""
66 | nl_lengths = []
67 | code_lengths = []
68 | ast_lengths = []
69 |
70 | nl_token_counter = Counter()
71 | code_token_counter = Counter()
72 |
73 | for ex in train_data:
74 | if self.generate:
75 | trg_sequence = [START] + ex.span_minimal_diff_comment_subtokens + [END]
76 | nl_token_counter.update(trg_sequence)
77 | nl_lengths.append(len(trg_sequence))
78 |
79 | old_nl_sequence = ex.old_comment_subtokens
80 | nl_token_counter.update(old_nl_sequence)
81 | nl_lengths.append(len(old_nl_sequence))
82 |
83 | if self.encode_code_sequence:
84 | code_sequence = self.get_code_representation(ex, 'sequence')
85 | code_token_counter.update(code_sequence)
86 | code_lengths.append(len(code_sequence))
87 |
88 | if self.attend_code_graph_states:
89 | code_sequence = [n.value for n in self.get_code_representation(ex, 'graph').nodes]
90 | code_token_counter.update(code_sequence)
91 | ast_lengths.append(len(code_sequence))
92 |
93 | self.max_nl_length = int(np.percentile(np.asarray(sorted(nl_lengths)),
94 | LENGTH_CUTOFF_PCT))
95 | self.max_vocab_extension = self.max_nl_length
96 |
97 | if self.encode_code_sequence:
98 | self.max_code_length = int(np.percentile(np.asarray(sorted(code_lengths)),
99 | LENGTH_CUTOFF_PCT))
100 | self.max_vocab_extension += self.max_code_length
101 |
102 | if self.attend_code_graph_states:
103 | self.max_ast_length = int(np.percentile(np.asarray(sorted(ast_lengths)),
104 | LENGTH_CUTOFF_PCT))
105 |
106 | nl_counts = np.asarray(sorted(nl_token_counter.values()))
107 | nl_threshold = int(np.percentile(nl_counts, VOCAB_CUTOFF_PCT)) + 1
108 | code_counts = np.asarray(sorted(code_token_counter.values()))
109 |         code_threshold = int(np.percentile(code_counts, VOCAB_CUTOFF_PCT)) + 1
110 |
111 | self.embedding_store = EmbeddingStore(nl_threshold, NL_EMBEDDING_SIZE, nl_token_counter,
112 | code_threshold, CODE_EMBEDDING_SIZE, code_token_counter,
113 | DROPOUT_RATE, len(SrcType), SRC_EMBEDDING_SIZE, CODE_EMBEDDING_SIZE, True)
114 |
115 | self.out_dim = 2*HIDDEN_SIZE
116 |
117 | # Accounting for the old NL encoder
118 | self.num_encoders = 1
119 | self.num_seq_encoders += 1
120 | self.attention_state_size += 2*HIDDEN_SIZE
121 | self.nl_encoder = Encoder(NL_EMBEDDING_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT_RATE)
122 | self.nl_attention_transform_matrix = nn.Parameter(torch.randn(
123 | self.out_dim, self.out_dim, dtype=torch.float, requires_grad=True))
124 | self.self_attention = nn.MultiheadAttention(self.out_dim, MULTI_HEADS, DROPOUT_RATE)
125 |
126 | if self.encode_code_sequence:
127 | self.sequence_code_encoder = Encoder(CODE_EMBEDDING_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT_RATE)
128 | self.num_encoders += 1
129 | self.num_seq_encoders += 1
130 |
131 | if self.attend_code_sequence_states:
132 | self.attention_state_size += 2*HIDDEN_SIZE
133 | self.sequence_attention_transform_matrix = nn.Parameter(torch.randn(
134 | self.out_dim, self.out_dim, dtype=torch.float, requires_grad=True))
135 | self.code_sequence_multihead_attention = nn.MultiheadAttention(self.out_dim, MULTI_HEADS, DROPOUT_RATE)
136 |
137 | if self.attend_code_graph_states:
138 | self.graph_code_encoder = ASTGraphEncoder(CODE_EMBEDDING_SIZE, len(DiffEdgeType))
139 | self.num_encoders += 1
140 | self.attention_state_size += 2*HIDDEN_SIZE
141 | self.graph_attention_transform_matrix = nn.Parameter(torch.randn(
142 | CODE_EMBEDDING_SIZE, self.out_dim, dtype=torch.float, requires_grad=True))
143 | self.graph_multihead_attention = nn.MultiheadAttention(self.out_dim, MULTI_HEADS, DROPOUT_RATE)
144 |
145 | if self.features:
146 | self.code_features_to_embedding = nn.Linear(CODE_EMBEDDING_SIZE + get_num_code_features(),
147 | CODE_EMBEDDING_SIZE, bias=False)
148 | self.nl_features_to_embedding = nn.Linear(
149 | NL_EMBEDDING_SIZE + get_num_nl_features(),
150 | NL_EMBEDDING_SIZE, bias=False)
151 |
152 | if self.generate:
153 | self.update_encoder_state_size = self.num_seq_encoders*self.out_dim
154 | self.encoder_final_to_decoder_initial = nn.Parameter(torch.randn(self.update_encoder_state_size,
155 | self.out_dim, dtype=torch.float, requires_grad=True))
156 |
157 | if self.classify:
158 | self.attended_nl_encoder = Encoder(self.out_dim, HIDDEN_SIZE, NUM_LAYERS, DROPOUT_RATE)
159 | self.attended_nl_encoder_output_layer = nn.Linear(self.attention_state_size, self.out_dim, bias=False)
160 |
161 | def get_batches(self, dataset, device, shuffle=False):
162 |         """Divides the dataset into batches based on the pre-defined BATCH_SIZE hyperparameter.
163 | Each batch is tensorized so that it can be directly passed into the network."""
164 | batches = []
165 | if shuffle:
166 | random.shuffle(dataset)
167 |
168 | curr_idx = 0
169 | while curr_idx < len(dataset):
170 | start_idx = curr_idx
171 | end_idx = min(start_idx + BATCH_SIZE, len(dataset))
172 |
173 | code_token_ids = []
174 | code_lengths = []
175 | old_nl_token_ids = []
176 | old_nl_lengths = []
177 | trg_token_ids = []
178 | trg_extended_token_ids = []
179 | trg_lengths = []
180 | invalid_copy_positions = []
181 | inp_str_reps = []
182 | inp_ids = []
183 | code_features = []
184 | nl_features = []
185 | labels = []
186 |
187 | graph_batch = initialize_graph_method_batch(len(DiffEdgeType))
188 |
189 | for i in range(start_idx, end_idx):
190 | if self.encode_code_sequence:
191 | code_sequence = self.get_code_representation(dataset[i], 'sequence')
192 | code_sequence_ids = self.embedding_store.get_padded_code_ids(
193 | code_sequence, self.max_code_length)
194 | code_length = min(len(code_sequence), self.max_code_length)
195 | code_token_ids.append(code_sequence_ids)
196 | code_lengths.append(code_length)
197 |
198 | if self.attend_code_graph_states:
199 | ast = self.get_code_representation(dataset[i], 'graph')
200 | ast_sequence = [n.value for n in ast.nodes]
201 | ast_length = min(len(ast_sequence), self.max_ast_length)
202 | ast.nodes = ast.nodes[:ast_length]
203 | graph_batch = insert_graph(graph_batch, dataset[i], ast,
204 | self.embedding_store.code_vocabulary, self.features, self.max_ast_length)
205 |
206 | old_nl_sequence = dataset[i].old_comment_subtokens
207 | old_nl_length = min(len(old_nl_sequence), self.max_nl_length)
208 | old_nl_sequence_ids = self.embedding_store.get_padded_nl_ids(
209 | old_nl_sequence, self.max_nl_length)
210 |
211 | old_nl_token_ids.append(old_nl_sequence_ids)
212 | old_nl_lengths.append(old_nl_length)
213 |
214 | if self.generate:
215 | ex_inp_str_reps = []
216 | ex_inp_ids = []
217 |
218 | extra_counter = len(self.embedding_store.nl_vocabulary)
219 | max_limit = len(self.embedding_store.nl_vocabulary) + self.max_vocab_extension
220 | out_ids = set()
221 |
222 | copy_inputs = []
223 | copy_inputs += code_sequence[:code_length]
224 |
225 | copy_inputs += old_nl_sequence[:old_nl_length]
226 | for c in copy_inputs:
227 | nl_id = self.embedding_store.get_nl_id(c)
228 | if self.embedding_store.is_nl_unk(nl_id) and extra_counter < max_limit:
229 | if c in ex_inp_str_reps:
230 | nl_id = ex_inp_ids[ex_inp_str_reps.index(c)]
231 | else:
232 | nl_id = extra_counter
233 | extra_counter += 1
234 |
235 | out_ids.add(nl_id)
236 | ex_inp_str_reps.append(c)
237 | ex_inp_ids.append(nl_id)
238 |
239 |                 trg_sequence = [START] + dataset[i].span_minimal_diff_comment_subtokens + [END]
240 | trg_sequence_ids = self.embedding_store.get_padded_nl_ids(
241 | trg_sequence, self.max_nl_length)
242 | trg_extended_sequence_ids = self.embedding_store.get_extended_padded_nl_ids(
243 | trg_sequence, self.max_nl_length, ex_inp_ids, ex_inp_str_reps)
244 |
245 | trg_token_ids.append(trg_sequence_ids)
246 | trg_extended_token_ids.append(trg_extended_sequence_ids)
247 | trg_lengths.append(min(len(trg_sequence), self.max_nl_length))
248 | inp_str_reps.append(ex_inp_str_reps)
249 | inp_ids.append(self.embedding_store.pad_length(ex_inp_ids, self.max_vocab_extension))
250 |
251 | invalid_copy_positions.append(get_invalid_copy_locations(ex_inp_str_reps, self.max_vocab_extension,
252 | trg_sequence, self.max_nl_length))
253 |
254 | labels.append(dataset[i].label)
255 |
256 | if self.features:
257 | if self.encode_code_sequence:
258 | code_features.append(get_code_features(code_sequence, dataset[i], self.max_code_length))
259 | nl_features.append(get_nl_features(old_nl_sequence, dataset[i], self.max_nl_length))
260 |
261 | batches.append(UpdateBatchData(torch.tensor(code_token_ids, dtype=torch.int64, device=device),
262 | torch.tensor(code_lengths, dtype=torch.int64, device=device),
263 | torch.tensor(old_nl_token_ids, dtype=torch.int64, device=device),
264 | torch.tensor(old_nl_lengths, dtype=torch.int64, device=device),
265 | torch.tensor(trg_token_ids, dtype=torch.int64, device=device),
266 | torch.tensor(trg_extended_token_ids, dtype=torch.int64, device=device),
267 | torch.tensor(trg_lengths, dtype=torch.int64, device=device),
268 | torch.tensor(invalid_copy_positions, dtype=torch.uint8, device=device),
269 | inp_str_reps,
270 | torch.tensor(inp_ids, dtype=torch.int64, device=device),
271 | torch.tensor(code_features, dtype=torch.float32, device=device),
272 | torch.tensor(nl_features, dtype=torch.float32, device=device),
273 | torch.tensor(labels, dtype=torch.int64, device=device),
274 | tensorize_graph_method_batch(graph_batch, device, self.max_ast_length)))
275 | curr_idx = end_idx
276 | return batches
277 |
278 | def get_encoder_output(self, batch_data, device):
279 |         """Gets the hidden states, final state, and length mask corresponding to each encoder."""
280 | encoder_hidden_states = None
281 | input_lengths = None
282 | final_states = None
283 | mask = None
284 |
285 | # Encode old NL
286 | old_nl_embedded_subtokens = self.embedding_store.get_nl_embeddings(batch_data.old_nl_ids)
287 | if self.features:
288 | old_nl_embedded_subtokens = self.nl_features_to_embedding(torch.cat(
289 | [old_nl_embedded_subtokens, batch_data.nl_features], dim=-1))
290 | old_nl_hidden_states, old_nl_final_state = self.nl_encoder.forward(old_nl_embedded_subtokens,
291 | batch_data.old_nl_lengths, device)
292 | old_nl_masks = (torch.arange(
293 | old_nl_hidden_states.shape[1], device=device).view(1, -1) >= batch_data.old_nl_lengths.view(-1, 1)).unsqueeze(1)
294 | attention_states = compute_attention_states(old_nl_hidden_states, old_nl_masks,
295 | old_nl_hidden_states, transformation_matrix=self.nl_attention_transform_matrix, multihead_attention=self.self_attention)
296 |
297 | # Encode code
298 | code_hidden_states = None
299 | code_masks = None
300 | code_final_state = None
301 |
302 | if self.encode_code_sequence:
303 | code_embedded_subtokens = self.embedding_store.get_code_embeddings(batch_data.code_ids)
304 | if self.features:
305 | code_embedded_subtokens = self.code_features_to_embedding(torch.cat(
306 | [code_embedded_subtokens, batch_data.code_features], dim=-1))
307 | code_hidden_states, code_final_state = self.sequence_code_encoder.forward(code_embedded_subtokens,
308 | batch_data.code_lengths, device)
309 | code_masks = (torch.arange(
310 | code_hidden_states.shape[1], device=device).view(1, -1) >= batch_data.code_lengths.view(-1, 1)).unsqueeze(1)
311 | encoder_hidden_states = code_hidden_states
312 | input_lengths = batch_data.code_lengths
313 | final_states = code_final_state
314 |
315 | if self.attend_code_sequence_states:
316 | attention_states = torch.cat([attention_states, compute_attention_states(
317 | code_hidden_states, code_masks, old_nl_hidden_states,
318 | transformation_matrix=self.sequence_attention_transform_matrix,
319 | multihead_attention=self.code_sequence_multihead_attention)], dim=-1)
320 |
321 | if self.attend_code_graph_states:
322 | embedded_nodes = self.embedding_store.get_node_embeddings(
323 | batch_data.graph_batch.value_lookup_ids, batch_data.graph_batch.src_type_ids)
324 |
325 | if self.features:
326 | embedded_nodes = self.code_features_to_embedding(torch.cat(
327 | [embedded_nodes, batch_data.graph_batch.node_features], dim=-1))
328 |
329 | graph_states = self.graph_code_encoder.forward(embedded_nodes, batch_data.graph_batch, device)
330 | graph_lengths = batch_data.graph_batch.num_nodes_per_graph
331 | graph_masks = (torch.arange(
332 | graph_states.shape[1], device=device).view(1, -1) >= graph_lengths.view(-1, 1)).unsqueeze(1)
333 |
334 | transformed_graph_states = torch.einsum('ijk,km->ijm', graph_states, self.graph_attention_transform_matrix)
335 | graph_attention_states = compute_attention_states(transformed_graph_states, graph_masks,
336 | old_nl_hidden_states, multihead_attention=self.graph_multihead_attention)
337 | attention_states = torch.cat([attention_states, graph_attention_states], dim=-1)
338 |
339 | if self.classify:
340 | nl_attended_states = torch.tanh(self.attended_nl_encoder_output_layer(attention_states))
341 | _, attended_old_nl_final_state = self.attended_nl_encoder.forward(nl_attended_states,
342 | batch_data.old_nl_lengths, device)
343 | else:
344 | attended_old_nl_final_state = None
345 |
346 | if self.generate:
347 | encoder_final_state = torch.einsum('ij,jk->ik',
348 | torch.cat([final_states, old_nl_final_state], dim=-1),
349 | self.encoder_final_to_decoder_initial)
350 | encoder_hidden_states, input_lengths = merge_encoder_outputs(encoder_hidden_states,
351 | input_lengths, old_nl_hidden_states, batch_data.old_nl_lengths, device)
352 | mask = (torch.arange(
353 | encoder_hidden_states.shape[1], device=device).view(1, -1) >= input_lengths.view(-1, 1)).unsqueeze(1)
354 | else:
355 | encoder_final_state = None
356 |
357 | return EncoderOutputs(encoder_hidden_states, mask, encoder_final_state, code_hidden_states, code_masks,
358 | old_nl_hidden_states, old_nl_masks, old_nl_final_state, attended_old_nl_final_state)
--------------------------------------------------------------------------------
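`get_encoder_output` relies throughout on a single masking idiom: positions at or beyond a sequence's true length are marked True so that attention ignores padding. A minimal sketch (lengths and sizes are made up):

```
import torch

lengths = torch.tensor([2, 4])  # true lengths of two padded sequences
max_len = 4                     # padded time dimension

# Shape [batch, 1, max_len]; True marks padding positions
mask = (torch.arange(max_len).view(1, -1) >= lengths.view(-1, 1)).unsqueeze(1)
print(mask)
# first row masks positions 2 and 3; second row masks nothing
```

--------------------------------------------------------------------------------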
/run_comment_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import os
4 | import sys
5 | import torch
6 |
7 | sys.path.append('comment_update')
8 | from comment_generation import CommentGenerationModel
9 | from update_module import UpdateModule
10 | from detection_module import DetectionModule
11 | from data_loader import get_data_splits
12 | from module_manager import ModuleManager
13 |
14 | def build_model(task, model_path, manager):
15 |     """Builds the appropriate model with task-specific modules."""
16 | if task == 'dual':
17 | detection_module = DetectionModule(None, manager)
18 | model = UpdateModule(model_path, manager, detection_module)
19 | elif 'update' in task:
20 | model = UpdateModule(model_path, manager, None)
21 | else:
22 | model = DetectionModule(model_path, manager)
23 |
24 | return model
25 |
26 | def load_model(model_path, evaluate_detection=False):
27 | """Loads a pretrained model from model_path."""
28 | print('Loading model from: {}'.format(model_path))
29 | sys.stdout.flush()
30 | if torch.cuda.is_available() and evaluate_detection:
31 | model = torch.load(model_path)
32 | model.torch_device_name = 'gpu'
33 | model.cuda()
34 | for c in model.children():
35 | c.cuda()
36 | else:
37 | model = torch.load(model_path, map_location='cpu')
38 | model.torch_device_name = 'cpu'
39 | model.cpu()
40 | for c in model.children():
41 | c.cpu()
42 | return model
43 |
44 | def train(model, train_examples, valid_examples):
45 | """Trains a model."""
46 | print('Training with {} examples (validation {})'.format(len(train_examples), len(valid_examples)))
47 | sys.stdout.flush()
48 | if torch.cuda.is_available():
49 | model.torch_device_name = 'gpu'
50 | model.cuda()
51 | for c in model.children():
52 | c.cuda()
53 | else:
54 | model.torch_device_name = 'cpu'
55 | model.cpu()
56 | for c in model.children():
57 | c.cpu()
58 |
59 | model.run_train(train_examples, valid_examples)
60 |
61 | def evaluate(task, model, test_examples, model_name, rerank):
62 | """Runs evaluation over a given model."""
63 | print('Evaluating {} examples'.format(len(test_examples)))
64 | sys.stdout.flush()
65 | if task == 'detect':
66 | model.run_evaluation(test_examples, model_name)
67 | else:
68 | model.run_evaluation(test_examples, rerank, model_name)
69 |
70 | if __name__ == "__main__":
71 | parser = argparse.ArgumentParser()
72 | parser.add_argument('--task', help='detect, update, or dual')
73 | parser.add_argument('--attend_code_sequence_states', action='store_true', help='attend to sequence-based code hidden states for detection')
74 | parser.add_argument('--attend_code_graph_states', action='store_true', help='attend to graph-based code hidden states for detection')
75 | parser.add_argument('--features', action='store_true', help='concatenate lexical and linguistic feats to code/comment input embeddings')
76 | parser.add_argument('--posthoc', action='store_true', help='whether to run in posthoc mode where old code is not available')
77 | parser.add_argument('--positive_only', action='store_true', help='whether to train on only inconsistent examples')
78 | parser.add_argument('--test_mode', action='store_true', help='whether to run evaluation')
79 | parser.add_argument('--rerank', action='store_true', help='whether to use reranking in the update module (if task is update or dual)')
80 | parser.add_argument('--model_path', help='path to save model (training) or path to saved model (evaluation)')
81 | parser.add_argument('--model_name', help='name of model (used to save model output)')
82 | args = parser.parse_args()
83 |
84 | train_examples, valid_examples, test_examples, high_level_details = get_data_splits()
85 | if args.positive_only:
86 | train_examples = [ex for ex in train_examples if ex.label == 1]
87 | valid_examples = [ex for ex in valid_examples if ex.label == 1]
88 |
89 | print('Train: {}'.format(len(train_examples)))
90 | print('Valid: {}'.format(len(valid_examples)))
91 | print('Test: {}'.format(len(test_examples)))
92 |
93 | if args.task == 'detect' and (not args.attend_code_sequence_states and not args.attend_code_graph_states):
94 | raise ValueError('Please specify attention states for detection')
95 | if args.posthoc and (args.task != 'detect' or args.features):
96 | # Features and update rely on code changes
97 | raise ValueError('Posthoc setting not supported for given arguments')
98 |
99 | if args.test_mode:
100 | print('Starting evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
101 |
102 |         model = load_model(args.model_path, args.task == 'detect')
103 | evaluate(args.task, model, test_examples, args.model_name, args.rerank)
104 |
105 | print('Terminating evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
106 | else:
107 | print('Starting training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
108 |
109 | manager = ModuleManager(args.attend_code_sequence_states, args.attend_code_graph_states, args.features, args.posthoc, args.task)
110 | manager.initialize(train_examples)
111 | model = build_model(args.task, args.model_path, manager)
112 |
113 | print('Model path: {}'.format(args.model_path))
114 | sys.stdout.flush()
115 |
116 | train(model, train_examples, valid_examples)
117 |
118 | print('Terminating training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
--------------------------------------------------------------------------------
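For reference, illustrative invocations of the driver above (model paths and names are placeholders; per the checks in `__main__`, detection requires at least one of the two attention flags):

```
# Train an inconsistency detector attending to both code representations, with features
python run_comment_model.py --task detect --attend_code_sequence_states \
    --attend_code_graph_states --features --model_path models/detect.pt --model_name detect

# Evaluate the saved model on the test split
python run_comment_model.py --task detect --attend_code_sequence_states \
    --attend_code_graph_states --features --test_mode --model_path models/detect.pt --model_name detect
```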