├── LICENSE
├── README.md
├── __init__.py
├── ast_graph_encoder.py
├── baselines
│   └── codebert_bow.py
├── comment_update
│   ├── SARI.py
│   ├── comment_generation.py
│   ├── decoder.py
│   ├── embedding_store.py
│   ├── external_cache.py
│   ├── generation_decoder.py
│   ├── tensor_utils.py
│   ├── update_decoder.py
│   └── update_evaluation_utils.py
├── constants.py
├── data_loader.py
├── data_processing
│   ├── ast_diffing
│   │   ├── code_samples
│   │   │   ├── new.java
│   │   │   └── old.java
│   │   └── python
│   │       └── xml_diff_parser.py
│   ├── build_example.py
│   ├── data_formatting_utils.py
│   ├── high_level_feature_extractor.py
│   └── tokenization_feature_extractor.py
├── data_utils.py
├── detection_evaluation_utils.py
├── detection_module.py
├── diff_utils.py
├── display_scores.py
├── encoder.py
├── gleu
│   ├── README.md
│   ├── data
│   │   ├── all_judgments.csv
│   │   └── all_judgments.xml
│   ├── gleu_update_2016.pdf
│   └── scripts
│       ├── compute_gleu
│       ├── gleu.py
│       └── original_gleu
│           ├── compute_gleu
│           └── gleu.py
├── gnn.py
├── module_manager.py
├── run_comment_model.py
└── update_module.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 panthap2

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Just-In-Time Inconsistency Detection Between Comments and Source Code

**Code and datasets for our AAAI-2021 paper "Deep Just-In-Time Inconsistency Detection Between Comments and Source Code"**
which can be found [here](https://arxiv.org/pdf/2010.01625.pdf).

If you find this work useful, please consider citing our paper:

```
@inproceedings{PanthaplackelETAL21DeepJITInconsistency,
  author = {Panthaplackel, Sheena and Li, Junyi Jessy and Gligoric, Milos and Mooney, Raymond J.},
  title = {Deep Just-In-Time Inconsistency Detection Between Comments and Source Code},
  booktitle = {AAAI},
  pages = {427--435},
  year = {2021},
}
```

The code base shares components with our prior work, [Learning to Update Natural Language Comments Based on Code Changes](https://github.com/panthap2/LearningToUpdateNLComments).

Download data from [here](https://drive.google.com/drive/folders/1heqEQGZHgO6gZzCjuQD1EyYertN4SAYZ?usp=sharing).
Download additional model resources from [here](https://drive.google.com/drive/folders/1cutxr4rMDkT1g2BbmCAR2wqKTxeFH11K?usp=sharing). Edit configurations in `constants.py` to specify data, resource, and output locations.

**Inconsistency Detection:**

*SEQ(C, Medit) + features*
```
python3 run_comment_model.py --task=detect --attend_code_sequence_states --features --model_path=detect_attend_code_sequence_states_features.pkl.gz --model_name=detect_attend_code_sequence_states_features
```

*GRAPH(C, Tedit) + features*
(The GGNN used for this approach is derived from [here](https://github.com/pcyin/pytorch-gated-graph-neural-network/blob/master/gnn.py).)
```
python3 run_comment_model.py --task=detect --attend_code_graph_states --features --model_path=detect_attend_code_graph_states_features.pkl.gz --model_name=detect_attend_code_graph_states_features
```

*HYBRID(C, Medit, Tedit) + features*
```
python3 run_comment_model.py --task=detect --attend_code_sequence_states --attend_code_graph_states --features --model_path=detect_attend_code_sequence_states_attend_code_graph_states_features.pkl.gz --model_name=detect_attend_code_sequence_states_attend_code_graph_states_features
```

To run inference on a detection model, add `--test_mode` to the command used to train the model.

**Combined Detection + Update:**

*Update w/ implicit detection*
```
python3 run_comment_model.py --task=update --features --model_path=update_features.pkl.gz --model_name=update_features
```

To run inference, add `--test_mode --rerank` to the command used to train the model.

*Pretrained update + detection*
```
python3 run_comment_model.py --task=update --features --positive_only --model_path=update_features_positive_only.pkl.gz --model_name=update_features_positive_only
```

One of the detection models should also be trained, following instructions provided in the "Inconsistency Detection" section above. To run inference on the update model, add `--test_mode --rerank` to the command used to train the model. Inference on the detection model should also be done as instructed in the "Inconsistency Detection" section.

*Jointly trained update + detection*

To train, simply replace `--task=detect` with `--task=dual` in the configurations given for "Inconsistency Detection." For inference, additionally include `--test_mode --rerank`.

**Displaying metrics:**

To display metrics for the full test set as well as the cleaned test sample, run:

```
python3 display_scores.py --detection_output_file=[PATH TO DETECTION PREDICTIONS] --update_output_file=[PATH TO UPDATE PREDICTIONS]
```

For evaluating in the pretrained update + detection setting, both filepaths are required. For all other settings, only one should be specified.

**AST Diffing:**

The AST diffs were built using Java files provided by [Pengyu Nie](https://github.com/pengyunie). First, download `ast-diffing-1.6-jar-with-dependencies.jar` from [here](https://drive.google.com/file/d/1JVfIfJoDDSFBaFOhK18UsBOmC39z03am/view?usp=sharing).
Then, go to `data_processing/ast_diffing/python` and run:

```
python3 xml_diff_parser.py --old_sample_path=[PATH TO OLD VERSION OF CODE] --new_sample_path=[PATH TO NEW VERSION OF CODE] --jar_path=[PATH TO DOWNLOADED JAR FILE]
```

You can see an example by running:

```
python3 xml_diff_parser.py --old_sample_path=../code_samples/old.java --new_sample_path=../code_samples/new.java --jar_path=[PATH TO DOWNLOADED JAR FILE]
```
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/panthap2/deep-jit-inconsistency-detection/dacf8513c155f35157eedc2bf630212bf815544c/__init__.py
--------------------------------------------------------------------------------
/ast_graph_encoder.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn

from constants import *
from gnn import GatedGraphNeuralNetwork, AdjacencyList

class ASTGraphEncoder(nn.Module):
    """Encoder which learns a representation of a method's AST. The underlying network is a Gated Graph Neural Network."""
    def __init__(self, hidden_size, num_edge_types):
        super(ASTGraphEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_edge_types = num_edge_types
        self.gnn = GatedGraphNeuralNetwork(self.hidden_size, self.num_edge_types,
                                           [GNN_LAYER_TIMESTEPS], {}, GNN_DROPOUT_RATE, GNN_DROPOUT_RATE)

    def forward(self, initial_node_representation, graph_batch, device):
        adjacency_lists = []
        for edge_type in range(self.num_edge_types):
            adjacency_lists.append(AdjacencyList(node_num=graph_batch.num_nodes,
                                                 adj_list=graph_batch.edges[edge_type], device=device))
        node_representations = self.gnn.compute_node_representations(
            initial_node_representation=initial_node_representation, adjacency_lists=adjacency_lists)
        hidden_states = node_representations[graph_batch.node_positions]
        return hidden_states
--------------------------------------------------------------------------------
/baselines/codebert_bow.py:
--------------------------------------------------------------------------------
import argparse
from datetime import datetime
import numpy as np
import os
import random
import torch
from torch import nn
from transformers import RobertaModel, RobertaTokenizer
import sys
import json

sys.path.append('../')
sys.path.append('../comment_update')
from constants import *
from data_loader import get_data_splits
from detection_evaluation_utils import compute_score

BERT_HIDDEN_SIZE = 768
DROPOUT_RATE = 0.6
BATCH_SIZE = 100
CLASSIFICATION_HIDDEN_SIZE = 256
TRANSFORMERS_CACHE = None  # TODO: Fill in (a local cache directory; None falls back to the default cache)

class BERTBatch():
    def __init__(self, old_comment_ids, old_comment_lengths,
                 new_code_ids, new_code_lengths, diff_code_ids, diff_code_lengths, labels):
        self.old_comment_ids = old_comment_ids
        self.old_comment_lengths = old_comment_lengths
        self.new_code_ids = new_code_ids
        self.new_code_lengths = new_code_lengths
        self.diff_code_ids = diff_code_ids
        self.diff_code_lengths = diff_code_lengths
        self.labels = labels

class BERTClassifier(nn.Module):
    def __init__(self, model_path, new_code, diff_code):
        super(BERTClassifier, self).__init__()
        self.model_path = model_path
        self.new_code = new_code
        self.diff_code = diff_code

        self.code_tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base", cache_dir=TRANSFORMERS_CACHE)
        self.code_model = RobertaModel.from_pretrained("microsoft/codebert-base", cache_dir=TRANSFORMERS_CACHE)
        self.comment_tokenizer = self.code_tokenizer
        self.comment_model = self.code_model

        self.torch_device_name = 'cpu'
        self.max_nl_length = 0
        self.max_code_length = 0

        print('Model path: {}'.format(self.model_path))
        print('New code: {}'.format(self.new_code))
        print('Diff code: {}'.format(self.diff_code))
        sys.stdout.flush()

    def initialize(self, train_examples):
        self.max_nl_length = 200
        self.max_code_length = 200

        output_size = BERT_HIDDEN_SIZE

        if self.new_code:
            output_size += BERT_HIDDEN_SIZE
        if self.diff_code:
            output_size += BERT_HIDDEN_SIZE

        self.classification_dropout_layer = nn.Dropout(p=DROPOUT_RATE)
        self.fc1 = nn.Linear(output_size, CLASSIFICATION_HIDDEN_SIZE)
        self.fc2 = nn.Linear(CLASSIFICATION_HIDDEN_SIZE, CLASSIFICATION_HIDDEN_SIZE)
        self.output_layer = nn.Linear(CLASSIFICATION_HIDDEN_SIZE, NUM_CLASSES)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=LR)

    def get_code_inputs(self, input_text, max_length):
        tokens = self.code_tokenizer.tokenize(input_text)
        length = min(len(tokens), max_length)
        tokens = tokens[:length]
        token_ids = self.code_tokenizer.convert_tokens_to_ids(tokens)

        padding_length = max_length - len(tokens)
        token_ids += [self.code_tokenizer.pad_token_id]*padding_length
        return token_ids, length

    def get_comment_inputs(self, input_text, max_length):
        tokens = self.comment_tokenizer.tokenize(input_text)
        length = min(len(tokens), max_length)
        tokens = tokens[:length]
        token_ids = self.comment_tokenizer.convert_tokens_to_ids(tokens)

        padding_length = max_length - len(tokens)
        token_ids += [self.comment_tokenizer.pad_token_id]*padding_length
        return token_ids, length

    def get_batches(self, dataset, shuffle=False):
        batches = []
        if shuffle:
            random.shuffle(dataset)

        curr_idx = 0
        while curr_idx < len(dataset):
            batch_idx = 0

            start_idx = curr_idx
            end_idx = min(start_idx + BATCH_SIZE, len(dataset))
            labels = []
            old_comment_ids = []
            old_comment_lengths = []
            new_code_ids = []
            new_code_lengths = []
            diff_code_ids = []
            diff_code_lengths = []

            for i in range(start_idx, end_idx):
                comment_ids, comment_length = self.get_comment_inputs(dataset[i].old_comment_raw, self.max_nl_length)
                old_comment_ids.append(comment_ids)
                old_comment_lengths.append(comment_length)

                if self.new_code:
                    code_ids, code_length = self.get_code_inputs(dataset[i].new_code_raw, self.max_code_length)
                    new_code_ids.append(code_ids)
                    new_code_lengths.append(code_length)

                if self.diff_code:
                    code_ids, code_length = self.get_code_inputs(' '.join(dataset[i].span_diff_code_tokens), self.max_code_length)
                    diff_code_ids.append(code_ids)
                    diff_code_lengths.append(code_length)

                labels.append(dataset[i].label)

            curr_idx = end_idx
            batches.append(BERTBatch(
                torch.tensor(old_comment_ids, dtype=torch.int64, device=self.get_device()),
                torch.tensor(old_comment_lengths, dtype=torch.int64, device=self.get_device()),
                torch.tensor(new_code_ids, dtype=torch.int64, device=self.get_device()),
                torch.tensor(new_code_lengths, dtype=torch.int64, device=self.get_device()),
                torch.tensor(diff_code_ids, dtype=torch.int64, device=self.get_device()),
                torch.tensor(diff_code_lengths, dtype=torch.int64, device=self.get_device()),
                torch.tensor(labels, dtype=torch.int64, device=self.get_device())
            ))

        return batches

    def get_code_representation(self, input_ids, masks):
        embeddings = self.code_model.embeddings(input_ids)
        if self.torch_device_name == 'cpu':
            factor = masks.type(torch.FloatTensor).unsqueeze(-1)
        else:
            factor = masks.type(torch.FloatTensor).cuda(self.get_device()).unsqueeze(-1)
        embeddings = embeddings * factor
        vector = torch.sum(embeddings, dim=1)/torch.sum(factor, dim=1)
        return embeddings, vector

    def get_comment_representation(self, input_ids, masks):
        embeddings = self.comment_model.embeddings(input_ids)
        if self.torch_device_name == 'cpu':
            factor = masks.type(torch.FloatTensor).unsqueeze(-1)
        else:
            factor = masks.type(torch.FloatTensor).cuda(self.get_device()).unsqueeze(-1)
        embeddings = embeddings * factor
        vector = torch.sum(embeddings, dim=1)/torch.sum(factor, dim=1)
        return embeddings, vector

    def get_input_features(self, batch_data):
        old_comment_masks = (torch.arange(
            batch_data.old_comment_ids.shape[1], device=self.get_device()).view(1, -1) < batch_data.old_comment_lengths.view(-1, 1))
        old_comment_hidden_states, old_comment_final_state = self.get_comment_representation(batch_data.old_comment_ids, old_comment_masks)
        final_state = old_comment_final_state

        if self.new_code:
            new_code_masks = (torch.arange(
                batch_data.new_code_ids.shape[1], device=self.get_device()).view(1, -1) < batch_data.new_code_lengths.view(-1, 1))
            new_code_hidden_states, new_code_final_state = self.get_code_representation(batch_data.new_code_ids, new_code_masks)
            final_state = torch.cat([final_state, new_code_final_state], dim=-1)

        if self.diff_code:
            diff_code_masks = (torch.arange(
                batch_data.diff_code_ids.shape[1], device=self.get_device()).view(1, -1) < batch_data.diff_code_lengths.view(-1, 1))
            diff_code_hidden_states, diff_code_final_state = self.get_code_representation(batch_data.diff_code_ids, diff_code_masks)
            final_state = torch.cat([final_state, diff_code_final_state], dim=-1)

        return final_state

    def get_logits(self, batch_data):
        all_features = self.get_input_features(batch_data)
        all_features = self.classification_dropout_layer(torch.nn.functional.relu(self.fc1(all_features)))
        all_features = self.classification_dropout_layer(torch.nn.functional.relu(self.fc2(all_features)))

        return self.output_layer(all_features)

    def get_logprobs(self, batch_data):
        logits = self.get_logits(batch_data)
        return torch.nn.functional.log_softmax(logits, dim=-1)

    def forward(self, batch_data, is_training=True):
        logprobs = self.get_logprobs(batch_data)
        loss = torch.nn.functional.nll_loss(logprobs, batch_data.labels)
        return loss, logprobs

    def run_train(self, train_examples, valid_examples):
        best_loss = float('inf')
        best_f1 = 0.0
        patience_tally = 0
        valid_batches = self.get_batches(valid_examples)

        for epoch in range(MAX_EPOCHS):
            if patience_tally > PATIENCE:
                print('Terminating')
                break

            self.train()
            train_batches = self.get_batches(train_examples, shuffle=True)

            train_loss = 0
            for batch_data in train_batches:
                train_loss += self.run_gradient_step(batch_data)

            self.eval()
            validation_loss = 0
            validation_predicted_labels = []
            validation_gold_labels = []
            with torch.no_grad():
                for batch_data in valid_batches:
                    b_loss, b_logprobs = self.forward(batch_data)
                    validation_loss += float(b_loss.cpu())
                    validation_predicted_labels.extend(b_logprobs.argmax(-1).tolist())
                    validation_gold_labels.extend(batch_data.labels.tolist())

            validation_loss = validation_loss/len(valid_batches)
            validation_precision, validation_recall, validation_f1 = compute_score(
                validation_predicted_labels, validation_gold_labels, verbose=False)

            if validation_f1 >= best_f1:
                best_f1 = validation_f1
                torch.save(self, self.model_path)
                saved = True
                patience_tally = 0
            else:
                saved = False
                patience_tally += 1

            print('Epoch: {}'.format(epoch))
            print('Training loss: {:.3f}'.format(train_loss/len(train_batches)))
            print('Validation loss: {:.3f}'.format(validation_loss))
            print('Validation precision: {:.3f}'.format(validation_precision))
            print('Validation recall: {:.3f}'.format(validation_recall))
            print('Validation f1: {:.3f}'.format(validation_f1))
            if saved:
                print('Saved')
            print('-----------------------------------')
            sys.stdout.flush()

    def get_device(self):
        """Returns the proper device."""
        if self.torch_device_name == 'gpu':
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    def run_gradient_step(self, batch_data):
        """Performs gradient step."""
        self.optimizer.zero_grad()
        loss, _ = self.forward(batch_data)
        loss.backward()
        self.optimizer.step()
        return float(loss.cpu())

    def run_evaluation(self, test_examples, write_file):
        self.eval()

        test_batches = self.get_batches(test_examples)
        test_predictions = []

        with torch.no_grad():
            for b, batch in enumerate(test_batches):
                print('Testing batch {}/{}'.format(b, len(test_batches)))
                sys.stdout.flush()
                batch_logprobs = self.get_logprobs(batch)
                test_predictions.extend(batch_logprobs.argmax(dim=-1).tolist())

        self.compute_metrics(test_predictions, test_examples, write_file)

    def compute_metrics(self, predicted_labels, test_examples, write_file):
        gold_labels = []
        correct = 0

        print('Writing to: {}'.format(write_file))
        with open(write_file, 'w+') as f:
            for e, ex in enumerate(test_examples):
                f.write('{} {}\n'.format(ex.id, predicted_labels[e]))
                gold_label = ex.label
                if gold_label == predicted_labels[e]:
                    correct += 1
                gold_labels.append(gold_label)

        accuracy = float(correct)/len(test_examples)
        precision, recall, f1 = compute_score(predicted_labels, gold_labels, False)

        print('Precision: {}'.format(precision))
        print('Recall: {}'.format(recall))
        print('F1: {}'.format(f1))
        print('Accuracy: {}'.format(accuracy))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--new_code', action='store_true')
    parser.add_argument('--diff_code', action='store_true')
    parser.add_argument('--comment_type')
    parser.add_argument('--trial')
    parser.add_argument('--test_mode', action='store_true')
    args = parser.parse_args()

    print('Starting')
    sys.stdout.flush()

    train_examples, valid_examples, test_examples, high_level_details = get_data_splits()

    print('Train: {}'.format(len(train_examples)))
    print('Valid: {}'.format(len(valid_examples)))
    print('Test: {}'.format(len(test_examples)))
    sys.stdout.flush()

    model_name = 'bert'

    if args.new_code:
        model_name += '-new_code'
    if args.diff_code:
        model_name += '-diff_code'

    if args.comment_type:
        model_name += '-{}'.format(args.comment_type)
    if args.trial:
        model_name += '-{}'.format(args.trial)

    # Assumes that saved_bert_models directory exists
    model_path = 'saved_bert_models/{}.pkl.gz'.format(model_name)
    sys.stdout.flush()

    if args.test_mode:
        print('Loading model from: {}'.format(model_path))
        print('Starting evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
        sys.stdout.flush()
        model = torch.load(model_path)
        if torch.cuda.is_available():
            model.torch_device_name = 'gpu'
            model.cuda()
            for c in model.children():
                c.cuda()
        else:
            model.torch_device_name = 'cpu'
            model.cpu()
            for c in model.children():
                c.cpu()

        # Assumes that bert_predictions directory exists
        write_file = os.path.join('bert_predictions', '{}.txt'.format(model_name))
        model.run_evaluation(test_examples, write_file)
        print('Terminating evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
    else:
        print('Starting training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
        sys.stdout.flush()
        model = BERTClassifier(model_path, args.new_code, args.diff_code)
        model.initialize(train_examples)

        if torch.cuda.is_available():
            model.torch_device_name = 'gpu'
            model.cuda()
            for c in model.children():
                c.cuda()
        else:
            model.torch_device_name = 'cpu'
            model.cpu()
            for c in model.children():
                c.cpu()

        model.run_train(train_examples, valid_examples)
        print('Terminating training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
--------------------------------------------------------------------------------
/comment_update/SARI.py:
--------------------------------------------------------------------------------
# =======================================================
# SARI -- Text Simplification Tunable Evaluation Metric
# =======================================================
#
# Author: Wei Xu (UPenn xwe@cis.upenn.edu)
#
# A Python implementation of the SARI metric for text simplification
# evaluation in the following paper
#
# "Optimizing Statistical Machine Translation for Text Simplification"
# Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch
# In Transactions of the Association for Computational Linguistics (TACL) 2015
#
# There is also a Java implementation of the SARI metric
# that is integrated into the Joshua MT Decoder. It can
# be used for tuning Joshua models for a real end-to-end
# text simplification model.
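#
# Example sentence-level usage (SARIsent, defined below, returns a score in
# [0, 1] given a whitespace-tokenized source sentence, a candidate, and a
# list of reference sentences):
#
#   score = SARIsent("About 95 species are currently accepted .",
#                    "About 95 species are now accepted .",
#                    ["About 95 species are now accepted ."])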
#

from __future__ import division
from collections import Counter
import sys



def ReadInFile(filename):

    with open(filename) as f:
        lines = f.readlines()
        lines = [x.strip() for x in lines]
    return lines


def SARIngram(sgrams, cgrams, rgramslist, numref):
    rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams]
    rgramcounter = Counter(rgramsall)

    sgramcounter = Counter(sgrams)
    sgramcounter_rep = Counter()
    for sgram, scount in sgramcounter.items():
        sgramcounter_rep[sgram] = scount * numref

    cgramcounter = Counter(cgrams)
    cgramcounter_rep = Counter()
    for cgram, ccount in cgramcounter.items():
        cgramcounter_rep[cgram] = ccount * numref


    # KEEP
    keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep
    keepgramcountergood_rep = keepgramcounter_rep & rgramcounter
    keepgramcounterall_rep = sgramcounter_rep & rgramcounter

    keeptmpscore1 = 0
    keeptmpscore2 = 0
    for keepgram in keepgramcountergood_rep:
        keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram]
        keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram]
        #print "KEEP", keepgram, keepscore, cgramcounter[keepgram], sgramcounter[keepgram], rgramcounter[keepgram]
    keepscore_precision = 0
    if len(keepgramcounter_rep) > 0:
        keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep)
    keepscore_recall = 0
    if len(keepgramcounterall_rep) > 0:
        keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep)
    keepscore = 0
    if keepscore_precision > 0 or keepscore_recall > 0:
        keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall)


    # DELETION
    delgramcounter_rep = sgramcounter_rep - cgramcounter_rep
    delgramcountergood_rep = delgramcounter_rep - rgramcounter
    delgramcounterall_rep = sgramcounter_rep - rgramcounter
    deltmpscore1 = 0
    deltmpscore2 = 0
    for delgram in delgramcountergood_rep:
        deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram]
        deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram]
    delscore_precision = 0
    if len(delgramcounter_rep) > 0:
        delscore_precision = deltmpscore1 / len(delgramcounter_rep)
    delscore_recall = 0
    if len(delgramcounterall_rep) > 0:
        delscore_recall = deltmpscore2 / len(delgramcounterall_rep)
    delscore = 0
    if delscore_precision > 0 or delscore_recall > 0:
        delscore = 2 * delscore_precision * delscore_recall / (delscore_precision + delscore_recall)


    # ADDITION
    addgramcounter = set(cgramcounter) - set(sgramcounter)
    addgramcountergood = set(addgramcounter) & set(rgramcounter)
    addgramcounterall = set(rgramcounter) - set(sgramcounter)

    addtmpscore = 0
    for addgram in addgramcountergood:
        addtmpscore += 1

    addscore_precision = 0
    addscore_recall = 0
    if len(addgramcounter) > 0:
        addscore_precision = addtmpscore / len(addgramcounter)
    if len(addgramcounterall) > 0:
        addscore_recall = addtmpscore / len(addgramcounterall)
    addscore = 0
    if addscore_precision > 0 or addscore_recall > 0:
        addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)

    # Note: deletion contributes only its precision to the final score.
    return (keepscore, delscore_precision, addscore)


def SARIsent(ssent, csent, rsents):
    numref = len(rsents)

    s1grams = ssent.lower().split(" ")
    c1grams = csent.lower().split(" ")
    s2grams = []
    c2grams = []
    s3grams = []
    c3grams = []
    s4grams = []
    c4grams = []

    r1gramslist = []
    r2gramslist = []
    r3gramslist = []
    r4gramslist = []
    for rsent in rsents:
        r1grams = rsent.lower().split(" ")
        r2grams = []
        r3grams = []
        r4grams = []
        r1gramslist.append(r1grams)
        for i in range(0, len(r1grams)-1):
            if i < len(r1grams) - 1:
                r2gram = r1grams[i] + " " + r1grams[i+1]
                r2grams.append(r2gram)
            if i < len(r1grams)-2:
                r3gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2]
                r3grams.append(r3gram)
            if i < len(r1grams)-3:
                r4gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2] + " " + r1grams[i+3]
                r4grams.append(r4gram)
        r2gramslist.append(r2grams)
        r3gramslist.append(r3grams)
        r4gramslist.append(r4grams)

    for i in range(0, len(s1grams)-1):
        if i < len(s1grams) - 1:
            s2gram = s1grams[i] + " " + s1grams[i+1]
            s2grams.append(s2gram)
        if i < len(s1grams)-2:
            s3gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2]
            s3grams.append(s3gram)
        if i < len(s1grams)-3:
            s4gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2] + " " + s1grams[i+3]
            s4grams.append(s4gram)

    for i in range(0, len(c1grams)-1):
        if i < len(c1grams) - 1:
            c2gram = c1grams[i] + " " + c1grams[i+1]
            c2grams.append(c2gram)
        if i < len(c1grams)-2:
            c3gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2]
            c3grams.append(c3gram)
        if i < len(c1grams)-3:
            c4gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2] + " " + c1grams[i+3]
            c4grams.append(c4gram)


    (keep1score, del1score, add1score) = SARIngram(s1grams, c1grams, r1gramslist, numref)
    (keep2score, del2score, add2score) = SARIngram(s2grams, c2grams, r2gramslist, numref)
    (keep3score, del3score, add3score) = SARIngram(s3grams, c3grams, r3gramslist, numref)
    (keep4score, del4score, add4score) = SARIngram(s4grams, c4grams, r4gramslist, numref)
    avgkeepscore = sum([keep1score, keep2score, keep3score, keep4score])/4
    avgdelscore = sum([del1score, del2score, del3score, del4score])/4
    avgaddscore = sum([add1score, add2score, add3score, add4score])/4
    finalscore = (avgkeepscore + avgdelscore + avgaddscore) / 3

    return finalscore


def main():

    fnamenorm = "./turkcorpus/test.8turkers.tok.norm"
    fnamesimp = "./turkcorpus/test.8turkers.tok.simp"
    fnameturk = "./turkcorpus/test.8turkers.tok.turk."


    ssent = "About 95 species are currently accepted ."
    csent1 = "About 95 you now get in ."
    csent2 = "About 95 species are now agreed ."
    csent3 = "About 95 species are currently agreed ."
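    # Reference simplifications; each of the three candidates above is scored against all of them.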
    rsents = ["About 95 species are currently known .", "About 95 species are now accepted .", "95 species are now accepted ."]

    print(SARIsent(ssent, csent1, rsents))
    print(SARIsent(ssent, csent2, rsents))
    print(SARIsent(ssent, csent3, rsents))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/comment_update/decoder.py:
--------------------------------------------------------------------------------
from abc import abstractmethod
import torch
from torch import nn


class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, attention_state_size, embedding_store,
                 embedding_size, dropout_rate):
        super(Decoder, self).__init__()
        self.input_size = input_size  # Dimension of input into decoder cell
        self.hidden_size = hidden_size  # Dimension of output from decoder cell
        self.attention_state_size = attention_state_size  # Dimension of the encoder hidden states to attend to
        self.embedding_store = embedding_store
        self.gen_vocabulary_size = len(self.embedding_store.nl_vocabulary)
        self.embedding_size = embedding_size
        self.dropout_rate = dropout_rate

        self.gru = nn.GRU(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            batch_first=True
        )

        # Parameters for attention
        self.attention_encoder_hidden_transform_matrix = nn.Parameter(
            torch.randn(self.attention_state_size, self.hidden_size,
                        dtype=torch.float, requires_grad=True)
        )
        self.attention_output_layer = nn.Linear(self.attention_state_size + self.hidden_size,
                                                self.hidden_size, bias=False)

        # Parameters for generating/copying
        self.generation_output_matrix = nn.Parameter(
            torch.randn(self.hidden_size, self.gen_vocabulary_size,
                        dtype=torch.float, requires_grad=True)
        )

        self.copy_encoder_hidden_transform_matrix = nn.Parameter(
            torch.randn(self.attention_state_size, self.hidden_size,
                        dtype=torch.float, requires_grad=True)
        )

    @abstractmethod
    def decode(self):
        return NotImplemented

    @abstractmethod
    def forward(self, initial_state, decoder_input_embeddings, encoder_hidden_states, masks):
        return NotImplemented
--------------------------------------------------------------------------------
/comment_update/embedding_store.py:
--------------------------------------------------------------------------------
import argparse
from collections import Counter, defaultdict
from dpu_utils.mlutils import Vocabulary
import heapq
import json
import logging
import numpy as np
import os
import random
import sys
import torch
from torch import nn

from constants import START, END, NL_EMBEDDING_PATH, CODE_EMBEDDING_PATH, MAX_VOCAB_SIZE,\
    NL_EMBEDDING_SIZE, CODE_EMBEDDING_SIZE
from diff_utils import get_edit_keywords

class EmbeddingStore(nn.Module):
    def __init__(self, nl_threshold, nl_embedding_size, nl_token_counter,
                 code_threshold, code_embedding_size, code_token_counter,
                 dropout_rate, num_src_embeddings, src_embedding_size, node_embedding_size,
                 load_pretrained_embeddings=False):
        """Keeps track of the NL and code vocabularies and embeddings."""
        super(EmbeddingStore, self).__init__()
        edit_keywords = get_edit_keywords()
        self.__nl_vocabulary = Vocabulary.create_vocabulary(tokens=edit_keywords,
                                                            max_size=MAX_VOCAB_SIZE,
                                                            count_threshold=1,
                                                            add_pad=True)
        self.__nl_vocabulary.update(nl_token_counter, MAX_VOCAB_SIZE, nl_threshold)
        self.__nl_embedding_layer = nn.Embedding(num_embeddings=len(self.__nl_vocabulary),
                                                 embedding_dim=nl_embedding_size,
                                                 padding_idx=self.__nl_vocabulary.get_id_or_unk(
                                                     Vocabulary.get_pad()))
        self.nl_embedding_dropout_layer = nn.Dropout(p=dropout_rate)


        self.__code_vocabulary = Vocabulary.create_vocabulary(tokens=edit_keywords,
                                                              max_size=MAX_VOCAB_SIZE,
                                                              count_threshold=1,
                                                              add_pad=True)
        self.__code_vocabulary.update(code_token_counter, MAX_VOCAB_SIZE, code_threshold)
        self.__code_embedding_layer = nn.Embedding(num_embeddings=len(self.__code_vocabulary),
                                                   embedding_dim=code_embedding_size,
                                                   padding_idx=self.__code_vocabulary.get_id_or_unk(
                                                       Vocabulary.get_pad()))
        self.code_embedding_dropout_layer = nn.Dropout(p=dropout_rate)

        self.src_embedding_layer = nn.Embedding(num_embeddings=num_src_embeddings, embedding_dim=src_embedding_size)
        self.src_embedding_dropout_layer = nn.Dropout(p=dropout_rate)
        self.node_synthesis_layer = nn.Linear(code_embedding_size+src_embedding_size, node_embedding_size, bias=False)

        print('NL vocabulary size: {}'.format(len(self.__nl_vocabulary)))
        print('Code vocabulary size: {}'.format(len(self.__code_vocabulary)))

        if load_pretrained_embeddings:
            self.initialize_embeddings()

    def initialize_embeddings(self):
        with open(NL_EMBEDDING_PATH) as f:
            nl_embeddings = json.load(f)

        nl_weights_matrix = np.zeros((len(self.__nl_vocabulary), NL_EMBEDDING_SIZE), dtype=np.float64)
        nl_word_count = 0
        for i, word in enumerate(self.__nl_vocabulary.id_to_token):
            try:
                nl_weights_matrix[i] = nl_embeddings[word]
                nl_word_count += 1
            except KeyError:
                nl_weights_matrix[i] = np.random.normal(scale=0.6, size=(NL_EMBEDDING_SIZE, ))

        self.__nl_embedding_layer.weight = torch.nn.Parameter(torch.FloatTensor(nl_weights_matrix),
                                                              requires_grad=True)

        with open(CODE_EMBEDDING_PATH) as f:
            code_embeddings = json.load(f)

        code_weights_matrix = np.zeros((len(self.__code_vocabulary), CODE_EMBEDDING_SIZE))
        code_word_count = 0
        for i, word in enumerate(self.__code_vocabulary.id_to_token):
            try:
                code_weights_matrix[i] = code_embeddings[word]
                code_word_count += 1
            except KeyError:
                code_weights_matrix[i] = np.random.normal(scale=0.6, size=(CODE_EMBEDDING_SIZE, ))

        self.__code_embedding_layer.weight = torch.nn.Parameter(torch.FloatTensor(code_weights_matrix),
                                                                requires_grad=True)

        print('Using {} pre-trained NL embeddings'.format(nl_word_count))
        print('Using {} pre-trained code embeddings'.format(code_word_count))

    def get_nl_embeddings(self, token_ids):
        return self.nl_embedding_dropout_layer(self.__nl_embedding_layer(token_ids))

    def get_code_embeddings(self, token_ids):
        return self.code_embedding_dropout_layer(self.__code_embedding_layer(token_ids))

    def get_src_embeddings(self, src_ids):
        return self.src_embedding_dropout_layer(self.src_embedding_layer(src_ids))

    def get_node_embeddings(self, lookup_ids, src_ids):
        lookup_embeddings = self.get_code_embeddings(lookup_ids)
        src_embeddings = self.get_src_embeddings(src_ids)

        embeddings = torch.cat([lookup_embeddings, src_embeddings], dim=-1)
        node_embeddings = self.node_synthesis_layer(embeddings)
        return node_embeddings

    @property
    def nl_vocabulary(self):
        return self.__nl_vocabulary

    @property
    def code_vocabulary(self):
        return self.__code_vocabulary

    @property
    def nl_embedding_layer(self):
        return self.__nl_embedding_layer

    @property
    def code_embedding_layer(self):
        return self.__code_embedding_layer

    def get_padded_code_ids(self, code_sequence, pad_length):
        return self.__code_vocabulary.get_id_or_unk_multiple(code_sequence,
                                                             pad_to_size=pad_length,
                                                             padding_element=self.__code_vocabulary.get_id_or_unk(
                                                                 Vocabulary.get_pad()),
                                                             )

    def get_padded_nl_ids(self, nl_sequence, pad_length):
        return self.__nl_vocabulary.get_id_or_unk_multiple(nl_sequence,
                                                           pad_to_size=pad_length,
                                                           padding_element=self.__nl_vocabulary.get_id_or_unk(
                                                               Vocabulary.get_pad()),
                                                           )

    def get_extended_padded_nl_ids(self, nl_sequence, pad_length, inp_ids, inp_tokens):
        # Derived from: https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/mlutils/vocabulary.py
        nl_ids = []
        for token in nl_sequence:
            nl_id = self.get_nl_id(token)
            if self.is_nl_unk(nl_id) and token in inp_tokens:
                copy_idx = inp_tokens.index(token)
                nl_id = inp_ids[copy_idx]
            nl_ids.append(nl_id)

        if len(nl_ids) > pad_length:
            return nl_ids[:pad_length]
        else:
            padding = [self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_pad())] * (pad_length - len(nl_ids))
            return nl_ids + padding

    def pad_length(self, sequence, target_length):
        if len(sequence) >= target_length:
            return sequence[:target_length]
        else:
            return sequence + [self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_pad()) for _ in range(target_length-len(sequence))]

    def get_code_id(self, token):
        return self.__code_vocabulary.get_id_or_unk(token)

    def is_code_unk(self, id):
        return id == self.__code_vocabulary.get_id_or_unk(Vocabulary.get_unk())

    def get_code_token(self, token_id):
        return self.__code_vocabulary.get_name_for_id(token_id)

    def get_nl_id(self, token):
        return self.__nl_vocabulary.get_id_or_unk(token)

    def is_nl_unk(self, id):
        return id == self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_unk())

    def get_nl_token(self, token_id):
        return self.__nl_vocabulary.get_name_for_id(token_id)

    def get_vocab_extended_nl_token(self, token_id, inp_ids, inp_tokens):
        if token_id < len(self.__nl_vocabulary):
            return self.get_nl_token(token_id)
        elif token_id in inp_ids:
            copy_idx = inp_ids.index(token_id)
            return inp_tokens[copy_idx]
        else:
            return Vocabulary.get_unk()

    def get_nl_tokens(self, token_ids, inp_ids, inp_tokens):
        tokens = [self.get_vocab_extended_nl_token(t, inp_ids, inp_tokens) for t in token_ids]
        if END in tokens:
            return tokens[:tokens.index(END)]
        return tokens

    def get_end_id(self):
        return self.get_nl_id(END)

    def get_nl_pad_id(self):
        return self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_pad())

    def get_code_pad_id(self):
        return self.__code_vocabulary.get_id_or_unk(Vocabulary.get_pad())
--------------------------------------------------------------------------------
/comment_update/external_cache.py:
--------------------------------------------------------------------------------
import json
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk import pos_tag
import numpy as np
import os
import re

from constants import *
from diff_utils import *

method_details = dict()
tokenization_features = dict()
for d in os.listdir(RESOURCES_PATH):
    try:
        with open(os.path.join(RESOURCES_PATH, d, 'high_level_details.json')) as f:
            method_details.update(json.load(f))
        with open(os.path.join(RESOURCES_PATH, d, 'tokenization_features.json')) as f:
            tokenization_features.update(json.load(f))
    except:
        print('Failed parsing: {}'.format(d))

stop_words = set(stopwords.words('english'))
java_keywords = set(['abstract', 'assert', 'boolean', 'break', 'byte', 'case', 'catch', 'char', 'class',
                     'continue', 'default', 'do', 'double', 'else', 'enum', 'extends', 'final', 'finally',
                     'float', 'for', 'if', 'implements', 'import', 'instanceof', 'int', 'interface', 'long',
                     'native', 'new', 'null', 'package', 'private', 'protected', 'public', 'return', 'short',
                     'static', 'strictfp', 'super', 'switch', 'synchronized', 'this', 'throw', 'throws', 'transient',
                     'try', 'void', 'volatile', 'while'])

tags = ['CC','CD','DT','EX','FW','IN','JJ','JJR','JJS','LS','MD','NN','NNS','NNP','NNPS','PDT',
        'POS','PRP','PRP$','RB','RBR','RBS','RP','TO','UH','VB','VBD','VBG','VBN','VBP','VBZ','WDT','WP','WP$','WRB',
        'OTHER']

NUM_CODE_FEATURES = 19
NUM_NL_FEATURES = 17 + len(tags)

def get_num_code_features():
    return NUM_CODE_FEATURES

def get_num_nl_features():
    return NUM_NL_FEATURES

def is_java_keyword(token):
    return token in java_keywords

def is_operator(token):
    for s in token:
        if s.isalnum():
            return False
    return True

def get_return_type_subtokens(example):
    return method_details[example.id]['new']['subtoken']['return_type']

def get_old_return_type_subtokens(example):
    return method_details[example.id]['old']['subtoken']['return_type']

def get_method_name_subtokens(example):
    return method_details[example.id]['new']['subtoken']['method_name']

def get_new_return_sequence(example):
    return method_details[example.id]['new']['subtoken']['return_statement']

def get_old_return_sequence(example):
    return method_details[example.id]['old']['subtoken']['return_statement']

def get_old_argument_type_subtokens(example):
    return method_details[example.id]['old']['subtoken']['argument_type']

def get_new_argument_type_subtokens(example):
    return method_details[example.id]['new']['subtoken']['argument_type']

def get_old_argument_name_subtokens(example):
    return method_details[example.id]['old']['subtoken']['argument_name']

def get_new_argument_name_subtokens(example):
    return method_details[example.id]['new']['subtoken']['argument_name']

def get_old_code(example):
    return example.old_code_raw

def get_new_code(example):
    return example.new_code_raw

def get_edit_span_subtoken_tokenization_labels(example):
    return tokenization_features[example.id]['edit_span_subtoken_labels']

def get_edit_span_subtoken_tokenization_indices(example):
    return tokenization_features[example.id]['edit_span_subtoken_indices']

def get_nl_subtoken_tokenization_labels(example):
    return tokenization_features[example.id]['old_nl_subtoken_labels']

def get_nl_subtoken_tokenization_indices(example):
    return tokenization_features[example.id]['old_nl_subtoken_indices']

def get_node_features(nodes, example, max_ast_length):
    old_return_type_subtokens = get_old_return_type_subtokens(example)
    new_return_type_subtokens = get_return_type_subtokens(example)
    method_name_subtokens = get_method_name_subtokens(example)

    old_return_sequence = get_old_return_sequence(example)
    new_return_sequence = get_new_return_sequence(example)

    old_return_line_terms = set([t for t in old_return_sequence if not is_java_keyword(t) and not is_operator(t)])
    new_return_line_terms = set([t for t in new_return_sequence if not is_java_keyword(t) and not is_operator(t)])
    return_line_intersection = old_return_line_terms.intersection(new_return_line_terms)

    old_set = set(old_return_type_subtokens)
    new_set = set(new_return_type_subtokens)

    intersection = old_set.intersection(new_set)

    features = np.zeros((len(nodes), get_num_code_features()), dtype=np.int64)

    old_nl_tokens = set(example.old_comment_subtokens)
    last_command = None

    for i, node in enumerate(nodes):
        if not node.is_leaf:
            continue

        token = node.value

        if token in intersection:
            features[i][0] = True
        elif token in old_set:
            features[i][1] = True
        elif token in new_set:
            features[i][2] = True
        else:
            features[i][3] = True

        if token in return_line_intersection:
            features[i][4] = True
        elif token in old_return_line_terms:
            features[i][5] = True
        elif token in new_return_line_terms:
            features[i][6] = True
        else:
            features[i][7] = True

        if is_edit_keyword(token):
            features[i][8] = True
        if is_java_keyword(token):
            features[i][9] = True
        if is_operator(token):
            features[i][10] = True
        if token in old_nl_tokens:
            features[i][11] = True

        if not is_edit_keyword(token):
            if last_command == KEEP:
                features[i][12] = 1
            elif last_command == INSERT:
                features[i][13] = 1
            elif last_command == DELETE:
                features[i][14] = 1
            elif last_command == REPLACE_NEW:
                features[i][15] = 1
            else:
                features[i][16] = 1
        else:
            last_command = token

        if len(node.subtoken_children) > 0 or len(node.subtoken_parents) > 0:
            features[i][17] = True

        if len(node.subtoken_parents) == 1:
            features[i][18] = node.subtoken_parents[0].subtoken_children.index(node)

    return features.astype(np.float32)

def get_code_features(code_sequence, example, max_code_length):
    old_return_type_subtokens = get_old_return_type_subtokens(example)
    new_return_type_subtokens = get_return_type_subtokens(example)
    method_name_subtokens = get_method_name_subtokens(example)

    old_return_sequence = get_old_return_sequence(example)
    new_return_sequence = get_new_return_sequence(example)

    old_return_line_terms = set([t for t in old_return_sequence if not is_java_keyword(t) and not is_operator(t)])
    new_return_line_terms = set([t for t in new_return_sequence if not is_java_keyword(t) and not is_operator(t)])
    return_line_intersection = old_return_line_terms.intersection(new_return_line_terms)
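    # Each code token (up to max_code_length) gets a 19-dimensional feature vector:
    #   0-3:   return-type subtoken membership (shared / old-only / new-only / neither)
    #   4-7:   return-statement term membership (shared / old-only / new-only / neither)
    #   8-11:  edit keyword, Java keyword, operator, appears in the old comment
    #   12-16: most recent edit command (KEEP / INSERT / DELETE / REPLACE_NEW / other)
    #   17-18: subtokenization label and index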
186 | 187 | old_set = set(old_return_type_subtokens) 188 | new_set = set(new_return_type_subtokens) 189 | 190 | intersection = old_set.intersection(new_set) 191 | 192 | features = np.zeros((max_code_length, get_num_code_features()), dtype=np.int64) 193 | 194 | old_nl_tokens = set(example.old_comment_subtokens) 195 | last_command = None 196 | 197 | subtoken_labels = get_edit_span_subtoken_tokenization_labels(example) 198 | subtoken_indices = get_edit_span_subtoken_tokenization_indices(example) 199 | 200 | for i, token in enumerate(code_sequence): 201 | if i >= max_code_length: 202 | break 203 | if token in intersection: 204 | features[i][0] = True 205 | elif token in old_set: 206 | features[i][1] = True 207 | elif token in new_set: 208 | features[i][2] = True 209 | else: 210 | features[i][3] = True 211 | 212 | if token in return_line_intersection: 213 | features[i][4] = True 214 | elif token in old_return_line_terms: 215 | features[i][5] = True 216 | elif token in new_return_line_terms: 217 | features[i][6] = True 218 | else: 219 | features[i][7] = True 220 | 221 | if is_edit_keyword(token): 222 | features[i][8] = True 223 | if is_java_keyword(token): 224 | features[i][9] = True 225 | if is_operator(token): 226 | features[i][10] = True 227 | if token in old_nl_tokens: 228 | features[i][11] = True 229 | 230 | if not is_edit_keyword(token): 231 | if last_command == KEEP: 232 | features[i][12] = 1 233 | elif last_command == INSERT: 234 | features[i][13] = 1 235 | elif last_command == DELETE: 236 | features[i][14] = 1 237 | elif last_command == REPLACE_NEW: 238 | features[i][15] = 1 239 | else: 240 | features[i][16] = 1 241 | else: 242 | last_command = token 243 | 244 | features[i][17] = subtoken_labels[i] 245 | features[i][18] = subtoken_indices[i] 246 | 247 | return features.astype(np.float32) 248 | 249 | def get_nl_features(old_nl_sequence, example, max_nl_length): 250 | insert_code_tokens = set() 251 | keep_code_tokens = set() 252 | delete_code_tokens = set() 253 | replace_old_code_tokens = set() 254 | replace_new_code_tokens = set() 255 | 256 | frequency_map = dict() 257 | for tok in old_nl_sequence: 258 | if tok not in frequency_map: 259 | frequency_map[tok] = 0 260 | frequency_map[tok] += 1 261 | 262 | pos_tags = pos_tag(word_tokenize(' '.join(old_nl_sequence))) 263 | pos_tag_indices = [] 264 | for _, t in pos_tags: 265 | if t in tags: 266 | pos_tag_indices.append(tags.index(t)) 267 | else: 268 | pos_tag_indices.append(tags.index('OTHER')) 269 | 270 | i = 0 271 | code_tokens = example.token_diff_code_subtokens 272 | 273 | while i < len(code_tokens): 274 | if code_tokens[i] == INSERT: 275 | insert_code_tokens.add(code_tokens[i+1].lower()) 276 | i += 2 277 | elif code_tokens[i] == KEEP: 278 | keep_code_tokens.add(code_tokens[i+1].lower()) 279 | i += 2 280 | elif code_tokens[i] == DELETE: 281 | delete_code_tokens.add(code_tokens[i+1].lower()) 282 | i += 2 283 | elif code_tokens[i] == REPLACE_OLD: 284 | replace_old_code_tokens.add(code_tokens[i+1].lower()) 285 | i += 2 286 | elif code_tokens[i] == REPLACE_NEW: 287 | replace_new_code_tokens.add(code_tokens[i+1].lower()) 288 | i += 2 289 | 290 | old_return_type_subtokens = get_old_return_type_subtokens(example) 291 | new_return_type_subtokens = get_return_type_subtokens(example) 292 | 293 | old_return_sequence = get_old_return_sequence(example) 294 | new_return_sequence = get_new_return_sequence(example) 295 | 296 | old_return_line_terms = set([t for t in old_return_sequence if not is_java_keyword(t) and not is_operator(t)]) 297 | 
new_return_line_terms = set([t for t in new_return_sequence if not is_java_keyword(t) and not is_operator(t)]) 298 | return_line_intersection = old_return_line_terms.intersection(new_return_line_terms) 299 | 300 | old_set = set(old_return_type_subtokens) 301 | new_set = set(new_return_type_subtokens) 302 | 303 | intersection = old_set.intersection(new_set) 304 | 305 | method_name_subtokens = method_name_subtokens = get_method_name_subtokens(example) 306 | 307 | nl_subtoken_labels = get_nl_subtoken_tokenization_labels(example) 308 | nl_subtoken_indices = get_nl_subtoken_tokenization_indices(example) 309 | 310 | features = np.zeros((max_nl_length, get_num_nl_features()), dtype=np.int64) 311 | for i in range(len(old_nl_sequence)): 312 | if i >= max_nl_length: 313 | break 314 | token = old_nl_sequence[i].lower() 315 | if token in intersection: 316 | features[i][0] = True 317 | elif token in old_set: 318 | features[i][1] = True 319 | elif token in new_set: 320 | features[i][2] = True 321 | else: 322 | features[i][3] = True 323 | 324 | if token in return_line_intersection: 325 | features[i][4] = True 326 | elif token in old_return_line_terms: 327 | features[i][5] = True 328 | elif token in new_return_line_terms: 329 | features[i][6] = True 330 | else: 331 | features[i][7] = True 332 | 333 | features[i][8] = token in insert_code_tokens 334 | features[i][9] = token in keep_code_tokens 335 | features[i][10] = token in delete_code_tokens 336 | features[i][11] = token in replace_old_code_tokens 337 | features[i][12] = token in replace_new_code_tokens 338 | features[i][13] = token in stop_words 339 | features[i][14] = frequency_map[token] > 1 340 | 341 | features[i][15] = nl_subtoken_labels[i] 342 | features[i][16] = nl_subtoken_indices[i] 343 | features[i][17 + pos_tag_indices[i]] = 1 344 | 345 | return features.astype(np.float32) 346 | -------------------------------------------------------------------------------- /comment_update/generation_decoder.py: -------------------------------------------------------------------------------- 1 | from dpu_utils.mlutils import Vocabulary 2 | import logging 3 | import numpy as np 4 | import os 5 | import random 6 | import sys 7 | import torch 8 | from torch import nn 9 | 10 | from constants import START, BEAM_SIZE 11 | from decoder import Decoder 12 | 13 | class GenerationDecoder(Decoder): 14 | def __init__(self, input_size, hidden_size, attention_state_size, embedding_store, 15 | embedding_size, dropout_rate): 16 | """Decoder for the generation model which generates a comment based on a 17 | learned representation of a method.""" 18 | super(GenerationDecoder, self).__init__(input_size, hidden_size, attention_state_size, 19 | embedding_store, embedding_size, dropout_rate) 20 | 21 | def decode(self, initial_state, decoder_input_embeddings, encoder_hidden_states, masks): 22 | """Decoding with attention and copy.""" 23 | decoder_states, decoder_final_state = self.gru.forward(decoder_input_embeddings, 24 | initial_state.unsqueeze(0)) 25 | 26 | # https://stackoverflow.com/questions/50571991/implementing-luong-attention-in-pytorch 27 | attn_alignment = torch.einsum('ijk,km,inm->inj', encoder_hidden_states, 28 | self.attention_encoder_hidden_transform_matrix, decoder_states) 29 | attn_alignment.masked_fill_(masks, float('-inf')) 30 | attention_scores = nn.functional.softmax(attn_alignment, dim=-1) 31 | contexts = torch.einsum('ijk,ikm->ijm', attention_scores, encoder_hidden_states) 32 | decoder_states = torch.tanh(self.attention_output_layer(torch.cat([contexts, 
decoder_states], dim=-1))) 33 | 34 | generation_scores = torch.einsum('ijk,km->ijm', decoder_states, self.generation_output_matrix) 35 | copy_scores = torch.einsum('ijk,km,inm->inj', encoder_hidden_states, 36 | self.copy_encoder_hidden_transform_matrix, decoder_states) 37 | copy_scores.masked_fill_(masks, float('-inf')) 38 | 39 | combined_logprobs = nn.functional.log_softmax(torch.cat([generation_scores, copy_scores], dim=-1), dim=-1) 40 | generation_logprobs = combined_logprobs[:,:,:len(self.embedding_store.nl_vocabulary)] 41 | copy_logprobs = combined_logprobs[:, :,len(self.embedding_store.nl_vocabulary):] 42 | 43 | return decoder_states, decoder_final_state, generation_logprobs, copy_logprobs 44 | 45 | def forward(self, initial_state, decoder_input_embeddings, encoder_hidden_states, masks): 46 | """Runs decoding.""" 47 | return self.decode(initial_state, decoder_input_embeddings, encoder_hidden_states, masks) 48 | 49 | def greedy_decode(self, initial_state, encoder_hidden_states, masks, max_out_len, batch_data, device): 50 | """Greedily generates the output sequence.""" 51 | # Derived from https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/9f6b66f43d2e05175dabcc024f79e1d37a667070/decode_beam.py#L163 52 | batch_size = initial_state.shape[0] 53 | decoder_state = initial_state 54 | decoder_input = torch.tensor( 55 | [[self.embedding_store.get_nl_id(START)]] * batch_size, 56 | device=device 57 | ) 58 | 59 | decoded_batch = np.zeros([batch_size, max_out_len], dtype=np.int64) 60 | decoded_batch_scores = np.zeros([batch_size, max_out_len]) 61 | 62 | for i in range(max_out_len): 63 | decoder_input_embeddings = self.embedding_store.get_nl_embeddings(decoder_input) 64 | decoder_attention_states, decoder_state, generation_logprobs, copy_logprobs = self.decode(decoder_state, 65 | decoder_input_embeddings, encoder_hidden_states, masks) 66 | 67 | generation_logprobs = generation_logprobs.squeeze(1) 68 | copy_logprobs = copy_logprobs.squeeze(1) 69 | 70 | prob_scores = torch.zeros([generation_logprobs.shape[0], 71 | generation_logprobs.shape[-1] + copy_logprobs.shape[-1]], dtype=torch.float32, device=device) 72 | prob_scores[:, :generation_logprobs.shape[-1]] = torch.exp(generation_logprobs) 73 | for b in range(generation_logprobs.shape[0]): 74 | for c, inp_id in enumerate(batch_data.input_ids[b]): 75 | prob_scores[b, inp_id] = prob_scores[b, inp_id] + torch.exp(copy_logprobs[b,c]) 76 | 77 | predicted_ids = torch.argmax(prob_scores, dim=-1) 78 | decoded_batch_scores[:, i] = prob_scores[torch.arange(prob_scores.shape[0]), predicted_ids].cpu() 79 | decoded_batch[:, i] = predicted_ids.cpu() 80 | 81 | unks = torch.ones( 82 | predicted_ids.shape[0], dtype=torch.int64, device=device) * self.embedding_store.get_nl_id(Vocabulary.get_unk()) 83 | decoder_input = torch.where(predicted_ids < len(self.embedding_store.nl_vocabulary), predicted_ids, unks).unsqueeze(1) 84 | decoder_state = decoder_state.squeeze(0) 85 | 86 | return decoded_batch, decoded_batch_scores 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /comment_update/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def merge_encoder_outputs(a_states, a_lengths, b_states, b_lengths, device): 5 | a_max_len = a_states.size(1) 6 | b_max_len = b_states.size(1) 7 | combined_len = a_max_len + b_max_len 8 | padded_b_states = torch.zeros([b_states.size(0), combined_len, b_states.size(-1)], 
device=device) 9 | padded_b_states[:, :b_max_len, :] = b_states 10 | full_matrix = torch.cat([a_states, padded_b_states], dim=1) 11 | a_idxs = torch.arange(combined_len, dtype=torch.long, device=device).view(-1, 1) 12 | b_idxs = torch.arange(combined_len, dtype=torch.long, 13 | device=device).view(-1,1) - a_lengths.view(1, -1) + a_max_len 14 | idxs = torch.where(b_idxs < a_max_len, a_idxs, b_idxs).permute(1, 0) 15 | offset = torch.arange(0, full_matrix.size(0) * full_matrix.size(1), full_matrix.size(1), device=device) 16 | idxs = idxs + offset.unsqueeze(1) 17 | combined_states = full_matrix.reshape(-1, full_matrix.shape[-1])[idxs] 18 | combined_lengths = a_lengths + b_lengths 19 | 20 | return combined_states, combined_lengths 21 | 22 | def get_invalid_copy_locations(input_sequence, max_input_length, output_sequence, max_output_length): 23 | input_length = min(len(input_sequence), max_input_length) 24 | output_length = min(len(output_sequence), max_output_length) 25 | 26 | invalid_copy_locations = np.ones([max_output_length, max_input_length], dtype=bool) 27 | for o in range(output_length): 28 | for i in range(input_length): 29 | invalid_copy_locations[o,i] = output_sequence[o] != input_sequence[i] 30 | 31 | return invalid_copy_locations 32 | 33 | def compute_attention_states(key_states, masks, query_states, transformation_matrix=None, multihead_attention=None): 34 | if multihead_attention is not None: 35 | if transformation_matrix is not None: 36 | key = torch.einsum('bsh,hd->sbd', key_states, transformation_matrix) # S x B x D 37 | else: 38 | key = key_states.permute(1,0,2) # S x B x D 39 | 40 | query = query_states.permute(1,0,2) # T x B x D 41 | value = key 42 | attn_output, attn_output_weights = multihead_attention(query, key, value, key_padding_mask=masks.squeeze(1)) 43 | return attn_output.permute(1,0,2) 44 | else: 45 | if transformation_matrix is not None: 46 | alignment = torch.einsum('bsh,hd,btd->bts', key_states, transformation_matrix, query_states) 47 | else: 48 | alignment = torch.einsum('bsh,bth->bts', key_states, query_states) 49 | alignment.masked_fill_(masks, float('-inf')) 50 | attention_scores = torch.nn.functional.softmax(alignment, dim=-1) 51 | return torch.einsum('ijk,ikm->ijm', attention_scores, key_states) -------------------------------------------------------------------------------- /comment_update/update_decoder.py: -------------------------------------------------------------------------------- 1 | from dpu_utils.mlutils import Vocabulary 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch_scatter import scatter_add 6 | 7 | from constants import START, BEAM_SIZE 8 | from decoder import Decoder 9 | from tensor_utils import compute_attention_states 10 | 11 | class UpdateDecoder(Decoder): 12 | def __init__(self, input_size, hidden_size, attention_state_size, embedding_store, 13 | embedding_size, dropout_rate, attn_input_size): 14 | """Decoder for the edit model which generates a sequence of NL edits based on learned representations of 15 | the old comment and code edits.""" 16 | super(UpdateDecoder, self).__init__(input_size, hidden_size, attention_state_size, 17 | embedding_store, embedding_size, dropout_rate) 18 | 19 | self.sequence_attention_code_transform_matrix = nn.Parameter( 20 | torch.randn(self.attention_state_size, self.hidden_size, 21 | dtype=torch.float, requires_grad=True) 22 | ) 23 | self.attention_old_nl_hidden_transform_matrix = nn.Parameter( 24 | torch.randn(self.attention_state_size, self.hidden_size, 25 | 
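# Illustrative sketch (editor addition, not part of the original source):
# get_invalid_copy_locations, defined above, marks every (output, input)
# position pair where copying would be wrong, i.e. where the tokens differ;
# padding positions stay True (invalid). Toy example:
inp = ['the', 'max', 'score']
out = ['the', 'min', 'score']
mask = get_invalid_copy_locations(inp, 4, out, 4)  # 4 x 4 boolean matrix
# mask[0, 0] is False ('the' may be copied); mask[1, 1] is True ('min' != 'max')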
dtype=torch.float, requires_grad=True) 26 | ) 27 | 28 | self.attention_output_layer = nn.Linear(attn_input_size + self.hidden_size, 29 | self.hidden_size, bias=False) 30 | 31 | def decode(self, initial_state, decoder_input_embeddings, encoder_hidden_states, 32 | code_hidden_states, old_nl_hidden_states, masks, code_masks, old_nl_masks): 33 | """Decoding with attention and copy. Attention is computed separately for each set of encoder hidden states.""" 34 | decoder_states, decoder_final_state = self.gru.forward(decoder_input_embeddings, initial_state.unsqueeze(0)) 35 | 36 | attention_context_states = compute_attention_states(old_nl_hidden_states, old_nl_masks, 37 | decoder_states, self.attention_old_nl_hidden_transform_matrix, None) 38 | 39 | code_contexts = compute_attention_states(code_hidden_states, code_masks, 40 | decoder_states, self.sequence_attention_code_transform_matrix, None) 41 | attention_context_states = torch.cat([attention_context_states, code_contexts], dim=-1) 42 | 43 | decoder_states = torch.tanh(self.attention_output_layer( 44 | torch.cat([attention_context_states, decoder_states], dim=-1))) 45 | 46 | generation_scores = torch.einsum('ijk,km->ijm', decoder_states, self.generation_output_matrix) 47 | copy_scores = torch.einsum('ijk,km,inm->inj', encoder_hidden_states, 48 | self.copy_encoder_hidden_transform_matrix, decoder_states) 49 | copy_scores.masked_fill_(masks, float('-inf')) 50 | 51 | combined_logprobs = nn.functional.log_softmax(torch.cat([generation_scores, copy_scores], dim=-1), dim=-1) 52 | generation_logprobs = combined_logprobs[:,:,:len(self.embedding_store.nl_vocabulary)] 53 | copy_logprobs = combined_logprobs[:, :,len(self.embedding_store.nl_vocabulary):] 54 | 55 | return decoder_states, decoder_final_state, generation_logprobs, copy_logprobs 56 | 57 | def forward(self, initial_state, decoder_input_embeddings, encoder_hidden_states, 58 | code_hidden_states, old_nl_hidden_states, masks, code_masks, old_nl_masks): 59 | """Runs decoding.""" 60 | return self.decode(initial_state, decoder_input_embeddings, encoder_hidden_states, 61 | code_hidden_states, old_nl_hidden_states, masks, code_masks, old_nl_masks) 62 | 63 | def beam_decode(self, initial_state, encoder_hidden_states, code_hidden_states, old_nl_hidden_states, 64 | masks, max_out_len, batch_data, code_masks, old_nl_masks, device): 65 | """Beam search. 
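# Illustrative sketch (editor addition, not part of the original source):
# UpdateDecoder.decode above attends over the old-comment states and the code
# states separately and concatenates the two context vectors before the output
# projection. The same pattern with compute_attention_states and toy tensors
# (assuming B x 1 x S boolean masks, which is what masked_fill_ broadcasts):
import torch
from tensor_utils import compute_attention_states

B, T, S, H = 2, 3, 5, 8
decoder_states = torch.randn(B, T, H)
nl_states, code_states = torch.randn(B, S, H), torch.randn(B, S, H)
nl_masks = torch.zeros(B, 1, S, dtype=torch.bool)    # nothing masked in this toy case
code_masks = torch.zeros(B, 1, S, dtype=torch.bool)

nl_contexts = compute_attention_states(nl_states, nl_masks, decoder_states)
code_contexts = compute_attention_states(code_states, code_masks, decoder_states)
contexts = torch.cat([nl_contexts, code_contexts], dim=-1)  # B x T x 2H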
Generates the top K candidate predictions.""" 66 | batch_size = initial_state.shape[0] 67 | decoded_batch = [list() for _ in range(batch_size)] 68 | decoded_batch_scores = np.zeros([batch_size, BEAM_SIZE]) 69 | 70 | decoder_input = torch.tensor( 71 | [[self.embedding_store.get_nl_id(START)]] * batch_size, device=device) 72 | decoder_input = decoder_input.unsqueeze(1) 73 | decoder_state = initial_state.unsqueeze(1).expand( 74 | -1, decoder_input.shape[1], -1).reshape(-1, initial_state.shape[-1]) 75 | 76 | beam_scores = torch.ones([batch_size, 1], dtype=torch.float32, device=device) 77 | beam_status = torch.zeros([batch_size, 1], dtype=torch.uint8, device=device) 78 | beam_predicted_ids = torch.full([batch_size, 1, max_out_len], self.embedding_store.get_end_id(), 79 | dtype=torch.int64, device=device) 80 | 81 | for i in range(max_out_len): 82 | beam_size = decoder_input.shape[1] 83 | if beam_status[:,0].sum() == batch_size: 84 | break 85 | 86 | tiled_encoder_states = encoder_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1) 87 | tiled_masks = masks.unsqueeze(1).expand(-1, beam_size, -1, -1) 88 | tiled_code_hidden_states = code_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1) 89 | tiled_code_masks = code_masks.unsqueeze(1).expand(-1, beam_size, -1, -1) 90 | tiled_old_nl_hidden_states = old_nl_hidden_states.unsqueeze(1).expand(-1, beam_size, -1, -1) 91 | tiled_old_nl_masks = old_nl_masks.unsqueeze(1).expand(-1, beam_size, -1, -1) 92 | 93 | flat_decoder_input = decoder_input.reshape(-1, decoder_input.shape[-1]) 94 | flat_encoder_states = tiled_encoder_states.reshape(-1, tiled_encoder_states.shape[-2], tiled_encoder_states.shape[-1]) 95 | flat_masks = tiled_masks.reshape(-1, tiled_masks.shape[-2], tiled_masks.shape[-1]) 96 | flat_code_hidden_states = tiled_code_hidden_states.reshape(-1, tiled_code_hidden_states.shape[-2], tiled_code_hidden_states.shape[-1]) 97 | flat_code_masks = tiled_code_masks.reshape(-1, tiled_code_masks.shape[-2], tiled_code_masks.shape[-1]) 98 | flat_old_nl_hidden_states = tiled_old_nl_hidden_states.reshape(-1, tiled_old_nl_hidden_states.shape[-2], tiled_old_nl_hidden_states.shape[-1]) 99 | flat_old_nl_masks = tiled_old_nl_masks.reshape(-1, tiled_old_nl_masks.shape[-2], tiled_old_nl_masks.shape[-1]) 100 | 101 | decoder_input_embeddings = self.embedding_store.get_nl_embeddings(flat_decoder_input) 102 | decoder_attention_states, flat_decoder_state, generation_logprobs, copy_logprobs = self.decode( 103 | decoder_state, decoder_input_embeddings, flat_encoder_states, flat_code_hidden_states, 104 | flat_old_nl_hidden_states, flat_masks, flat_code_masks, flat_old_nl_masks) 105 | 106 | generation_logprobs = generation_logprobs.squeeze(1) 107 | copy_logprobs = copy_logprobs.squeeze(1) 108 | 109 | generation_logprobs = generation_logprobs.reshape(batch_size, beam_size, generation_logprobs.shape[-1]) 110 | copy_logprobs = copy_logprobs.reshape(batch_size, beam_size, copy_logprobs.shape[-1]) 111 | 112 | prob_scores = torch.zeros([batch_size, beam_size, 113 | generation_logprobs.shape[-1] + copy_logprobs.shape[-1]], dtype=torch.float32, device=device) 114 | prob_scores[:, :, :generation_logprobs.shape[-1]] = torch.exp(generation_logprobs) 115 | 116 | # Factoring in the copy scores 117 | expanded_token_ids = batch_data.input_ids.unsqueeze(1).expand(-1, beam_size, -1) 118 | prob_scores += scatter_add(src=torch.exp(copy_logprobs), index=expanded_token_ids, out=torch.zeros_like(prob_scores)) 119 | 120 | top_scores_per_beam, top_indices_per_beam = torch.topk(prob_scores, 
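# Illustrative sketch (editor addition, not part of the original source): each
# beam-search step above scores candidate extensions for every live beam,
# flattens the (beam, candidate) pool per example, and keeps the global top-K
# survivors. The core selection, with made-up numbers and K = 2:
import torch

pool = torch.tensor([[0.30, 0.05, 0.40, 0.10]])  # beam 0's candidates, then beam 1's
top_scores, top_indices = torch.topk(pool, k=2, dim=-1)
beam_of_origin = top_indices // 2                # which beam each survivor extends
# top_indices is [[2, 0]]: the single best extension comes from beam 1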
k=BEAM_SIZE, dim=-1) 121 | 122 | updated_scores = torch.einsum('eb,ebm->ebm', beam_scores, top_scores_per_beam) 123 | retained_scores = beam_scores.unsqueeze(-1).expand(-1, -1, top_scores_per_beam.shape[-1]) 124 | 125 | # Trying to keep at most one ray corresponding to completed beams 126 | end_mask = (torch.arange(beam_size) == 0).type(torch.float32).to(device) 127 | end_scores = torch.einsum('b,ebm->ebm', end_mask, retained_scores) 128 | 129 | possible_next_scores = torch.where(beam_status.unsqueeze(-1) == 1, end_scores, updated_scores) 130 | possible_next_status = torch.where(top_indices_per_beam == self.embedding_store.get_end_id(), 131 | torch.ones([batch_size, beam_size, top_scores_per_beam.shape[-1]], dtype=torch.uint8, device=device), 132 | beam_status.unsqueeze(-1).expand(-1,-1,top_scores_per_beam.shape[-1])) 133 | 134 | possible_beam_predicted_ids = beam_predicted_ids.unsqueeze(2).expand(-1, -1, top_scores_per_beam.shape[-1], -1) 135 | pool_next_scores = possible_next_scores.reshape(batch_size, -1) 136 | pool_next_status = possible_next_status.reshape(batch_size, -1) 137 | pool_next_ids = top_indices_per_beam.reshape(batch_size, -1) 138 | pool_predicted_ids = possible_beam_predicted_ids.reshape(batch_size, -1, beam_predicted_ids.shape[-1]) 139 | 140 | possible_decoder_state = flat_decoder_state.reshape(batch_size, beam_size, flat_decoder_state.shape[-1]) 141 | possible_decoder_state = possible_decoder_state.unsqueeze(2).expand(-1, -1, top_scores_per_beam.shape[-1], -1) 142 | pool_decoder_state = possible_decoder_state.reshape(batch_size, -1, possible_decoder_state.shape[-1]) 143 | 144 | top_scores, top_indices = torch.topk(pool_next_scores, k=BEAM_SIZE, dim=-1) 145 | next_step_ids = torch.gather(pool_next_ids, -1, top_indices) 146 | 147 | decoder_state = torch.gather(pool_decoder_state, 1, top_indices.unsqueeze(-1).expand(-1,-1, pool_decoder_state.shape[-1])) 148 | decoder_state = decoder_state.reshape(-1, decoder_state.shape[-1]) 149 | beam_status = torch.gather(pool_next_status, -1, top_indices) 150 | beam_scores = torch.gather(pool_next_scores, -1, top_indices) 151 | 152 | end_tags = torch.full_like(next_step_ids, self.embedding_store.get_end_id()) 153 | next_step_ids = torch.where(beam_status == 1, end_tags, next_step_ids) 154 | 155 | beam_predicted_ids = torch.gather(pool_predicted_ids, 1, top_indices.unsqueeze(-1).expand(-1, -1, pool_predicted_ids.shape[-1])) 156 | beam_predicted_ids[:,:,i] = next_step_ids 157 | 158 | unks = torch.full_like(next_step_ids, self.embedding_store.get_nl_id(Vocabulary.get_unk())) 159 | decoder_input = torch.where(next_step_ids < len(self.embedding_store.nl_vocabulary), next_step_ids, unks).unsqueeze(-1) 160 | 161 | return beam_predicted_ids, beam_scores -------------------------------------------------------------------------------- /comment_update/update_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import logging 3 | import os 4 | import numpy as np 5 | from typing import List, NamedTuple 6 | import subprocess 7 | 8 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 9 | from pycocoevalcap.meteor.meteor import Meteor 10 | from SARI import SARIsent 11 | 12 | from data_utils import get_processed_comment_str 13 | 14 | def compute_accuracy(reference_strings, predicted_strings): 15 | assert(len(reference_strings) == len(predicted_strings)) 16 | correct = 0.0 17 | for i in range(len(reference_strings)): 18 | if reference_strings[i] == 
predicted_strings[i]: 19 | correct += 1 20 | return 100 * correct/float(len(reference_strings)) 21 | 22 | def compute_bleu(references, hypotheses): 23 | bleu_4_sentence_scores = [] 24 | for ref, hyp in zip(references, hypotheses): 25 | bleu_4_sentence_scores.append(sentence_bleu(ref, hyp, 26 | smoothing_function=SmoothingFunction().method2)) 27 | return 100*sum(bleu_4_sentence_scores)/float(len(bleu_4_sentence_scores)) 28 | 29 | def compute_sentence_bleu(ref, hyp): 30 | return sentence_bleu(ref, hyp, smoothing_function=SmoothingFunction().method2) 31 | 32 | def compute_sentence_meteor(reference_list, sentences): 33 | preds = dict() 34 | refs = dict() 35 | 36 | for i in range(len(sentences)): 37 | preds[i] = [' '.join([s for s in sentences[i]])] 38 | refs[i] = [' '.join(l) for l in reference_list[i]] 39 | 40 | final_scores = dict() 41 | 42 | scorers = [ 43 | (Meteor(),"METEOR") 44 | ] 45 | 46 | for scorer, method in scorers: 47 | score, scores = scorer.compute_score(refs, preds) 48 | if type(method) == list: 49 | for sc, scs, m in zip(score, scores, method): 50 | final_scores[m] = scs 51 | else: 52 | final_scores[method] = scores 53 | 54 | meteor_scores = final_scores["METEOR"] 55 | return meteor_scores 56 | 57 | def compute_meteor(reference_list, sentences): 58 | meteor_scores = compute_sentence_meteor(reference_list, sentences) 59 | return 100 * sum(meteor_scores)/len(meteor_scores) 60 | 61 | def compute_unchanged(test_data, predictions): 62 | source_sentences = [get_processed_comment_str(ex.old_comment_subtokens) for ex in test_data] 63 | predicted_sentences = [' '.join(p) for p in predictions] 64 | unchanged = 0 65 | 66 | for source, predicted in zip(source_sentences, predicted_sentences): 67 | if source == predicted: 68 | unchanged += 1 69 | 70 | return 100*(unchanged)/len(test_data) 71 | 72 | def compute_sari(test_data, predictions): 73 | source_sentences = [get_processed_comment_str(ex.old_comment_subtokens) for ex in test_data] 74 | target_sentences = [[get_processed_comment_str(ex.new_comment_subtokens)] for ex in test_data] 75 | predicted_sentences = [' '.join(p) for p in predictions] 76 | 77 | inp = zip(source_sentences, target_sentences, predicted_sentences) 78 | scores = [] 79 | 80 | for source, target, predicted in inp: 81 | scores.append(SARIsent(source, predicted, target)) 82 | 83 | return 100*sum(scores)/float(len(scores)) 84 | 85 | def compute_gleu(test_data, orig_file, ref_file, pred_file): 86 | command = 'python2.7 gleu/scripts/compute_gleu -s {} -r {} -o {} -d'.format(orig_file, ref_file, pred_file) 87 | output = subprocess.check_output(command.split()) 88 | 89 | output_lines = [l.strip() for l in output.decode("utf-8").split('\n') if len(l.strip()) > 0] 90 | l = 0 91 | while l < len(output_lines): 92 | if output_lines[l][0] == '0': 93 | break 94 | l += 1 95 | 96 | scores = np.zeros(len(test_data), dtype=np.float32) 97 | while l < len(test_data): 98 | terms = output_lines[l].split() 99 | idx = int(terms[0]) 100 | val = float(terms[1]) 101 | scores[idx] = val 102 | l += 1 103 | scores = np.ndarray.tolist(scores) 104 | return 100*sum(scores)/float(len(scores)) 105 | 106 | def write_predictions(predicted_strings, write_file): 107 | os.makedirs(os.path.dirname(write_file), exist_ok=True) 108 | with open(write_file, 'w+') as f: 109 | for p in predicted_strings: 110 | f.write('{}\n'.format(p)) -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | 
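# Illustrative sketch (editor addition, not part of the original source): the
# helpers in update_evaluation_utils.py above operate on token lists, and each
# prediction may have several references, so references are wrapped in an
# extra list. A toy call:
ref = [['@return', 'the', 'lowest', 'score']]   # one reference for one example
hyp = ['@return', 'the', 'highest', 'score']
print(compute_sentence_bleu(ref, hyp))          # smoothed sentence-level BLEU-4
print(compute_bleu([ref], [hyp]))               # corpus average of sentence scores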
import os 2 | 3 | START = '' 4 | END = '' 5 | NL_EMBEDDING_SIZE = 64 6 | CODE_EMBEDDING_SIZE = 64 7 | HIDDEN_SIZE = 64 8 | DROPOUT_RATE = 0.6 9 | NUM_LAYERS = 2 10 | LR = 0.001 11 | BATCH_SIZE = 100 12 | MAX_EPOCHS = 100 13 | PATIENCE = 10 14 | VOCAB_CUTOFF_PCT = 5 15 | LENGTH_CUTOFF_PCT = 95 16 | MAX_VOCAB_EXTENSION = 50 17 | BEAM_SIZE = 20 18 | MAX_VOCAB_SIZE = 10000 19 | FEATURE_DIMENSION = 128 20 | NUM_CLASSES = 2 21 | 22 | GNN_HIDDEN_SIZE = 64 23 | GNN_LAYER_TIMESTEPS = 8 24 | GNN_DROPOUT_RATE = 0.0 25 | SRC_EMBEDDING_SIZE = 8 26 | NODE_EMBEDDING_SIZE = 64 27 | 28 | MODEL_LAMBDA = 0.5 29 | LIKELIHOOD_LAMBDA = 0.3 30 | OLD_METEOR_LAMBDA = 0.2 31 | GEN_MODEL_LAMBDA = 0.5 32 | GEN_OLD_BLEU_LAMBDA = 0.5 33 | DECODER_HIDDEN_SIZE = 128 34 | MULTI_HEADS = 4 35 | NUM_TRANSFORMER_LAYERS = 2 36 | 37 | # Download data from here: https://drive.google.com/drive/folders/1heqEQGZHgO6gZzCjuQD1EyYertN4SAYZ?usp=sharing 38 | # DATA_PATH should point to the location in which the above data is saved locally 39 | DATA_PATH = '[PATH TO DOWNLOADED DATA]' # TODO 40 | RESOURCES_PATH = os.path.join(DATA_PATH, 'resources') 41 | 42 | # Download model resources from here: https://drive.google.com/drive/folders/1cutxr4rMDkT1g2BbmCAR2wqKTxeFH11K?usp=sharing 43 | # MODEL_RESOURCES_PATH should point to the location in which the above resources are saved locally. 44 | MODEL_RESOURCES_PATH = '[PATH TO DOWNLOADED MODEL RESOURCES]' # TODO 45 | NL_EMBEDDING_PATH = os.path.join(MODEL_RESOURCES_PATH, 'nl_embeddings.json') 46 | CODE_EMBEDDING_PATH = os.path.join(MODEL_RESOURCES_PATH, 'code_embeddings.json') 47 | FULL_GENERATION_MODEL_PATH = os.path.join(MODEL_RESOURCES_PATH, 'generation-model.pkl.gz') 48 | 49 | # Should point to where the output is to be saved 50 | PREDICTION_DIR = '[ROOT DIR TO STORE PREDICTED OUTPUT FOR UPDATE AND DUAL MODELS]' # TODO 51 | DETECTION_DIR = '[ROOT DIR TO STORE PREDICTED OUTPUT FOR DETECTION MODELS]' # TODO -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from constants import DATA_PATH 5 | from data_utils import DiffAST, DiffExample, DiffASTExample, CommentCategory 6 | 7 | PARTITIONS = ['train', 'valid', 'test'] 8 | 9 | def get_data_splits(comment_type_str=None, ignore_ast=False): 10 | """Retrieves train/validation/test sets for the given comment_type_str. 11 | comment_type_str -- Return, Param, Summary, or None (if None, uses all comment types) 12 | ignore_ast -- Skip loading ASTs (they take a long time)""" 13 | dataset, high_level_details = load_processed_data(comment_type_str, ignore_ast) 14 | train_examples = dataset['train'] 15 | valid_examples = dataset['valid'] 16 | test_examples = dataset['test'] 17 | return train_examples, valid_examples, test_examples, high_level_details 18 | 19 | def load_cleaned_test_set(comment_type_str=None): 20 | """Retrieves the ids corresponding to clean examples, for the given comment_type_str. 
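# Illustrative usage sketch (editor addition, not part of the original source):
# once DATA_PATH in constants.py points at the downloaded data, the loaders in
# data_loader.py are typically driven like this (ignore_ast=True skips the
# slow AST deserialization):
train_examples, valid_examples, test_examples, high_level_details = get_data_splits(
    comment_type_str='Return', ignore_ast=True)
print(len(train_examples), len(valid_examples), len(test_examples))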
21 | comment_type_str -- Return, Param, Summary, or None (if None, uses all comment types)""" 22 | if not comment_type_str: 23 | comment_types = [CommentCategory(category).name for category in CommentCategory] 24 | else: 25 | comment_types = [comment_type_str] 26 | 27 | test_ids = [] 28 | for comment_type in comment_types: 29 | resources_path = os.path.join(DATA_PATH, 'resources', comment_type, 'clean_test_ids.json') 30 | with open(resources_path) as f: 31 | test_ids.extend(json.load(f)) 32 | return test_ids 33 | 34 | def load_processed_data(comment_type_str, ignore_ast): 35 | """Processes saved data for the given comment_type_str. 36 | comment_type_str -- Return, Param, Summary, or None (if None, uses all comment types) 37 | ignore_ast -- Skip loading ASTs (they take a long time)""" 38 | if not comment_type_str: 39 | comment_types = [CommentCategory(category).name for category in CommentCategory] 40 | else: 41 | comment_types = [comment_type_str] 42 | 43 | print('Loading data from: {}'.format(comment_types)) 44 | 45 | dataset = dict() 46 | high_level_details = dict() 47 | for comment_type in comment_types: 48 | path = os.path.join(DATA_PATH, comment_type) 49 | loaded = load_raw_data_from_path(path) 50 | category_high_level_details_path = os.path.join(DATA_PATH, 'resources', comment_type, 'high_level_details.json') 51 | 52 | with open(category_high_level_details_path) as f: 53 | category_high_level_details = json.load(f) 54 | high_level_details.update(category_high_level_details) 55 | 56 | if not ignore_ast: 57 | ast_path = os.path.join(DATA_PATH, 'resources', comment_type, 'ast_objs.json') 58 | with open(ast_path) as f: 59 | ast_details = json.load(f) 60 | 61 | for partition, examples in loaded.items(): 62 | if partition not in dataset: 63 | dataset[partition] = [] 64 | 65 | if ignore_ast: 66 | dataset[partition].extend(examples) 67 | else: 68 | for ex in examples: 69 | ex_ast_info = ast_details[ex.id] 70 | old_ast = DiffAST.from_json(ex_ast_info['old_ast']) 71 | new_ast = DiffAST.from_json(ex_ast_info['new_ast']) 72 | diff_ast = DiffAST.from_json(ex_ast_info['diff_ast']) 73 | 74 | ast_ex = DiffASTExample(ex.id, ex.label, ex.comment_type, ex.old_comment_raw, 75 | ex.old_comment_subtokens, ex.new_comment_raw, ex.new_comment_subtokens, ex.span_minimal_diff_comment_subtokens, 76 | ex.old_code_raw, ex.old_code_subtokens, ex.new_code_raw, ex.new_code_subtokens, 77 | ex.span_diff_code_subtokens, ex.token_diff_code_subtokens, old_ast, new_ast, diff_ast) 78 | 79 | dataset[partition].append(ast_ex) 80 | 81 | return dataset, high_level_details 82 | 83 | def load_raw_data_from_path(path): 84 | """Reads saved partition-level data from a directory path""" 85 | dataset = dict() 86 | 87 | for partition in PARTITIONS: 88 | dataset[partition] = [] 89 | dataset[partition].extend(read_diff_examples_from_file(os.path.join(path, '{}.json'.format(partition)))) 90 | 91 | return dataset 92 | 93 | def read_diff_examples_from_file(filename): 94 | """Reads saved data from filename""" 95 | with open(filename) as f: 96 | data = json.load(f) 97 | return [DiffExample(**d) for d in data] -------------------------------------------------------------------------------- /data_processing/ast_diffing/code_samples/new.java: -------------------------------------------------------------------------------- 1 | /**Computes the highest value from the list of scores.*/ 2 | public double getBestScore() { 3 | return Collections.max(scores); 4 | } -------------------------------------------------------------------------------- 
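# Editor note (not part of the original source): old.java and new.java are the
# sample inputs for the AST diffing pipeline; xml_diff_parser.py below is run
# on them roughly as follows, with illustrative relative paths:
#
#   python3 xml_diff_parser.py --old_sample_path ../code_samples/old.java \
#       --new_sample_path ../code_samples/new.java --jar_path [PATH TO DOWNLOADED JAR]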
/data_processing/ast_diffing/code_samples/old.java: -------------------------------------------------------------------------------- 1 | /**Computes the lowest value from the list of scores.*/ 2 | public int getBestScore() { 3 | return Collections.min(scores); 4 | } -------------------------------------------------------------------------------- /data_processing/ast_diffing/python/xml_diff_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import subprocess 6 | import sys 7 | 8 | import xml.etree.ElementTree as ET 9 | 10 | sys.path.append('../../../') 11 | sys.path.append('../../../comment_update') 12 | from data_utils import DiffTreeNode, DiffAST 13 | 14 | 15 | class Indexer: 16 | def __init__ (self): 17 | self.count = 0 18 | 19 | def generate(self): 20 | new_id = self.count 21 | self.count += 1 22 | return new_id 23 | 24 | class XMLNode: 25 | def __init__(self, value, node_id, parent, attribute, 26 | alignment_id, location_id, src, is_leaf=True): 27 | self.value = value 28 | self.node_id = node_id 29 | self.parent = parent 30 | self.attribute = attribute 31 | self.alignment_id = alignment_id 32 | self.location_id = location_id 33 | self.src = src 34 | self.is_leaf = is_leaf 35 | self.children = [] 36 | self.pseudo_children = [] 37 | self.prev_sibling = None 38 | self.next_sibling = None 39 | 40 | def print_node(self): 41 | parent_value = None 42 | if self.parent: 43 | parent_value = self.parent.value 44 | 45 | print('{}: {} ({}, {})'.format(self.node_id, self.value, parent_value, len(self.children))) 46 | for c in self.children: 47 | c.print_node() 48 | 49 | class AST: 50 | def __init__(self, ast_root): 51 | self.root = ast_root 52 | self.nodes = [] 53 | self.traverse(ast_root) 54 | 55 | def traverse(self, curr_node): 56 | self.nodes.append(curr_node) 57 | for c, child_node in enumerate(curr_node.children): 58 | if c > 0: 59 | child_node.prev_sibling = curr_node.children[c-1] 60 | if c < len(curr_node.children) - 1: 61 | child_node.next_sibling = curr_node.children[c+1] 62 | self.traverse(child_node) 63 | 64 | @property 65 | def leaves(self): 66 | return [n for n in self.nodes if n.is_leaf] 67 | 68 | def parse_xml_obj(xml_obj, indexer, parent, src): 69 | fields = xml_obj.attrib 70 | attribute = fields['typeLabel'] 71 | is_leaf = False 72 | 73 | if 'label' in fields: 74 | is_leaf = True 75 | value = fields['label'] 76 | else: 77 | value = attribute 78 | 79 | alignment_id = None 80 | location_id = '{}-{}-{}-{}'.format(fields['type'], value, fields['pos'], fields['length']) 81 | 82 | if 'other_pos' in fields: 83 | if src == 'old': 84 | alignment_id = '{}-{}-{}-{}'.format(fields['pos'], fields['length'], fields['other_pos'], fields['other_length']) 85 | else: 86 | alignment_id = '{}-{}-{}-{}'.format(fields['other_pos'], fields['other_length'], fields['pos'], fields['length']) 87 | 88 | node = XMLNode(value, indexer.generate(), parent, 89 | attribute, alignment_id, location_id, src, is_leaf) 90 | 91 | for child_obj in xml_obj: 92 | node.children.append(parse_xml_obj(child_obj, indexer, node, src)) 93 | return node 94 | 95 | def set_id(diff_node, indexer): 96 | diff_node.node_id = indexer.generate() 97 | for node in diff_node.children: 98 | set_id(node, indexer) 99 | 100 | def print_diff_node(diff_node): 101 | print('{} ({}-{}): {}, {}'.format(diff_node.value, diff_node.src, diff_node.node_id, 102 | [c.value for c in diff_node.children], [p.node_id for p in diff_node.parents])) 103 
| for child in diff_node.children: 104 | print_diff_node(child) 105 | 106 | def get_individual_ast_objs(old_sample_path, new_sample_path, actions_json, jar_path): 107 | old_xml_path = os.path.join(XML_DIR, 'old.xml') 108 | new_xml_path = os.path.join(XML_DIR, 'new.xml') 109 | 110 | output = subprocess.check_output(['java', '-jar', jar_path, old_sample_path, 111 | new_sample_path, old_xml_path, new_xml_path, actions_json]) 112 | 113 | xml_obj = ET.parse(old_xml_path) 114 | old_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'old') 115 | old_ast = AST(old_root) 116 | 117 | xml_obj = ET.parse(new_xml_path) 118 | new_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'new') 119 | new_ast = AST(new_root) 120 | 121 | old_nodes = old_ast.nodes 122 | old_diff_nodes = [DiffTreeNode(n.value, n.attribute, n.src, n.is_leaf) for n in old_nodes] 123 | 124 | old_diff_nodes_by_alignment = dict() 125 | for n, old_node in enumerate(old_nodes): 126 | old_diff_node = old_diff_nodes[n] 127 | if old_node.parent: 128 | old_diff_node.parents.append(old_diff_nodes[old_node.parent.node_id]) 129 | 130 | for c in old_node.children: 131 | old_diff_node.children.append(old_diff_nodes[c.node_id]) 132 | 133 | if old_node.prev_sibling: 134 | old_diff_node.prev_siblings.append(old_diff_nodes[old_node.prev_sibling.node_id]) 135 | 136 | if old_node.next_sibling: 137 | old_diff_node.next_siblings.append(old_diff_nodes[old_node.next_sibling.node_id]) 138 | 139 | if old_node.alignment_id: 140 | old_diff_nodes_by_alignment[old_node.alignment_id] = old_diff_node 141 | 142 | new_nodes = new_ast.nodes 143 | new_diff_nodes = [DiffTreeNode(n.value, n.attribute, n.src, n.is_leaf) for n in new_nodes] 144 | 145 | for n, new_node in enumerate(new_nodes): 146 | new_diff_node = new_diff_nodes[n] 147 | if new_node.parent: 148 | new_diff_node.parents.append(new_diff_nodes[new_node.parent.node_id]) 149 | 150 | for c in new_node.children: 151 | new_diff_node.children.append(new_diff_nodes[c.node_id]) 152 | 153 | if new_node.prev_sibling: 154 | new_diff_node.prev_siblings.append(new_diff_nodes[new_node.prev_sibling.node_id]) 155 | 156 | if new_node.next_sibling: 157 | new_diff_node.next_siblings.append(new_diff_nodes[new_node.next_sibling.node_id]) 158 | 159 | old_diff_ast = DiffAST(old_diff_nodes[0]) 160 | new_diff_ast = DiffAST(new_diff_nodes[0]) 161 | 162 | return old_diff_ast, new_diff_ast 163 | 164 | def get_diff_ast(old_sample_path, new_sample_path, actions_json, jar_path): 165 | old_xml_path = os.path.join(XML_DIR, 'old.xml') 166 | new_xml_path = os.path.join(XML_DIR, 'new.xml') 167 | output = subprocess.check_output(['java', '-jar', jar_path, old_sample_path, 168 | new_sample_path, old_xml_path, new_xml_path, actions_json]) 169 | 170 | xml_obj = ET.parse(old_xml_path) 171 | old_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'old') 172 | old_ast = AST(old_root) 173 | 174 | xml_obj = ET.parse(new_xml_path) 175 | new_root = parse_xml_obj(xml_obj.getroot()[1], Indexer(), None, 'new') 176 | new_ast = AST(new_root) 177 | 178 | with open(actions_json) as f: 179 | actions = json.load(f) 180 | 181 | old_actions = dict() 182 | new_actions = dict() 183 | 184 | for action in actions: 185 | location_id = '{}-{}-{}-{}'.format(action['type'], action['label'], action['position'], action['length']) 186 | if action['action'] == 'Insert': 187 | new_actions[location_id] = action['action'] 188 | else: 189 | old_actions[location_id] = action['action'] 190 | 191 | old_nodes = old_ast.nodes 192 | old_diff_nodes = [] 193 | for 
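# Illustrative sketch (editor addition, not part of the original source): old
# and new nodes are matched through alignment ids built from (pos, length)
# pairs reported by the diffing jar; both sides produce the same id string, so
# aligned nodes meet in old_diff_nodes_by_alignment. With made-up positions:
old_fields = {'pos': '10', 'length': '3', 'other_pos': '12', 'other_length': '3'}
new_fields = {'pos': '12', 'length': '3', 'other_pos': '10', 'other_length': '3'}
old_id = '{}-{}-{}-{}'.format(old_fields['pos'], old_fields['length'],
                              old_fields['other_pos'], old_fields['other_length'])
new_id = '{}-{}-{}-{}'.format(new_fields['other_pos'], new_fields['other_length'],
                              new_fields['pos'], new_fields['length'])
assert old_id == new_id == '10-3-12-3'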
n in old_nodes: 194 | old_diff_node = DiffTreeNode(n.value, n.attribute, n.src, n.is_leaf) 195 | if n.location_id in old_actions: 196 | old_diff_node.action_type = old_actions[n.location_id] 197 | old_diff_nodes.append(old_diff_node) 198 | 199 | old_diff_nodes_by_alignment = dict() 200 | for n, old_node in enumerate(old_nodes): 201 | old_diff_node = old_diff_nodes[n] 202 | if old_node.parent: 203 | old_diff_node.parents.append(old_diff_nodes[old_node.parent.node_id]) 204 | 205 | for c in old_node.children: 206 | old_diff_node.children.append(old_diff_nodes[c.node_id]) 207 | 208 | if old_node.prev_sibling: 209 | old_diff_node.prev_siblings.append(old_diff_nodes[old_node.prev_sibling.node_id]) 210 | 211 | if old_node.next_sibling: 212 | old_diff_node.next_siblings.append(old_diff_nodes[old_node.next_sibling.node_id]) 213 | 214 | if old_node.alignment_id: 215 | if old_node.alignment_id not in old_diff_nodes_by_alignment: 216 | old_diff_nodes_by_alignment[old_node.alignment_id] = [] 217 | old_diff_nodes_by_alignment[old_node.alignment_id].append(old_diff_node) 218 | 219 | new_nodes = new_ast.nodes 220 | new_diff_nodes = [] 221 | 222 | for n, new_node in enumerate(new_nodes): 223 | if new_node.alignment_id in old_diff_nodes_by_alignment and len(old_diff_nodes_by_alignment[new_node.alignment_id]) > 0: 224 | old_diff_node = old_diff_nodes_by_alignment[new_node.alignment_id].pop(0) 225 | if new_node.value == old_diff_node.value: 226 | new_diff_node = old_diff_node 227 | new_diff_node.src = 'both' 228 | new_diff_nodes.append(new_diff_node) 229 | else: 230 | new_diff_node = DiffTreeNode(new_node.value, new_node.attribute, new_node.src, new_node.is_leaf) 231 | new_diff_node.aligned_neighbors.append(old_diff_node) 232 | old_diff_node.aligned_neighbors.append(new_diff_node) 233 | new_diff_node.action_type = old_diff_node.action_type 234 | 235 | if new_node.location_id in new_actions: 236 | new_diff_node.action_type = new_actions[new_node.location_id] 237 | 238 | new_diff_nodes.append(new_diff_node) 239 | else: 240 | new_diff_node = DiffTreeNode(new_node.value, new_node.attribute, new_node.src, new_node.is_leaf) 241 | if new_node.location_id in new_actions: 242 | new_diff_node.action_type = new_actions[new_node.location_id] 243 | new_diff_nodes.append(new_diff_node) 244 | 245 | for n, new_node in enumerate(new_nodes): 246 | new_diff_node = new_diff_nodes[n] 247 | if new_node.parent and new_diff_nodes[new_node.parent.node_id] not in new_diff_node.parents: 248 | new_diff_node.parents.append(new_diff_nodes[new_node.parent.node_id]) 249 | 250 | for c in new_node.children: 251 | if new_diff_nodes[c.node_id] not in new_diff_node.children: 252 | new_diff_node.children.append(new_diff_nodes[c.node_id]) 253 | 254 | if new_node.prev_sibling and new_diff_nodes[new_node.prev_sibling.node_id] not in new_diff_node.prev_siblings: 255 | new_diff_node.prev_siblings.append(new_diff_nodes[new_node.prev_sibling.node_id]) 256 | 257 | if new_node.next_sibling and new_diff_nodes[new_node.next_sibling.node_id] not in new_diff_node.next_siblings: 258 | new_diff_node.next_siblings.append(new_diff_nodes[new_node.next_sibling.node_id]) 259 | 260 | super_root = DiffTreeNode('SuperRoot', 'SuperRoot', 'both', False) 261 | super_root.children.append(old_diff_nodes[0]) 262 | old_diff_nodes[0].parents.append(super_root) 263 | 264 | if old_diff_nodes[0] != new_diff_nodes[0]: 265 | super_root.children.append(new_diff_nodes[0]) 266 | new_diff_nodes[0].parents.append(super_root) 267 | 268 | diff_ast = DiffAST(super_root) 269 | return 
diff_ast 270 | 271 | if __name__ == "__main__": 272 | parser = argparse.ArgumentParser() 273 | parser.add_argument('--old_sample_path', help='path to java file containing old version of method') 274 | parser.add_argument('--new_sample_path', help='path to java file containing new version of method') 275 | parser.add_argument('--jar_path', help='path to downloaded jar file') 276 | args = parser.parse_args() 277 | 278 | logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s') 279 | logging.basicConfig(level=logging.ERROR, format='%(asctime)-15s %(message)s') 280 | 281 | XML_DIR = 'xml_files/' 282 | os.makedirs(XML_DIR, exist_ok=True) 283 | 284 | old_ast, new_ast = get_individual_ast_objs(args.old_sample_path, args.new_sample_path, 'old_new_ast_actions.json', args.jar_path) 285 | diff_ast = get_diff_ast(args.old_sample_path, args.new_sample_path, 'diff_ast_actions.json', args.jar_path) 286 | 287 | print(diff_ast.to_json()) -------------------------------------------------------------------------------- /data_processing/build_example.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from data_formatting_utils import subtokenize_code, subtokenize_comment 4 | 5 | sys.path.append('../') 6 | sys.path.append('../comment_update') 7 | from data_utils import DiffASTExample 8 | from diff_utils import compute_minimal_comment_diffs, compute_code_diffs 9 | 10 | # NOTE: Javalang will need to be installed for this 11 | def build_test_example(): 12 | example_id = 'test-id-0' 13 | label = 1 14 | comment_type = 'Return' 15 | old_comment_raw = '@return the highest score' 16 | old_comment_subtokens = subtokenize_comment(old_comment_raw).split() 17 | new_comment_raw = '@return the lowest score' 18 | new_comment_subtokens = subtokenize_comment(new_comment_raw).split() 19 | span_minimal_diff_comment_subtokens, _, _ = compute_minimal_comment_diffs( 20 | old_comment_subtokens, new_comment_subtokens) 21 | old_code_raw = 'public int getBestScore()\n{\n\treturn Collections.max(scores);\n}' 22 | old_code_subtokens = subtokenize_code(old_code_raw).split() 23 | new_code_raw = 'public int getBestScore()\n{\n\treturn Collections.min(scores);\n}' 24 | new_code_subtokens = subtokenize_code(new_code_raw).split() 25 | span_diff_code_subtokens, token_diff_code_subtokens, _ = compute_code_diffs(old_code_subtokens, new_code_subtokens) 26 | 27 | # TODO: Add code for parsing ASTs 28 | old_ast = None 29 | new_ast = None 30 | diff_ast = None 31 | 32 | return DiffASTExample(example_id, label, comment_type, old_comment_raw, old_comment_subtokens, new_comment_raw, 33 | new_comment_subtokens, span_minimal_diff_comment_subtokens, old_code_raw, old_code_subtokens, new_code_raw, 34 | new_code_subtokens, span_diff_code_subtokens, token_diff_code_subtokens, old_ast, new_ast, diff_ast) -------------------------------------------------------------------------------- /data_processing/data_formatting_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import javalang 3 | import json 4 | import numpy as np 5 | import os 6 | import random 7 | import re 8 | import string 9 | 10 | SPECIAL_TAGS = ['{', '}', '@code', '@docRoot', '@inheritDoc', '@link', '@linkplain', '@value'] 11 | 12 | def remove_html_tag(line): 13 | clean = re.compile('<.*?>') 14 | line = re.sub(clean, '', line) 15 | 16 | for tag in SPECIAL_TAGS: 17 | line = line.replace(tag, '') 18 | 19 | return line 20 | 21 | def remove_tag_string(line): 22 | 
search_strings = ['@return', '@ return', '@param', '@ param', '@throws', '@ throws'] 23 | for s in search_strings: 24 | line = line.replace(s, '').strip() 25 | return line 26 | 27 | def tokenize_comment(comment_line, remove_tag=True): 28 | if remove_tag: 29 | comment_line = remove_tag_string(comment_line) 30 | comment_line = remove_html_tag(comment_line) 31 | comment_line = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", comment_line.strip()) 32 | comment_line = ' '.join(comment_line) 33 | comment_line = comment_line.replace('\n', ' ').strip() 34 | 35 | return comment_line 36 | 37 | def subtokenize_comment(comment_line, remove_tag=True): 38 | if remove_tag: 39 | comment_line = remove_tag_string(comment_line) 40 | comment_line = remove_html_tag(comment_line.replace('/**', '').replace('**/', '').replace('/*', '').replace('*/', '').replace('*', '').strip()) 41 | comment_line = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", comment_line.strip()) 42 | comment_line = ' '.join(comment_line) 43 | comment_line = comment_line.replace('\n', ' ').strip() 44 | 45 | tokens = comment_line.split(' ') 46 | subtokens = [] 47 | for token in tokens: 48 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split() 49 | try: 50 | new_curr = [] 51 | for c in curr: 52 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip()) 53 | new_curr = new_curr + by_symbol 54 | 55 | curr = new_curr 56 | except: 57 | curr = [] 58 | subtokens = subtokens + [c.lower() for c in curr] 59 | 60 | comment_line = ' '.join(subtokens) 61 | return comment_line.lower() 62 | 63 | def subtokenize_code(line): 64 | try: 65 | tokens = get_clean_code(list(javalang.tokenizer.tokenize(line))) 66 | except: 67 | tokens = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", line.strip()) 68 | subtokens = [] 69 | for token in tokens: 70 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split() 71 | subtokens = subtokens + [c.lower() for c in curr] 72 | 73 | return ' '.join(subtokens) 74 | 75 | def tokenize_code(line): 76 | try: 77 | tokens = [t.value for t in list(javalang.tokenizer.tokenize(line))] 78 | return ' '.join(tokens) 79 | except: 80 | return tokenize_clean_code(line) 81 | 82 | def tokenize_clean_code(line): 83 | try: 84 | return ' '.join(get_clean_code(list(javalang.tokenizer.tokenize(line)))) 85 | except: 86 | return ' '.join(re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", line.strip())) 87 | 88 | def get_clean_code(tokenized_code): 89 | token_vals = [t.value for t in tokenized_code] 90 | new_token_vals = [] 91 | for t in token_vals: 92 | n = [c for c in re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", t.encode('ascii', errors='ignore').decode().strip()) if len(c) > 0] 93 | new_token_vals = new_token_vals + n 94 | 95 | token_vals = new_token_vals 96 | cleaned_code_tokens = [] 97 | 98 | for c in token_vals: 99 | try: 100 | cleaned_code_tokens.append(str(c)) 101 | except: 102 | pass 103 | 104 | return cleaned_code_tokens -------------------------------------------------------------------------------- /data_processing/high_level_feature_extractor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | 6 | from build_example import build_test_example 7 | from data_formatting_utils import tokenize_clean_code, subtokenize_code 8 | 9 | sys.path.append('../') 10 | from diff_utils import is_edit_keyword, KEEP, DELETE, INSERT, REPLACE_OLD, REPLACE_NEW 11 | 12 | EDIT_INDICES = 
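# Illustrative sketch (editor addition, not part of the original source): the
# subtokenizers in data_formatting_utils.py above split camelCase identifiers
# with this regex before lowercasing, the same pattern reused across the repo:
import re

token = 'getBestScore'
subtokens = [c.lower() for c in re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split()]
print(subtokens)  # ['get', 'best', 'score']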
[KEEP, DELETE, INSERT, REPLACE_OLD, REPLACE_NEW] 13 | 14 | def extract_arguments(code_block): 15 | i = 0 16 | while i < len(code_block): 17 | line = code_block[i].strip() 18 | if len(line.strip()) == 0: 19 | i += 1 20 | continue 21 | if line[0] == '@' and ' ' not in line: 22 | i += 1 23 | continue 24 | if '//' in line or '*' in line: 25 | i += 1 26 | continue 27 | else: 28 | break 29 | 30 | argument_string = line[line.index('(')+1:] 31 | 32 | if argument_string.count('(') + 1 == argument_string.count(')'): 33 | argument_string = argument_string[:argument_string.rfind(')')] 34 | else: 35 | curr_open_count = argument_string.count('(') + 1 36 | curr_close_count = argument_string.count(')') 37 | i += 1 38 | extension = '' 39 | while i < len(code_block): 40 | for w in code_block[i].strip(): 41 | extension += w 42 | if w == '(': 43 | curr_open_count += 1 44 | elif w == ')': 45 | curr_close_count += 1 46 | if curr_open_count == curr_close_count: 47 | break 48 | if curr_open_count == curr_close_count: 49 | break 50 | i += 1 51 | 52 | if curr_open_count != curr_close_count: 53 | raise ValueError('Invalid arguments') 54 | 55 | argument_string = argument_string + extension[:-1] 56 | 57 | argument_types = [] 58 | argument_names = [] 59 | 60 | argument_string = ' '.join([a for a in argument_string.split() if '@' not in a]) 61 | terms = [] 62 | a = 0 63 | curr_term = [] 64 | 65 | open_count = 0 66 | close_count = 0 67 | 68 | while a < len(argument_string): 69 | t = argument_string[a] 70 | if t == ' ' and open_count == close_count: 71 | terms.append(''.join(curr_term).strip()) 72 | curr_term = [] 73 | a += 1 74 | continue 75 | if t == ',' and open_count == close_count: 76 | curr_term.append(t) 77 | terms.append(''.join(curr_term).strip()) 78 | curr_term = [] 79 | a += 1 80 | continue 81 | 82 | if t == ',' and open_count != close_count: 83 | a += 1 84 | continue 85 | 86 | if t == '<': 87 | open_count += 1 88 | 89 | if t == '>': 90 | close_count += 1 91 | 92 | curr_term.append(t) 93 | a += 1 94 | 95 | if len(curr_term) > 0: 96 | terms.append(''.join(curr_term).strip()) 97 | 98 | terms = [t for t in terms if t not in ['private', 'protected', 'public', 'final', 'static']] 99 | arguments = ' '.join(terms).split(',') 100 | arguments = [a.strip() for a in arguments if len(a.strip()) > 0] 101 | for argument in arguments: 102 | argument_tokens = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", argument.strip()) 103 | argument_types.append(argument_tokens[0]) 104 | argument_names.append(argument_tokens[-1]) 105 | 106 | return argument_names, argument_types 107 | 108 | def strip_comment(s): 109 | """Checks whether a single line follows the structure of a comment.""" 110 | new_s = re.sub(r'\"(.+?)\"', '', s) 111 | matched_obj = re.findall("(?:/\\*(?:[^*]|(?:\\*+[^*/]))*\\*+/)|(?://.*)", new_s) 112 | url_match = re.findall('https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+', new_s) 113 | file_match = re.findall('^(.*/)?(?:$|(.+?)(?:(\.[^.]*$)|$))', new_s) 114 | 115 | if matched_obj and not url_match: 116 | for m in matched_obj: 117 | s = s.replace(m, ' ') 118 | return s.strip() 119 | 120 | def extract_return_statements(code_block): 121 | cleaned_lines = [] 122 | for l in code_block: 123 | cleaned_l = strip_comment(l) 124 | if len(cleaned_l) > 0: 125 | cleaned_lines.append(cleaned_l) 126 | 127 | combined_block = ' '.join(cleaned_lines) 128 | if 'return' not in combined_block: 129 | return [] 130 | indices = [m.start() for m in re.finditer('return ', combined_block)] 131 | return_statements = [] 132 | for idx in 
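# Illustrative sketch (editor addition, not part of the original source):
# extract_return_statements, being defined here, strips comments from each
# line and then pulls the expression out of every `return ...;`. Toy input:
body = ['public int getBestScore() {',
        '  return Collections.max(scores); // pick best',
        '}']
print(extract_return_statements(body))  # ['Collections.max(scores)']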
indices: 133 | s_idx = idx + len('return ') 134 | e_idx = s_idx + combined_block[s_idx:].index(';') 135 | statement = combined_block[s_idx:e_idx].strip() 136 | if len(statement) > 0: 137 | return_statements.append(statement) 138 | 139 | return return_statements 140 | 141 | 142 | def is_operator(token): 143 | for s in token: 144 | if s.isalnum(): 145 | return False 146 | return True 147 | 148 | def extract_method_name(code_block): 149 | i = 0 150 | while i < len(code_block): 151 | line = code_block[i].strip() 152 | if len(line.strip()) == 0: 153 | i += 1 154 | continue 155 | if line[0] == '@' and ' ' not in line: 156 | i += 1 157 | continue 158 | if '//' in line or '*' in line: 159 | i += 1 160 | continue 161 | else: 162 | break 163 | 164 | try: 165 | method_components = line.strip().split('(')[0].split(' ') 166 | method_components = [m for m in method_components if len(m) > 0] 167 | method_name = method_components[-1].strip() 168 | except: 169 | method_name = '' 170 | 171 | return method_name 172 | 173 | def extract_return_type(code_block): 174 | i = 0 175 | while i < len(code_block): 176 | line = code_block[i].strip() 177 | if len(line.strip()) == 0: 178 | i += 1 179 | continue 180 | if line[0] == '@': 181 | i += 1 182 | continue 183 | if '//' in line or '*' in line: 184 | i += 1 185 | continue 186 | else: 187 | break 188 | 189 | before_method_name_tokens = line.split('(')[0].split(' ')[:-1] 190 | return_type_tokens = [] 191 | for tok in before_method_name_tokens: 192 | if tok not in ['private', 'protected', 'public', 'final', 'static']: 193 | return_type_tokens.append(tok) 194 | return ' '.join(return_type_tokens) 195 | 196 | def get_change_labels(tokens): 197 | cache = dict() 198 | for label in EDIT_INDICES: 199 | cache[label] = set() 200 | 201 | label = None 202 | for t in tokens: 203 | if is_edit_keyword(t): 204 | label = t 205 | elif is_operator(t): 206 | continue 207 | else: 208 | cache[label].add(t) 209 | 210 | for label, label_set in cache.items(): 211 | cache[label] = list(label_set) 212 | return cache 213 | 214 | def extract_throwable_exceptions(code_block): 215 | i = 0 216 | while i < len(code_block): 217 | line = code_block[i].strip() 218 | if 'throws' in line: 219 | break 220 | i += 1 221 | 222 | if 'throws' not in line: 223 | return [] 224 | 225 | throws_string = line[line.index('throws') + len('throws'):] 226 | if '{' in throws_string: 227 | throws_string = throws_string[:throws_string.index('{')] 228 | else: 229 | extension = '' 230 | i += 1 231 | while i < len(code_block): 232 | line = code_block[i].strip() 233 | if len(line) == 0: 234 | i += 1 235 | continue 236 | for w in line: 237 | if w == '{': 238 | break 239 | else: 240 | extension += w 241 | if w == '{': 242 | break 243 | else: 244 | i += 1 245 | 246 | throws_string += extension 247 | 248 | exception_tokens = [t for t in tokenize_clean_code(throws_string).split() if not is_operator(t)] 249 | return exception_tokens 250 | 251 | def extract_throw_statements(code_block): 252 | cleaned_lines = [] 253 | for l in code_block: 254 | cleaned_l = strip_comment(l) 255 | if len(cleaned_l) > 0: 256 | cleaned_lines.append(cleaned_l) 257 | 258 | combined_block = ' '.join(cleaned_lines) 259 | if 'throw' not in combined_block: 260 | return [] 261 | indices = [m.start() for m in re.finditer('throw ', combined_block)] 262 | throw_statements = [] 263 | for idx in indices: 264 | s_idx = idx + len('throw ') 265 | e_idx = s_idx + combined_block[s_idx:].index(';') 266 | statement = combined_block[s_idx:e_idx].strip() 267 | if 
len(statement) > 0: 268 | throw_statements.append(statement) 269 | 270 | return throw_statements 271 | 272 | def get_method_elements(code_block): 273 | argument_names, argument_types = extract_arguments(code_block) 274 | return_statements = extract_return_statements(code_block) 275 | return_type = extract_return_type(code_block) 276 | 277 | throwable_exception_tokens = extract_throwable_exceptions(code_block) 278 | throwable_exception_subtokens = [] 279 | for throwable_exception in throwable_exception_tokens: 280 | throwable_exception_subtokens.extend(subtokenize_code(throwable_exception).split()) 281 | 282 | throw_statements = extract_throw_statements(code_block) 283 | throw_statement_tokens = [] 284 | throw_statement_subtokens = [] 285 | for throw_statement in throw_statements: 286 | throw_statement_tokens.extend([t for t in tokenize_clean_code(throw_statement).split() if not is_operator(t)]) 287 | throw_statement_subtokens.extend([t for t in subtokenize_code(throw_statement).split() if not is_operator(t)]) 288 | 289 | argument_name_tokens = [] 290 | argument_name_subtokens = [] 291 | argument_type_tokens = [] 292 | argument_type_subtokens = [] 293 | 294 | for argument_name in argument_names: 295 | argument_name_tokens.extend([t for t in tokenize_clean_code(argument_name).split() if not is_operator(t)]) 296 | argument_name_subtokens.extend([t for t in subtokenize_code(argument_name).split() if not is_operator(t)]) 297 | 298 | for argument_type in argument_types: 299 | argument_type_tokens.extend([t for t in tokenize_clean_code(argument_type).split() if not is_operator(t)]) 300 | argument_type_subtokens.extend([t for t in subtokenize_code(argument_type).split() if not is_operator(t)]) 301 | 302 | return_statement_tokens = [] 303 | return_statement_subtokens = [] 304 | for return_statement in return_statements: 305 | return_statement_tokens.extend([t for t in tokenize_clean_code(return_statement).split() if not is_operator(t)]) 306 | return_statement_subtokens.extend([t for t in subtokenize_code(return_statement).split() if not is_operator(t)]) 307 | 308 | return_type_tokens = [t for t in tokenize_clean_code(return_type).split() if not is_operator(t)] 309 | return_type_subtokens = [t for t in subtokenize_code(return_type).split() if not is_operator(t)] 310 | 311 | method_name = extract_method_name(code_block) 312 | method_name_tokens = [method_name] 313 | method_name_subtokens = subtokenize_code(method_name).split() 314 | 315 | token_elements = { 316 | 'argument_name': argument_name_tokens, 317 | 'argument_type': argument_type_tokens, 318 | 'return_type': return_type_tokens, 319 | 'return_statement': return_statement_tokens, 320 | 'throwable_exception': throwable_exception_tokens, 321 | 'throw_statement': throw_statement_tokens, 322 | 'method_name': method_name_tokens 323 | } 324 | 325 | subtoken_elements = { 326 | 'argument_name': argument_name_subtokens, 327 | 'argument_type': argument_type_subtokens, 328 | 'return_type': return_type_subtokens, 329 | 'return_statement': return_statement_subtokens, 330 | 'throwable_exception': throwable_exception_subtokens, 331 | 'throw_statement': throw_statement_subtokens, 332 | 'method_name': method_name_subtokens 333 | } 334 | 335 | return { 336 | 'token': token_elements, 337 | 'subtoken': subtoken_elements 338 | } 339 | 340 | if __name__ == "__main__": 341 | # Demo for extracting high level features for one example 342 | # Corresponds to what is written in high_level_features.json files 343 | 344 | ex = build_test_example() 345 | cache = dict() 346 | 
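# Illustrative sketch (editor addition, not part of the original source):
# get_change_labels, defined above, buckets the non-operator subtokens of an
# edit sequence under the most recent edit keyword. With the keywords imported
# from diff_utils and a made-up token stream:
toy_tokens = [KEEP, 'public', DELETE, 'max', INSERT, 'min']
print(get_change_labels(toy_tokens))
# roughly: {KEEP: ['public'], DELETE: ['max'], INSERT: ['min'], REPLACE_OLD: [], REPLACE_NEW: []}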
cache[ex.id] = { 347 | 'old': get_method_elements(ex.old_code_raw.split('\n')), 348 | 'new': get_method_elements(ex.new_code_raw.split('\n')), 349 | 'code_change_labels': {'subtoken': get_change_labels(ex.token_diff_code_subtokens)} 350 | } -------------------------------------------------------------------------------- /data_processing/tokenization_feature_extractor.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import javalang 3 | import json 4 | import os 5 | import re 6 | import sys 7 | 8 | from build_example import build_test_example 9 | from data_formatting_utils import subtokenize_code, tokenize_clean_code, get_clean_code,\ 10 | subtokenize_comment, tokenize_comment 11 | 12 | sys.path.append('../') 13 | from diff_utils import is_edit_keyword, KEEP, KEEP_END, REPLACE_OLD, REPLACE_NEW,\ 14 | REPLACE_END, INSERT, INSERT_END, DELETE, DELETE_END, compute_code_diffs 15 | 16 | def subtokenize_token(token, parse_comment=False): 17 | if parse_comment and token in ['@return', '@param', '@throws']: 18 | return [token] 19 | if is_edit_keyword(token): 20 | return [token] 21 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split() 22 | 23 | try: 24 | new_curr = [] 25 | for t in curr: 26 | new_curr.extend([c for c in re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", t.encode('ascii', errors='ignore').decode().strip()) if len(c) > 0]) 27 | curr = new_curr 28 | except: 29 | pass 30 | try: 31 | new_curr = [] 32 | for c in curr: 33 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip()) 34 | new_curr = new_curr + by_symbol 35 | 36 | curr = new_curr 37 | except: 38 | curr = [] 39 | subtokens = [c.lower() for c in curr] 40 | 41 | return subtokens 42 | 43 | def get_subtoken_labels(gold_subtokens, tokens, parse_comment=False): 44 | labels = [] 45 | indices = [] 46 | all_subtokens = [] 47 | 48 | token_map = [] 49 | subtoken_map = [] 50 | 51 | gold_idx = 0 52 | 53 | for token in tokens: 54 | subtokens = subtokenize_token(token, parse_comment) 55 | all_subtokens.extend(subtokens) 56 | token_map.append(subtokens) 57 | if len(subtokens) == 1: 58 | label = 0 59 | labels.append(label) 60 | indices.append(0) 61 | subtoken_map.append([token]) 62 | else: 63 | label = 1 64 | for s, subtoken in enumerate(subtokens): 65 | labels.append(label) 66 | indices.append(s) 67 | subtoken_map.append([token]) 68 | try: 69 | assert len(labels) == len(gold_subtokens) 70 | assert len(indices) == len(gold_subtokens) 71 | assert len(token_map) == len(tokens) 72 | assert len(subtoken_map) == len(gold_subtokens) 73 | except: 74 | print(tokens) 75 | print('\n') 76 | print(gold_subtokens) 77 | print('\n') 78 | for s, subtoken in enumerate(all_subtokens): 79 | print('Parsed: {}'.format(subtoken)) 80 | print('True: {}'.format(gold_subtokens[s])) 81 | print('---------------------------------') 82 | if subtoken != gold_subtokens[s]: 83 | break 84 | print(len(labels)) 85 | print(len(gold_subtokens)) 86 | raise ValueError('stop') 87 | return labels, indices, token_map, subtoken_map 88 | 89 | def get_code_subtoken_labels(gold_subtokens, tokens, raw_code): 90 | labels = [] 91 | indices = [] 92 | all_subtokens = [] 93 | 94 | token_map = [] 95 | subtoken_map = [] 96 | 97 | for token in tokens: 98 | if is_edit_keyword(token): 99 | token_map.append([token]) 100 | else: 101 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split() 102 | new_curr = [] 103 | for c in curr: 104 | by_symbol = 
re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip()) 105 | new_curr = new_curr + by_symbol 106 | token_map.append([s.lower() for s in new_curr]) 107 | 108 | try: 109 | parsed_tokens = get_clean_code(list(javalang.tokenizer.tokenize(raw_code))) 110 | except: 111 | parsed_tokens = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", raw_code.strip()) 112 | 113 | subtokens = [] 114 | for t, token in enumerate(parsed_tokens): 115 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split() 116 | subtokens = [c.lower() for c in curr] 117 | all_subtokens.extend(subtokens) 118 | if len(subtokens) == 1: 119 | label = 0 120 | labels.append(label) 121 | indices.append(0) 122 | subtoken_map.append([token]) 123 | else: 124 | label = 1 125 | for s, subtoken in enumerate(subtokens): 126 | labels.append(label) 127 | indices.append(s) 128 | subtoken_map.append([token]) 129 | try: 130 | assert len(labels) == len(gold_subtokens) 131 | assert len(indices) == len(gold_subtokens) 132 | assert len(token_map) == len(tokens) 133 | assert len(subtoken_map) == len(gold_subtokens) 134 | except: 135 | print(tokens) 136 | print('\n') 137 | print(gold_subtokens) 138 | print('\n') 139 | for s, subtoken in enumerate(all_subtokens): 140 | print('Parsed: {}'.format(subtoken)) 141 | print('True: {}'.format(gold_subtokens[s])) 142 | print('---------------------------------') 143 | if subtoken != gold_subtokens[s]: 144 | break 145 | print(len(labels)) 146 | print(len(gold_subtokens)) 147 | raise ValueError('stop') 148 | return labels, indices, token_map, subtoken_map 149 | 150 | def get_diff_subtoken_labels(diff_subtokens, old_subtokens, old_tokens, new_subtokens, new_tokens, diff_tokens, old_code_raw, new_code_raw): 151 | old_labels, old_indices, old_token_map, old_subtoken_map = get_code_subtoken_labels(old_subtokens, old_tokens, old_code_raw) 152 | new_labels, new_indices, new_token_map, new_subtoken_map = get_code_subtoken_labels(new_subtokens, new_tokens, new_code_raw) 153 | 154 | diff_labels = [] 155 | diff_indices = [] 156 | 157 | diff_token_map = [] 158 | diff_subtoken_map = [] 159 | 160 | for token in diff_tokens: 161 | if is_edit_keyword(token): 162 | diff_token_map.append([token]) 163 | else: 164 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', token).split() 165 | new_curr = [] 166 | for c in curr: 167 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip()) 168 | new_curr = new_curr + by_symbol 169 | diff_token_map.append([s.lower() for s in new_curr]) 170 | 171 | for edit_type, o_start, o_end, n_start, n_end in difflib.SequenceMatcher(None, old_subtokens, new_subtokens).get_opcodes(): 172 | if edit_type == 'equal': 173 | diff_labels.extend([0] + old_labels[o_start:o_end] + [0]) 174 | diff_indices.extend([0] + old_indices[o_start:o_end] + [0]) 175 | diff_subtoken_map.append([KEEP]) 176 | diff_subtoken_map.extend(old_subtoken_map[o_start:o_end]) 177 | diff_subtoken_map.append([KEEP_END]) 178 | elif edit_type == 'replace': 179 | diff_labels.extend([0] + old_labels[o_start:o_end] + [0] + new_labels[n_start:n_end] + [0]) 180 | diff_indices.extend([0] + old_indices[o_start:o_end] + [0] + new_indices[n_start:n_end] + [0]) 181 | diff_subtoken_map.append([REPLACE_OLD]) 182 | diff_subtoken_map.extend(old_subtoken_map[o_start:o_end]) 183 | diff_subtoken_map.append([REPLACE_NEW]) 184 | diff_subtoken_map.extend(new_subtoken_map[n_start:n_end]) 185 | diff_subtoken_map.append([REPLACE_END]) 186 | elif edit_type == 'insert': 187 | diff_labels.extend([0] + 
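# Illustrative sketch (editor addition, not part of the original source): the
# surrounding loop walks difflib's opcodes, where each opcode names an edit
# type and the half-open spans it covers in the old and new subtoken sequences:
import difflib

old = ['return', 'collections', '.', 'max']
new = ['return', 'collections', '.', 'min']
for tag, o_start, o_end, n_start, n_end in difflib.SequenceMatcher(None, old, new).get_opcodes():
    print(tag, old[o_start:o_end], new[n_start:n_end])
# equal ['return', 'collections', '.'] ['return', 'collections', '.']
# replace ['max'] ['min']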
new_labels[n_start:n_end] + [0]) 188 | diff_indices.extend([0] + new_indices[n_start:n_end] + [0]) 189 | diff_subtoken_map.append([INSERT]) 190 | diff_subtoken_map.extend(new_subtoken_map[n_start:n_end]) 191 | diff_subtoken_map.append([INSERT_END]) 192 | else: 193 | diff_labels.extend([0] + old_labels[o_start:o_end] + [0]) 194 | diff_indices.extend([0] + old_indices[o_start:o_end] + [0]) 195 | diff_subtoken_map.append([DELETE]) 196 | diff_subtoken_map.extend(old_subtoken_map[o_start:o_end]) 197 | diff_subtoken_map.append([DELETE_END]) 198 | 199 | assert len(diff_labels) == len(diff_subtokens) 200 | assert len(diff_indices) == len(diff_subtokens) 201 | assert len(diff_subtoken_map) == len(diff_subtokens) 202 | assert len(diff_token_map) == len(diff_tokens) 203 | return diff_labels, diff_indices, diff_token_map, diff_subtoken_map 204 | 205 | if __name__ == "__main__": 206 | # Demo for extracting tokenization features for one example 207 | # Corresponds to what is written in tokenization_features.json files 208 | ex = build_test_example() 209 | 210 | old_code_tokens = tokenize_clean_code(ex.old_code_raw).split() 211 | new_code_tokens = tokenize_clean_code(ex.new_code_raw).split() 212 | span_diff_code_tokens, _, _ = compute_code_diffs(old_code_tokens, new_code_tokens) 213 | 214 | edit_span_subtoken_labels, edit_span_subtoken_indices, edit_span_token_map, edit_span_subtoken_map = get_diff_subtoken_labels( 215 | ex.span_diff_code_subtokens, ex.old_code_subtokens, old_code_tokens, ex.new_code_subtokens, new_code_tokens, 216 | span_diff_code_tokens, ex.old_code_raw, ex.new_code_raw) 217 | 218 | old_comment_tokens = tokenize_comment(ex.old_comment_raw).split() 219 | 220 | prefix = [] 221 | if ex.comment_type == 'Return': 222 | prefix = ['@return'] 223 | elif ex.comment_type == 'Param': 224 | prefix = ['@param'] 225 | 226 | old_nl_subtoken_labels, old_nl_subtoken_indices, old_nl_token_map, old_nl_subtoken_map = get_subtoken_labels( 227 | prefix + ex.old_comment_subtokens, prefix + old_comment_tokens, parse_comment=True) 228 | 229 | cache = dict() 230 | cache[ex.id] = { 231 | 'old_nl_subtoken_labels': old_nl_subtoken_labels, 232 | 'old_nl_subtoken_indices': old_nl_subtoken_indices, 233 | 'edit_span_subtoken_labels': edit_span_subtoken_labels, 234 | 'edit_span_subtoken_indices': edit_span_subtoken_indices, 235 | 'old_nl_token_map': old_nl_token_map, 236 | 'old_nl_subtoken_map': old_nl_subtoken_map, 237 | 'edit_span_token_map': edit_span_token_map, 238 | 'edit_span_subtoken_map': edit_span_subtoken_map 239 | } -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from enum import Enum 3 | import json 4 | import numpy as np 5 | import re 6 | import torch 7 | from typing import List, NamedTuple 8 | 9 | from external_cache import get_node_features 10 | 11 | @enum.unique 12 | class CommentCategory(Enum): 13 | Return = 0 14 | Param = 1 15 | Summary = 2 16 | 17 | @enum.unique 18 | class DiffEdgeType(Enum): 19 | PARENT = 0 20 | CHILD = 1 21 | SUBTOKEN_CHILD = 2 22 | SUBTOKEN_PARENT = 3 23 | PREV_SUBTOKEN = 4 24 | NEXT_SUBTOKEN = 5 25 | ALIGNED_NEIGHBOR = 6 26 | 27 | @enum.unique 28 | class SrcType(Enum): 29 | KEEP = 0 30 | INSERT = 1 31 | DELETE = 2 32 | REPLACE_OLD = 3 33 | REPLACE_NEW = 4 34 | MOVE = 5 35 | 36 | class DiffTreeNode: 37 | def __init__(self, value, attribute, src, is_leaf): 38 | self.value = value 39 | self.node_id = -1 40 | self.parents = [] 41 | 
self.attribute = attribute 42 | self.src = src 43 | self.is_leaf = is_leaf 44 | self.children = [] 45 | self.prev_siblings = [] 46 | self.next_siblings = [] 47 | self.aligned_neighbors = [] 48 | self.action_type = None 49 | self.prev_tokens = [] 50 | self.next_tokens = [] 51 | self.subtokens = [] 52 | 53 | self.subtoken_children = [] 54 | self.subtoken_parents = [] 55 | self.prev_subtokens = [] 56 | self.next_subtokens = [] 57 | 58 | def to_json(self): 59 | return { 60 | 'value': self.value, 61 | 'node_id': self.node_id, 62 | 'parent_ids': [p.node_id for p in self.parents], 63 | 'attribute': self.attribute, 64 | 'src': self.src, 65 | 'is_leaf': self.is_leaf, 66 | 'children_ids': [c.node_id for c in self.children], 67 | 'prev_sibling_ids': [p.node_id for p in self.prev_siblings], 68 | 'next_sibling_ids': [n.node_id for n in self.next_siblings], 69 | 'aligned_neighbor_ids': [n.node_id for n in self.aligned_neighbors], 70 | 'action_type': self.action_type, 71 | } 72 | 73 | @property 74 | def is_identifier(self): 75 | return self.is_leaf and self.attribute == 'SimpleName' 76 | 77 | class DiffAST: 78 | def __init__(self, ast_root): 79 | self.node_cache = set() 80 | self.root = ast_root 81 | self.nodes = [] 82 | self.traverse(self.root) 83 | 84 | def traverse(self, curr_node): 85 | if curr_node not in self.node_cache: 86 | self.node_cache.add(curr_node) 87 | curr_node.node_id = len(self.nodes) 88 | self.nodes.append(curr_node) 89 | for child in curr_node.subtoken_children: 90 | self.traverse(child) 91 | for child in curr_node.children: 92 | self.traverse(child) 93 | 94 | def to_json(self): 95 | return [n.to_json() for n in self.nodes] 96 | 97 | @property 98 | def leaves(self): 99 | return [n for n in self.nodes if n.is_leaf] 100 | 101 | @classmethod 102 | def from_json(cls, obj): 103 | nodes = [] 104 | for node_obj in obj: 105 | node = DiffTreeNode(node_obj['value'], node_obj['attribute'], node_obj['src'], False) 106 | if 'action_type' in node_obj: 107 | node.action_type = node_obj['action_type'] 108 | nodes.append(node) 109 | 110 | new_nodes = [] 111 | 112 | for n, node_obj in enumerate(obj): 113 | nodes[n].parents = [nodes[i] for i in node_obj['parent_ids']] 114 | nodes[n].children = [nodes[i] for i in node_obj['children_ids']] 115 | nodes[n].prev_siblings = [nodes[i] for i in node_obj['prev_sibling_ids']] 116 | nodes[n].next_siblings = [nodes[i] for i in node_obj['next_sibling_ids']] 117 | nodes[n].aligned_neighbors = [nodes[i] for i in node_obj['aligned_neighbor_ids']] 118 | new_nodes.append(nodes[n]) 119 | 120 | if len(nodes[n].children) == 0: 121 | nodes[n].is_leaf = True 122 | curr = re.sub('([a-z0-9])([A-Z])', r'\1 \2', nodes[n].value).split() 123 | new_curr = [] 124 | for c in curr: 125 | by_symbol = re.findall(r"[a-zA-Z0-9]+|[^\sa-zA-Z0-9]|[^_\sa-zA-Z0-9]", c.strip()) 126 | new_curr = new_curr + by_symbol 127 | nodes[n].subtokens = [s.lower() for s in new_curr] 128 | 129 | if len(nodes[n].subtokens) > 1: 130 | for s in nodes[n].subtokens: 131 | sub_node = DiffTreeNode(s, '', nodes[n].src, True) 132 | sub_node.action_type = nodes[n].action_type 133 | sub_node.subtoken_parents.append(nodes[n]) 134 | 135 | if len(nodes[n].subtoken_children) > 0: 136 | nodes[n].subtoken_children[-1].next_subtokens.append(sub_node) 137 | sub_node.prev_subtokens.append(nodes[n].subtoken_children[-1]) 138 | 139 | nodes[n].subtoken_children.append(sub_node) 140 | new_nodes.append(sub_node) 141 | 142 | nodes[n].value = nodes[n].value.lower() 143 | 144 | return cls(new_nodes[0]) 145 | 146 | def 
insert_graph(batch, ex, ast, vocabulary, use_features, max_ast_length): 147 | batch.root_ids.append(batch.num_nodes) 148 | graph_node_positions = [] 149 | for n, node in enumerate(ast.nodes): 150 | batch.graph_ids.append(batch.num_graphs) 151 | batch.is_internal.append(not node.is_leaf) 152 | batch.value_lookup_ids.append(vocabulary.get_id_or_unk(node.value)) 153 | 154 | if node.action_type == 'Insert': 155 | src_type = SrcType.INSERT 156 | elif node.action_type == 'Delete': 157 | src_type = SrcType.DELETE 158 | elif node.action_type == 'Move': 159 | src_type = SrcType.MOVE 160 | elif node.src == 'old' and node.action_type == 'Update': 161 | src_type = SrcType.REPLACE_OLD 162 | elif node.src == 'new' and node.action_type == 'Update': 163 | src_type = SrcType.REPLACE_NEW 164 | else: 165 | src_type = SrcType.KEEP 166 | 167 | batch.src_type_ids.append(src_type.value) 168 | graph_node_positions.append(batch.num_nodes + node.node_id) 169 | 170 | for parent in node.parents: 171 | if parent.node_id < len(ast.nodes): 172 | batch.edges[DiffEdgeType.PARENT.value].append( 173 | (batch.num_nodes + node.node_id, batch.num_nodes + parent.node_id)) 174 | 175 | for child in node.children: 176 | if child.node_id < len(ast.nodes): 177 | batch.edges[DiffEdgeType.CHILD.value].append( 178 | (batch.num_nodes + node.node_id, batch.num_nodes + child.node_id)) 179 | 180 | for subtoken_parent in node.subtoken_parents: 181 | if subtoken_parent.node_id < len(ast.nodes): 182 | batch.edges[DiffEdgeType.SUBTOKEN_PARENT.value].append( 183 | (batch.num_nodes + node.node_id, batch.num_nodes + subtoken_parent.node_id)) 184 | 185 | for subtoken_child in node.subtoken_children: 186 | if subtoken_child.node_id < len(ast.nodes): 187 | batch.edges[DiffEdgeType.SUBTOKEN_CHILD.value].append( 188 | (batch.num_nodes + node.node_id, batch.num_nodes + subtoken_child.node_id)) 189 | 190 | for next_subtoken in node.next_subtokens: 191 | if next_subtoken.node_id < len(ast.nodes): 192 | batch.edges[DiffEdgeType.NEXT_SUBTOKEN.value].append( 193 | (batch.num_nodes + node.node_id, batch.num_nodes + next_subtoken.node_id)) 194 | 195 | for prev_subtoken in node.prev_subtokens: 196 | if prev_subtoken.node_id < len(ast.nodes): 197 | batch.edges[DiffEdgeType.PREV_SUBTOKEN.value].append( 198 | (batch.num_nodes + node.node_id, batch.num_nodes + prev_subtoken.node_id)) 199 | 200 | if len(batch.edges) == len(DiffEdgeType): 201 | for aligned_neighbor in node.aligned_neighbors: 202 | if aligned_neighbor.node_id < len(ast.nodes): 203 | batch.edges[DiffEdgeType.ALIGNED_NEIGHBOR.value].append( 204 | (batch.num_nodes + node.node_id, batch.num_nodes + aligned_neighbor.node_id)) 205 | 206 | if use_features: 207 | node_features = get_node_features(ast.nodes, ex, max_ast_length) 208 | batch.node_features.extend(node_features) 209 | 210 | batch.node_positions.append(graph_node_positions) 211 | batch.num_nodes_per_graph.append(len(ast.nodes)) 212 | batch.num_nodes += len(ast.nodes) 213 | batch.num_graphs += 1 214 | return batch 215 | 216 | 217 | class GraphMethodBatch: 218 | def __init__(self, graph_ids, value_lookup_ids, src_type_ids, root_ids, is_internal, 219 | edges, num_graphs, num_nodes, node_features, node_positions, num_nodes_per_graph): 220 | self.graph_ids = graph_ids 221 | self.value_lookup_ids = value_lookup_ids 222 | self.src_type_ids = src_type_ids 223 | self.root_ids = root_ids 224 | self.is_internal = is_internal 225 | self.edges = edges 226 | self.num_graphs = num_graphs 227 | self.num_nodes = num_nodes 228 | self.node_features = node_features 
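# Note (added): node_positions and num_nodes_per_graph record, per graph, where its
# nodes sit in the flattened batch; tensorize_graph_method_batch below pads them into a
# [num_graphs, max_num_nodes_per_graph] index matrix, filling unused slots with the graph's root id.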
229 | self.node_positions = node_positions 230 | self.num_nodes_per_graph = num_nodes_per_graph 231 | 232 | def initialize_graph_method_batch(num_edges): 233 | return GraphMethodBatch( 234 | graph_ids = [], 235 | value_lookup_ids = [], 236 | src_type_ids = [], 237 | root_ids = [], 238 | is_internal = [], 239 | edges = [[] for _ in range(num_edges)], 240 | num_graphs = 0, 241 | num_nodes = 0, 242 | node_features = [], 243 | node_positions = [], 244 | num_nodes_per_graph = [] 245 | ) 246 | 247 | def tensorize_graph_method_batch(batch, device, max_num_nodes_per_graph): 248 | node_positions = np.zeros([batch.num_graphs, max_num_nodes_per_graph], dtype=np.int64) 249 | for g in range(batch.num_graphs): 250 | graph_node_positions = batch.node_positions[g] 251 | node_positions[g,:len(graph_node_positions)] = graph_node_positions 252 | node_positions[g,len(graph_node_positions):] = batch.root_ids[g] 253 | 254 | return GraphMethodBatch( 255 | torch.tensor(batch.graph_ids, dtype=torch.int64, device=device), 256 | torch.tensor(batch.value_lookup_ids, dtype=torch.int64, device=device), 257 | torch.tensor(batch.src_type_ids, dtype=torch.int64, device=device), 258 | torch.tensor(batch.root_ids, dtype=torch.int64, device=device), 259 | torch.tensor(batch.is_internal, dtype=torch.uint8, device=device), 260 | batch.edges, batch.num_graphs, batch.num_nodes, 261 | torch.tensor(batch.node_features, dtype=torch.float32, device=device), 262 | torch.tensor(node_positions, dtype=torch.int64, device=device), 263 | torch.tensor(batch.num_nodes_per_graph, dtype=torch.int64, device=device)) 264 | 265 | class GenerationBatchData(NamedTuple): 266 | """Stores tensorized batch used in generation model.""" 267 | code_ids: torch.Tensor 268 | code_lengths: torch.Tensor 269 | trg_nl_ids: torch.Tensor 270 | trg_extended_nl_ids: torch.Tensor 271 | trg_nl_lengths: torch.Tensor 272 | invalid_copy_positions: torch.Tensor 273 | input_str_reps: List[List[str]] 274 | input_ids: List[List[str]] 275 | 276 | class UpdateBatchData(NamedTuple): 277 | """Stores tensorized batch used in edit model.""" 278 | code_ids: torch.Tensor 279 | code_lengths: torch.Tensor 280 | old_nl_ids: torch.Tensor 281 | old_nl_lengths: torch.Tensor 282 | trg_nl_ids: torch.Tensor 283 | trg_extended_nl_ids: torch.Tensor 284 | trg_nl_lengths: torch.Tensor 285 | invalid_copy_positions: torch.Tensor 286 | input_str_reps: List[List[str]] 287 | input_ids: List[List[str]] 288 | code_features: torch.Tensor 289 | nl_features: torch.Tensor 290 | labels: torch.Tensor 291 | graph_batch: GraphMethodBatch 292 | 293 | class EncoderOutputs(NamedTuple): 294 | """Stores tensorized batch used in edit model.""" 295 | encoder_hidden_states: torch.Tensor 296 | masks: torch.Tensor 297 | encoder_final_state: torch.Tensor 298 | code_hidden_states: torch.Tensor 299 | code_masks: torch.Tensor 300 | old_nl_hidden_states: torch.Tensor 301 | old_nl_masks: torch.Tensor 302 | old_nl_final_state: torch.Tensor 303 | attended_old_nl_final_state: torch.Tensor 304 | 305 | class Example(NamedTuple): 306 | """Data format for examples used in generation model.""" 307 | id: str 308 | old_comment: str 309 | old_comment_tokens: List[str] 310 | new_comment: str 311 | new_comment_tokens: List[str] 312 | old_code: str 313 | old_code_tokens: List[str] 314 | new_code: str 315 | new_code_tokens: List[str] 316 | 317 | class DiffExample(NamedTuple): 318 | id: str 319 | label: int 320 | comment_type: str 321 | old_comment_raw: str 322 | old_comment_subtokens: List[str] 323 | new_comment_raw: str 324 | 
new_comment_subtokens: List[str] 325 | span_minimal_diff_comment_subtokens: List[str] 326 | old_code_raw: str 327 | old_code_subtokens: List[str] 328 | new_code_raw: str 329 | new_code_subtokens: List[str] 330 | span_diff_code_subtokens: List[str] 331 | token_diff_code_subtokens: List[str] 332 | 333 | class DiffASTExample(NamedTuple): 334 | id: str 335 | label: int 336 | comment_type: str 337 | old_comment_raw: str 338 | old_comment_subtokens: List[str] 339 | new_comment_raw: str 340 | new_comment_subtokens: List[str] 341 | span_minimal_diff_comment_subtokens: List[str] 342 | old_code_raw: str 343 | old_code_subtokens: List[str] 344 | new_code_raw: str 345 | new_code_subtokens: List[str] 346 | span_diff_code_subtokens: List[str] 347 | token_diff_code_subtokens: List[str] 348 | old_ast: DiffAST 349 | new_ast: DiffAST 350 | diff_ast: DiffAST 351 | 352 | def get_processed_comment_sequence(comment_subtokens): 353 | """Returns sequence without tag string. Tag strings are excluded for evaluation purposes.""" 354 | if len(comment_subtokens) > 0 and comment_subtokens[0] in ['@param', '@return']: 355 | return comment_subtokens[1:] 356 | 357 | return comment_subtokens 358 | 359 | def get_processed_comment_str(comment_subtokens): 360 | """Returns string without tag string. Tag strings are excluded for evaluation purposes.""" 361 | return ' '.join(get_processed_comment_sequence(comment_subtokens)) 362 | 363 | def read_full_examples_from_file(filename): 364 | """Reads in data in the format used for generation model.""" 365 | with open(filename) as f: 366 | data = json.load(f) 367 | return [Example(**d) for d in data] -------------------------------------------------------------------------------- /detection_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from sklearn.metrics import precision_recall_fscore_support 4 | 5 | def compute_average(values): 6 | return sum(values)/float(len(values)) 7 | 8 | def compute_score(predicted_labels, gold_labels, verbose=True): 9 | true_positives = 0.0 10 | true_negatives = 0.0 11 | false_positives = 0.0 12 | false_negatives = 0.0 13 | 14 | assert(len(predicted_labels) == len(gold_labels)) 15 | 16 | for i in range(len(gold_labels)): 17 | if gold_labels[i]: 18 | if predicted_labels[i]: 19 | true_positives += 1 20 | else: 21 | false_negatives += 1 22 | else: 23 | if predicted_labels[i]: 24 | false_positives += 1 25 | else: 26 | true_negatives += 1 27 | 28 | if verbose: 29 | print('True positives: {}'.format(true_positives)) 30 | print('False positives: {}'.format(false_positives)) 31 | print('True negatives: {}'.format(true_negatives)) 32 | print('False negatives: {}'.format(false_negatives)) 33 | 34 | try: 35 | precision = true_positives/(true_positives + false_positives) 36 | except: 37 | precision = 0.0 38 | try: 39 | recall = true_positives/(true_positives + false_negatives) 40 | except: 41 | recall = 0.0 42 | try: 43 | f1 = 2*((precision * recall)/(precision + recall)) 44 | except: 45 | f1 = 0.0 46 | return precision, recall, f1 -------------------------------------------------------------------------------- /detection_module.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import Counter 3 | import numpy as np 4 | import os 5 | import random 6 | import sys 7 | import torch 8 | from torch import nn 9 | 10 | from constants import * 11 | from detection_evaluation_utils import compute_score 12 | 13 | 14 | class 
DetectionModule(nn.Module): 15 | """Binary classification model for detecting inconsistent comments.""" 16 | def __init__(self, model_path, manager): 17 | super(DetectionModule, self).__init__() 18 | 19 | self.model_path = model_path 20 | self.manager = manager 21 | feature_input_dimension = self.manager.out_dim 22 | 23 | self.output_layer = nn.Linear(feature_input_dimension, NUM_CLASSES) 24 | self.optimizer = torch.optim.Adam(self.parameters(), lr=LR) 25 | 26 | def get_logprobs(self, encoder_outputs): 27 | """Computes the class-level log probabilities corresponding to the examples in the batch.""" 28 | logits = self.output_layer(encoder_outputs.attended_old_nl_final_state) 29 | return torch.nn.functional.log_softmax(logits, dim=-1) 30 | 31 | def compute_detection_loss(self, encoder_outputs, batch_data): 32 | """Computes the negative log likelihood loss against the gold labels corresponding to the examples in the batch.""" 33 | logprobs = self.get_logprobs(encoder_outputs) 34 | return torch.nn.functional.nll_loss(logprobs, batch_data.labels), logprobs 35 | 36 | def forward(self, batch_data): 37 | """Computes prediction loss for given batch.""" 38 | encoder_outputs = self.manager.get_encoder_output(batch_data, self.get_device()) 39 | loss, logprobs = self.compute_detection_loss(encoder_outputs, batch_data) 40 | return loss, logprobs 41 | 42 | def run_train(self, train_examples, valid_examples): 43 | """Runs training over the entire training set across several epochs. Following each epoch, 44 | F1 on the validation data is computed. If the validation F1 has improved, save the model. 45 | Early-stopping is employed to stop training if validation hasn't improved for a certain number 46 | of epochs.""" 47 | valid_batches = self.manager.get_batches(valid_examples, self.get_device()) 48 | best_loss = float('inf') 49 | best_f1 = 0.0 50 | patience_tally = 0 51 | 52 | for epoch in range(MAX_EPOCHS): 53 | if patience_tally > PATIENCE: 54 | print('Terminating: {}'.format(epoch)) 55 | break 56 | 57 | self.train() 58 | train_batches = self.manager.get_batches(train_examples, self.get_device(), shuffle=True) 59 | 60 | train_loss = 0 61 | for batch_data in train_batches: 62 | train_loss += self.run_gradient_step(batch_data) 63 | 64 | self.eval() 65 | validation_loss = 0 66 | validation_predicted_labels = [] 67 | validation_gold_labels = [] 68 | with torch.no_grad(): 69 | for batch_data in valid_batches: 70 | b_loss, b_logprobs = self.forward(batch_data) 71 | validation_loss += float(b_loss.cpu()) 72 | validation_predicted_labels.extend(b_logprobs.argmax(-1).tolist()) 73 | validation_gold_labels.extend(batch_data.labels.tolist()) 74 | 75 | validation_loss = validation_loss/len(valid_batches) 76 | validation_precision, validation_recall, validation_f1 = compute_score( 77 | validation_predicted_labels, validation_gold_labels, verbose=False) 78 | 79 | if validation_f1 >= best_f1: 80 | best_f1 = validation_f1 81 | torch.save(self, self.model_path) 82 | saved = True 83 | patience_tally = 0 84 | else: 85 | saved = False 86 | patience_tally += 1 87 | 88 | print('Epoch: {}'.format(epoch)) 89 | print('Training loss: {:.3f}'.format(train_loss/len(train_batches))) 90 | print('Validation loss: {:.3f}'.format(validation_loss)) 91 | print('Validation precision: {:.3f}'.format(validation_precision)) 92 | print('Validation recall: {:.3f}'.format(validation_recall)) 93 | print('Validation f1: {:.3f}'.format(validation_f1)) 94 | if saved: 95 | print('Saved') 96 | print('-----------------------------------') 97 | 
sys.stdout.flush() 98 | 99 | def get_device(self): 100 | """Returns the proper device.""" 101 | if self.torch_device_name == 'gpu': 102 | return torch.device('cuda') 103 | else: 104 | return torch.device('cpu') 105 | 106 | def run_gradient_step(self, batch_data): 107 | """Performs gradient step.""" 108 | self.optimizer.zero_grad() 109 | loss, _ = self.forward(batch_data) 110 | loss.backward() 111 | self.optimizer.step() 112 | return float(loss.cpu()) 113 | 114 | def run_evaluation(self, test_examples, model_name): 115 | """Predicts labels for all comments in the test set and computes evaluation metrics.""" 116 | self.eval() 117 | 118 | test_batches = self.manager.get_batches(test_examples, self.get_device()) 119 | test_predictions = [] 120 | 121 | with torch.no_grad(): 122 | for b, batch in enumerate(test_batches): 123 | print('Testing batch {}/{}'.format(b, len(test_batches))) 124 | sys.stdout.flush() 125 | encoder_outputs = self.manager.get_encoder_output(batch, self.get_device()) 126 | batch_logprobs = self.get_logprobs(encoder_outputs) 127 | test_predictions.extend(batch_logprobs.argmax(dim=-1).tolist()) 128 | 129 | self.compute_metrics(test_predictions, test_examples, model_name) 130 | 131 | def compute_metrics(self, predicted_labels, test_examples, model_name): 132 | """Computes evaluation metrics.""" 133 | gold_labels = [] 134 | correct = 0 135 | for e, ex in enumerate(test_examples): 136 | if ex.label == predicted_labels[e]: 137 | correct += 1 138 | gold_labels.append(ex.label) 139 | 140 | accuracy = float(correct)/len(test_examples) 141 | precision, recall, f1 = compute_score(predicted_labels, gold_labels) 142 | 143 | print('Precision: {}'.format(precision)) 144 | print('Recall: {}'.format(recall)) 145 | print('F1: {}'.format(f1)) 146 | print('Accuracy: {}\n'.format(accuracy)) 147 | 148 | write_file = os.path.join(DETECTION_DIR, '{}_detection.txt'.format(model_name)) 149 | with open(write_file, 'w+') as f: 150 | for e, ex in enumerate(test_examples): 151 | f.write('{} {}\n'.format(ex.id, predicted_labels[e])) 152 | -------------------------------------------------------------------------------- /display_scores.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.append('comment_update') 6 | from data_loader import get_data_splits, load_cleaned_test_set 7 | from data_utils import get_processed_comment_str 8 | from detection_evaluation_utils import compute_score 9 | from update_evaluation_utils import write_predictions, compute_accuracy, compute_bleu,\ 10 | compute_meteor, compute_sari, compute_gleu 11 | 12 | """Script for printing update or detection metrics for output, on full and clean test sets.""" 13 | 14 | def load_predicted_detection_labels(filepath, selected_positions): 15 | with open(filepath) as f: 16 | lines = f.readlines() 17 | 18 | selected_labels = [] 19 | for s in selected_positions: 20 | selected_labels.append(int(lines[s].strip().split()[-1])) 21 | return selected_labels 22 | 23 | def load_predicted_generation_sequences(filepath, selected_positions): 24 | with open(filepath) as f: 25 | lines = f.readlines() 26 | 27 | selected_sequences = [] 28 | for s in selected_positions: 29 | selected_sequences.append(lines[s].strip()) 30 | return selected_sequences 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--detection_output_file', help='path to detection output file') 35 | parser.add_argument('--update_output_file', 
help='path to update output file') 36 | args = parser.parse_args() 37 | 38 | # NOTE: To evaluate the pretrained approach, detection_output_file and 39 | # update_output_file must be both specified. For all other approaches, 40 | # only one should be specified. 41 | 42 | _, _, test_examples, _ = get_data_splits(ignore_ast=True) 43 | positions = list(range(len(test_examples))) 44 | 45 | clean_ids = load_cleaned_test_set() 46 | clean_positions = [] 47 | for e, example in enumerate(test_examples): 48 | if example.id in clean_ids: 49 | clean_positions.append(e) 50 | clean_test_examples = [test_examples[pos] for pos in clean_positions] 51 | 52 | eval_tuples = [(test_examples, positions, 'full'), (clean_test_examples, clean_positions, 'clean')] 53 | 54 | for (examples, indices, test_type) in eval_tuples: 55 | if args.detection_output_file: 56 | predicted_labels = load_predicted_detection_labels(args.detection_output_file, indices) 57 | gold_labels = [ex.label for ex in examples] 58 | 59 | precision, recall, f1 = compute_score(predicted_labels, gold_labels, verbose=False) 60 | 61 | num_correct = 0 62 | for p, p_label in enumerate(predicted_labels): 63 | if p_label == gold_labels[p]: 64 | num_correct += 1 65 | 66 | print('Detection Precision: {}'.format(precision)) 67 | print('Detection Recall: {}'.format(recall)) 68 | print('Detection F1: {}'.format(f1)) 69 | print('Detection Accuracy: {}\n'.format(float(num_correct)/len(predicted_labels))) 70 | 71 | if args.update_output_file: 72 | update_strs = load_predicted_generation_sequences(args.update_output_file, indices) 73 | 74 | references = [] 75 | pred_instances = [] 76 | src_strs = [] 77 | gold_strs = [] 78 | pred_strs = [] 79 | 80 | for i in range(len(examples)): 81 | src_str = get_processed_comment_str(examples[i].old_comment_subtokens) 82 | src_strs.append(src_str) 83 | 84 | gold_str = get_processed_comment_str(examples[i].new_comment_subtokens) 85 | gold_strs.append(gold_str) 86 | references.append([gold_str.split()]) 87 | 88 | if args.detection_output_file and predicted_labels[i] == 0: 89 | pred_instances.append(src_str.split()) 90 | pred_strs.append(src_str) 91 | else: 92 | pred_instances.append(update_strs[i].split()) 93 | pred_strs.append(update_strs[i]) 94 | 95 | prediction_file = os.path.join(os.getcwd(), 'pred.txt') 96 | src_file = os.path.join(os.getcwd(), 'src.txt') 97 | ref_file = os.path.join(os.getcwd(), 'ref.txt') 98 | 99 | write_predictions(pred_strs, prediction_file) 100 | write_predictions(src_strs, src_file) 101 | write_predictions(gold_strs, ref_file) 102 | 103 | predicted_accuracy = compute_accuracy(gold_strs, pred_strs) 104 | predicted_bleu = compute_bleu(references, pred_instances) 105 | predicted_meteor = compute_meteor(references, pred_instances) 106 | predicted_sari = compute_sari(examples, pred_instances) 107 | predicted_gleu = compute_gleu(examples, src_file, ref_file, prediction_file) 108 | 109 | print('Update Accuracy: {}'.format(predicted_accuracy)) 110 | print('Update BLEU: {}'.format(predicted_bleu)) 111 | print('Update Meteor: {}'.format(predicted_meteor)) 112 | print('Update SARI: {}'.format(predicted_sari)) 113 | print('Update GLEU: {}\n'.format(predicted_gleu)) 114 | 115 | print('Test type: {}'.format(test_type)) 116 | print('Detection file: {}'.format(args.detection_output_file)) 117 | print('Update file: {}'.format(args.update_output_file)) 118 | print('Total: {}'.format(len(examples))) 119 | print('--------------------------------------') 120 | 121 | 122 | 123 | 124 | 
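For quick sanity checks, detection predictions can also be scored directly, without rerunning `display_scores.py`. The sketch below assumes the one-prediction-per-line `<example id> <predicted label>` format written by `DetectionModule.compute_metrics`; `score_detection_file` and `gold_labels_by_id` are illustrative names, not part of the repository.

```python
# Minimal sketch (not part of the repo): score a detection predictions file.
# Assumes each line has the form "<example id> <predicted label>", as written
# by DetectionModule.compute_metrics; gold_labels_by_id maps example id -> gold label.
from detection_evaluation_utils import compute_score

def score_detection_file(pred_path, gold_labels_by_id):
    predicted_labels, gold_labels = [], []
    with open(pred_path) as f:
        for line in f:
            if not line.strip():
                continue
            ex_id, label = line.strip().rsplit(' ', 1)
            predicted_labels.append(int(label))
            gold_labels.append(gold_labels_by_id[ex_id])
    # compute_score returns (precision, recall, f1)
    return compute_score(predicted_labels, gold_labels, verbose=False)
```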
-------------------------------------------------------------------------------- /encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class Encoder(nn.Module): 5 | def __init__(self, embedding_size, hidden_size, num_layers, dropout, bidirectional=True): 6 | super(Encoder, self).__init__() 7 | self.__rnn = nn.GRU(input_size=embedding_size, 8 | hidden_size=hidden_size, 9 | dropout=dropout, 10 | num_layers=num_layers, 11 | batch_first=True, 12 | bidirectional=bidirectional) 13 | 14 | def forward(self, src_embedded_tokens, src_lengths, device): 15 | encoder_hidden_states, _ = self.__rnn.forward(src_embedded_tokens) 16 | encoder_final_state = encoder_hidden_states[torch.arange( 17 | src_embedded_tokens.size()[0], dtype=torch.int64, device=device), src_lengths-1] 18 | # encoder_final_state, _ = torch.max(encoder_hidden_states, dim=1) 19 | return encoder_hidden_states, encoder_final_state -------------------------------------------------------------------------------- /gleu/README.md: -------------------------------------------------------------------------------- 1 | # Ground Truth for Grammatical Error Correction Metrics 2 | 3 | 4 | This repository contains a python implementation of the GLEU metric 5 | (**G**eneral **L**anguage **E**valuation **U**nderstanding), which 6 | can be used for any monolingual "translation" task. It also contains 7 | human rankings of the CoNLL-14 Shared Task system output as well as 8 | scripts to evaluate the rankings to extract an absolute system 9 | ranking. 10 | 11 | These results were described in the ACL 2015 paper: 12 | 13 | > [*Ground Truth for Grammatical Error Correction Metrics*](http://www.aclweb.org/anthology/P/P15/P15-2097.pdf) 14 | by Courtney Napoles, Keisuke Sakaguchi, Joel Tetreault, and Matt Post 15 | 16 | Please cite this work when using this data or the GLEU metric. 17 | 18 | @InProceedings{napoles-EtAl:2015:ACL-IJCNLP, 19 | author = {Napoles, Courtney and Sakaguchi, Keisuke and Post, Matt and Tetreault, Joel}, 20 | title = {Ground Truth for Grammatical Error Correction Metrics}, 21 | booktitle = {Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 2: Short Papers)}, 22 | month = {July}, 23 | year = {2015}, 24 | address = {Beijing, China}, 25 | publisher = {Association for Computational Linguistics}, 26 | pages = {588--593}, 27 | url = {http://www.aclweb.org/anthology/P15-2097} 28 | } 29 | 30 | --- 31 | 32 | # GLEU Update 33 | 34 | As of May 2, 2016, we have identified a problem with the GLEU metric as the number of references increases. 35 | To resolve this issue, we made a minor adjustment to the metric so that it no longer has a tunable weight and is reliable using any number of reference sets. 36 | This update to GLEU is reflected in `scripts/compute_gleu` and `scripts/gleu.py`. 37 | The original GLEU scripts can be found in `scripts/original_gleu/`. 38 | We do not recommend using the original GLEU code. The new GLEU should be used instead. 39 | 40 | The changes to GLEU and updated results to our ACL 2015 paper are described in the eprint, [*GLEU Without Tuning*](http://arxiv.org/abs/1605.02592). 
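Concretely, the updated `scripts/gleu.py` computes, for each n-gram order, a numerator of max(matches with the reference minus matches with source-only n-grams, 0) over a denominator of len(hypothesis) + 1 - n, and reports exp(min(0, 1 - r/c) + (1/4) * sum_n log p_n), where c is the total hypothesis length and r the reference length.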
41 | The citation for the updated metric is 42 | 43 | @Article{napoles2016gleu, 44 | author = {Napoles, Courtney and Sakaguchi, Keisuke and Post, Matt and Tetreault, Joel}, 45 | title = {{GLEU} Without Tuning}, 46 | journal = {eprint arXiv:1605.02592 [cs.CL]}, 47 | year = {2016}, 48 | url = {http://arxiv.org/abs/1605.02592} 49 | } 50 | 51 | --- 52 | 53 | ## Instructions 54 | 55 | ### 1. Obtain the raw system output 56 | 57 | The rankings found in the gec-ranking-data correspond to the 12 system outputs 58 | from the CoNLL-14 Shared Task on Grammatical Error Correction, which can be 59 | downloaded from . 60 | 61 | Human judgments are located in `gec-ranking/data`. 62 | 63 | ### 2. Run TrueSkill 64 | 65 | To get the human rankings, run TrueSkill (which can be downloaded from 66 | ) on `all_judgments.csv`, following 67 | the instructions in the TrueSkill readme. 68 | 69 | ### 3. Calculate metric scores 70 | 71 | GLEU is included in `gec-ranking/scripts`. To obtain the GLEU scores for 72 | system output, run the following command: 73 | 74 | ``` 75 | ./compute_gleu -s source_sentences -r reference [reference ...] \ 76 | -o system_output [system_output ...] -n 4 -l 0.0 77 | ``` 78 | 79 | where each file contains one sentence per line. GLEU can be run with multiple 80 | references. To get the GLEU scores of multiple outputs, include the path to 81 | each system output file. GLEU was developed using Python 2.7. 82 | 83 | I-measure scores were taken from Felice and Briscoe's 2015 NAACL paper, 84 | *Towards a standard evaluation method for grammatical error detection and 85 | correction*. The I-measure scorer can be downloaded from 86 | . 87 | 88 | M2 scores were calculated using the official scorer (3.2) of the CoNLL-2014 Shared Task (). 89 | 90 | --- 91 | 92 | ## Errata 93 | 94 | There was an error in the calculation of the GLEU denominator, which was corrected in the 10 March 2016 commit. 95 | 96 | --- 97 | 98 | Please contact Courtney Napoles (courtneyn[at]jhu[dot]edu) or Keisuke Sakaguchi (keisuke[at]cs[dot]jhu[dot]edu) with any questions. 99 | 100 | Last updated 10 May 2016 101 | -------------------------------------------------------------------------------- /gleu/gleu_update_2016.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/panthap2/deep-jit-inconsistency-detection/dacf8513c155f35157eedc2bf630212bf815544c/gleu/gleu_update_2016.pdf -------------------------------------------------------------------------------- /gleu/scripts/compute_gleu: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Courtney Napoles 4 | # 5 | # 21 June 2015 6 | # ## 7 | # compute_gleu 8 | # 9 | # This script calls gleu.py to calculate the GLEU score of a sentence, as 10 | # described in our ACL 2015 paper, Ground Truth for Grammatical Error 11 | # Correction Metrics by Courtney Napoles, Keisuke Sakaguchi, Matt Post, 12 | # and Joel Tetreault. 13 | # 14 | # For instructions on how to get the GLEU score, call "compute_gleu -h" 15 | # 16 | # Updated 2 May 2016: This is an updated version of GLEU that has been 17 | # modified to handle multiple references more fairly. 18 | # 19 | # This script was adapted from compute-bleu by Adam Lopez. 
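# Note (added): with multiple references, this script runs --iter bootstrap iterations,
# sampling one reference per sentence per iteration (seed j*101), and reports the mean
# corpus-level GLEU across iterations; with a single reference it runs only once.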
20 | # 21 | 22 | import argparse 23 | import sys 24 | import os 25 | from gleu import GLEU 26 | import scipy.stats 27 | import numpy as np 28 | import random 29 | 30 | def get_gleu_stats(scores) : 31 | mean = np.mean(scores) 32 | std = np.std(scores) 33 | ci = scipy.stats.norm.interval(0.95,loc=mean,scale=std) 34 | return ['%f'%mean, 35 | '%f'%std, 36 | '(%.3f,%.3f)'%(ci[0],ci[1])] 37 | 38 | if __name__ == '__main__' : 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("-r", "--reference", 42 | help="Target language reference sentences. Multiple " 43 | "files for multiple references.", 44 | nargs="*", 45 | dest="reference", 46 | required=True) 47 | parser.add_argument("-s", "--source", 48 | help="Source language source sentences", 49 | dest="source", 50 | required=True) 51 | parser.add_argument("-o", "--hypothesis", 52 | help="Target language hypothesis sentences to evaluate " 53 | "(can be more than one file--the GLEU score of each " 54 | "file will be output separately). Use '-o -' to read " 55 | "hypotheses from stdin.", 56 | nargs="*", 57 | dest="hypothesis", 58 | required=True) 59 | parser.add_argument("-n", 60 | help="Maximum order of ngrams", 61 | type=int, 62 | default=4) 63 | parser.add_argument("-d","--debug", 64 | help="Debug; print sentence-level scores", 65 | default=False, 66 | action="store_true") 67 | parser.add_argument('--iter', 68 | type=int, 69 | default=500, 70 | help='the number of iterations to run') 71 | 72 | args = parser.parse_args() 73 | 74 | num_iterations = args.iter 75 | 76 | # if there is only one reference, just do one iteration 77 | if len(args.reference) == 1 : 78 | num_iterations = 1 79 | 80 | gleu_calculator = GLEU(args.n) 81 | 82 | gleu_calculator.load_sources(args.source) 83 | gleu_calculator.load_references(args.reference) 84 | 85 | for hpath in args.hypothesis : 86 | instream = sys.stdin if hpath == '-' else open(hpath) 87 | hyp = [line.split() for line in instream] 88 | 89 | if not args.debug : 90 | print os.path.basename(hpath), 91 | 92 | # first generate a random list of indices, using a different seed 93 | # for each iteration 94 | indices = [] 95 | for j in range(num_iterations) : 96 | random.seed(j*101) 97 | indices.append([random.randint(0,len(args.reference)-1) 98 | for i in range(len(hyp))]) 99 | 100 | if args.debug : 101 | print 102 | print '===== Sentence-level scores =====' 103 | print 'SID Mean Stdev 95%CI GLEU' 104 | 105 | iter_stats = [ [0 for i in xrange(2*args.n+2)] 106 | for j in range(num_iterations) ] 107 | 108 | for i,h in enumerate(hyp) : 109 | 110 | gleu_calculator.load_hypothesis_sentence(h) 111 | # we are going to store the score of this sentence for each ref 112 | # so we don't have to recalculate them 500 times 113 | 114 | stats_by_ref = [ None for r in range(len(args.reference)) ] 115 | 116 | for j in range(num_iterations) : 117 | ref = indices[j][i] 118 | this_stats = stats_by_ref[ref] 119 | 120 | if this_stats is None : 121 | this_stats = [ s for s in gleu_calculator.gleu_stats( 122 | i,r_ind=ref) ] 123 | stats_by_ref[ref] = this_stats 124 | 125 | iter_stats[j] = [ sum(scores) 126 | for scores in zip(iter_stats[j], this_stats)] 127 | 128 | if args.debug : 129 | # sentence-level GLEU is the mean GLEU of the hypothesis 130 | # compared to each reference 131 | for r in range(len(args.reference)) : 132 | if stats_by_ref[r] is None : 133 | stats_by_ref[r] = [s for s in gleu_calculator.gleu_stats( 134 | i,r_ind=r) ] 135 | 136 | print i, 137 | print ' 
'.join(get_gleu_stats([gleu_calculator.gleu(stats,smooth=True) 138 | for stats in stats_by_ref])) 139 | 140 | if args.debug : 141 | print '\n==== Overall score =====' 142 | print 'Mean Stdev 95%CI GLEU' 143 | print ' '.join(get_gleu_stats([gleu_calculator.gleu(stats) 144 | for stats in iter_stats ])) 145 | else : 146 | print get_gleu_stats([gleu_calculator.gleu(stats) 147 | for stats in iter_stats ])[0] 148 | 149 | -------------------------------------------------------------------------------- /gleu/scripts/gleu.py: -------------------------------------------------------------------------------- 1 | # Courtney Napoles 2 | # 3 | # 21 June 2015 4 | # ## 5 | # gleu.py 6 | # 7 | # This script calculates the GLEU score of a sentence, as described in 8 | # our ACL 2015 paper, Ground Truth for Grammatical Error Correction Metrics 9 | # by Courtney Napoles, Keisuke Sakaguchi, Matt Post, and Joel Tetreault. 10 | # 11 | # For instructions on how to get the GLEU score, call "compute_gleu -h" 12 | # 13 | # Updated 2 May 2016: This is an updated version of GLEU that has been 14 | # modified to handle multiple references more fairly. 15 | # 16 | # Updated 6 9 2017: Fixed inverse brevity penalty 17 | # 18 | # This script was adapted from bleu.py by Adam Lopez. 19 | # 20 | 21 | import math 22 | from collections import Counter 23 | 24 | class GLEU : 25 | 26 | def __init__(self,n=4) : 27 | self.order = 4 28 | 29 | def load_hypothesis_sentence(self,hypothesis) : 30 | self.hlen = len(hypothesis) 31 | self.this_h_ngrams = [ self.get_ngram_counts(hypothesis,n) 32 | for n in range(1,self.order+1) ] 33 | 34 | def load_sources(self,spath) : 35 | self.all_s_ngrams = [ [ self.get_ngram_counts(line.split(),n) 36 | for n in range(1,self.order+1) ] 37 | for line in open(spath) ] 38 | 39 | def load_references(self,rpaths) : 40 | self.refs = [ [] for i in range(len(self.all_s_ngrams)) ] 41 | self.rlens = [ [] for i in range(len(self.all_s_ngrams)) ] 42 | for rpath in rpaths : 43 | for i,line in enumerate(open(rpath)) : 44 | self.refs[i].append(line.split()) 45 | self.rlens[i].append(len(line.split())) 46 | 47 | # count number of references each n-gram appear sin 48 | self.all_rngrams_freq = [ Counter() for i in range(self.order) ] 49 | 50 | self.all_r_ngrams = [ ] 51 | for refset in self.refs : 52 | all_ngrams = [] 53 | self.all_r_ngrams.append(all_ngrams) 54 | 55 | for n in range(1,self.order+1) : 56 | ngrams = self.get_ngram_counts(refset[0],n) 57 | all_ngrams.append(ngrams) 58 | 59 | for k in ngrams.keys() : 60 | self.all_rngrams_freq[n-1][k]+=1 61 | 62 | for ref in refset[1:] : 63 | new_ngrams = self.get_ngram_counts(ref,n) 64 | for nn in new_ngrams.elements() : 65 | if new_ngrams[nn] > ngrams.get(nn,0) : 66 | ngrams[nn] = new_ngrams[nn] 67 | 68 | def get_ngram_counts(self,sentence,n) : 69 | return Counter([tuple(sentence[i:i+n]) 70 | for i in xrange(len(sentence)+1-n)]) 71 | 72 | # returns ngrams in a but not in b 73 | def get_ngram_diff(self,a,b) : 74 | diff = Counter(a) 75 | for k in (set(a) & set(b)) : 76 | del diff[k] 77 | return diff 78 | 79 | def normalization(self,ngram,n) : 80 | return 1.0*self.all_rngrams_freq[n-1][ngram]/len(self.rlens[0]) 81 | 82 | # Collect BLEU-relevant statistics for a single hypothesis/reference pair. 83 | # Return value is a generator yielding: 84 | # (c, r, numerator1, denominator1, ... 
numerator4, denominator4) 85 | # Summing the columns across calls to this function on an entire corpus 86 | # will produce a vector of statistics that can be used to compute GLEU 87 | def gleu_stats(self,i,r_ind=None): 88 | 89 | hlen = self.hlen 90 | rlen = self.rlens[i][r_ind] 91 | 92 | yield hlen 93 | yield rlen 94 | 95 | for n in xrange(1,self.order+1): 96 | h_ngrams = self.this_h_ngrams[n-1] 97 | s_ngrams = self.all_s_ngrams[i][n-1] 98 | r_ngrams = self.get_ngram_counts(self.refs[i][r_ind],n) 99 | 100 | s_ngram_diff = self.get_ngram_diff(s_ngrams,r_ngrams) 101 | 102 | yield max([ sum( (h_ngrams & r_ngrams).values() ) - \ 103 | sum( (h_ngrams & s_ngram_diff).values() ), 0 ]) 104 | 105 | yield max([hlen+1-n, 0]) 106 | 107 | # Compute GLEU from collected statistics obtained by call(s) to gleu_stats 108 | def gleu(self,stats,smooth=False): 109 | # smooth 0 counts for sentence-level scores 110 | if smooth : 111 | stats = [ s if s != 0 else 1 for s in stats ] 112 | if len(filter(lambda x: x==0, stats)) > 0: 113 | return 0 114 | (c, r) = stats[:2] 115 | log_gleu_prec = sum([math.log(float(x)/y) 116 | for x,y in zip(stats[2::2],stats[3::2])]) / 4 117 | return math.exp(min([0, 1-float(r)/c]) + log_gleu_prec) 118 | -------------------------------------------------------------------------------- /gleu/scripts/original_gleu/compute_gleu: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Courtney Napoles 4 | # 5 | # 21 June 2015 6 | # ## 7 | # compute_gleu 8 | # 9 | # This script calls gleu.py to calculate the GLEU score of a sentence, as 10 | # described in our ACL 2015 paper, Ground Truth for Grammatical Error 11 | # Correction Metrics by Courtney Napoles, Keisuke Sakaguchi, Matt Post, 12 | # and Joel Tetreault. 13 | # 14 | # For instructions on how to get the GLEU score, call "compute_gleu -h" 15 | # 16 | # This script was adapted from compute-bleu by Adam Lopez. 17 | # 18 | # 19 | # THIS IS AN OLD VERSION OF GLEU. Please see the repository for the correct, 20 | # new version (https://github.com/cnap/gec-ranking) 21 | 22 | import argparse 23 | import sys 24 | import os 25 | from gleu import GLEU 26 | 27 | if __name__ == '__main__' : 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-r", "--reference", 31 | help="Target language reference sentences. Multiple files for " 32 | " multiple references.", 33 | nargs="*", 34 | dest="reference", 35 | default=["data/dev.ref"]) 36 | parser.add_argument("-s", "--source", 37 | help="Source language source sentences", 38 | dest="source", 39 | default="data/dev.src") 40 | parser.add_argument("-o", "--hypothesis", 41 | help="Target language hypothesis sentences to evaluate (can " 42 | "be more than one file--the GLEU score of each file will be) " 43 | "output separately. 
Use '-o -' to read hypotheses from stdin.", 44 | nargs="*", 45 | dest="hypothesis", 46 | default=["data/dev.hyp"]) 47 | parser.add_argument("-n", 48 | help="Maximum order of ngrams", 49 | type=int, 50 | default=4) 51 | parser.add_argument("-l", 52 | help="Lambda weight for penalizing incorrectly unchanged n-grams", 53 | nargs='*', 54 | default=[0]) 55 | parser.add_argument("-d","--debug", 56 | help="Debug; print sentence-level scores", 57 | default=False, 58 | action="store_true") 59 | 60 | args = parser.parse_args() 61 | 62 | gleu_calculator = GLEU(args.n,args.l) 63 | 64 | gleu_calculator.load_sources(args.source) 65 | gleu_calculator.load_references(args.reference) 66 | 67 | for hpath in args.hypothesis : 68 | instream = sys.stdin if hpath == '-' else open(hpath) 69 | hyp = [line.split() for line in instream] 70 | 71 | for l in args.l : 72 | l = float(l) 73 | gleu_calculator.set_lambda(l) 74 | print os.path.basename(hpath),l, 75 | 76 | if args.debug : 77 | print 78 | print '===== Sentence-level scores =====' 79 | print 'SID\tGLEU' 80 | 81 | stats = [0 for i in xrange(2*args.n+2)] 82 | for i,h in enumerate(hyp): 83 | this_stats = [s for s in gleu_calculator.gleu_stats(h,i)] 84 | if args.debug : 85 | print '%d\t%f'%(i,gleu_calculator.gleu(this_stats)) 86 | stats = [sum(scores) for scores in zip(stats, this_stats)] 87 | if args.debug : 88 | print '\n==== Overall score =====' 89 | print gleu_calculator.gleu(stats) 90 | -------------------------------------------------------------------------------- /gleu/scripts/original_gleu/gleu.py: -------------------------------------------------------------------------------- 1 | # Courtney Napoles 2 | # 3 | # 21 June 2015 4 | # ## 5 | # gleu.py 6 | # 7 | # This script calculates the GLEU score of a sentence, as described in 8 | # our ACL 2015 paper, Ground Truth for Grammatical Error Correction Metrics 9 | # by Courtney Napoles, Keisuke Sakaguchi, Matt Post, and Joel Tetreault. 10 | # 11 | # For instructions on how to get the GLEU score, call "compute_gleu -h" 12 | # 13 | # This script was adapted from bleu.py by Adam Lopez. 14 | # 15 | # 16 | # THIS IS AN OLD VERSION OF GLEU. 
Please see the repository for the correct, 17 | # new version (https://github.com/cnap/gec-ranking) 18 | 19 | import math 20 | from collections import Counter 21 | 22 | class GLEU : 23 | 24 | def __init__(self,n=4,l=1) : 25 | self.order = 4 26 | self.weight = l 27 | 28 | def load_sources(self,spath) : 29 | self.all_s_ngrams = [ [ self.get_ngram_counts(line.split(),n) \ 30 | for n in range(1,self.order+1) ] \ 31 | for line in open(spath) ] 32 | 33 | def load_references(self,rpaths) : 34 | refs = [ [] for i in range(len(self.all_s_ngrams)) ] 35 | self.rlens = [ [] for i in range(len(self.all_s_ngrams)) ] 36 | for rpath in rpaths : 37 | for i,line in enumerate(open(rpath)) : 38 | refs[i].append(line.split()) 39 | self.rlens[i].append(len(line.split())) 40 | 41 | self.all_r_ngrams = [ ] 42 | for refset in refs : 43 | all_ngrams = [] 44 | self.all_r_ngrams.append(all_ngrams) 45 | 46 | for n in range(1,self.order+1) : 47 | ngrams = self.get_ngram_counts(refset[0],n) 48 | all_ngrams.append(ngrams) 49 | for ref in refset[1:] : 50 | new_ngrams = self.get_ngram_counts(ref,n) 51 | for nn in new_ngrams.elements() : 52 | if new_ngrams[nn] > ngrams.get(nn,0) : 53 | ngrams[nn] = new_ngrams[nn] 54 | 55 | 56 | def get_ngram_counts(self,sentence,n) : 57 | return Counter([tuple(sentence[i:i+n]) for i in xrange(len(sentence)+1-n)]) 58 | 59 | def set_lambda(self,l) : 60 | self.weight = l 61 | 62 | # Collect BLEU-relevant statistics for a single hypothesis/reference pair. 63 | # Return value is a generator yielding: 64 | # (c, r, numerator1, denominator1, ... numerator4, denominator4) 65 | # Summing the columns across calls to this function on an entire corpus will 66 | # produce a vector of statistics that can be used to compute BLEU or GLEU 67 | def gleu_stats(self,hypothesis, i): 68 | 69 | hlen=len(hypothesis) 70 | rlen = self.rlens[i][0] 71 | 72 | # set the reference length to be the reference length closest to the hyp length 73 | for r in self.rlens[i][1:] : 74 | if abs(r - hlen) < abs(rlen - hlen) : 75 | rlen = r 76 | 77 | yield rlen 78 | yield hlen 79 | 80 | for n in xrange(1,self.order+1): 81 | h_ngrams = self.get_ngram_counts(hypothesis,n) 82 | s_ngrams = self.all_s_ngrams[i][n-1] 83 | r_ngrams = self.all_r_ngrams[i][n-1] 84 | 85 | r_ngram_diff = r_ngrams - s_ngrams 86 | # some n-grams may appear in both sets but have a higher count in the subtracted 87 | # one so these n-grams should be deleted so a single occurrence of one of those 88 | # n-grams doesn't penalize the precision 89 | for k in r_ngram_diff.keys() : 90 | if k in s_ngrams : 91 | del r_ngram_diff[k] 92 | s_ngram_diff = s_ngrams - r_ngrams 93 | for k in s_ngram_diff.keys() : 94 | if k in r_ngrams : 95 | del s_ngram_diff[k] 96 | 97 | yield sum( (h_ngrams & r_ngram_diff).values() ) + \ 98 | max([ sum( (h_ngrams & r_ngrams).values() ) - \ 99 | self.weight * sum( (h_ngrams & s_ngram_diff).values() ), 0 ]) 100 | 101 | yield sum( (h_ngrams & r_ngram_diff).values() ) + max([hlen+1-n, 0]) 102 | 103 | ## here is the original, erroneous way to calculate the denominator 104 | #yield max([sum(r_ngram_diff.values()), 0]) + max([hlen+1-n, 0]) 105 | 106 | # Compute GLEU from collected statistics obtained by call(s) to gleu_stats 107 | def gleu(self,stats): 108 | if len(filter(lambda x: x==0, stats)) > 0: 109 | return 0 110 | (c, r) = stats[:2] 111 | log_gleu_prec = sum([math.log(float(x)/y) for x,y in zip(stats[2::2],stats[3::2])]) / 4. 
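# NOTE (added): gleu_stats above yields (rlen, hlen), so c below is the reference
# length and r the hypothesis length; the brevity penalty is therefore inverted.
# The updated scripts/gleu.py corrects this ("Fixed inverse brevity penalty", 2017).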
112 | 113 | return math.exp(min([0, 1-float(r)/c]) + log_gleu_prec) 114 | -------------------------------------------------------------------------------- /gnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.utils 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 8 | 9 | from typing import List, Tuple, Dict, Sequence, Any 10 | 11 | # https://github.com/pcyin/pytorch-gated-graph-neural-network/blob/master/gnn.py 12 | 13 | class AdjacencyList: 14 | """represent the topology of a graph""" 15 | def __init__(self, node_num: int, adj_list: List, device: torch.device): 16 | self.node_num = node_num 17 | self.data = torch.tensor(adj_list, dtype=torch.long, device=device) 18 | self.edge_num = len(adj_list) 19 | 20 | @property 21 | def device(self): 22 | return self.data.device 23 | 24 | def __getitem__(self, item): 25 | return self.data[item] 26 | 27 | 28 | class GatedGraphNeuralNetwork(nn.Module): 29 | def __init__(self, hidden_size, num_edge_types, layer_timesteps, 30 | residual_connections, 31 | state_to_message_dropout=0.3, 32 | rnn_dropout=0.3, 33 | use_bias_for_message_linear=True): 34 | 35 | super(GatedGraphNeuralNetwork, self).__init__() 36 | 37 | self.hidden_size = hidden_size 38 | self.num_edge_types = num_edge_types 39 | self.layer_timesteps = layer_timesteps 40 | self.residual_connections = residual_connections 41 | self.state_to_message_dropout = state_to_message_dropout 42 | self.rnn_dropout = rnn_dropout 43 | self.use_bias_for_message_linear = use_bias_for_message_linear 44 | 45 | # Prepare linear transformations from node states to messages, for each layer and each edge type 46 | # Prepare rnn cells for each layer 47 | self.state_to_message_linears = [] 48 | self.rnn_cells = [] 49 | for layer_idx in range(len(self.layer_timesteps)): 50 | state_to_msg_linears_cur_layer = [] 51 | # Initiate a linear transformation for each edge type 52 | for edge_type_j in range(self.num_edge_types): 53 | # TODO: glorot_init? 
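# A separate hidden-to-hidden linear map is created for every (layer, edge type)
# pair, so messages propagated along different edge types are transformed differently.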
54 | state_to_msg_linear_layer_i_type_j = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias_for_message_linear) 55 | setattr(self, 56 | 'state_to_message_linear_layer%d_type%d' % (layer_idx, edge_type_j), 57 | state_to_msg_linear_layer_i_type_j) 58 | 59 | state_to_msg_linears_cur_layer.append(state_to_msg_linear_layer_i_type_j) 60 | self.state_to_message_linears.append(state_to_msg_linears_cur_layer) 61 | 62 | layer_residual_connections = self.residual_connections.get(layer_idx, []) 63 | rnn_cell_layer_i = nn.GRUCell(self.hidden_size * (1 + len(layer_residual_connections)), self.hidden_size) 64 | setattr(self, 'rnn_cell_layer%d' % layer_idx, rnn_cell_layer_i) 65 | self.rnn_cells.append(rnn_cell_layer_i) 66 | 67 | self.state_to_message_dropout_layer = nn.Dropout(self.state_to_message_dropout) 68 | self.rnn_dropout_layer = nn.Dropout(self.rnn_dropout) 69 | 70 | @property 71 | def device(self): 72 | return self.rnn_cells[0].weight_hh.device 73 | 74 | def forward(self, 75 | initial_node_representation: Variable, 76 | adjacency_lists: List[AdjacencyList], 77 | return_all_states=False) -> Variable: 78 | return self.compute_node_representations(initial_node_representation, adjacency_lists, 79 | return_all_states=return_all_states) 80 | 81 | def compute_node_representations(self, 82 | initial_node_representation: Variable, 83 | adjacency_lists: List[AdjacencyList], 84 | return_all_states=False) -> Variable: 85 | # If the dimension of initial node embedding is smaller, then perform padding first 86 | # one entry per layer (final state of that layer), shape: number of nodes in batch v x D 87 | init_node_repr_size = initial_node_representation.size(1) 88 | device = adjacency_lists[0].data.device 89 | if init_node_repr_size < self.hidden_size: 90 | pad_size = self.hidden_size - init_node_repr_size 91 | zero_pads = torch.zeros(initial_node_representation.size(0), pad_size, dtype=torch.float, device=device) 92 | initial_node_representation = torch.cat([initial_node_representation, zero_pads], dim=-1) 93 | node_states_per_layer = [initial_node_representation] 94 | 95 | node_num = initial_node_representation.size(0) 96 | 97 | message_targets = [] # list of tensors of message targets of shape [E] 98 | for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists): 99 | if adjacency_list_for_edge_type.edge_num > 0: 100 | edge_targets = adjacency_list_for_edge_type[:, 1] 101 | message_targets.append(edge_targets) 102 | message_targets = torch.cat(message_targets, dim=0) # Shape [M] 103 | 104 | # sparse matrix of shape [V, M] 105 | # incoming_msg_sparse_matrix = self.get_incoming_message_sparse_matrix(adjacency_lists).to(device) 106 | for layer_idx, num_timesteps in enumerate(self.layer_timesteps): 107 | # Used shape abbreviations: 108 | # V ~ number of nodes 109 | # D ~ state dimension 110 | # E ~ number of edges of current type 111 | # M ~ number of messages (sum of all E) 112 | 113 | # Extract residual messages, if any: 114 | layer_residual_connections = self.residual_connections.get(layer_idx, []) 115 | # List[(V, D)] 116 | layer_residual_states: List[torch.FloatTensor] = [node_states_per_layer[residual_layer_idx] 117 | for residual_layer_idx in layer_residual_connections] 118 | 119 | # Record new states for this layer. 
Initialised to last state, but will be updated below: 120 | node_states_for_this_layer = node_states_per_layer[-1] 121 | # For each message propagation step 122 | for t in range(num_timesteps): 123 | messages: List[torch.FloatTensor] = [] # list of tensors of messages of shape [E, D] 124 | message_source_states: List[torch.FloatTensor] = [] # list of tensors of edge source states of shape [E, D] 125 | 126 | # Collect incoming messages per edge type 127 | for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists): 128 | if adjacency_list_for_edge_type.edge_num > 0: 129 | # shape [E] 130 | edge_sources = adjacency_list_for_edge_type[:, 0] 131 | # shape [E, D] 132 | edge_source_states = node_states_for_this_layer[edge_sources] 133 | 134 | f_state_to_message = self.state_to_message_linears[layer_idx][edge_type_idx] 135 | # Shape [E, D] 136 | all_messages_for_edge_type = self.state_to_message_dropout_layer(f_state_to_message(edge_source_states)) 137 | 138 | messages.append(all_messages_for_edge_type) 139 | message_source_states.append(edge_source_states) 140 | 141 | # shape [M, D] 142 | messages: torch.FloatTensor = torch.cat(messages, dim=0) 143 | 144 | # Sum up messages that go to the same target node 145 | # shape [V, D] 146 | incoming_messages = torch.zeros(node_num, messages.size(1), device=device) 147 | incoming_messages = incoming_messages.scatter_add_(0, 148 | message_targets.unsqueeze(-1).expand_as(messages), 149 | messages) 150 | 151 | # shape [V, D * (1 + num of residual connections)] 152 | incoming_information = torch.cat(layer_residual_states + [incoming_messages], dim=-1) 153 | 154 | # pass updated vertex features into RNN cell 155 | # Shape [V, D] 156 | updated_node_states = self.rnn_cells[layer_idx](incoming_information, node_states_for_this_layer) 157 | updated_node_states = self.rnn_dropout_layer(updated_node_states) 158 | node_states_for_this_layer = updated_node_states 159 | 160 | node_states_per_layer.append(node_states_for_this_layer) 161 | 162 | if return_all_states: 163 | return node_states_per_layer[1:] 164 | else: 165 | node_states_for_last_layer = node_states_per_layer[-1] 166 | return node_states_for_last_layer 167 | 168 | 169 | def main(): 170 | gnn = GatedGraphNeuralNetwork(hidden_size=64, num_edge_types=2, 171 | layer_timesteps=[3, 5, 7, 2], residual_connections={2: [0], 3: [0, 1]}) 172 | 173 | adj_list_type1 = AdjacencyList(node_num=4, adj_list=[(0, 2), (2, 1), (1, 3)], device=gnn.device) 174 | adj_list_type2 = AdjacencyList(node_num=4, adj_list=[(0, 0), (0, 1)], device=gnn.device) 175 | 176 | node_representations = gnn.compute_node_representations(initial_node_representation=torch.randn(4, 64), 177 | adjacency_lists=[adj_list_type1, adj_list_type2]) 178 | 179 | print(node_representations) 180 | 181 | 182 | if __name__ == '__main__': 183 | main() -------------------------------------------------------------------------------- /module_manager.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import Counter 3 | import numpy as np 4 | import os 5 | import random 6 | import sys 7 | import torch 8 | from torch import nn 9 | 10 | from dpu_utils.mlutils import Vocabulary 11 | 12 | from ast_graph_encoder import ASTGraphEncoder 13 | from constants import * 14 | from data_utils import * 15 | import diff_utils 16 | from embedding_store import EmbeddingStore 17 | from encoder import Encoder 18 | from external_cache import get_code_features, get_nl_features, get_num_code_features, 
get_num_nl_features
19 | from tensor_utils import *
20 | 
21 | 
22 | class ModuleManager(nn.Module):
23 |     """Utility class which helps manage related attributes of the update and detection tasks."""
24 |     def __init__(self, attend_code_sequence_states, attend_code_graph_states, features, posthoc, task):
25 |         super(ModuleManager, self).__init__()
26 |         self.attend_code_sequence_states = attend_code_sequence_states
27 |         self.attend_code_graph_states = attend_code_graph_states
28 |         self.features = features
29 |         self.posthoc = posthoc
30 |         self.task = task
31 | 
32 |         self.num_encoders = 0
33 |         self.num_seq_encoders = 0
34 |         self.out_dim = 0
35 |         self.attention_state_size = 0
36 |         self.update_encoder_state_size = 0
37 |         self.max_ast_length = 0
38 |         self.max_code_length = 0
39 |         self.max_nl_length = 0
40 |         self.generate = task in ['update', 'dual']
41 |         self.classify = task in ['detect', 'dual']
42 | 
43 |         self.encode_code_sequence = self.generate or self.attend_code_sequence_states
44 | 
45 |         print('Attend code sequence states: {}'.format(self.attend_code_sequence_states))
46 |         print('Attend code graph states: {}'.format(self.attend_code_graph_states))
47 |         print('Features: {}'.format(self.features))
48 |         print('Task: {}'.format(self.task))
49 |         sys.stdout.flush()
50 | 
51 |     def get_code_representation(self, ex, data_type):
52 |         if self.posthoc:
53 |             if data_type == 'sequence':
54 |                 return ex.new_code_subtokens
55 |             else:
56 |                 return ex.new_ast
57 |         else:
58 |             if data_type == 'sequence':
59 |                 return ex.span_diff_code_subtokens
60 |             else:
61 |                 return ex.diff_ast
62 | 
63 |     def initialize(self, train_data):
64 |         """Initializes model parameters from pre-defined hyperparameters and from hyperparameters
65 |         computed from statistics over the training data."""
66 |         nl_lengths = []
67 |         code_lengths = []
68 |         ast_lengths = []
69 | 
70 |         nl_token_counter = Counter()
71 |         code_token_counter = Counter()
72 | 
73 |         for ex in train_data:
74 |             if self.generate:
75 |                 trg_sequence = [START] + ex.span_minimal_diff_comment_subtokens + [END]
76 |                 nl_token_counter.update(trg_sequence)
77 |                 nl_lengths.append(len(trg_sequence))
78 | 
79 |             old_nl_sequence = ex.old_comment_subtokens
80 |             nl_token_counter.update(old_nl_sequence)
81 |             nl_lengths.append(len(old_nl_sequence))
82 | 
83 |             if self.encode_code_sequence:
84 |                 code_sequence = self.get_code_representation(ex, 'sequence')
85 |                 code_token_counter.update(code_sequence)
86 |                 code_lengths.append(len(code_sequence))
87 | 
88 |             if self.attend_code_graph_states:
89 |                 code_sequence = [n.value for n in self.get_code_representation(ex, 'graph').nodes]
90 |                 code_token_counter.update(code_sequence)
91 |                 ast_lengths.append(len(code_sequence))
92 | 
93 |         self.max_nl_length = int(np.percentile(np.asarray(sorted(nl_lengths)),
94 |                                                LENGTH_CUTOFF_PCT))
95 |         self.max_vocab_extension = self.max_nl_length
96 | 
97 |         if self.encode_code_sequence:
98 |             self.max_code_length = int(np.percentile(np.asarray(sorted(code_lengths)),
99 |                                                      LENGTH_CUTOFF_PCT))
100 |             self.max_vocab_extension += self.max_code_length
101 | 
102 |         if self.attend_code_graph_states:
103 |             self.max_ast_length = int(np.percentile(np.asarray(sorted(ast_lengths)),
104 |                                                     LENGTH_CUTOFF_PCT))
105 | 
106 |         nl_counts = np.asarray(sorted(nl_token_counter.values()))  # token-frequency distribution over NL tokens
107 |         nl_threshold = int(np.percentile(nl_counts, VOCAB_CUTOFF_PCT)) + 1  # frequency cutoff for the NL vocabulary
108 |         code_counts = np.asarray(sorted(code_token_counter.values()))
109 |         code_threshold = int(np.percentile(code_counts, VOCAB_CUTOFF_PCT)) + 1
110 | 
111 |         self.embedding_store = 
EmbeddingStore(nl_threshold, NL_EMBEDDING_SIZE, nl_token_counter, 112 | code_threshold, CODE_EMBEDDING_SIZE, code_token_counter, 113 | DROPOUT_RATE, len(SrcType), SRC_EMBEDDING_SIZE, CODE_EMBEDDING_SIZE, True) 114 | 115 | self.out_dim = 2*HIDDEN_SIZE 116 | 117 | # Accounting for the old NL encoder 118 | self.num_encoders = 1 119 | self.num_seq_encoders += 1 120 | self.attention_state_size += 2*HIDDEN_SIZE 121 | self.nl_encoder = Encoder(NL_EMBEDDING_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT_RATE) 122 | self.nl_attention_transform_matrix = nn.Parameter(torch.randn( 123 | self.out_dim, self.out_dim, dtype=torch.float, requires_grad=True)) 124 | self.self_attention = nn.MultiheadAttention(self.out_dim, MULTI_HEADS, DROPOUT_RATE) 125 | 126 | if self.encode_code_sequence: 127 | self.sequence_code_encoder = Encoder(CODE_EMBEDDING_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT_RATE) 128 | self.num_encoders += 1 129 | self.num_seq_encoders += 1 130 | 131 | if self.attend_code_sequence_states: 132 | self.attention_state_size += 2*HIDDEN_SIZE 133 | self.sequence_attention_transform_matrix = nn.Parameter(torch.randn( 134 | self.out_dim, self.out_dim, dtype=torch.float, requires_grad=True)) 135 | self.code_sequence_multihead_attention = nn.MultiheadAttention(self.out_dim, MULTI_HEADS, DROPOUT_RATE) 136 | 137 | if self.attend_code_graph_states: 138 | self.graph_code_encoder = ASTGraphEncoder(CODE_EMBEDDING_SIZE, len(DiffEdgeType)) 139 | self.num_encoders += 1 140 | self.attention_state_size += 2*HIDDEN_SIZE 141 | self.graph_attention_transform_matrix = nn.Parameter(torch.randn( 142 | CODE_EMBEDDING_SIZE, self.out_dim, dtype=torch.float, requires_grad=True)) 143 | self.graph_multihead_attention = nn.MultiheadAttention(self.out_dim, MULTI_HEADS, DROPOUT_RATE) 144 | 145 | if self.features: 146 | self.code_features_to_embedding = nn.Linear(CODE_EMBEDDING_SIZE + get_num_code_features(), 147 | CODE_EMBEDDING_SIZE, bias=False) 148 | self.nl_features_to_embedding = nn.Linear( 149 | NL_EMBEDDING_SIZE + get_num_nl_features(), 150 | NL_EMBEDDING_SIZE, bias=False) 151 | 152 | if self.generate: 153 | self.update_encoder_state_size = self.num_seq_encoders*self.out_dim 154 | self.encoder_final_to_decoder_initial = nn.Parameter(torch.randn(self.update_encoder_state_size, 155 | self.out_dim, dtype=torch.float, requires_grad=True)) 156 | 157 | if self.classify: 158 | self.attended_nl_encoder = Encoder(self.out_dim, HIDDEN_SIZE, NUM_LAYERS, DROPOUT_RATE) 159 | self.attended_nl_encoder_output_layer = nn.Linear(self.attention_state_size, self.out_dim, bias=False) 160 | 161 | def get_batches(self, dataset, device, shuffle=False): 162 | """Divides the dataset into batches based on pre-defined BATCH_SIZE hyperparameter. 
163 |        Each batch is tensorized so that it can be directly passed into the network."""
164 |         batches = []
165 |         if shuffle:
166 |             random.shuffle(dataset)
167 | 
168 |         curr_idx = 0
169 |         while curr_idx < len(dataset):
170 |             start_idx = curr_idx
171 |             end_idx = min(start_idx + BATCH_SIZE, len(dataset))
172 | 
173 |             code_token_ids = []
174 |             code_lengths = []
175 |             old_nl_token_ids = []
176 |             old_nl_lengths = []
177 |             trg_token_ids = []
178 |             trg_extended_token_ids = []
179 |             trg_lengths = []
180 |             invalid_copy_positions = []
181 |             inp_str_reps = []
182 |             inp_ids = []
183 |             code_features = []
184 |             nl_features = []
185 |             labels = []
186 | 
187 |             graph_batch = initialize_graph_method_batch(len(DiffEdgeType))
188 | 
189 |             for i in range(start_idx, end_idx):
190 |                 if self.encode_code_sequence:
191 |                     code_sequence = self.get_code_representation(dataset[i], 'sequence')
192 |                     code_sequence_ids = self.embedding_store.get_padded_code_ids(
193 |                         code_sequence, self.max_code_length)
194 |                     code_length = min(len(code_sequence), self.max_code_length)
195 |                     code_token_ids.append(code_sequence_ids)
196 |                     code_lengths.append(code_length)
197 | 
198 |                 if self.attend_code_graph_states:
199 |                     ast = self.get_code_representation(dataset[i], 'graph')
200 |                     ast_sequence = [n.value for n in ast.nodes]
201 |                     ast_length = min(len(ast_sequence), self.max_ast_length)
202 |                     ast.nodes = ast.nodes[:ast_length]
203 |                     graph_batch = insert_graph(graph_batch, dataset[i], ast,
204 |                         self.embedding_store.code_vocabulary, self.features, self.max_ast_length)
205 | 
206 |                 old_nl_sequence = dataset[i].old_comment_subtokens
207 |                 old_nl_length = min(len(old_nl_sequence), self.max_nl_length)
208 |                 old_nl_sequence_ids = self.embedding_store.get_padded_nl_ids(
209 |                     old_nl_sequence, self.max_nl_length)
210 | 
211 |                 old_nl_token_ids.append(old_nl_sequence_ids)
212 |                 old_nl_lengths.append(old_nl_length)
213 | 
214 |                 if self.generate:
215 |                     ex_inp_str_reps = []
216 |                     ex_inp_ids = []
217 | 
218 |                     extra_counter = len(self.embedding_store.nl_vocabulary)
219 |                     max_limit = len(self.embedding_store.nl_vocabulary) + self.max_vocab_extension
220 |                     out_ids = set()
221 | 
222 |                     copy_inputs = []
223 |                     copy_inputs += code_sequence[:code_length]
224 | 
225 |                     copy_inputs += old_nl_sequence[:old_nl_length]
226 |                     for c in copy_inputs:  # build the per-example extended vocabulary used by the copy mechanism
227 |                         nl_id = self.embedding_store.get_nl_id(c)
228 |                         if self.embedding_store.is_nl_unk(nl_id) and extra_counter < max_limit:
229 |                             if c in ex_inp_str_reps:
230 |                                 nl_id = ex_inp_ids[ex_inp_str_reps.index(c)]
231 |                             else:
232 |                                 nl_id = extra_counter
233 |                                 extra_counter += 1
234 | 
235 |                         out_ids.add(nl_id)
236 |                         ex_inp_str_reps.append(c)
237 |                         ex_inp_ids.append(nl_id)
238 | 
239 |                     trg_sequence = [START] + dataset[i].span_minimal_diff_comment_subtokens + [END]
240 |                     trg_sequence_ids = self.embedding_store.get_padded_nl_ids(
241 |                         trg_sequence, self.max_nl_length)
242 |                     trg_extended_sequence_ids = self.embedding_store.get_extended_padded_nl_ids(
243 |                         trg_sequence, self.max_nl_length, ex_inp_ids, ex_inp_str_reps)
244 | 
245 |                     trg_token_ids.append(trg_sequence_ids)
246 |                     trg_extended_token_ids.append(trg_extended_sequence_ids)
247 |                     trg_lengths.append(min(len(trg_sequence), self.max_nl_length))
248 |                     inp_str_reps.append(ex_inp_str_reps)
249 |                     inp_ids.append(self.embedding_store.pad_length(ex_inp_ids, self.max_vocab_extension))
250 | 
251 |                     invalid_copy_positions.append(get_invalid_copy_locations(ex_inp_str_reps, self.max_vocab_extension,
252 |                         trg_sequence, self.max_nl_length))
253 | 
254 |                 labels.append(dataset[i].label)
255 | 
256 |                 if self.features:
257 |                     if self.encode_code_sequence:
258 |                         code_features.append(get_code_features(code_sequence, dataset[i], self.max_code_length))
259 |                     nl_features.append(get_nl_features(old_nl_sequence, dataset[i], self.max_nl_length))
260 | 
261 |             batches.append(UpdateBatchData(torch.tensor(code_token_ids, dtype=torch.int64, device=device),
262 |                 torch.tensor(code_lengths, dtype=torch.int64, device=device),
263 |                 torch.tensor(old_nl_token_ids, dtype=torch.int64, device=device),
264 |                 torch.tensor(old_nl_lengths, dtype=torch.int64, device=device),
265 |                 torch.tensor(trg_token_ids, dtype=torch.int64, device=device),
266 |                 torch.tensor(trg_extended_token_ids, dtype=torch.int64, device=device),
267 |                 torch.tensor(trg_lengths, dtype=torch.int64, device=device),
268 |                 torch.tensor(invalid_copy_positions, dtype=torch.uint8, device=device),
269 |                 inp_str_reps,
270 |                 torch.tensor(inp_ids, dtype=torch.int64, device=device),
271 |                 torch.tensor(code_features, dtype=torch.float32, device=device),
272 |                 torch.tensor(nl_features, dtype=torch.float32, device=device),
273 |                 torch.tensor(labels, dtype=torch.int64, device=device),
274 |                 tensorize_graph_method_batch(graph_batch, device, self.max_ast_length)))
275 |             curr_idx = end_idx
276 |         return batches
277 | 
278 |     def get_encoder_output(self, batch_data, device):
279 |         """Gets hidden states, final state, and length masks corresponding to each encoder."""
280 |         encoder_hidden_states = None
281 |         input_lengths = None
282 |         final_states = None
283 |         mask = None
284 | 
285 |         # Encode old NL
286 |         old_nl_embedded_subtokens = self.embedding_store.get_nl_embeddings(batch_data.old_nl_ids)
287 |         if self.features:
288 |             old_nl_embedded_subtokens = self.nl_features_to_embedding(torch.cat(
289 |                 [old_nl_embedded_subtokens, batch_data.nl_features], dim=-1))
290 |         old_nl_hidden_states, old_nl_final_state = self.nl_encoder.forward(old_nl_embedded_subtokens,
291 |             batch_data.old_nl_lengths, device)
292 |         old_nl_masks = (torch.arange(
293 |             old_nl_hidden_states.shape[1], device=device).view(1, -1) >= batch_data.old_nl_lengths.view(-1, 1)).unsqueeze(1)  # True at padded positions
294 |         attention_states = compute_attention_states(old_nl_hidden_states, old_nl_masks,
295 |             old_nl_hidden_states, transformation_matrix=self.nl_attention_transform_matrix, multihead_attention=self.self_attention)
296 | 
297 |         # Encode code
298 |         code_hidden_states = None
299 |         code_masks = None
300 |         code_final_state = None
301 | 
302 |         if self.encode_code_sequence:
303 |             code_embedded_subtokens = self.embedding_store.get_code_embeddings(batch_data.code_ids)
304 |             if self.features:
305 |                 code_embedded_subtokens = self.code_features_to_embedding(torch.cat(
306 |                     [code_embedded_subtokens, batch_data.code_features], dim=-1))
307 |             code_hidden_states, code_final_state = self.sequence_code_encoder.forward(code_embedded_subtokens,
308 |                 batch_data.code_lengths, device)
309 |             code_masks = (torch.arange(
310 |                 code_hidden_states.shape[1], device=device).view(1, -1) >= batch_data.code_lengths.view(-1, 1)).unsqueeze(1)
311 |             encoder_hidden_states = code_hidden_states
312 |             input_lengths = batch_data.code_lengths
313 |             final_states = code_final_state
314 | 
315 |             if self.attend_code_sequence_states:
316 |                 attention_states = torch.cat([attention_states, compute_attention_states(
317 |                     code_hidden_states, code_masks, old_nl_hidden_states,
318 |                     transformation_matrix=self.sequence_attention_transform_matrix,
319 |                     multihead_attention=self.code_sequence_multihead_attention)], dim=-1)
320 | 
321 |         if self.attend_code_graph_states:
322 | 
embedded_nodes = self.embedding_store.get_node_embeddings( 323 | batch_data.graph_batch.value_lookup_ids, batch_data.graph_batch.src_type_ids) 324 | 325 | if self.features: 326 | embedded_nodes = self.code_features_to_embedding(torch.cat( 327 | [embedded_nodes, batch_data.graph_batch.node_features], dim=-1)) 328 | 329 | graph_states = self.graph_code_encoder.forward(embedded_nodes, batch_data.graph_batch, device) 330 | graph_lengths = batch_data.graph_batch.num_nodes_per_graph 331 | graph_masks = (torch.arange( 332 | graph_states.shape[1], device=device).view(1, -1) >= graph_lengths.view(-1, 1)).unsqueeze(1) 333 | 334 | transformed_graph_states = torch.einsum('ijk,km->ijm', graph_states, self.graph_attention_transform_matrix) 335 | graph_attention_states = compute_attention_states(transformed_graph_states, graph_masks, 336 | old_nl_hidden_states, multihead_attention=self.graph_multihead_attention) 337 | attention_states = torch.cat([attention_states, graph_attention_states], dim=-1) 338 | 339 | if self.classify: 340 | nl_attended_states = torch.tanh(self.attended_nl_encoder_output_layer(attention_states)) 341 | _, attended_old_nl_final_state = self.attended_nl_encoder.forward(nl_attended_states, 342 | batch_data.old_nl_lengths, device) 343 | else: 344 | attended_old_nl_final_state = None 345 | 346 | if self.generate: 347 | encoder_final_state = torch.einsum('ij,jk->ik', 348 | torch.cat([final_states, old_nl_final_state], dim=-1), 349 | self.encoder_final_to_decoder_initial) 350 | encoder_hidden_states, input_lengths = merge_encoder_outputs(encoder_hidden_states, 351 | input_lengths, old_nl_hidden_states, batch_data.old_nl_lengths, device) 352 | mask = (torch.arange( 353 | encoder_hidden_states.shape[1], device=device).view(1, -1) >= input_lengths.view(-1, 1)).unsqueeze(1) 354 | else: 355 | encoder_final_state = None 356 | 357 | return EncoderOutputs(encoder_hidden_states, mask, encoder_final_state, code_hidden_states, code_masks, 358 | old_nl_hidden_states, old_nl_masks, old_nl_final_state, attended_old_nl_final_state) -------------------------------------------------------------------------------- /run_comment_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import os 4 | import sys 5 | import torch 6 | 7 | sys.path.append('comment_update') 8 | from comment_generation import CommentGenerationModel 9 | from update_module import UpdateModule 10 | from detection_module import DetectionModule 11 | from data_loader import get_data_splits 12 | from module_manager import ModuleManager 13 | 14 | def build_model(task, model_path, manager): 15 | """ Builds the appropriate model, with task-specific modules.""" 16 | if task == 'dual': 17 | detection_module = DetectionModule(None, manager) 18 | model = UpdateModule(model_path, manager, detection_module) 19 | elif 'update' in task: 20 | model = UpdateModule(model_path, manager, None) 21 | else: 22 | model = DetectionModule(model_path, manager) 23 | 24 | return model 25 | 26 | def load_model(model_path, evaluate_detection=False): 27 | """Loads a pretrained model from model_path.""" 28 | print('Loading model from: {}'.format(model_path)) 29 | sys.stdout.flush() 30 | if torch.cuda.is_available() and evaluate_detection: 31 | model = torch.load(model_path) 32 | model.torch_device_name = 'gpu' 33 | model.cuda() 34 | for c in model.children(): 35 | c.cuda() 36 | else: 37 | model = torch.load(model_path, map_location='cpu') 38 | model.torch_device_name = 'cpu' 
39 |         model.cpu()
40 |         for c in model.children():
41 |             c.cpu()
42 |     return model
43 | 
44 | def train(model, train_examples, valid_examples):
45 |     """Trains a model."""
46 |     print('Training with {} examples (validation {})'.format(len(train_examples), len(valid_examples)))
47 |     sys.stdout.flush()
48 |     if torch.cuda.is_available():
49 |         model.torch_device_name = 'gpu'
50 |         model.cuda()
51 |         for c in model.children():
52 |             c.cuda()
53 |     else:
54 |         model.torch_device_name = 'cpu'
55 |         model.cpu()
56 |         for c in model.children():
57 |             c.cpu()
58 | 
59 |     model.run_train(train_examples, valid_examples)
60 | 
61 | def evaluate(task, model, test_examples, model_name, rerank):
62 |     """Runs evaluation over a given model."""
63 |     print('Evaluating {} examples'.format(len(test_examples)))
64 |     sys.stdout.flush()
65 |     if task == 'detect':
66 |         model.run_evaluation(test_examples, model_name)
67 |     else:
68 |         model.run_evaluation(test_examples, rerank, model_name)
69 | 
70 | if __name__ == "__main__":
71 |     parser = argparse.ArgumentParser()
72 |     parser.add_argument('--task', help='detect, update, or dual')
73 |     parser.add_argument('--attend_code_sequence_states', action='store_true', help='attend to sequence-based code hidden states for detection')
74 |     parser.add_argument('--attend_code_graph_states', action='store_true', help='attend to graph-based code hidden states for detection')
75 |     parser.add_argument('--features', action='store_true', help='concatenate lexical and linguistic features to code/comment input embeddings')
76 |     parser.add_argument('--posthoc', action='store_true', help='whether to run in posthoc mode where old code is not available')
77 |     parser.add_argument('--positive_only', action='store_true', help='whether to train on only inconsistent examples')
78 |     parser.add_argument('--test_mode', action='store_true', help='whether to run evaluation')
79 |     parser.add_argument('--rerank', action='store_true', help='whether to use reranking in the update module (if task is update or dual)')
80 |     parser.add_argument('--model_path', help='path to save model (training) or path to saved model (evaluation)')
81 |     parser.add_argument('--model_name', help='name of model (used to save model output)')
82 |     args = parser.parse_args()
83 | 
84 |     train_examples, valid_examples, test_examples, high_level_details = get_data_splits()
85 |     if args.positive_only:
86 |         train_examples = [ex for ex in train_examples if ex.label == 1]
87 |         valid_examples = [ex for ex in valid_examples if ex.label == 1]
88 | 
89 |     print('Train: {}'.format(len(train_examples)))
90 |     print('Valid: {}'.format(len(valid_examples)))
91 |     print('Test: {}'.format(len(test_examples)))
92 | 
93 |     if args.task == 'detect' and (not args.attend_code_sequence_states and not args.attend_code_graph_states):
94 |         raise ValueError('Please specify attention states for detection')
95 |     if args.posthoc and (args.task != 'detect' or args.features):
96 |         # Features and update rely on code changes
97 |         raise ValueError('Posthoc setting not supported for given arguments')
98 | 
99 |     if args.test_mode:
100 |         print('Starting evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
101 | 
102 |         model = load_model(args.model_path, args.task == 'detect')
103 |         evaluate(args.task, model, test_examples, args.model_name, args.rerank)
104 | 
105 |         print('Terminating evaluation: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
106 |     else:
107 |         print('Starting training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
108 | 
109 |         manager = 
ModuleManager(args.attend_code_sequence_states, args.attend_code_graph_states, args.features, args.posthoc, args.task) 110 | manager.initialize(train_examples) 111 | model = build_model(args.task, args.model_path, manager) 112 | 113 | print('Model path: {}'.format(args.model_path)) 114 | sys.stdout.flush() 115 | 116 | train(model, train_examples, valid_examples) 117 | 118 | print('Terminating training: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S"))) --------------------------------------------------------------------------------
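
A minimal, self-contained sketch of the `scatter_add_`-based message aggregation that `gnn.py` above relies on; the toy edge list and dimensions here are illustrative only and are not part of the repository:

```python
import torch

# Toy graph with 4 nodes and 3 directed edges (source -> target): 0->2, 2->1, 1->3
edge_sources = torch.tensor([0, 2, 1])
edge_targets = torch.tensor([2, 1, 3])
node_states = torch.randn(4, 8)  # [V, D] current node states

# One message per edge, read from the edge's source node. (gnn.py additionally
# applies a per-edge-type linear layer and dropout before aggregating.)
messages = node_states[edge_sources]  # [E, D]

# Sum messages that share a target node: incoming[v] = sum of messages on edges into v
incoming = torch.zeros(4, 8)
incoming.scatter_add_(0, edge_targets.unsqueeze(-1).expand_as(messages), messages)

print(incoming[3])  # equals node_states[1], the lone message sent along 1 -> 3
```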