├── CodeBERT
│   ├── model.py
│   └── run.py
├── GraphCodeBERT
│   ├── model.py
│   ├── parser
│   │   ├── DFG.py
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── build.sh
│   │   └── utils.py
│   └── run.py
├── Mixup.py
├── README.md
├── Tool
│   ├── Java_refactor
│   │   ├── generate_refactoring.py
│   │   ├── processing_source_code.py
│   │   ├── refactoring_methods.py
│   │   └── util.py
│   └── Python_refactor
│       ├── generate_refactoring.py
│       ├── processing_source_code.py
│       ├── refactoring_methods.py
│       └── util.py
├── img
│   └── overview.png
└── model
    ├── BagOfToken.py
    ├── GAT_model.py
    ├── GCN_model.py
    ├── GGNN_model.py
    └── SeqofToken.py

/CodeBERT/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | import copy
7 | import torch.nn.functional as F
8 | from torch.nn import CrossEntropyLoss, MSELoss
9 | import numpy as np
10 | 
11 | 
12 | def mixup_data(x, y, runs=1, alpha=0.1, use_cuda=True):
13 |     # Mixup within a batch: interpolate each example (and its one-hot label)
14 |     # with a randomly chosen example from the same batch, `runs` times.
15 |     output_x = []
16 |     output_y = []
17 |     for i in range(runs):
18 |         batch_size = x.size()[0]
19 |         if alpha > 0.:
20 |             lam = np.random.beta(alpha, alpha)
21 |         else:
22 |             lam = 1.
23 | 
24 |         if use_cuda:
25 |             index = torch.randperm(batch_size).cuda()
26 |         else:
27 |             index = torch.randperm(batch_size)
28 |         mixed_x = lam * x + (1 - lam) * x[index, :]
29 |         mixed_y = lam * y + (1 - lam) * y[index, :]
30 |         output_x.append(mixed_x)
31 |         output_y.append(mixed_y)
32 |     return torch.cat(output_x, dim=0), torch.cat(output_y, dim=0)
33 | 
34 | 
35 | def mixup_data_refactor(x, y, x_refactor, y_refactor, alpha, runs, use_cuda=True):
36 |     # Mixup across pairs: interpolate each original example with a randomly
37 |     # chosen refactored example (x_refactor, y_refactor).
38 |     output_x = []
39 |     output_y = []
40 |     for i in range(runs):
41 |         batch_size = x.size()[0]
42 |         if alpha > 0.:
43 |             lam = np.random.beta(alpha, alpha)
44 |         else:
45 |             lam = 1.
46 |         if use_cuda:
47 |             index = torch.randperm(batch_size).cuda()
48 |         else:
49 |             index = torch.randperm(batch_size)
50 |         mixed_x = lam * x + (1 - lam) * x_refactor[index, :]
51 |         mixed_y = lam * y + (1 - lam) * y_refactor[index, :]
52 |         output_x.append(mixed_x)
53 |         output_y.append(mixed_y)
54 |     return torch.cat(output_x, dim=0), torch.cat(output_y, dim=0)
55 | 
56 | 
57 | class Model(nn.Module):
58 |     def __init__(self, encoder, config, tokenizer, args):
59 |         super(Model, self).__init__()
60 |         self.encoder = encoder
61 |         self.config = config
62 |         self.tokenizer = tokenizer
63 |         self.args = args
64 | 
65 |     def forward(self, input_ids=None, labels=None):
66 |         logits = self.encoder(input_ids, attention_mask=input_ids.ne(1))[0]
67 |         logits, labels = mixup_data(logits, labels)  # Mixup Data
68 |         prob = torch.nn.functional.log_softmax(logits, -1)
69 |         if labels is not None:
70 |             # Soft cross-entropy against the mixed (soft) labels.
71 |             loss = -torch.sum(prob * labels)
72 |             return loss, prob
73 |         else:
74 |             return prob
75 | 
--------------------------------------------------------------------------------
/CodeBERT/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 20 | """ 21 | 22 | from __future__ import absolute_import, division, print_function 23 | 24 | import argparse 25 | import glob 26 | import logging 27 | import os 28 | import pickle 29 | import random 30 | import re 31 | import shutil 32 | 33 | import numpy as np 34 | import torch 35 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset 36 | from torch.utils.data.distributed import DistributedSampler 37 | import json 38 | 39 | 40 | from tqdm import tqdm, trange 41 | import multiprocessing 42 | from model import Model 43 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 44 | RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer) 45 | 46 | logger = logging.getLogger(__name__) 47 | 48 | class InputFeatures(object): 49 | """A single training/test features for a example.""" 50 | def __init__(self, 51 | input_tokens, 52 | input_ids, 53 | label, 54 | 55 | ): 56 | self.input_tokens = input_tokens 57 | self.input_ids = input_ids 58 | self.label=label 59 | 60 | 61 | def convert_examples_to_features(js,tokenizer,args): 62 | #source 63 | code=' '.join(js['code'].split()) 64 | code_tokens=tokenizer.tokenize(code)[:args.block_size-2] 65 | source_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] 66 | source_ids = tokenizer.convert_tokens_to_ids(source_tokens) 67 | padding_length = args.block_size - len(source_ids) 68 | source_ids+=[tokenizer.pad_token_id]*padding_length 69 | return InputFeatures(source_tokens,source_ids,js['label']) 70 | 71 | class TextDataset(Dataset): 72 | def __init__(self, tokenizer, args, file_path=None): 73 | self.examples = [] 74 | with open(file_path) as f: 75 | for line in f: 76 | js=json.loads(line.strip()) 77 | self.examples.append(convert_examples_to_features(js,tokenizer,args)) 78 | if 'train' in file_path: 79 | for idx, example in enumerate(self.examples[:3]): 80 | logger.info("*** Example ***") 81 | logger.info("label: {}".format(example.label)) 82 | logger.info("input_tokens: {}".format([x.replace('\u0120','_') for x in example.input_tokens])) 83 | logger.info("input_ids: {}".format(' '.join(map(str, example.input_ids)))) 84 | 85 | def __len__(self): 86 | return len(self.examples) 87 | 88 | def __getitem__(self, i): 89 | return torch.tensor(self.examples[i].input_ids),torch.tensor(self.examples[i].label) 90 | 91 | 92 | def set_seed(seed=42): 93 | random.seed(seed) 94 | os.environ['PYHTONHASHSEED'] = str(seed) 95 | np.random.seed(seed) 96 | torch.manual_seed(seed) 97 | torch.cuda.manual_seed(seed) 98 | torch.backends.cudnn.deterministic = True 99 | 100 | def train(args, train_dataset, model, tokenizer): 101 | """ Train the model """ 102 | train_sampler = RandomSampler(train_dataset) 103 | 104 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, 
105 | batch_size=args.train_batch_size,num_workers=4,pin_memory=True) 106 | 107 | 108 | 109 | # Prepare optimizer and schedule (linear warmup and decay) 110 | no_decay = ['bias', 'LayerNorm.weight'] 111 | optimizer_grouped_parameters = [ 112 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 113 | 'weight_decay': args.weight_decay}, 114 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 115 | ] 116 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 117 | max_steps = len(train_dataloader) * args.num_train_epochs 118 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max_steps*0.1, 119 | num_training_steps=max_steps) 120 | 121 | # Train! 122 | logger.info("***** Running training *****") 123 | logger.info(" Num examples = %d", len(train_dataset)) 124 | logger.info(" Num Epochs = %d", args.num_train_epochs) 125 | logger.info(" batch size = %d", args.train_batch_size) 126 | logger.info(" Total optimization steps = %d", max_steps) 127 | best_acc=0.0 128 | model.zero_grad() 129 | 130 | for idx in range(args.num_train_epochs): 131 | bar = tqdm(train_dataloader,total=len(train_dataloader)) 132 | losses=[] 133 | for step, batch in enumerate(bar): 134 | inputs = batch[0].to(args.device) 135 | labels=batch[1].to(args.device) 136 | #print(inputs.size()) 137 | #print(labels.size()) 138 | labels = torch.nn.functional.one_hot(labels, args.num_labels) 139 | # print(labels.size()) 140 | model.train() 141 | loss,logits = model(inputs,labels) 142 | 143 | if args.n_gpu > 1: 144 | loss = loss.mean() # mean() to average on multi-gpu parallel training 145 | 146 | loss.backward() 147 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 148 | 149 | losses.append(loss.item()) 150 | bar.set_description("epoch {} loss {}".format(idx,round(np.mean(losses),3))) 151 | optimizer.step() 152 | optimizer.zero_grad() 153 | scheduler.step() 154 | 155 | results = evaluate(args, model, tokenizer) 156 | for key, value in results.items(): 157 | logger.info(" %s = %s", key, round(value,4)) 158 | 159 | # Save model checkpoint 160 | if results['eval_acc']>best_acc: 161 | best_acc=results['eval_acc'] 162 | logger.info(" "+"*"*20) 163 | logger.info(" Best acc:%s",round(best_acc,4)) 164 | logger.info(" "+"*"*20) 165 | 166 | checkpoint_prefix = 'checkpoint-best-acc' 167 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 168 | if not os.path.exists(output_dir): 169 | os.makedirs(output_dir) 170 | model_to_save = model.module if hasattr(model,'module') else model 171 | output_dir = os.path.join(output_dir, '{}'.format('model.bin')) 172 | torch.save(model_to_save.state_dict(), output_dir) 173 | logger.info("Saving model checkpoint to %s", output_dir) 174 | 175 | def evaluate(args, model, tokenizer): 176 | eval_output_dir = args.output_dir 177 | 178 | eval_dataset = TextDataset(tokenizer, args,args.eval_data_file) 179 | 180 | if not os.path.exists(eval_output_dir): 181 | os.makedirs(eval_output_dir) 182 | 183 | 184 | eval_sampler = SequentialSampler(eval_dataset) 185 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size,num_workers=4,pin_memory=True) 186 | 187 | # Eval! 
188 |     logger.info("***** Running evaluation *****")
189 |     logger.info("  Num examples = %d", len(eval_dataset))
190 |     logger.info("  Batch size = %d", args.eval_batch_size)
191 |     eval_loss = 0.0
192 |     nb_eval_steps = 0
193 |     model.eval()
194 |     logits=[]
195 |     labels=[]
196 |     for batch in eval_dataloader:
197 |         inputs = batch[0].to(args.device)
198 |         label=batch[1].to(args.device)
199 |         label = torch.nn.functional.one_hot(label, args.num_labels)
200 | 
201 |         with torch.no_grad():
202 |             lm_loss,logit = model(inputs,label)
203 |             eval_loss += lm_loss.mean().item()
204 |             logits.append(logit.cpu().numpy())
205 |             labels.append(label.cpu().numpy())
206 |         nb_eval_steps += 1
207 |     logits=np.concatenate(logits,0)
208 |     labels=np.concatenate(labels,0)
209 |     #print('logits:',logits)
210 |     preds=logits.argmax(-1)
211 |     #print('preds:',preds)
212 |     labels=labels.argmax(-1)
213 |     eval_acc=np.mean(labels==preds)
214 |     eval_loss = eval_loss / nb_eval_steps
215 |     perplexity = torch.tensor(eval_loss)
216 | 
217 |     result = {
218 |         "eval_loss": float(perplexity),
219 |         "eval_acc":round(eval_acc,4),
220 |     }
221 |     return result
222 | 
223 | def test(args, model, tokenizer):
224 |     # Note that DistributedSampler samples randomly
225 |     eval_dataset = TextDataset(tokenizer, args,args.test_data_file)
226 |     eval_sampler = SequentialSampler(eval_dataset)
227 |     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
228 | 
229 |     # Eval!
230 |     logger.info("***** Running Test *****")
231 |     logger.info("  Num examples = %d", len(eval_dataset))
232 |     logger.info("  Batch size = %d", args.eval_batch_size)
233 |     eval_loss = 0.0
234 |     nb_eval_steps = 0
235 |     model.eval()
236 |     logits=[]
237 |     labels=[]
238 |     for batch in tqdm(eval_dataloader,total=len(eval_dataloader)):
239 |         inputs = batch[0].to(args.device)
240 |         label=batch[1].to(args.device)
241 |         label = torch.nn.functional.one_hot(label, args.num_labels)
242 |         with torch.no_grad():
243 |             _, logit = model(inputs,label)
244 |             logits.append(logit.cpu().numpy())
245 |             labels.append(label.cpu().numpy())
246 |         nb_eval_steps += 1
247 | 
248 |     logits=np.concatenate(logits,0)
249 |     labels=np.concatenate(labels,0)
250 |     preds=logits.argmax(-1)
251 |     labels=labels.argmax(-1)
252 |     test_acc = np.mean(labels==preds)
253 | 
254 |     result = {
255 |         "test_acc":round(test_acc,4),
256 |     }
257 |     return result
258 | 
259 | 
260 | 
261 | #with open(os.path.join(args.output_dir,"predictions.txt"),'w') as f:
262 | #    for example,pred in zip(eval_dataset.examples,preds):
263 | #        f.write(str(pred)+'\n')
264 | 
265 | def main():
266 |     parser = argparse.ArgumentParser()
267 | 
268 |     ## Required parameters
269 |     parser.add_argument("--train_data_file", default=None, type=str, required=True,
270 |                         help="The input training data file (a text file).")
271 |     parser.add_argument("--output_dir", default=None, type=str, required=True,
272 |                         help="The output directory where the model predictions and checkpoints will be written.")
273 | 
274 |     ## Other parameters
275 |     parser.add_argument("--eval_data_file", default=None, type=str,
276 |                         help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
277 |     parser.add_argument("--test_data_file", default=None, type=str,
278 |                         help="An optional input test data file to evaluate accuracy on (a text file).")
279 |     parser.add_argument("--model_name_or_path", default=None, type=str,
280 |                         help="The model checkpoint for weights initialization.")
281 |     parser.add_argument("--tokenizer_name", default="",
type=str, 282 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") 283 | parser.add_argument("--block_size", default=-1, type=int, 284 | help="Optional input sequence length after tokenization.") 285 | parser.add_argument("--do_train", action='store_true', 286 | help="Whether to run training.") 287 | parser.add_argument("--do_eval", action='store_true', 288 | help="Whether to run eval on the dev set.") 289 | parser.add_argument("--do_test", action='store_true', 290 | help="Whether to run eval on the dev set.") 291 | parser.add_argument("--train_batch_size", default=4, type=int, 292 | help="Batch size per GPU/CPU for training.") 293 | parser.add_argument("--eval_batch_size", default=4, type=int, 294 | help="Batch size per GPU/CPU for evaluation.") 295 | parser.add_argument("--learning_rate", default=5e-5, type=float, 296 | help="The initial learning rate for Adam.") 297 | parser.add_argument("--weight_decay", default=0.0, type=float, 298 | help="Weight deay if we apply some.") 299 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 300 | help="Epsilon for Adam optimizer.") 301 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 302 | help="Max gradient norm.") 303 | parser.add_argument("--warmup_steps", default=0, type=int, 304 | help="Linear warmup over warmup_steps.") 305 | parser.add_argument('--seed', type=int, default=42, 306 | help="random seed for initialization") 307 | parser.add_argument('--num_train_epochs', type=int, default=42, 308 | help="num_train_epochs") 309 | parser.add_argument('--num_labels', type=int, default=None, 310 | help = 'num_labels') 311 | 312 | args = parser.parse_args() 313 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 314 | args.n_gpu = torch.cuda.device_count() 315 | 316 | args.device = device 317 | # Setup logging 318 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 319 | datefmt='%m/%d/%Y %H:%M:%S', 320 | level=logging.INFO) 321 | logger.warning("device: %s, n_gpu: %s", device, args.n_gpu) 322 | 323 | # Set seed 324 | set_seed(args.seed) 325 | 326 | config = RobertaConfig.from_pretrained(args.model_name_or_path) 327 | config.num_labels=args.num_labels 328 | tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) 329 | model = RobertaForSequenceClassification.from_pretrained(args.model_name_or_path,config=config) 330 | 331 | model=Model(model,config,tokenizer,args) 332 | 333 | # multi-gpu training (should be after apex fp16 initialization) 334 | model.to(args.device) 335 | if args.n_gpu > 1: 336 | model = torch.nn.DataParallel(model) 337 | 338 | logger.info("Training/evaluation parameters %s", args) 339 | 340 | # Training 341 | if args.do_train: 342 | train_dataset = TextDataset(tokenizer, args,args.train_data_file) 343 | train(args, train_dataset, model, tokenizer) 344 | 345 | # Evaluation 346 | results = {} 347 | if args.do_eval: 348 | checkpoint_prefix = 'checkpoint-best-acc/model.bin' 349 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 350 | model.load_state_dict(torch.load(output_dir)) 351 | model.to(args.device) 352 | result=evaluate(args, model, tokenizer) 353 | logger.info("***** Final Eval results *****") 354 | for key in sorted(result.keys()): 355 | logger.info(" %s = %s", key, str(round(result[key],4))) 356 | 357 | if args.do_test: 358 | checkpoint_prefix = 'checkpoint-best-acc/model.bin' 359 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 360 | 
        model.load_state_dict(torch.load(output_dir))
361 |         model.to(args.device)
362 |         result = test(args, model, tokenizer)
363 |         logger.info("***** Final Test results *****")
364 |         for key in sorted(result.keys()):
365 |             logger.info("  %s = %s", key, str(round(result[key],4)))
366 | 
367 | 
368 |     return results
369 | 
370 | 
371 | if __name__ == "__main__":
372 |     main()
373 | 
--------------------------------------------------------------------------------
/GraphCodeBERT/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | import copy
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.nn import CrossEntropyLoss, MSELoss
10 | from torch.utils.data import SequentialSampler, DataLoader
11 | import pdb
12 | 
13 | 
14 | def mixup_data(x, y, runs=1, alpha=0.1, use_cuda=True):
15 |     # Mixup within a batch: interpolate each example (and its one-hot label)
16 |     # with a randomly chosen example from the same batch, `runs` times.
17 |     output_x = []
18 |     output_y = []
19 |     for i in range(runs):
20 |         batch_size = x.size()[0]
21 |         if alpha > 0.:
22 |             lam = np.random.beta(alpha, alpha)
23 |         else:
24 |             lam = 1.
25 | 
26 |         if use_cuda:
27 |             index = torch.randperm(batch_size).cuda()
28 |         else:
29 |             index = torch.randperm(batch_size)
30 |         mixed_x = lam * x + (1 - lam) * x[index, :]
31 |         mixed_y = lam * y + (1 - lam) * y[index, :]
32 |         output_x.append(mixed_x)
33 |         output_y.append(mixed_y)
34 |     return torch.cat(output_x, dim=0), torch.cat(output_y, dim=0)
35 | 
36 | 
37 | def mixup_data_refactor(x, y, x_refactor, y_refactor, alpha, runs, use_cuda=True):
38 |     # Mixup across pairs: interpolate each original example with a randomly
39 |     # chosen refactored example (x_refactor, y_refactor).
40 |     output_x = []
41 |     output_y = []
42 |     for i in range(runs):
43 |         batch_size = x.size()[0]
44 |         if alpha > 0.:
45 |             lam = np.random.beta(alpha, alpha)
46 |         else:
47 |             lam = 1.
48 | 
49 | if use_cuda: 50 | index = torch.randperm(batch_size).cuda() 51 | else: 52 | index = torch.randperm(batch_size) 53 | mixed_x = lam * x + (1 - lam) * x_refactor[index, :] 54 | mixed_y = lam * y + (1 - lam) * y_refactor[index, :] 55 | output_x.append(mixed_x) 56 | output_y.append(mixed_y) 57 | return torch.cat(output_x,dim=0), torch.cat(output_y,dim=0) 58 | 59 | class RobertaClassificationHead(nn.Module): 60 | """Head for sentence-level classification tasks.""" 61 | 62 | def __init__(self, config): 63 | super().__init__() 64 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 65 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 66 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 67 | 68 | def forward(self, features, **kwargs): 69 | x = features[:, 0, :] 70 | x = self.dropout(x) 71 | x = self.dense(x) 72 | x = torch.tanh(x) 73 | x = self.dropout(x) 74 | x = self.out_proj(x) 75 | return x 76 | 77 | class Model(nn.Module): 78 | def __init__(self, encoder,config,tokenizer,args): 79 | super(Model, self).__init__() 80 | self.encoder = encoder 81 | self.config=config 82 | self.tokenizer=tokenizer 83 | self.classifier=RobertaClassificationHead(config) 84 | self.args=args 85 | self.query = 0 86 | 87 | 88 | def forward(self, inputs_ids=None, attn_mask=None, position_idx=None, labels=None): 89 | 90 | nodes_mask=position_idx.eq(0) 91 | token_mask=position_idx.ge(2) 92 | 93 | inputs_embeddings=self.encoder.roberta.embeddings.word_embeddings(inputs_ids) 94 | nodes_to_token_mask=nodes_mask[:,:,None]&token_mask[:,None,:]&attn_mask 95 | nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None] 96 | avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings) 97 | inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None] 98 | outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings, attention_mask=attn_mask, position_ids=position_idx, token_type_ids=position_idx.eq(-1).long())[0] 99 | 100 | logits=self.classifier(outputs) 101 | logits, labels = mixup_data(logits,labels) # Mixup Data 102 | prob=torch.nn.functional.log_softmax(logits,-1) 103 | 104 | if labels is not None: 105 | loss = -torch.sum(prob*labels) 106 | return loss,prob 107 | else: 108 | return prob 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/DFG.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
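# Language-specific data-flow-graph (DFG) extractors used by GraphCodeBERT's parser.
# Each DFG_<lang> function below walks a tree-sitter AST and returns (edges, states):
# every edge is a tuple (token, token_index, 'comesFrom' | 'computedFrom',
# source_tokens, source_indexes), and states maps a variable name to the token
# indexes that last defined it.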
3 | 4 | from tree_sitter import Language, Parser 5 | from .utils import (remove_comments_and_docstrings, 6 | tree_to_token_index, 7 | index_to_code_token, 8 | tree_to_variable_index) 9 | 10 | 11 | def DFG_python(root_node,index_to_code,states): 12 | assignment=['assignment','augmented_assignment','for_in_clause'] 13 | if_statement=['if_statement'] 14 | for_statement=['for_statement'] 15 | while_statement=['while_statement'] 16 | do_first_statement=['for_in_clause'] 17 | def_statement=['default_parameter'] 18 | states=states.copy() 19 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 20 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 21 | if root_node.type==code: 22 | return [],states 23 | elif code in states: 24 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 25 | else: 26 | if root_node.type=='identifier': 27 | states[code]=[idx] 28 | return [(code,idx,'comesFrom',[],[])],states 29 | elif root_node.type in def_statement: 30 | name=root_node.child_by_field_name('name') 31 | value=root_node.child_by_field_name('value') 32 | DFG=[] 33 | if value is None: 34 | indexs=tree_to_variable_index(name,index_to_code) 35 | for index in indexs: 36 | idx,code=index_to_code[index] 37 | DFG.append((code,idx,'comesFrom',[],[])) 38 | states[code]=[idx] 39 | return sorted(DFG,key=lambda x:x[1]),states 40 | else: 41 | name_indexs=tree_to_variable_index(name,index_to_code) 42 | value_indexs=tree_to_variable_index(value,index_to_code) 43 | temp,states=DFG_python(value,index_to_code,states) 44 | DFG+=temp 45 | for index1 in name_indexs: 46 | idx1,code1=index_to_code[index1] 47 | for index2 in value_indexs: 48 | idx2,code2=index_to_code[index2] 49 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 50 | states[code1]=[idx1] 51 | return sorted(DFG,key=lambda x:x[1]),states 52 | elif root_node.type in assignment: 53 | if root_node.type=='for_in_clause': 54 | right_nodes=[root_node.children[-1]] 55 | left_nodes=[root_node.child_by_field_name('left')] 56 | else: 57 | if root_node.child_by_field_name('right') is None: 58 | return [],states 59 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] 60 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] 61 | if len(right_nodes)!=len(left_nodes): 62 | left_nodes=[root_node.child_by_field_name('left')] 63 | right_nodes=[root_node.child_by_field_name('right')] 64 | if len(left_nodes)==0: 65 | left_nodes=[root_node.child_by_field_name('left')] 66 | if len(right_nodes)==0: 67 | right_nodes=[root_node.child_by_field_name('right')] 68 | DFG=[] 69 | for node in right_nodes: 70 | temp,states=DFG_python(node,index_to_code,states) 71 | DFG+=temp 72 | 73 | for left_node,right_node in zip(left_nodes,right_nodes): 74 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 75 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 76 | temp=[] 77 | for token1_index in left_tokens_index: 78 | idx1,code1=index_to_code[token1_index] 79 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 80 | [index_to_code[x][0] for x in right_tokens_index])) 81 | states[code1]=[idx1] 82 | DFG+=temp 83 | return sorted(DFG,key=lambda x:x[1]),states 84 | elif root_node.type in if_statement: 85 | DFG=[] 86 | current_states=states.copy() 87 | others_states=[] 88 | tag=False 89 | if 'else' in root_node.type: 90 | tag=True 91 | for child in root_node.children: 92 | if 'else' in child.type: 
93 | tag=True 94 | if child.type not in ['elif_clause','else_clause']: 95 | temp,current_states=DFG_python(child,index_to_code,current_states) 96 | DFG+=temp 97 | else: 98 | temp,new_states=DFG_python(child,index_to_code,states) 99 | DFG+=temp 100 | others_states.append(new_states) 101 | others_states.append(current_states) 102 | if tag is False: 103 | others_states.append(states) 104 | new_states={} 105 | for dic in others_states: 106 | for key in dic: 107 | if key not in new_states: 108 | new_states[key]=dic[key].copy() 109 | else: 110 | new_states[key]+=dic[key] 111 | for key in new_states: 112 | new_states[key]=sorted(list(set(new_states[key]))) 113 | return sorted(DFG,key=lambda x:x[1]),new_states 114 | elif root_node.type in for_statement: 115 | DFG=[] 116 | for i in range(2): 117 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] 118 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] 119 | if len(right_nodes)!=len(left_nodes): 120 | left_nodes=[root_node.child_by_field_name('left')] 121 | right_nodes=[root_node.child_by_field_name('right')] 122 | if len(left_nodes)==0: 123 | left_nodes=[root_node.child_by_field_name('left')] 124 | if len(right_nodes)==0: 125 | right_nodes=[root_node.child_by_field_name('right')] 126 | for node in right_nodes: 127 | temp,states=DFG_python(node,index_to_code,states) 128 | DFG+=temp 129 | for left_node,right_node in zip(left_nodes,right_nodes): 130 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 131 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 132 | temp=[] 133 | for token1_index in left_tokens_index: 134 | idx1,code1=index_to_code[token1_index] 135 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 136 | [index_to_code[x][0] for x in right_tokens_index])) 137 | states[code1]=[idx1] 138 | DFG+=temp 139 | if root_node.children[-1].type=="block": 140 | temp,states=DFG_python(root_node.children[-1],index_to_code,states) 141 | DFG+=temp 142 | dic={} 143 | for x in DFG: 144 | if (x[0],x[1],x[2]) not in dic: 145 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 146 | else: 147 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 148 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 149 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 150 | return sorted(DFG,key=lambda x:x[1]),states 151 | elif root_node.type in while_statement: 152 | DFG=[] 153 | for i in range(2): 154 | for child in root_node.children: 155 | temp,states=DFG_python(child,index_to_code,states) 156 | DFG+=temp 157 | dic={} 158 | for x in DFG: 159 | if (x[0],x[1],x[2]) not in dic: 160 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 161 | else: 162 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 163 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 164 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 165 | return sorted(DFG,key=lambda x:x[1]),states 166 | else: 167 | DFG=[] 168 | for child in root_node.children: 169 | if child.type in do_first_statement: 170 | temp,states=DFG_python(child,index_to_code,states) 171 | DFG+=temp 172 | for child in root_node.children: 173 | if child.type not in do_first_statement: 174 | temp,states=DFG_python(child,index_to_code,states) 175 | DFG+=temp 176 | 177 | return sorted(DFG,key=lambda x:x[1]),states 178 | 179 | 180 | def DFG_java(root_node,index_to_code,states): 
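    # Java counterpart of DFG_python: the traversal logic is the same, but the
    # node-type lists below follow the Java tree-sitter grammar
    # (variable_declarator, update_expression, enhanced_for_statement, and so on).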
181 | assignment=['assignment_expression'] 182 | def_statement=['variable_declarator'] 183 | increment_statement=['update_expression'] 184 | if_statement=['if_statement','else'] 185 | for_statement=['for_statement'] 186 | enhanced_for_statement=['enhanced_for_statement'] 187 | while_statement=['while_statement'] 188 | do_first_statement=[] 189 | states=states.copy() 190 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 191 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 192 | if root_node.type==code: 193 | return [],states 194 | elif code in states: 195 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 196 | else: 197 | if root_node.type=='identifier': 198 | states[code]=[idx] 199 | return [(code,idx,'comesFrom',[],[])],states 200 | elif root_node.type in def_statement: 201 | name=root_node.child_by_field_name('name') 202 | value=root_node.child_by_field_name('value') 203 | DFG=[] 204 | if value is None: 205 | indexs=tree_to_variable_index(name,index_to_code) 206 | for index in indexs: 207 | idx,code=index_to_code[index] 208 | DFG.append((code,idx,'comesFrom',[],[])) 209 | states[code]=[idx] 210 | return sorted(DFG,key=lambda x:x[1]),states 211 | else: 212 | name_indexs=tree_to_variable_index(name,index_to_code) 213 | value_indexs=tree_to_variable_index(value,index_to_code) 214 | temp,states=DFG_java(value,index_to_code,states) 215 | DFG+=temp 216 | for index1 in name_indexs: 217 | idx1,code1=index_to_code[index1] 218 | for index2 in value_indexs: 219 | idx2,code2=index_to_code[index2] 220 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 221 | states[code1]=[idx1] 222 | return sorted(DFG,key=lambda x:x[1]),states 223 | elif root_node.type in assignment: 224 | left_nodes=root_node.child_by_field_name('left') 225 | right_nodes=root_node.child_by_field_name('right') 226 | DFG=[] 227 | temp,states=DFG_java(right_nodes,index_to_code,states) 228 | DFG+=temp 229 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 230 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 231 | for index1 in name_indexs: 232 | idx1,code1=index_to_code[index1] 233 | for index2 in value_indexs: 234 | idx2,code2=index_to_code[index2] 235 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 236 | states[code1]=[idx1] 237 | return sorted(DFG,key=lambda x:x[1]),states 238 | elif root_node.type in increment_statement: 239 | DFG=[] 240 | indexs=tree_to_variable_index(root_node,index_to_code) 241 | for index1 in indexs: 242 | idx1,code1=index_to_code[index1] 243 | for index2 in indexs: 244 | idx2,code2=index_to_code[index2] 245 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 246 | states[code1]=[idx1] 247 | return sorted(DFG,key=lambda x:x[1]),states 248 | elif root_node.type in if_statement: 249 | DFG=[] 250 | current_states=states.copy() 251 | others_states=[] 252 | flag=False 253 | tag=False 254 | if 'else' in root_node.type: 255 | tag=True 256 | for child in root_node.children: 257 | if 'else' in child.type: 258 | tag=True 259 | if child.type not in if_statement and flag is False: 260 | temp,current_states=DFG_java(child,index_to_code,current_states) 261 | DFG+=temp 262 | else: 263 | flag=True 264 | temp,new_states=DFG_java(child,index_to_code,states) 265 | DFG+=temp 266 | others_states.append(new_states) 267 | others_states.append(current_states) 268 | if tag is False: 269 | others_states.append(states) 270 | new_states={} 271 | for dic in others_states: 272 | for key in dic: 273 | if key not in 
new_states: 274 | new_states[key]=dic[key].copy() 275 | else: 276 | new_states[key]+=dic[key] 277 | for key in new_states: 278 | new_states[key]=sorted(list(set(new_states[key]))) 279 | return sorted(DFG,key=lambda x:x[1]),new_states 280 | elif root_node.type in for_statement: 281 | DFG=[] 282 | for child in root_node.children: 283 | temp,states=DFG_java(child,index_to_code,states) 284 | DFG+=temp 285 | flag=False 286 | for child in root_node.children: 287 | if flag: 288 | temp,states=DFG_java(child,index_to_code,states) 289 | DFG+=temp 290 | elif child.type=="local_variable_declaration": 291 | flag=True 292 | dic={} 293 | for x in DFG: 294 | if (x[0],x[1],x[2]) not in dic: 295 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 296 | else: 297 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 298 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 299 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 300 | return sorted(DFG,key=lambda x:x[1]),states 301 | elif root_node.type in enhanced_for_statement: 302 | name=root_node.child_by_field_name('name') 303 | value=root_node.child_by_field_name('value') 304 | body=root_node.child_by_field_name('body') 305 | DFG=[] 306 | for i in range(2): 307 | temp,states=DFG_java(value,index_to_code,states) 308 | DFG+=temp 309 | name_indexs=tree_to_variable_index(name,index_to_code) 310 | value_indexs=tree_to_variable_index(value,index_to_code) 311 | for index1 in name_indexs: 312 | idx1,code1=index_to_code[index1] 313 | for index2 in value_indexs: 314 | idx2,code2=index_to_code[index2] 315 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 316 | states[code1]=[idx1] 317 | temp,states=DFG_java(body,index_to_code,states) 318 | DFG+=temp 319 | dic={} 320 | for x in DFG: 321 | if (x[0],x[1],x[2]) not in dic: 322 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 323 | else: 324 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 325 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 326 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 327 | return sorted(DFG,key=lambda x:x[1]),states 328 | elif root_node.type in while_statement: 329 | DFG=[] 330 | for i in range(2): 331 | for child in root_node.children: 332 | temp,states=DFG_java(child,index_to_code,states) 333 | DFG+=temp 334 | dic={} 335 | for x in DFG: 336 | if (x[0],x[1],x[2]) not in dic: 337 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 338 | else: 339 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 340 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 341 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 342 | return sorted(DFG,key=lambda x:x[1]),states 343 | else: 344 | DFG=[] 345 | for child in root_node.children: 346 | if child.type in do_first_statement: 347 | temp,states=DFG_java(child,index_to_code,states) 348 | DFG+=temp 349 | for child in root_node.children: 350 | if child.type not in do_first_statement: 351 | temp,states=DFG_java(child,index_to_code,states) 352 | DFG+=temp 353 | 354 | return sorted(DFG,key=lambda x:x[1]),states 355 | 356 | def DFG_csharp(root_node,index_to_code,states): 357 | assignment=['assignment_expression'] 358 | def_statement=['variable_declarator'] 359 | increment_statement=['postfix_unary_expression'] 360 | if_statement=['if_statement','else'] 361 | for_statement=['for_statement'] 362 | enhanced_for_statement=['for_each_statement'] 363 | 
while_statement=['while_statement'] 364 | do_first_statement=[] 365 | states=states.copy() 366 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 367 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 368 | if root_node.type==code: 369 | return [],states 370 | elif code in states: 371 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 372 | else: 373 | if root_node.type=='identifier': 374 | states[code]=[idx] 375 | return [(code,idx,'comesFrom',[],[])],states 376 | elif root_node.type in def_statement: 377 | if len(root_node.children)==2: 378 | name=root_node.children[0] 379 | value=root_node.children[1] 380 | else: 381 | name=root_node.children[0] 382 | value=None 383 | DFG=[] 384 | if value is None: 385 | indexs=tree_to_variable_index(name,index_to_code) 386 | for index in indexs: 387 | idx,code=index_to_code[index] 388 | DFG.append((code,idx,'comesFrom',[],[])) 389 | states[code]=[idx] 390 | return sorted(DFG,key=lambda x:x[1]),states 391 | else: 392 | name_indexs=tree_to_variable_index(name,index_to_code) 393 | value_indexs=tree_to_variable_index(value,index_to_code) 394 | temp,states=DFG_csharp(value,index_to_code,states) 395 | DFG+=temp 396 | for index1 in name_indexs: 397 | idx1,code1=index_to_code[index1] 398 | for index2 in value_indexs: 399 | idx2,code2=index_to_code[index2] 400 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 401 | states[code1]=[idx1] 402 | return sorted(DFG,key=lambda x:x[1]),states 403 | elif root_node.type in assignment: 404 | left_nodes=root_node.child_by_field_name('left') 405 | right_nodes=root_node.child_by_field_name('right') 406 | DFG=[] 407 | temp,states=DFG_csharp(right_nodes,index_to_code,states) 408 | DFG+=temp 409 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 410 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 411 | for index1 in name_indexs: 412 | idx1,code1=index_to_code[index1] 413 | for index2 in value_indexs: 414 | idx2,code2=index_to_code[index2] 415 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 416 | states[code1]=[idx1] 417 | return sorted(DFG,key=lambda x:x[1]),states 418 | elif root_node.type in increment_statement: 419 | DFG=[] 420 | indexs=tree_to_variable_index(root_node,index_to_code) 421 | for index1 in indexs: 422 | idx1,code1=index_to_code[index1] 423 | for index2 in indexs: 424 | idx2,code2=index_to_code[index2] 425 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 426 | states[code1]=[idx1] 427 | return sorted(DFG,key=lambda x:x[1]),states 428 | elif root_node.type in if_statement: 429 | DFG=[] 430 | current_states=states.copy() 431 | others_states=[] 432 | flag=False 433 | tag=False 434 | if 'else' in root_node.type: 435 | tag=True 436 | for child in root_node.children: 437 | if 'else' in child.type: 438 | tag=True 439 | if child.type not in if_statement and flag is False: 440 | temp,current_states=DFG_csharp(child,index_to_code,current_states) 441 | DFG+=temp 442 | else: 443 | flag=True 444 | temp,new_states=DFG_csharp(child,index_to_code,states) 445 | DFG+=temp 446 | others_states.append(new_states) 447 | others_states.append(current_states) 448 | if tag is False: 449 | others_states.append(states) 450 | new_states={} 451 | for dic in others_states: 452 | for key in dic: 453 | if key not in new_states: 454 | new_states[key]=dic[key].copy() 455 | else: 456 | new_states[key]+=dic[key] 457 | for key in new_states: 458 | new_states[key]=sorted(list(set(new_states[key]))) 459 | return 
sorted(DFG,key=lambda x:x[1]),new_states 460 | elif root_node.type in for_statement: 461 | DFG=[] 462 | for child in root_node.children: 463 | temp,states=DFG_csharp(child,index_to_code,states) 464 | DFG+=temp 465 | flag=False 466 | for child in root_node.children: 467 | if flag: 468 | temp,states=DFG_csharp(child,index_to_code,states) 469 | DFG+=temp 470 | elif child.type=="local_variable_declaration": 471 | flag=True 472 | dic={} 473 | for x in DFG: 474 | if (x[0],x[1],x[2]) not in dic: 475 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 476 | else: 477 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 478 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 479 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 480 | return sorted(DFG,key=lambda x:x[1]),states 481 | elif root_node.type in enhanced_for_statement: 482 | name=root_node.child_by_field_name('left') 483 | value=root_node.child_by_field_name('right') 484 | body=root_node.child_by_field_name('body') 485 | DFG=[] 486 | for i in range(2): 487 | temp,states=DFG_csharp(value,index_to_code,states) 488 | DFG+=temp 489 | name_indexs=tree_to_variable_index(name,index_to_code) 490 | value_indexs=tree_to_variable_index(value,index_to_code) 491 | for index1 in name_indexs: 492 | idx1,code1=index_to_code[index1] 493 | for index2 in value_indexs: 494 | idx2,code2=index_to_code[index2] 495 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 496 | states[code1]=[idx1] 497 | temp,states=DFG_csharp(body,index_to_code,states) 498 | DFG+=temp 499 | dic={} 500 | for x in DFG: 501 | if (x[0],x[1],x[2]) not in dic: 502 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 503 | else: 504 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 505 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 506 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 507 | return sorted(DFG,key=lambda x:x[1]),states 508 | elif root_node.type in while_statement: 509 | DFG=[] 510 | for i in range(2): 511 | for child in root_node.children: 512 | temp,states=DFG_csharp(child,index_to_code,states) 513 | DFG+=temp 514 | dic={} 515 | for x in DFG: 516 | if (x[0],x[1],x[2]) not in dic: 517 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 518 | else: 519 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 520 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 521 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 522 | return sorted(DFG,key=lambda x:x[1]),states 523 | else: 524 | DFG=[] 525 | for child in root_node.children: 526 | if child.type in do_first_statement: 527 | temp,states=DFG_csharp(child,index_to_code,states) 528 | DFG+=temp 529 | for child in root_node.children: 530 | if child.type not in do_first_statement: 531 | temp,states=DFG_csharp(child,index_to_code,states) 532 | DFG+=temp 533 | 534 | return sorted(DFG,key=lambda x:x[1]),states 535 | 536 | 537 | 538 | 539 | def DFG_ruby(root_node,index_to_code,states): 540 | assignment=['assignment','operator_assignment'] 541 | if_statement=['if','elsif','else','unless','when'] 542 | for_statement=['for'] 543 | while_statement=['while_modifier','until'] 544 | do_first_statement=[] 545 | def_statement=['keyword_parameter'] 546 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 547 | states=states.copy() 548 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 549 | if 
root_node.type==code: 550 | return [],states 551 | elif code in states: 552 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 553 | else: 554 | if root_node.type=='identifier': 555 | states[code]=[idx] 556 | return [(code,idx,'comesFrom',[],[])],states 557 | elif root_node.type in def_statement: 558 | name=root_node.child_by_field_name('name') 559 | value=root_node.child_by_field_name('value') 560 | DFG=[] 561 | if value is None: 562 | indexs=tree_to_variable_index(name,index_to_code) 563 | for index in indexs: 564 | idx,code=index_to_code[index] 565 | DFG.append((code,idx,'comesFrom',[],[])) 566 | states[code]=[idx] 567 | return sorted(DFG,key=lambda x:x[1]),states 568 | else: 569 | name_indexs=tree_to_variable_index(name,index_to_code) 570 | value_indexs=tree_to_variable_index(value,index_to_code) 571 | temp,states=DFG_ruby(value,index_to_code,states) 572 | DFG+=temp 573 | for index1 in name_indexs: 574 | idx1,code1=index_to_code[index1] 575 | for index2 in value_indexs: 576 | idx2,code2=index_to_code[index2] 577 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 578 | states[code1]=[idx1] 579 | return sorted(DFG,key=lambda x:x[1]),states 580 | elif root_node.type in assignment: 581 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] 582 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] 583 | if len(right_nodes)!=len(left_nodes): 584 | left_nodes=[root_node.child_by_field_name('left')] 585 | right_nodes=[root_node.child_by_field_name('right')] 586 | if len(left_nodes)==0: 587 | left_nodes=[root_node.child_by_field_name('left')] 588 | if len(right_nodes)==0: 589 | right_nodes=[root_node.child_by_field_name('right')] 590 | if root_node.type=="operator_assignment": 591 | left_nodes=[root_node.children[0]] 592 | right_nodes=[root_node.children[-1]] 593 | 594 | DFG=[] 595 | for node in right_nodes: 596 | temp,states=DFG_ruby(node,index_to_code,states) 597 | DFG+=temp 598 | 599 | for left_node,right_node in zip(left_nodes,right_nodes): 600 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 601 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 602 | temp=[] 603 | for token1_index in left_tokens_index: 604 | idx1,code1=index_to_code[token1_index] 605 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 606 | [index_to_code[x][0] for x in right_tokens_index])) 607 | states[code1]=[idx1] 608 | DFG+=temp 609 | return sorted(DFG,key=lambda x:x[1]),states 610 | elif root_node.type in if_statement: 611 | DFG=[] 612 | current_states=states.copy() 613 | others_states=[] 614 | tag=False 615 | if 'else' in root_node.type: 616 | tag=True 617 | for child in root_node.children: 618 | if 'else' in child.type: 619 | tag=True 620 | if child.type not in if_statement: 621 | temp,current_states=DFG_ruby(child,index_to_code,current_states) 622 | DFG+=temp 623 | else: 624 | temp,new_states=DFG_ruby(child,index_to_code,states) 625 | DFG+=temp 626 | others_states.append(new_states) 627 | others_states.append(current_states) 628 | if tag is False: 629 | others_states.append(states) 630 | new_states={} 631 | for dic in others_states: 632 | for key in dic: 633 | if key not in new_states: 634 | new_states[key]=dic[key].copy() 635 | else: 636 | new_states[key]+=dic[key] 637 | for key in new_states: 638 | new_states[key]=sorted(list(set(new_states[key]))) 639 | return sorted(DFG,key=lambda x:x[1]),new_states 640 | elif root_node.type in for_statement: 
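        # Ruby for-loops are processed in two passes so that data-flow edges created
        # in the loop body are also propagated back through the loop variables,
        # approximating the loop's fixed point.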
641 | DFG=[] 642 | for i in range(2): 643 | left_nodes=[root_node.child_by_field_name('pattern')] 644 | right_nodes=[root_node.child_by_field_name('value')] 645 | assert len(right_nodes)==len(left_nodes) 646 | for node in right_nodes: 647 | temp,states=DFG_ruby(node,index_to_code,states) 648 | DFG+=temp 649 | for left_node,right_node in zip(left_nodes,right_nodes): 650 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 651 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 652 | temp=[] 653 | for token1_index in left_tokens_index: 654 | idx1,code1=index_to_code[token1_index] 655 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 656 | [index_to_code[x][0] for x in right_tokens_index])) 657 | states[code1]=[idx1] 658 | DFG+=temp 659 | temp,states=DFG_ruby(root_node.child_by_field_name('body'),index_to_code,states) 660 | DFG+=temp 661 | dic={} 662 | for x in DFG: 663 | if (x[0],x[1],x[2]) not in dic: 664 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 665 | else: 666 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 667 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 668 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 669 | return sorted(DFG,key=lambda x:x[1]),states 670 | elif root_node.type in while_statement: 671 | DFG=[] 672 | for i in range(2): 673 | for child in root_node.children: 674 | temp,states=DFG_ruby(child,index_to_code,states) 675 | DFG+=temp 676 | dic={} 677 | for x in DFG: 678 | if (x[0],x[1],x[2]) not in dic: 679 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 680 | else: 681 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 682 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 683 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 684 | return sorted(DFG,key=lambda x:x[1]),states 685 | else: 686 | DFG=[] 687 | for child in root_node.children: 688 | if child.type in do_first_statement: 689 | temp,states=DFG_ruby(child,index_to_code,states) 690 | DFG+=temp 691 | for child in root_node.children: 692 | if child.type not in do_first_statement: 693 | temp,states=DFG_ruby(child,index_to_code,states) 694 | DFG+=temp 695 | 696 | return sorted(DFG,key=lambda x:x[1]),states 697 | 698 | def DFG_go(root_node,index_to_code,states): 699 | assignment=['assignment_statement',] 700 | def_statement=['var_spec'] 701 | increment_statement=['inc_statement'] 702 | if_statement=['if_statement','else'] 703 | for_statement=['for_statement'] 704 | enhanced_for_statement=[] 705 | while_statement=[] 706 | do_first_statement=[] 707 | states=states.copy() 708 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 709 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 710 | if root_node.type==code: 711 | return [],states 712 | elif code in states: 713 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 714 | else: 715 | if root_node.type=='identifier': 716 | states[code]=[idx] 717 | return [(code,idx,'comesFrom',[],[])],states 718 | elif root_node.type in def_statement: 719 | name=root_node.child_by_field_name('name') 720 | value=root_node.child_by_field_name('value') 721 | DFG=[] 722 | if value is None: 723 | indexs=tree_to_variable_index(name,index_to_code) 724 | for index in indexs: 725 | idx,code=index_to_code[index] 726 | DFG.append((code,idx,'comesFrom',[],[])) 727 | states[code]=[idx] 728 | return 
sorted(DFG,key=lambda x:x[1]),states 729 | else: 730 | name_indexs=tree_to_variable_index(name,index_to_code) 731 | value_indexs=tree_to_variable_index(value,index_to_code) 732 | temp,states=DFG_go(value,index_to_code,states) 733 | DFG+=temp 734 | for index1 in name_indexs: 735 | idx1,code1=index_to_code[index1] 736 | for index2 in value_indexs: 737 | idx2,code2=index_to_code[index2] 738 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 739 | states[code1]=[idx1] 740 | return sorted(DFG,key=lambda x:x[1]),states 741 | elif root_node.type in assignment: 742 | left_nodes=root_node.child_by_field_name('left') 743 | right_nodes=root_node.child_by_field_name('right') 744 | DFG=[] 745 | temp,states=DFG_go(right_nodes,index_to_code,states) 746 | DFG+=temp 747 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 748 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 749 | for index1 in name_indexs: 750 | idx1,code1=index_to_code[index1] 751 | for index2 in value_indexs: 752 | idx2,code2=index_to_code[index2] 753 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 754 | states[code1]=[idx1] 755 | return sorted(DFG,key=lambda x:x[1]),states 756 | elif root_node.type in increment_statement: 757 | DFG=[] 758 | indexs=tree_to_variable_index(root_node,index_to_code) 759 | for index1 in indexs: 760 | idx1,code1=index_to_code[index1] 761 | for index2 in indexs: 762 | idx2,code2=index_to_code[index2] 763 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 764 | states[code1]=[idx1] 765 | return sorted(DFG,key=lambda x:x[1]),states 766 | elif root_node.type in if_statement: 767 | DFG=[] 768 | current_states=states.copy() 769 | others_states=[] 770 | flag=False 771 | tag=False 772 | if 'else' in root_node.type: 773 | tag=True 774 | for child in root_node.children: 775 | if 'else' in child.type: 776 | tag=True 777 | if child.type not in if_statement and flag is False: 778 | temp,current_states=DFG_go(child,index_to_code,current_states) 779 | DFG+=temp 780 | else: 781 | flag=True 782 | temp,new_states=DFG_go(child,index_to_code,states) 783 | DFG+=temp 784 | others_states.append(new_states) 785 | others_states.append(current_states) 786 | if tag is False: 787 | others_states.append(states) 788 | new_states={} 789 | for dic in others_states: 790 | for key in dic: 791 | if key not in new_states: 792 | new_states[key]=dic[key].copy() 793 | else: 794 | new_states[key]+=dic[key] 795 | for key in states: 796 | if key not in new_states: 797 | new_states[key]=states[key] 798 | else: 799 | new_states[key]+=states[key] 800 | for key in new_states: 801 | new_states[key]=sorted(list(set(new_states[key]))) 802 | return sorted(DFG,key=lambda x:x[1]),new_states 803 | elif root_node.type in for_statement: 804 | DFG=[] 805 | for child in root_node.children: 806 | temp,states=DFG_go(child,index_to_code,states) 807 | DFG+=temp 808 | flag=False 809 | for child in root_node.children: 810 | if flag: 811 | temp,states=DFG_go(child,index_to_code,states) 812 | DFG+=temp 813 | elif child.type=="for_clause": 814 | if child.child_by_field_name('update') is not None: 815 | temp,states=DFG_go(child.child_by_field_name('update'),index_to_code,states) 816 | DFG+=temp 817 | flag=True 818 | dic={} 819 | for x in DFG: 820 | if (x[0],x[1],x[2]) not in dic: 821 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 822 | else: 823 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 824 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 825 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in 
sorted(dic.items(),key=lambda t:t[0][1])] 826 | return sorted(DFG,key=lambda x:x[1]),states 827 | else: 828 | DFG=[] 829 | for child in root_node.children: 830 | if child.type in do_first_statement: 831 | temp,states=DFG_go(child,index_to_code,states) 832 | DFG+=temp 833 | for child in root_node.children: 834 | if child.type not in do_first_statement: 835 | temp,states=DFG_go(child,index_to_code,states) 836 | DFG+=temp 837 | 838 | return sorted(DFG,key=lambda x:x[1]),states 839 | 840 | 841 | 842 | 843 | def DFG_php(root_node,index_to_code,states): 844 | assignment=['assignment_expression','augmented_assignment_expression'] 845 | def_statement=['simple_parameter'] 846 | increment_statement=['update_expression'] 847 | if_statement=['if_statement','else_clause'] 848 | for_statement=['for_statement'] 849 | enhanced_for_statement=['foreach_statement'] 850 | while_statement=['while_statement'] 851 | do_first_statement=[] 852 | states=states.copy() 853 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 854 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 855 | if root_node.type==code: 856 | return [],states 857 | elif code in states: 858 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 859 | else: 860 | if root_node.type=='identifier': 861 | states[code]=[idx] 862 | return [(code,idx,'comesFrom',[],[])],states 863 | elif root_node.type in def_statement: 864 | name=root_node.child_by_field_name('name') 865 | value=root_node.child_by_field_name('default_value') 866 | DFG=[] 867 | if value is None: 868 | indexs=tree_to_variable_index(name,index_to_code) 869 | for index in indexs: 870 | idx,code=index_to_code[index] 871 | DFG.append((code,idx,'comesFrom',[],[])) 872 | states[code]=[idx] 873 | return sorted(DFG,key=lambda x:x[1]),states 874 | else: 875 | name_indexs=tree_to_variable_index(name,index_to_code) 876 | value_indexs=tree_to_variable_index(value,index_to_code) 877 | temp,states=DFG_php(value,index_to_code,states) 878 | DFG+=temp 879 | for index1 in name_indexs: 880 | idx1,code1=index_to_code[index1] 881 | for index2 in value_indexs: 882 | idx2,code2=index_to_code[index2] 883 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 884 | states[code1]=[idx1] 885 | return sorted(DFG,key=lambda x:x[1]),states 886 | elif root_node.type in assignment: 887 | left_nodes=root_node.child_by_field_name('left') 888 | right_nodes=root_node.child_by_field_name('right') 889 | DFG=[] 890 | temp,states=DFG_php(right_nodes,index_to_code,states) 891 | DFG+=temp 892 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 893 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 894 | for index1 in name_indexs: 895 | idx1,code1=index_to_code[index1] 896 | for index2 in value_indexs: 897 | idx2,code2=index_to_code[index2] 898 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 899 | states[code1]=[idx1] 900 | return sorted(DFG,key=lambda x:x[1]),states 901 | elif root_node.type in increment_statement: 902 | DFG=[] 903 | indexs=tree_to_variable_index(root_node,index_to_code) 904 | for index1 in indexs: 905 | idx1,code1=index_to_code[index1] 906 | for index2 in indexs: 907 | idx2,code2=index_to_code[index2] 908 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 909 | states[code1]=[idx1] 910 | return sorted(DFG,key=lambda x:x[1]),states 911 | elif root_node.type in if_statement: 912 | DFG=[] 913 | current_states=states.copy() 914 | others_states=[] 915 | flag=False 916 | tag=False 917 | if 'else' in 
root_node.type: 918 | tag=True 919 | for child in root_node.children: 920 | if 'else' in child.type: 921 | tag=True 922 | if child.type not in if_statement and flag is False: 923 | temp,current_states=DFG_php(child,index_to_code,current_states) 924 | DFG+=temp 925 | else: 926 | flag=True 927 | temp,new_states=DFG_php(child,index_to_code,states) 928 | DFG+=temp 929 | others_states.append(new_states) 930 | others_states.append(current_states) 931 | new_states={} 932 | for dic in others_states: 933 | for key in dic: 934 | if key not in new_states: 935 | new_states[key]=dic[key].copy() 936 | else: 937 | new_states[key]+=dic[key] 938 | for key in states: 939 | if key not in new_states: 940 | new_states[key]=states[key] 941 | else: 942 | new_states[key]+=states[key] 943 | for key in new_states: 944 | new_states[key]=sorted(list(set(new_states[key]))) 945 | return sorted(DFG,key=lambda x:x[1]),new_states 946 | elif root_node.type in for_statement: 947 | DFG=[] 948 | for child in root_node.children: 949 | temp,states=DFG_php(child,index_to_code,states) 950 | DFG+=temp 951 | flag=False 952 | for child in root_node.children: 953 | if flag: 954 | temp,states=DFG_php(child,index_to_code,states) 955 | DFG+=temp 956 | elif child.type=="assignment_expression": 957 | flag=True 958 | dic={} 959 | for x in DFG: 960 | if (x[0],x[1],x[2]) not in dic: 961 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 962 | else: 963 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 964 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 965 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 966 | return sorted(DFG,key=lambda x:x[1]),states 967 | elif root_node.type in enhanced_for_statement: 968 | name=None 969 | value=None 970 | for child in root_node.children: 971 | if child.type=='variable_name' and value is None: 972 | value=child 973 | elif child.type=='variable_name' and name is None: 974 | name=child 975 | break 976 | body=root_node.child_by_field_name('body') 977 | DFG=[] 978 | for i in range(2): 979 | temp,states=DFG_php(value,index_to_code,states) 980 | DFG+=temp 981 | name_indexs=tree_to_variable_index(name,index_to_code) 982 | value_indexs=tree_to_variable_index(value,index_to_code) 983 | for index1 in name_indexs: 984 | idx1,code1=index_to_code[index1] 985 | for index2 in value_indexs: 986 | idx2,code2=index_to_code[index2] 987 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 988 | states[code1]=[idx1] 989 | temp,states=DFG_php(body,index_to_code,states) 990 | DFG+=temp 991 | dic={} 992 | for x in DFG: 993 | if (x[0],x[1],x[2]) not in dic: 994 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 995 | else: 996 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 997 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 998 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 999 | return sorted(DFG,key=lambda x:x[1]),states 1000 | elif root_node.type in while_statement: 1001 | DFG=[] 1002 | for i in range(2): 1003 | for child in root_node.children: 1004 | temp,states=DFG_php(child,index_to_code,states) 1005 | DFG+=temp 1006 | dic={} 1007 | for x in DFG: 1008 | if (x[0],x[1],x[2]) not in dic: 1009 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 1010 | else: 1011 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 1012 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 1013 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 
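        # Loop constructs (for/while) are traversed twice so that def-use edges that
        # flow from one iteration back into the next are still captured; the dictionary
        # pass above then merges duplicate edges, unioning the source tokens and source
        # positions of entries that share the same (token, position, relation) key.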
1014 | return sorted(DFG,key=lambda x:x[1]),states 1015 | else: 1016 | DFG=[] 1017 | for child in root_node.children: 1018 | if child.type in do_first_statement: 1019 | temp,states=DFG_php(child,index_to_code,states) 1020 | DFG+=temp 1021 | for child in root_node.children: 1022 | if child.type not in do_first_statement: 1023 | temp,states=DFG_php(child,index_to_code,states) 1024 | DFG+=temp 1025 | 1026 | return sorted(DFG,key=lambda x:x[1]),states 1027 | 1028 | 1029 | def DFG_javascript(root_node,index_to_code,states): 1030 | assignment=['assignment_pattern','augmented_assignment_expression'] 1031 | def_statement=['variable_declarator'] 1032 | increment_statement=['update_expression'] 1033 | if_statement=['if_statement','else'] 1034 | for_statement=['for_statement'] 1035 | enhanced_for_statement=[] 1036 | while_statement=['while_statement'] 1037 | do_first_statement=[] 1038 | states=states.copy() 1039 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 1040 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 1041 | if root_node.type==code: 1042 | return [],states 1043 | elif code in states: 1044 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 1045 | else: 1046 | if root_node.type=='identifier': 1047 | states[code]=[idx] 1048 | return [(code,idx,'comesFrom',[],[])],states 1049 | elif root_node.type in def_statement: 1050 | name=root_node.child_by_field_name('name') 1051 | value=root_node.child_by_field_name('value') 1052 | DFG=[] 1053 | if value is None: 1054 | indexs=tree_to_variable_index(name,index_to_code) 1055 | for index in indexs: 1056 | idx,code=index_to_code[index] 1057 | DFG.append((code,idx,'comesFrom',[],[])) 1058 | states[code]=[idx] 1059 | return sorted(DFG,key=lambda x:x[1]),states 1060 | else: 1061 | name_indexs=tree_to_variable_index(name,index_to_code) 1062 | value_indexs=tree_to_variable_index(value,index_to_code) 1063 | temp,states=DFG_javascript(value,index_to_code,states) 1064 | DFG+=temp 1065 | for index1 in name_indexs: 1066 | idx1,code1=index_to_code[index1] 1067 | for index2 in value_indexs: 1068 | idx2,code2=index_to_code[index2] 1069 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 1070 | states[code1]=[idx1] 1071 | return sorted(DFG,key=lambda x:x[1]),states 1072 | elif root_node.type in assignment: 1073 | left_nodes=root_node.child_by_field_name('left') 1074 | right_nodes=root_node.child_by_field_name('right') 1075 | DFG=[] 1076 | temp,states=DFG_javascript(right_nodes,index_to_code,states) 1077 | DFG+=temp 1078 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 1079 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 1080 | for index1 in name_indexs: 1081 | idx1,code1=index_to_code[index1] 1082 | for index2 in value_indexs: 1083 | idx2,code2=index_to_code[index2] 1084 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 1085 | states[code1]=[idx1] 1086 | return sorted(DFG,key=lambda x:x[1]),states 1087 | elif root_node.type in increment_statement: 1088 | DFG=[] 1089 | indexs=tree_to_variable_index(root_node,index_to_code) 1090 | for index1 in indexs: 1091 | idx1,code1=index_to_code[index1] 1092 | for index2 in indexs: 1093 | idx2,code2=index_to_code[index2] 1094 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 1095 | states[code1]=[idx1] 1096 | return sorted(DFG,key=lambda x:x[1]),states 1097 | elif root_node.type in if_statement: 1098 | DFG=[] 1099 | current_states=states.copy() 1100 | others_states=[] 1101 | flag=False 1102 | tag=False 
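        # Branch handling below: `flag` flips to True once a child that is itself part of
        # the if/else construct (e.g. an else clause) is reached, and such children are
        # analysed against the pre-branch states; `tag` records whether an else branch
        # exists at all, and when it does not, the incoming states are kept as one more
        # possible outcome. The per-branch state dictionaries are then unioned key by
        # key, with position lists de-duplicated and sorted.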
1103 | if 'else' in root_node.type: 1104 | tag=True 1105 | for child in root_node.children: 1106 | if 'else' in child.type: 1107 | tag=True 1108 | if child.type not in if_statement and flag is False: 1109 | temp,current_states=DFG_javascript(child,index_to_code,current_states) 1110 | DFG+=temp 1111 | else: 1112 | flag=True 1113 | temp,new_states=DFG_javascript(child,index_to_code,states) 1114 | DFG+=temp 1115 | others_states.append(new_states) 1116 | others_states.append(current_states) 1117 | if tag is False: 1118 | others_states.append(states) 1119 | new_states={} 1120 | for dic in others_states: 1121 | for key in dic: 1122 | if key not in new_states: 1123 | new_states[key]=dic[key].copy() 1124 | else: 1125 | new_states[key]+=dic[key] 1126 | for key in states: 1127 | if key not in new_states: 1128 | new_states[key]=states[key] 1129 | else: 1130 | new_states[key]+=states[key] 1131 | for key in new_states: 1132 | new_states[key]=sorted(list(set(new_states[key]))) 1133 | return sorted(DFG,key=lambda x:x[1]),new_states 1134 | elif root_node.type in for_statement: 1135 | DFG=[] 1136 | for child in root_node.children: 1137 | temp,states=DFG_javascript(child,index_to_code,states) 1138 | DFG+=temp 1139 | flag=False 1140 | for child in root_node.children: 1141 | if flag: 1142 | temp,states=DFG_javascript(child,index_to_code,states) 1143 | DFG+=temp 1144 | elif child.type=="variable_declaration": 1145 | flag=True 1146 | dic={} 1147 | for x in DFG: 1148 | if (x[0],x[1],x[2]) not in dic: 1149 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 1150 | else: 1151 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 1152 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 1153 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 1154 | return sorted(DFG,key=lambda x:x[1]),states 1155 | elif root_node.type in while_statement: 1156 | DFG=[] 1157 | for i in range(2): 1158 | for child in root_node.children: 1159 | temp,states=DFG_javascript(child,index_to_code,states) 1160 | DFG+=temp 1161 | dic={} 1162 | for x in DFG: 1163 | if (x[0],x[1],x[2]) not in dic: 1164 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 1165 | else: 1166 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 1167 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 1168 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 1169 | return sorted(DFG,key=lambda x:x[1]),states 1170 | else: 1171 | DFG=[] 1172 | for child in root_node.children: 1173 | if child.type in do_first_statement: 1174 | temp,states=DFG_javascript(child,index_to_code,states) 1175 | DFG+=temp 1176 | for child in root_node.children: 1177 | if child.type not in do_first_statement: 1178 | temp,states=DFG_javascript(child,index_to_code,states) 1179 | DFG+=temp 1180 | 1181 | return sorted(DFG,key=lambda x:x[1]),states 1182 | 1183 | 1184 | 1185 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import (remove_comments_and_docstrings, 2 | tree_to_token_index, 3 | index_to_code_token, 4 | tree_to_variable_index) 5 | from .DFG import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp -------------------------------------------------------------------------------- /GraphCodeBERT/parser/build.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from tree_sitter import Language, Parser 5 | 6 | Language.build_library( 7 | # Store the library in the `build` directory 8 | 'my-languages.so', 9 | 10 | # Include one or more languages 11 | [ 12 | 'tree-sitter-go', 13 | 'tree-sitter-javascript', 14 | 'tree-sitter-python', 15 | 'tree-sitter-php', 16 | 'tree-sitter-java', 17 | 'tree-sitter-ruby', 18 | 'tree-sitter-c-sharp', 19 | ] 20 | ) 21 | 22 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/build.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/tree-sitter/tree-sitter-go 2 | git clone https://github.com/tree-sitter/tree-sitter-javascript 3 | git clone https://github.com/tree-sitter/tree-sitter-python 4 | git clone https://github.com/tree-sitter/tree-sitter-ruby 5 | git clone https://github.com/tree-sitter/tree-sitter-php 6 | git clone https://github.com/tree-sitter/tree-sitter-java 7 | git clone https://github.com/tree-sitter/tree-sitter-c-sharp 8 | python build.py 9 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from io import StringIO 3 | import tokenize 4 | def remove_comments_and_docstrings(source,lang): 5 | if lang in ['python']: 6 | """ 7 | Returns 'source' minus comments and docstrings. 8 | """ 9 | io_obj = StringIO(source) 10 | out = "" 11 | prev_toktype = tokenize.INDENT 12 | last_lineno = -1 13 | last_col = 0 14 | for tok in tokenize.generate_tokens(io_obj.readline): 15 | token_type = tok[0] 16 | token_string = tok[1] 17 | start_line, start_col = tok[2] 18 | end_line, end_col = tok[3] 19 | ltext = tok[4] 20 | if start_line > last_lineno: 21 | last_col = 0 22 | if start_col > last_col: 23 | out += (" " * (start_col - last_col)) 24 | # Remove comments: 25 | if token_type == tokenize.COMMENT: 26 | pass 27 | # This series of conditionals removes docstrings: 28 | elif token_type == tokenize.STRING: 29 | if prev_toktype != tokenize.INDENT: 30 | # This is likely a docstring; double-check we're not inside an operator: 31 | if prev_toktype != tokenize.NEWLINE: 32 | if start_col > 0: 33 | out += token_string 34 | else: 35 | out += token_string 36 | prev_toktype = token_type 37 | last_col = end_col 38 | last_lineno = end_line 39 | temp=[] 40 | for x in out.split('\n'): 41 | if x.strip()!="": 42 | temp.append(x) 43 | return '\n'.join(temp) 44 | elif lang in ['ruby']: 45 | return source 46 | else: 47 | def replacer(match): 48 | s = match.group(0) 49 | if s.startswith('/'): 50 | return " " # note: a space and not an empty string 51 | else: 52 | return s 53 | pattern = re.compile( 54 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', 55 | re.DOTALL | re.MULTILINE 56 | ) 57 | temp=[] 58 | for x in re.sub(pattern, replacer, source).split('\n'): 59 | if x.strip()!="": 60 | temp.append(x) 61 | return '\n'.join(temp) 62 | 63 | def tree_to_token_index(root_node): 64 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 65 | return [(root_node.start_point,root_node.end_point)] 66 | else: 67 | code_tokens=[] 68 | for child in root_node.children: 69 | code_tokens+=tree_to_token_index(child) 70 | return code_tokens 71 | 72 | def tree_to_variable_index(root_node,index_to_code): 73 | if (len(root_node.children)==0 or 
root_node.type=='string') and root_node.type!='comment': 74 | index=(root_node.start_point,root_node.end_point) 75 | _,code=index_to_code[index] 76 | if root_node.type!=code: 77 | return [(root_node.start_point,root_node.end_point)] 78 | else: 79 | return [] 80 | else: 81 | code_tokens=[] 82 | for child in root_node.children: 83 | code_tokens+=tree_to_variable_index(child,index_to_code) 84 | return code_tokens 85 | 86 | def index_to_code_token(index,code): 87 | start_point=index[0] 88 | end_point=index[1] 89 | if start_point[0]==end_point[0]: 90 | s=code[start_point[0]][start_point[1]:end_point[1]] 91 | else: 92 | s="" 93 | s+=code[start_point[0]][start_point[1]:] 94 | for i in range(start_point[0]+1,end_point[0]): 95 | s+=code[i] 96 | s+=code[end_point[0]][:end_point[1]] 97 | return s 98 | -------------------------------------------------------------------------------- /GraphCodeBERT/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 
20 | """ 21 | 22 | from __future__ import absolute_import, division, print_function 23 | 24 | import argparse 25 | import glob 26 | import logging 27 | import os 28 | import pickle 29 | import random 30 | import re 31 | import shutil 32 | import json 33 | import numpy as np 34 | import torch 35 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset 36 | from torch.utils.data.distributed import DistributedSampler 37 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 38 | RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer) 39 | from tqdm import tqdm, trange 40 | import multiprocessing 41 | from model import Model 42 | import pdb 43 | cpu_cont = 16 44 | logger = logging.getLogger(__name__) 45 | 46 | from parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript 47 | from parser import (remove_comments_and_docstrings, 48 | tree_to_token_index, 49 | index_to_code_token, 50 | tree_to_variable_index) 51 | from tree_sitter import Language, Parser 52 | 53 | root_path = "tg" 54 | 55 | dfg_function = { 56 | 'python': DFG_python, 57 | 'java': DFG_java, 58 | 'ruby': DFG_ruby, 59 | 'go': DFG_go, 60 | 'php': DFG_php, 61 | 'javascript': DFG_javascript 62 | } 63 | 64 | # load parsers 65 | parsers = {} 66 | for lang in dfg_function: 67 | LANGUAGE = Language('parser/my-languages.so', lang) 68 | parser = Parser() 69 | parser.set_language(LANGUAGE) 70 | parser = [parser, dfg_function[lang]] 71 | parsers[lang] = parser 72 | 73 | 74 | # remove comments, tokenize code and extract dataflow 75 | def extract_dataflow(code, parser, lang): 76 | # remove comments 77 | code = code.replace("\\n", "\n") 78 | try: 79 | code = remove_comments_and_docstrings(code, lang) 80 | except: 81 | pass 82 | # obtain dataflow 83 | if lang == "php": 84 | code = "" 85 | try: 86 | tree = parser[0].parse(bytes(code, 'utf8')) 87 | root_node = tree.root_node 88 | tokens_index = tree_to_token_index(root_node) 89 | code = code.split('\n') 90 | code_tokens = [index_to_code_token(x, code) for x in tokens_index] 91 | index_to_code = {} 92 | for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)): 93 | index_to_code[index] = (idx, code) 94 | try: 95 | DFG, _ = parser[1](root_node, index_to_code, {}) 96 | except: 97 | DFG = [] 98 | DFG = sorted(DFG, key=lambda x: x[1]) 99 | indexs = set() 100 | for d in DFG: 101 | if len(d[-1]) != 0: 102 | indexs.add(d[1]) 103 | for x in d[-1]: 104 | indexs.add(x) 105 | new_DFG = [] 106 | for d in DFG: 107 | if d[1] in indexs: 108 | new_DFG.append(d) 109 | dfg = new_DFG 110 | except: 111 | dfg = [] 112 | return code_tokens, dfg 113 | 114 | 115 | class InputFeatures(object): 116 | """A single training/test features for a example.""" 117 | 118 | def __init__(self, 119 | input_tokens, 120 | input_ids, 121 | position_idx, 122 | dfg_to_code, 123 | dfg_to_dfg, 124 | label 125 | 126 | ): 127 | self.input_tokens = input_tokens 128 | self.input_ids = input_ids 129 | self.position_idx = position_idx 130 | self.dfg_to_code = dfg_to_code 131 | self.dfg_to_dfg = dfg_to_dfg 132 | self.label = label 133 | 134 | 135 | def convert_examples_to_features(code, label, tokenizer, args): 136 | # source 137 | # code=' '.join(js['func'].split()) 138 | parser = parsers["java"] 139 | code_tokens, dfg = extract_dataflow(code, parser, "java") 140 | 141 | code_tokens = [tokenizer.tokenize('@ ' + x)[1:] if idx != 0 else tokenizer.tokenize(x) for idx, x in enumerate(code_tokens)] 142 | ori2cur_pos = {} 143 | ori2cur_pos[-1] = 
(0, 0) 144 | for i in range(len(code_tokens)): 145 | ori2cur_pos[i] = (ori2cur_pos[i - 1][1], ori2cur_pos[i - 1][1] + len(code_tokens[i])) 146 | code_tokens = [y for x in code_tokens for y in x] 147 | 148 | code_tokens = code_tokens[:args.code_length + args.data_flow_length - 2 - min(len(dfg), args.data_flow_length)] 149 | source_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token] 150 | source_ids = tokenizer.convert_tokens_to_ids(source_tokens) 151 | position_idx = [i + tokenizer.pad_token_id + 1 for i in range(len(source_tokens))] 152 | dfg = dfg[:args.code_length + args.data_flow_length - len(source_tokens)] 153 | source_tokens += [x[0] for x in dfg] 154 | position_idx += [0 for x in dfg] 155 | source_ids += [tokenizer.unk_token_id for x in dfg] 156 | padding_length = args.code_length + args.data_flow_length - len(source_ids) 157 | position_idx += [tokenizer.pad_token_id] * padding_length 158 | source_ids += [tokenizer.pad_token_id] * padding_length 159 | 160 | reverse_index = {} 161 | for idx, x in enumerate(dfg): 162 | reverse_index[x[1]] = idx 163 | for idx, x in enumerate(dfg): 164 | dfg[idx] = x[:-1] + ([reverse_index[i] for i in x[-1] if i in reverse_index],) 165 | dfg_to_dfg = [x[-1] for x in dfg] 166 | dfg_to_code = [ori2cur_pos[x[1]] for x in dfg] 167 | length = len([tokenizer.cls_token]) 168 | dfg_to_code = [(x[0] + length, x[1] + length) for x in dfg_to_code] 169 | 170 | return InputFeatures(source_tokens, source_ids, position_idx, dfg_to_code, dfg_to_dfg, label) 171 | 172 | 173 | class TextDataset(Dataset): 174 | def __init__(self, tokenizer, args, file_path=None): 175 | self.examples = [] 176 | self.args = args 177 | file_type = file_path.split('/')[-1].split('.')[0] 178 | folder = '/'.join(file_path.split('/')[:-1]) 179 | 180 | cache_file_path = os.path.join(folder, 'cached_{}'.format(file_type)) 181 | code_pairs_file_path = os.path.join(folder, 'cached_{}.pkl'.format(file_type)) 182 | 183 | print('\n cached_features_file: ', cache_file_path) 184 | try: 185 | self.examples = torch.load(cache_file_path) 186 | with open(code_pairs_file_path, 'rb') as f: 187 | code_files = pickle.load(f) 188 | logger.info("Loading features from cached file %s", cache_file_path) 189 | 190 | except: 191 | logger.info("Creating features from dataset file at %s", file_path) 192 | code_files = [] 193 | with open(file_path) as f: 194 | for line in f: 195 | js = json.loads(line.strip()) 196 | code = ' '.join(js['func'].split()) 197 | label = js['target'] 198 | #label = torch.nn.functional.one_hot(label, args.num_labels) 199 | self.examples.append(convert_examples_to_features(code, label, tokenizer, args)) 200 | code_files.append(code) 201 | assert (len(self.examples) == len(code_files)) 202 | with open(code_pairs_file_path, 'wb') as f: 203 | pickle.dump(code_files, f) 204 | logger.info("Saving features into cached file %s", cache_file_path) 205 | torch.save(self.examples, cache_file_path) 206 | 207 | def __len__(self): 208 | return len(self.examples) 209 | 210 | def __getitem__(self, item): 211 | # calculate graph-guided masked function 212 | attn_mask = np.zeros((self.args.code_length + self.args.data_flow_length, 213 | self.args.code_length + self.args.data_flow_length), dtype=np.bool) 214 | # calculate begin index of node and max length of input 215 | 216 | node_index = sum([i > 1 for i in self.examples[item].position_idx]) 217 | max_length = sum([i != 1 for i in self.examples[item].position_idx]) 218 | # sequence can attend to sequence 219 | attn_mask[:node_index, :node_index] = 
True 220 | # special tokens attend to all tokens 221 | 222 | for idx, i in enumerate(self.examples[item].input_ids): 223 | if i in [0, 2]: 224 | attn_mask[idx, :max_length] = True 225 | # nodes attend to code tokens that are identified from 226 | for idx, (a, b) in enumerate(self.examples[item].dfg_to_code): 227 | if a < node_index and b < node_index: 228 | attn_mask[idx + node_index, a:b] = True 229 | attn_mask[a:b, idx + node_index] = True 230 | # nodes attend to adjacent nodes 231 | for idx, nodes in enumerate(self.examples[item].dfg_to_dfg): 232 | for a in nodes: 233 | if a + node_index < len(self.examples[item].position_idx): 234 | attn_mask[idx + node_index, a + node_index] = True 235 | 236 | return (torch.tensor(self.examples[item].input_ids), 237 | torch.tensor(attn_mask), 238 | torch.tensor(self.examples[item].position_idx), 239 | torch.tensor(self.examples[item].label)) 240 | 241 | 242 | def set_seed(args): 243 | random.seed(args.seed) 244 | np.random.seed(args.seed) 245 | torch.manual_seed(args.seed) 246 | if args.n_gpu > 0: 247 | torch.cuda.manual_seed_all(args.seed) 248 | 249 | 250 | def train(args, train_dataset, model, tokenizer): 251 | """ Train the model """ 252 | 253 | # build dataloader 254 | train_sampler = RandomSampler(train_dataset) 255 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=4) 256 | 257 | args.max_steps = args.epochs * len(train_dataloader) 258 | args.save_steps = len(train_dataloader) 259 | args.warmup_steps = args.max_steps // 5 260 | model.to(args.device) 261 | 262 | # Prepare optimizer and schedule (linear warmup and decay) 263 | no_decay = ['bias', 'LayerNorm.weight'] 264 | optimizer_grouped_parameters = [ 265 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 266 | 'weight_decay': args.weight_decay}, 267 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 268 | ] 269 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 270 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, 271 | num_training_steps=args.max_steps) 272 | 273 | # multi-gpu training 274 | if args.n_gpu > 1: 275 | model = torch.nn.DataParallel(model) 276 | 277 | # Train! 
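    # In the training loop below, the integer class labels of each batch are first
    # expanded to one-hot vectors (torch.nn.functional.one_hot) so that the Mixup
    # augmentation can interpolate label vectors as well as hidden representations;
    # the loss returned by the model is accordingly a soft-label cross-entropy,
    # roughly -(labels * log_softmax(logits)).sum(). A checkpoint is saved whenever
    # the macro precision on the evaluation set improves.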
278 | logger.info("***** Running training *****") 279 | logger.info(" Num examples = %d", len(train_dataset)) 280 | logger.info(" Num Epochs = %d", args.epochs) 281 | logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size // max(args.n_gpu, 1)) 282 | logger.info(" Total train batch size = %d", args.train_batch_size * args.gradient_accumulation_steps) 283 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 284 | logger.info(" Total optimization steps = %d", args.max_steps) 285 | 286 | global_step = 0 287 | tr_loss, logging_loss, avg_loss, tr_nb, tr_num, train_loss = 0.0, 0.0, 0.0, 0, 0, 0 288 | best_f1 = 0 289 | 290 | model.zero_grad() 291 | 292 | for idx in range(args.epochs): 293 | bar = tqdm(train_dataloader, total=len(train_dataloader)) 294 | tr_num = 0 295 | train_loss = 0 296 | for step, batch in enumerate(bar): 297 | inputs_ids = batch[0].to(args.device) 298 | attn_mask = batch[1].to(args.device) 299 | position_idx = batch[2].to(args.device) 300 | labels = batch[3].to(args.device) 301 | labels = torch.nn.functional.one_hot(labels, args.num_labels) 302 | model.train() 303 | loss, logits = model(inputs_ids, attn_mask, position_idx, labels) 304 | 305 | if args.n_gpu > 1: 306 | loss = loss.mean() 307 | 308 | if args.gradient_accumulation_steps > 1: 309 | loss = loss / args.gradient_accumulation_steps 310 | 311 | loss.backward() 312 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 313 | 314 | tr_loss += loss.item() 315 | tr_num += 1 316 | train_loss += loss.item() 317 | if avg_loss == 0: 318 | avg_loss = tr_loss 319 | 320 | avg_loss = round(train_loss / tr_num, 5) 321 | bar.set_description("epoch {} loss {}".format(idx, avg_loss)) 322 | 323 | if (step + 1) % args.gradient_accumulation_steps == 0: 324 | optimizer.step() 325 | optimizer.zero_grad() 326 | scheduler.step() 327 | global_step += 1 328 | output_flag = True 329 | avg_loss = round(np.exp((tr_loss - logging_loss) / (global_step - tr_nb)), 4) 330 | 331 | if global_step % args.save_steps == 0: 332 | results = evaluate(args, model, tokenizer, eval_when_training=True) 333 | 334 | # Save model checkpoint 335 | if results['eval_precision'] > best_f1: 336 | best_f1 = results['eval_precision'] 337 | logger.info(" " + "*" * 20) 338 | logger.info(" Best precision:%s", round(best_f1, 4)) 339 | logger.info(" " + "*" * 20) 340 | 341 | output_dir1 = f"{root_path}/{args.output_dir}" 342 | if not os.path.exists(output_dir1): 343 | os.makedirs(output_dir1) 344 | model_to_save = model.module if hasattr(model, 'module') else model 345 | output_dir = os.path.join(output_dir1, '{}'.format('model.bin')) 346 | torch.save(model_to_save.state_dict(), output_dir) 347 | logger.info("Saving model checkpoint to %s", output_dir) 348 | 349 | 350 | def evaluate(args, model, tokenizer, eval_when_training=False): 351 | # build dataloader 352 | eval_dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file) 353 | eval_sampler = SequentialSampler(eval_dataset) 354 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=4) 355 | 356 | # multi-gpu evaluate 357 | if args.n_gpu > 1 and eval_when_training is False: 358 | model = torch.nn.DataParallel(model) 359 | 360 | # Eval! 
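    # The evaluation loop below accumulates logits and integer labels over the whole
    # eval set, takes the argmax of each logit vector as the prediction, and reports
    # macro-averaged precision together with accuracy (recall and F1 are computed but
    # not added to the returned result dict). Note that the batch tensor defined in the
    # loop is named `label`, which is what is passed to the model and collected into
    # y_trues, while the one-hot call references `labels`.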
361 | logger.info("***** Running evaluation *****") 362 | logger.info(" Num examples = %d", len(eval_dataset)) 363 | logger.info(" Batch size = %d", args.eval_batch_size) 364 | 365 | eval_loss = 0.0 366 | nb_eval_steps = 0 367 | model.eval() 368 | logits = [] 369 | y_trues = [] 370 | for batch in eval_dataloader: 371 | inputs_ids = batch[0].to(args.device) 372 | attn_mask = batch[1].to(args.device) 373 | position_idx = batch[2].to(args.device) 374 | label = batch[3].to(args.device) 375 | labels = torch.nn.functional.one_hot(labels, args.num_labels) 376 | with torch.no_grad(): 377 | lm_loss, logit = model(inputs_ids, attn_mask, position_idx, label) 378 | eval_loss += lm_loss.mean().item() 379 | logits.append(logit.cpu().numpy()) 380 | y_trues.append(label.cpu().numpy()) 381 | nb_eval_steps += 1 382 | logits = np.concatenate(logits, 0) 383 | y_trues = np.concatenate(y_trues, 0) 384 | 385 | y_preds = [] 386 | for logit in logits: 387 | y_preds.append(np.argmax(logit)) 388 | 389 | from sklearn.metrics import recall_score 390 | recall = recall_score(y_trues, y_preds, average='macro') 391 | from sklearn.metrics import precision_score 392 | precision = precision_score(y_trues, y_preds, average='macro') 393 | from sklearn.metrics import f1_score 394 | f1 = f1_score(y_trues, y_preds, average='macro') 395 | 396 | eval_acc = np.mean(y_trues == np.array(y_preds)) 397 | result = { 398 | "eval_precision": float(precision), 399 | "eval_acc": eval_acc, 400 | } 401 | 402 | # logger.info("***** Eval results {} *****".format(prefix)) 403 | for key in sorted(result.keys()): 404 | logger.info(" %s = %s", key, str(round(result[key], 4))) 405 | 406 | return result 407 | 408 | 409 | def main(): 410 | parser = argparse.ArgumentParser() 411 | 412 | ## Required parameters 413 | parser.add_argument("--data_dir", default="datasets/GraphCodeBERT/code_classification", type=str, 414 | help="The output directory where the model predictions and checkpoints will be written.") 415 | parser.add_argument("--output_dir", default="models/GraphCodeBERT/code_classification/model", type=str, 416 | help="The output directory where the model predictions and checkpoints will be written.") 417 | 418 | ## Other parameters 419 | parser.add_argument("--model_name_or_path", default="microsoft/graphcodebert-base", type=str, 420 | help="The model checkpoint for weights initialization.") 421 | parser.add_argument("--number_labels", type=int, default=250, 422 | help="The model checkpoint for weights initialization.") 423 | 424 | parser.add_argument("--config_name", default="microsoft/graphcodebert-base", type=str, 425 | help="Optional pretrained config name or path if not the same as model_name_or_path") 426 | parser.add_argument("--tokenizer_name", default="microsoft/graphcodebert-base", type=str, 427 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") 428 | 429 | parser.add_argument("--code_length", default=384, type=int, 430 | help="Optional Code input sequence length after tokenization.") 431 | parser.add_argument("--data_flow_length", default=128, type=int, 432 | help="Optional Data Flow input sequence length after tokenization.") 433 | parser.add_argument("--do_train", action='store_true', 434 | help="Whether to run training.") 435 | parser.add_argument("--do_eval", action='store_true', 436 | help="Whether to run eval on the dev set.") 437 | parser.add_argument("--do_test", action='store_true', 438 | help="Whether to run eval on the dev set.") 439 | parser.add_argument("--language_type", type=str, 
440 | help="The programming language type of dataset") 441 | parser.add_argument("--evaluate_during_training", action='store_true', 442 | help="Run evaluation during training at each logging step.") 443 | 444 | parser.add_argument("--train_batch_size", default=32, type=int, 445 | help="Batch size per GPU/CPU for training.") 446 | parser.add_argument("--eval_batch_size", default=64, type=int, 447 | help="Batch size per GPU/CPU for evaluation.") 448 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 449 | help="Number of updates steps to accumulate before performing a backward/update pass.") 450 | parser.add_argument("--learning_rate", default=2e-5, type=float, 451 | help="The initial learning rate for Adam.") 452 | parser.add_argument("--weight_decay", default=0.0, type=float, 453 | help="Weight deay if we apply some.") 454 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 455 | help="Epsilon for Adam optimizer.") 456 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 457 | help="Max gradient norm.") 458 | parser.add_argument("--epochs", default=20, type=int, 459 | help="Total number of training epochs to perform.") 460 | parser.add_argument("--max_steps", default=-1, type=int, 461 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 462 | parser.add_argument("--warmup_steps", default=0, type=int, 463 | help="Linear warmup over warmup_steps.") 464 | parser.add_argument('--num_labels', type=int, default=None, 465 | help = 'num_labels') 466 | 467 | parser.add_argument('--seed', type=int, default=123456, 468 | help="random seed for initialization") 469 | 470 | args = parser.parse_args() 471 | 472 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 473 | args.n_gpu = torch.cuda.device_count() 474 | 475 | args.device = device 476 | args.train_data_file = f"{root_path}/{args.data_dir}/train.jsonl" 477 | args.eval_data_file = f"{root_path}/{args.data_dir}/test.jsonl" 478 | args.test_data_file = f"{root_path}/{args.data_dir}/test.jsonl" 479 | 480 | # Setup logging 481 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 482 | logger.warning("device: %s, n_gpu: %s", device, args.n_gpu) 483 | 484 | # Set seed 485 | set_seed(args) 486 | config = RobertaConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) 487 | config.num_labels = args.number_labels 488 | tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) 489 | model = RobertaForSequenceClassification.from_pretrained(args.model_name_or_path, config=config) 490 | model = Model(model, config, tokenizer, args) 491 | 492 | logger.info("Training/evaluation parameters %s", args) 493 | 494 | # Training 495 | if args.do_train: 496 | train_dataset = TextDataset(tokenizer, args, args.train_data_file) 497 | train(args, train_dataset, model, tokenizer) 498 | 499 | # Evaluation 500 | if args.do_eval: 501 | checkpoint_prefix = f"{root_path}/{args.output_dir}/model.bin" 502 | model.load_state_dict(torch.load(checkpoint_prefix)) 503 | model.to(args.device) 504 | result = evaluate(args, model, tokenizer) 505 | logger.info("***** Eval results *****") 506 | for key in sorted(result.keys()): 507 | logger.info(" %s = %s", key, str(round(result[key], 4))) 508 | 509 | 510 | 511 | if __name__ == "__main__": 512 | main() 513 | 514 | -------------------------------------------------------------------------------- /Mixup.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | ##### Mixup-Tensorflow ################ 6 | def mixup_data( x, y, alpha, runs): 7 | if runs is None: 8 | runs = 1 9 | output_x = [] 10 | output_y = [] 11 | batch_size = x.shape[0] 12 | for i in range(runs): 13 | lam_vector = np.random.beta(alpha, alpha, batch_size) 14 | index = np.random.permutation(batch_size) 15 | mixed_x = (x.T * lam_vector).T + (x[index, :].T * (1.0 - lam_vector)).T 16 | output_x.append(mixed_x) 17 | if y is None: 18 | return np.concatenate(output_x, axis=0) 19 | mixed_y = (y.T * lam_vector).T + (y[index].T * (1.0 - lam_vector)).T 20 | output_y.append(mixed_y) 21 | return np.concatenate(output_x, axis=0), np.concatenate(output_y, axis=0) 22 | 23 | 24 | 25 | def mixup_data_refactor( x, y, x_refactor, y_refactor, alpha, runs): 26 | if runs is None: 27 | runs = 1 28 | output_x = [] 29 | output_y = [] 30 | batch_size = x.shape[0] 31 | for i in range(runs): 32 | lam_vector = np.random.beta(alpha, alpha, batch_size) 33 | index = np.random.permutation(batch_size) 34 | 35 | mixed_x = (x.T * lam_vector).T + (x_refactor[index, :].T * (1.0 - lam_vector)).T 36 | output_x.append(mixed_x) 37 | if y is None: 38 | return np.concatenate(output_x, axis=0) 39 | mixed_y = (y.T * lam_vector).T + (y_refactor[index].T * (1.0 - lam_vector)).T 40 | output_y.append(mixed_y) 41 | return np.concatenate(output_x, axis=0), np.concatenate(output_y, axis=0) 42 | 43 | 44 | 45 | 46 | ##### Mixup-Pytorch ################ 47 | def mixup_data(x, y, alpha=0.1, runs, use_cuda=True): 48 | for i in range(runs): 49 | output_x = torch.Tensor(0) 50 | output_x= output_x.numpy().tolist() 51 | output_y = torch.Tensor(0) 52 | output_y = output_y.numpy().tolist() 53 | batch_size = x.size()[0] 54 | if alpha > 0.: 55 | lam = np.random.beta(alpha, alpha) 56 | else: 57 | lam = 1. 58 | 59 | if use_cuda: 60 | index = torch.randperm(batch_size).cuda() 61 | else: 62 | index = torch.randperm(batch_size) 63 | mixed_x = lam * x + (1 - lam) * x[index, :] 64 | mixed_y = lam * y + (1 - lam) * y[index, :] 65 | output_x.append(mixed_x) 66 | output_y.append(mixed_y) 67 | return torch.cat(output_x,dim=0), torch.cat(output_y,dim=0) 68 | 69 | 70 | def mixup_data_refactor( x, y, x_refactor, y_refactor, alpha, runs, use_cuda=True): 71 | for i in range(runs): 72 | output_x = torch.Tensor(0) 73 | output_x= output_x.numpy().tolist() 74 | output_y = torch.Tensor(0) 75 | output_y = output_y.numpy().tolist() 76 | batch_size = x.size()[0] 77 | if alpha > 0.: 78 | lam = np.random.beta(alpha, alpha) 79 | else: 80 | lam = 1. 81 | if use_cuda: 82 | index = torch.randperm(batch_size).cuda() 83 | else: 84 | index = torch.randperm(batch_size) 85 | mixed_x = lam * x + (1 - lam) * x_refactor[index, :] 86 | mixed_y = lam * y + (1 - lam) * y_refactor[index, :] 87 | output_x.append(mixed_x) 88 | output_y.append(mixed_y) 89 | return torch.cat(output_x,dim=0), torch.cat(output_y,dim=0) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MixCode: Enhancing Code Classification by Mixup-Based Data Augmentation 2 | Implementation of SANER2023 paper MixCode: Enhancing Code Classification by Mixup-Based Data Augmentation [[arxiv]](https://arxiv.org/abs/2210.03003). 3 | 4 | We build this project on the top of [Project_CodeNet](https://github.com/IBM/Project_CodeNet). 
Please refer to this project for more details. 5 | 6 | ## Introduction 7 | MIXCODE aims to effectively supplement valid training data without manually collecting or labeling new code, inspired by the recent advance named Mixup in computer vision. Specifically, MIXCODE 1) first utilizes multiple code refactoring methods to generate transformed code that keeps the same label as the original data, and then 2) adapts the Mixup technique to linearly mix the original code with the transformed code to augment the training data. 8 | 9 |
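To make the mixing step concrete, below is a minimal NumPy sketch of how one original sample can be interpolated with its refactored counterpart (the function and variable names here are illustrative; `Mixup.py` in this repository applies the same idea batch-wise with a per-sample Beta-distributed ratio):

```python
import numpy as np

def mixup_pair(x, y, x_refactor, y_refactor, alpha=0.1):
    """Linearly interpolate an original sample with its refactored variant;
    x / x_refactor are feature arrays, y / y_refactor are one-hot label arrays."""
    lam = np.random.beta(alpha, alpha)            # mixing ratio drawn from Beta(alpha, alpha)
    mixed_x = lam * x + (1.0 - lam) * x_refactor  # mix the code representations
    mixed_y = lam * y + (1.0 - lam) * y_refactor  # mix the (soft) labels the same way
    return mixed_x, mixed_y
```

Because the labels are interpolated as well, the fine-tuned classifiers are trained with a soft-label cross-entropy rather than hard argmax targets.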
10 | 11 | 12 | ## Requirements 13 | On Ubuntu: 14 | 15 | - Task: Classification 16 | 17 | ```shell 18 | Python (>=3.6) 19 | TensorFlow (version 2.3.0) 20 | Keras (version 2.4.3) 21 | CUDA 10.1 22 | cuDNN (>=7.6) 23 | ``` 24 | 25 | - Task: Bug Detection 26 | ```shell 27 | Python (>=3.6) 28 | Pytorch (version 1.6.0) 29 | CUDA 10.1 30 | cuDNN (>=7.6) 31 | ``` 32 | 33 | ## CodeBERT/GraphCodeBERT for Classification Tasks 34 | 35 | - pip install torch==1.4.0 36 | - pip install transformers==2.5.0 37 | - pip install filelock 38 | 39 | ### Fine-Tune 40 | ```shell 41 | cd CodeBERT 42 | 43 | python run.py \ 44 | --output_dir=./saved_models \ 45 | --tokenizer_name=microsoft/codebert-base \ 46 | --model_name_or_path=microsoft/codebert-base \ 47 | --do_train \ 48 | --num_train_epochs 50 \ 49 | --block_size 256 \ 50 | --train_batch_size 8 \ 51 | --eval_batch_size 16 \ 52 | --learning_rate 2e-5 \ 53 | --max_grad_norm 1.0 \ 54 | --num_labels 250 \ # Number Classifications 55 | --seed 123456 2>&1 | tee train.log 56 | ``` 57 | 58 | ```shell 59 | cd GraphCodeBERT 60 | 61 | python run.py \ 62 | --tokenizer_name=microsoft/graphcodebert-base \ 63 | --model_name_or_path=microsoft/graphcodebert-base \ 64 | --config_name microsoft/graphcodebert-base \ 65 | --do_train \ 66 | --num_train_epochs 50 \ 67 | --code_length 384 \ 68 | --data_flow_length 384 \ 69 | --train_batch_size 8 \ 70 | --eval_batch_size 16 \ 71 | --learning_rate 2e-5 \ 72 | --max_grad_norm 1.0 \ 73 | --evaluate_during_training \ 74 | --num_labels 250 \ # Number Classifications 75 | --seed 123456 2>&1 | tee train.log 76 | ``` 77 | 78 | ## Dataset 79 | - Java250: https://developer.ibm.com/exchanges/data/all/project-codenet/ 80 | - Python800: https://developer.ibm.com/exchanges/data/all/project-codenet/ 81 | - Refactory: https://github.com/githubhuyang/refactory 82 | - CodRep: https://github.com/KTH/CodRep-competition 83 | 84 | ## Citation 85 | If you use the code in your research, please cite: 86 | ```bibtex 87 | @inproceedings{dong2023mixcode, 88 | title={MixCode: Enhancing Code Classification by Mixup-Based Data Augmentation}, 89 | author={Dong, Zeming and Hu, Qiang and Guo, Yuejun and Cordy, Maxime and Papadakis, Mike and Zhang, Zhenya and Le Traon, Yves and Zhao, Jianjun}, 90 | booktitle={2023 IEEE International Conference on Software Analysis, Evolution and Reengineering (SANER)}, 91 | pages={379--390}, 92 | year={2023}, 93 | organization={IEEE} 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /Tool/Java_refactor/generate_refactoring.py: -------------------------------------------------------------------------------- 1 | import os, random 2 | from shutil import copyfile 3 | from refactoring_methods import * 4 | 5 | 6 | def return_function_code(code, method_names): 7 | final_codes = [] 8 | final_names = [] 9 | 10 | Class_list, raw_code = extract_class(code) 11 | 12 | for class_name in Class_list: 13 | function_list, class_name = extract_function(class_name) 14 | 15 | for fun_code in function_list: 16 | 17 | for method_name in method_names: 18 | method_name_tem = method_name.replace('|', '') 19 | if method_name_tem.upper() in fun_code.split('\n')[0].upper(): 20 | 21 | final_codes.append(fun_code) 22 | final_names.append(method_name) 23 | 24 | return final_codes, final_names 25 | 26 | 27 | def generate_adversarial(k, code, method_names): 28 | 29 | method_name = method_names[0] 30 | function_list = [] 31 | class_name = '' 32 | 33 | Class_list, raw_code = extract_class(code) 34 | 35 | for class_name 
in Class_list: 36 | function_list, class_name = extract_function(class_name) 37 | 38 | refac = [] 39 | new_refactored_code = '' 40 | 41 | for code in function_list: 42 | if method_name not in code.split('\n')[0]: 43 | continue 44 | 45 | new_rf = code 46 | new_refactored_code = code 47 | 48 | # print(code) 49 | for t in range(k): 50 | refactors_list = [rename_argument, 51 | return_optimal, 52 | add_argumemts, 53 | rename_api, 54 | rename_local_variable, 55 | add_local_variable, 56 | rename_method_name, 57 | enhance_if, 58 | add_print, 59 | duplication, 60 | apply_plus_zero_math, 61 | dead_branch_if_else, 62 | dead_branch_if, 63 | dead_branch_while, 64 | dead_branch_for, 65 | dead_branch_switch 66 | ]# 67 | 68 | vv = 0 69 | 70 | while new_rf == new_refactored_code and vv <= 20: 71 | try: 72 | vv += 1 73 | 74 | refactor = random.choice(refactors_list) 75 | print('*'*50 , refactor , '*'*50) 76 | new_refactored_code = refactor(new_refactored_code) 77 | 78 | except Exception as error: 79 | print('error:\t', error) 80 | 81 | new_rf = new_refactored_code 82 | 83 | print('----------------------------OUT of WHILE----------------------------------', vv) 84 | print('----------------------------CHANGED THJIS TIME:----------------------------------', vv) 85 | 86 | refac.append(new_refactored_code) 87 | 88 | code_body = raw_code.strip() + ' ' + class_name.strip() 89 | for i in range(len(refac)): 90 | final_refactor = code_body.replace('vesal' + str(i), str(refac[i])) 91 | code_body = final_refactor 92 | 93 | 94 | return new_refactored_code 95 | 96 | 97 | def generate_adversarial_json(k, code): 98 | final_refactor = '' 99 | function_list = [] 100 | class_name = '' 101 | vv = 0 102 | if len(function_list) == 0: 103 | function_list.append(code) 104 | refac = [] 105 | for code in function_list: 106 | 107 | 108 | new_rf = code 109 | new_refactored_code = code 110 | 111 | for t in range(k): 112 | 113 | refactors_list = [rename_argument, 114 | return_optimal, 115 | add_argumemts, 116 | rename_api, 117 | rename_local_variable, 118 | add_local_variable, 119 | rename_method_name, 120 | enhance_if, 121 | add_print, 122 | duplication, 123 | apply_plus_zero_math, 124 | dead_branch_if_else, 125 | dead_branch_if, 126 | dead_branch_while, 127 | dead_branch_for, 128 | dead_branch_switch 129 | ] 130 | 131 | vv = 0 132 | 133 | while new_rf == new_refactored_code and vv <= 20: 134 | try: 135 | vv += 1 136 | refactor = random.choice(refactors_list) 137 | print('*' * 50, refactor, '*' * 50) 138 | new_refactored_code = refactor(new_refactored_code) 139 | 140 | except Exception as error: 141 | print('error:\t', error) 142 | 143 | new_rf = new_refactored_code 144 | 145 | refac.append(new_refactored_code) 146 | 147 | print("refactoring finished") 148 | return refac 149 | 150 | 151 | def generate_adversarial_file_level(k, code): 152 | new_refactored_code = '' 153 | new_rf = code 154 | new_refactored_code = code 155 | 156 | for t in range(k): 157 | refactors_list = [rename_argument, 158 | return_optimal, 159 | add_argumemts, 160 | rename_api, 161 | rename_local_variable, 162 | add_local_variable, 163 | rename_method_name, 164 | enhance_if, 165 | add_print, 166 | duplication, 167 | apply_plus_zero_math, 168 | dead_branch_if_else, 169 | dead_branch_if, 170 | dead_branch_while, 171 | dead_branch_for, 172 | dead_branch_switch 173 | ] 174 | 175 | vv = 0 176 | 177 | while new_rf == new_refactored_code and vv <= 20: 178 | try: 179 | vv += 1 180 | refactor = random.choice(refactors_list) 181 | print('*' * 50, refactor, '*' * 50) 182 | 
new_refactored_code = refactor(new_refactored_code) 183 | 184 | except Exception as error: 185 | print('error:\t', error) 186 | 187 | new_rf = new_refactored_code 188 | 189 | return new_refactored_code 190 | 191 | 192 | if __name__ == '__main__': 193 | K = 1 194 | filename = '**.py' 195 | open_file = open(filename, 'r', encoding='ISO-8859-1') 196 | code = open_file.read() 197 | new_code = generate_adversarial_file_level(K, code) 198 | print(new_code) 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /Tool/Java_refactor/processing_source_code.py: -------------------------------------------------------------------------------- 1 | import re, random 2 | from nltk.corpus import wordnet 3 | import wordninja 4 | 5 | from util import * 6 | 7 | reserved_kws = ["abstract", "assert", "boolean", 8 | "break", "byte", "case", "catch", "char", "class", "const", 9 | "continue", "default", "do", "double", "else", "extends", "false", 10 | "final", "finally", "float", "for", "goto", "if", "implements", 11 | "import", "instanceof", "int", "interface", "long", "native", 12 | "new", "null", "package", "private", "protected", "public", 13 | "return", "short", "static", "strictfp", "super", "switch", 14 | "synchronized", "this", "throw", "throws", "transient", "true", 15 | "try", "void", "volatile", "while"] 16 | 17 | reserved_cls = ["ArrayDeque", "ArrayList", "Arrays", "BitSet", "Calendar", "Collections", "Currency", 18 | "Date", "Dictionary", "EnumMap", "EnumSet", "Formatter", "GregorianCalendar", "HashMap", 19 | "HashSet", "Hashtable", "IdentityHashMap", "LinkedHashMap", "LinkedHashSet", 20 | "LinkedList", "ListResourceBundle", "Locale", "Observable", 21 | "PriorityQueue", "Properties", "PropertyPermission", 22 | "PropertyResourceBundle", "Random", "ResourceBundle", "ResourceBundle.Control", 23 | "Scanner", "ServiceLoader", "SimpleTimeZone", "Stack", 24 | "StringTokenizer", "Timer", "TimerTask", "TimeZone", 25 | "TreeMap", "TreeSet", "UUID", "Vector", "WeakHashMap" 26 | ] 27 | 28 | reserved_kws = reserved_kws + reserved_cls 29 | 30 | 31 | def word_synonym_replacement(word): 32 | if len(word) <= 3: 33 | return word + '_new' 34 | word_set = wordninja.split(word) 35 | while True: 36 | if word_set == []: 37 | return word + '_new' 38 | word_tar = random.choice(word_set) 39 | word_syn = wordnet.synsets(word_tar) 40 | if word_syn == []: 41 | word_set.remove(word_tar) 42 | else: 43 | break 44 | word_ret = [] 45 | for syn in word_syn: 46 | word_ret = word_ret + syn.lemma_names() 47 | if word_tar in word_ret: 48 | word_ret.remove(word_tar) 49 | try: 50 | word_new = random.choice(word_ret) 51 | except: 52 | word_new = word 53 | 54 | 55 | return word.replace(word_tar,word_new),word_ret 56 | 57 | 58 | def extract_method_name(string): 59 | match_ret = re.search('\w+\s*\(',string) 60 | if match_ret: 61 | method_name = match_ret.group()[:-1].strip() 62 | return method_name 63 | else: 64 | return None 65 | 66 | 67 | def extract_argument(string): 68 | end_pos = string.find('{') 69 | sta_pas = string.find('(') 70 | arguments = string[sta_pas + 1 :end_pos].strip()[:-1] 71 | arguments_list = arguments.split(',') 72 | if ' ' in arguments_list: 73 | arguments_list.remove(' ') 74 | if '' in arguments_list: 75 | arguments_list.remove('') 76 | return arguments_list 77 | 78 | 79 | def extract_brace(string,start_pos): 80 | length = 0 81 | brace_l_num = 0 82 | brace_r_num = 0 83 | for char in string[start_pos:]: 84 | if char == '{': 85 | brace_l_num += 1 86 | 
if char == '}': 87 | brace_r_num += 1 88 | if brace_l_num == brace_r_num and brace_l_num > 0: 89 | break; 90 | length += 1 91 | return string[start_pos: start_pos + length + 1] 92 | 93 | ''' 94 | def extract_import(string): 95 | import_list = re.findall('import .+;',string) 96 | return import_list,string 97 | ''' 98 | 99 | 100 | def extract_class(string): 101 | 102 | class_list = [] 103 | while ' class ' in string: 104 | start_pos = string.find(' class ') 105 | class_text = extract_brace(string, start_pos) 106 | class_list.append(class_text) 107 | string = string.replace(class_text,'') 108 | 109 | return class_list,string 110 | 111 | 112 | def extract_member_variable(string): 113 | 114 | variable_list = [] 115 | while True: 116 | match_ret = re.search('(private|public).+;', string) 117 | if match_ret: 118 | variable_text = match_ret.group() 119 | variable_list.append(variable_text) 120 | string = string.replace(variable_text,'') 121 | else: 122 | break 123 | return variable_list,string 124 | 125 | 126 | def extract_function(string): 127 | i = 0 128 | function_list = [] 129 | while True: 130 | match_ret = re.search('(protected|private|public).+\s*{', string) 131 | if match_ret: 132 | function_head = match_ret.group() 133 | start_pos = string.find(function_head) 134 | function_text = extract_brace(string, start_pos) 135 | function_list.append(function_text) 136 | string = string.replace(function_text, 'vesal'+ str(i)) 137 | i+=1 138 | else: 139 | break 140 | return function_list, string 141 | 142 | 143 | def extract_for_loop(string): 144 | 145 | for_list = [] 146 | while True: 147 | match_ret = re.search('for\s+\(', string) 148 | print(match_ret) 149 | if match_ret: 150 | for_head = match_ret.group() 151 | start_pos = string.find(for_head) 152 | for_text = extract_brace(string, start_pos) 153 | for_list.append(for_text) 154 | string = string.replace(for_text, '') 155 | else: 156 | break 157 | return for_list 158 | 159 | 160 | def extract_if(string): 161 | 162 | if_list = [] 163 | while True: 164 | match_ret = re.search('if\s+\(', string) 165 | if match_ret: 166 | if_head = match_ret.group() 167 | start_pos = string.find(if_head) 168 | if_text = extract_brace(string, start_pos) 169 | if_list.append(if_text) 170 | string = string.replace(if_text, '') 171 | else: 172 | break 173 | return if_list 174 | 175 | 176 | def extract_while_loop(string): 177 | 178 | while_list = [] 179 | while True: 180 | match_ret = re.search('while\s+\(', string) 181 | if match_ret: 182 | while_head = match_ret.group() 183 | start_pos = string.find(while_head) 184 | while_text = extract_brace(string, start_pos) 185 | while_list.append(while_text) 186 | string = string.replace(while_text, '') 187 | else: 188 | break 189 | return while_list, string 190 | 191 | 192 | def extract_local_variable(string): 193 | 194 | local_var_list = [] 195 | statement_list = string.split('\n') 196 | for line in statement_list: 197 | match_ret = re.search('[^\s]+\s+\w+\s+=', line) 198 | if match_ret: 199 | var_definition = match_ret.group() 200 | local_var_list.append(var_definition.split(' ')[1]) 201 | 202 | return local_var_list 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /Tool/Java_refactor/refactoring_methods.py: -------------------------------------------------------------------------------- 1 | import os, random, re 2 | 3 | from processing_source_code import * 4 | 5 | 6 | def rename_local_variable(method_string): 7 | local_var_list = extract_local_variable(method_string) 8 | 
if len(local_var_list) == 0: 9 | return method_string 10 | 11 | mutation_index = random.randint(0, len(local_var_list) - 1) 12 | return method_string.replace(local_var_list[mutation_index], word_synonym_replacement(local_var_list[mutation_index])[0]) 13 | 14 | 15 | def add_local_variable(method_string): 16 | local_var_list = extract_local_variable(method_string) 17 | if len(local_var_list) == 0: 18 | return method_string 19 | 20 | mutation_index = random.randint(0, len(local_var_list) - 1) 21 | match_ret = re.search('.+' + local_var_list[mutation_index] + '.+;', method_string) 22 | if match_ret: 23 | var_definition = match_ret.group() 24 | new_var_definition = var_definition.replace(local_var_list[mutation_index], word_synonym_replacement(local_var_list[mutation_index])[0]) 25 | method_string = method_string.replace(var_definition, var_definition + '\n' + new_var_definition) 26 | return method_string 27 | else: 28 | return method_string 29 | 30 | 31 | def duplication(method_string): 32 | local_var_list = extract_local_variable(method_string) 33 | if len(local_var_list) == 0: 34 | return method_string 35 | mutation_index = random.randint(0, len(local_var_list) - 1) 36 | match_ret = re.search('.+' + local_var_list[mutation_index] + '.+;', method_string) 37 | if match_ret: 38 | var_definition = match_ret.group() 39 | new_var_definition = var_definition 40 | method_string = method_string.replace(var_definition, var_definition + '\n' + new_var_definition) 41 | # print(method_string) 42 | return method_string 43 | else: 44 | # print(method_string) 45 | return method_string 46 | 47 | 48 | def rename_api(method_string): 49 | match_ret = re.findall('\.\s*\w+\s*\(', method_string) 50 | if match_ret != []: 51 | api_name = random.choice(match_ret)[1:-1] 52 | return method_string.replace(api_name,word_synonym_replacement(api_name)[0]) 53 | else: 54 | return method_string 55 | 56 | 57 | def rename_method_name(method_string): 58 | method_name = extract_method_name(method_string) 59 | if method_name: 60 | return method_string.replace(method_name, word_synonym_replacement(method_name)[0]) 61 | else: 62 | return method_string 63 | 64 | 65 | def rename_argument(method_string): 66 | arguments_list = extract_argument(method_string) 67 | if len(arguments_list) == 0: 68 | return method_string 69 | 70 | mutation_index = random.randint(0, len(arguments_list) - 1) 71 | # print(method_string.replace(arguments_list[mutation_index],word_synonym_replacement(arguments_list[mutation_index])[0])) 72 | return method_string.replace(arguments_list[mutation_index],word_synonym_replacement(arguments_list[mutation_index])[0]) 73 | 74 | 75 | def return_optimal(method_string): 76 | if 'return ' in method_string: 77 | return_statement = method_string[method_string.find('return ') : method_string.find(';', method_string.find('return ') + 1)] 78 | return_object = return_statement.replace('return ','') 79 | if return_object == 'null': 80 | return method_string 81 | optimal_statement = 'if (' + return_object + ' == null){\n\t\t\treturn 0;\n\t\t}\n' + return_statement 82 | method_string = method_string.replace(return_statement, optimal_statement) 83 | return method_string 84 | 85 | 86 | def enhance_for_loop(method_string): 87 | for_loop_list = extract_for_loop(method_string) 88 | if for_loop_list == []: 89 | return method_string 90 | mutation_index = random.randint(0, len(for_loop_list) - 1) 91 | for_text = for_loop_list[mutation_index] 92 | for_info = for_text[for_text.find('(') + 1 : for_text.find(')')] 93 | for_body = 
for_text[for_text.find('{') + 1 : for_text.rfind('}',-1,10)] 94 | if ':' in for_info: 95 | loop_bar = for_info.split(':')[-1].strip() 96 | loop_var = for_info.split(':')[0].strip().split(' ')[-1].strip() 97 | if loop_bar == None or loop_bar == '' or loop_var == None or loop_var == '': 98 | return method_string 99 | new_for_info = 'int i = 0; i < ' + loop_bar + '.size(); i ++' 100 | method_string = method_string.replace(for_info, new_for_info) 101 | method_string = method_string.replace(for_body,for_body.replace(loop_var, loop_bar + '.get(i)')) 102 | 103 | return method_string 104 | 105 | else: 106 | return method_string 107 | 108 | 109 | def add_print(method_string): 110 | statement_list = method_string.split(';') 111 | mutation_index = random.randint(1, len(statement_list) - 1) 112 | statement = statement_list[mutation_index] 113 | new_statement = '\t' + 'System.out.println("' + str(random.choice(word_synonym_replacement(statement)[1])) + '");' 114 | method_string = method_string.replace(statement, '\n' + new_statement + '\n' + statement) 115 | return method_string 116 | 117 | 118 | def enhance_if(method_string): 119 | if_list = extract_if(method_string) 120 | mutation_index = random.randint(0, len(if_list) - 1) 121 | if_text = if_list[mutation_index] 122 | if_info = if_text[if_text.find('(') + 1 :if_text.find('{')][:if_text.rfind(')',-1,5) -1] 123 | new_if_info = if_info 124 | if 'true' in if_info: 125 | new_if_info = if_info.replace('true','(0==0)') 126 | if 'flase' in if_info: 127 | new_if_info = if_info.replace('flase','(1==0)') 128 | if '!' in if_info and '!=' not in if_info and '(' not in if_info and '&&' not in if_info and '||' not in if_info: 129 | new_if_info = if_info.replace('!', 'flase == ') 130 | if '<' in if_info and '<=' not in if_info and '(' not in if_info and '&&' not in if_info and '||' not in if_info: 131 | new_if_info = if_info.split('<')[1] + ' > ' + if_info.split('<')[0] 132 | if '>' in if_info and '>=' not in if_info and '(' not in if_info and '&&' not in if_info and '||' not in if_info: 133 | new_if_info = if_info.split('>')[1] + ' < ' + if_info.split('>')[0] 134 | if '<=' in if_info and '(' not in if_info and '&&' not in if_info and '||' not in if_info: 135 | new_if_info = if_info.split('<=')[1] + ' >= ' + if_info.split('<=')[0] 136 | if '>=' in if_info and '(' not in if_info and '&&' not in if_info and '||' not in if_info: 137 | new_if_info = if_info.split('>=')[1] + ' <= ' + if_info.split('>=')[0] 138 | if '.equals(' in if_info: 139 | new_if_info = if_info.replace('.equals', '==') 140 | 141 | return method_string.replace(if_info,new_if_info) 142 | 143 | 144 | def add_argumemts(method_string): 145 | arguments_list = extract_argument(method_string) 146 | arguments_info = method_string[method_string.find('(') : method_string.find('{')] 147 | if len(arguments_list) == 0: 148 | arguments_info = 'String ' + word_synonym_replacement(extract_method_name(method_string))[0] 149 | return method_string[0 : method_string.find('()') + 1] + arguments_info + method_string[method_string.find('()') + 1 :] 150 | mutation_index = random.randint(0, len(arguments_list) - 1) 151 | org_argument = arguments_list[mutation_index] 152 | new_argument = word_synonym_replacement(arguments_list[mutation_index])[0] 153 | new_arguments_info = arguments_info.replace(org_argument,org_argument + ', ' + new_argument) 154 | method_string = method_string.replace(arguments_info,new_arguments_info) 155 | return method_string 156 | 157 | 158 | def enhance_filed(method_string): 159 | arguments_list = 
extract_argument(method_string) 160 | if len(arguments_list) == 0: 161 | return method_string 162 | mutation_index = random.randint(0, len(arguments_list) - 1) 163 | extra_info = "\n\tif (" + arguments_list[mutation_index].strip().split(' ')[-1] + " == null){\n\t\tSystem.out.println('please check your input');\n\t}" 164 | method_string = method_string[0 : method_string.find(';') + 1] + extra_info + method_string[method_string.find(';') + 1 : ] 165 | return method_string 166 | 167 | 168 | def apply_plus_zero_math(data): 169 | 170 | statement_list = data.split(';') 171 | mutation_index = random.randint(1, len(statement_list) - 1) 172 | statement = statement_list[mutation_index] 173 | 174 | tree = get_tree(data) 175 | var_list = get_local_vars(tree) 176 | var_list = [var for var, var_type in var_list if var_type in ( 177 | "int", "float", "double", "long")] 178 | if var_list==[]: 179 | return "" 180 | for var in var_list: 181 | mutant = ' ' + str(var) + ' = ' + str(var) + ' + ' + str(0) + ";" 182 | 183 | for idx, _ in enumerate(statement_list): 184 | if var in _ and idx < len(statement_list) - 1: 185 | insertion_index = idx + 1 186 | method_string = data.replace(statement, '\n' + mutant + '\n' + statement) 187 | return method_string 188 | 189 | 190 | 191 | 192 | def dead_branch_if_else(data): 193 | statement_list = data.split(';') 194 | mutation_index = random.randint(1, len(statement_list) - 1) 195 | statement = statement_list[mutation_index] 196 | new_statement = get_branch_if_else_mutant() 197 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 198 | # print(method_string) 199 | return method_string 200 | # return data 201 | 202 | 203 | def dead_branch_if(data): 204 | 205 | statement_list = data.split(';') 206 | mutation_index = random.randint(1, len(statement_list) - 1) 207 | statement = statement_list[mutation_index] 208 | new_statement = get_branch_if_mutant() 209 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 210 | return method_string 211 | 212 | 213 | def dead_branch_while(data): 214 | 215 | statement_list = data.split(';') 216 | mutation_index = random.randint(1, len(statement_list) - 1) 217 | statement = statement_list[mutation_index] 218 | new_statement = get_branch_while_mutant() 219 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 220 | return method_string 221 | 222 | 223 | def dead_branch_for(data): 224 | 225 | statement_list = data.split(';') 226 | mutation_index = random.randint(1, len(statement_list) - 1) 227 | statement = statement_list[mutation_index] 228 | new_statement = get_branch_for_mutant() 229 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 230 | return method_string 231 | 232 | 233 | def dead_branch_switch(data): 234 | 235 | statement_list = data.split(';') 236 | mutation_index = random.randint(1, len(statement_list) - 1) 237 | statement = statement_list[mutation_index] 238 | new_statement = get_branch_switch_mutant() 239 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 240 | return method_string 241 | 242 | 243 | 244 | if __name__ == '__main__': 245 | filename = '**.java' 246 | open_file = open(filename, 'r', encoding='ISO-8859-1') 247 | code = open_file.read() 248 | Class_list, raw_code = extract_class(code) 249 | for class_name in Class_list: 250 | function_list, class_name = extract_function(class_name) 251 | candidate_code = function_list[0] 252 | print(candidate_code) 253 | new_code = 
enhance_for_loop(candidate_code) 254 | print(new_code) 255 | 256 | -------------------------------------------------------------------------------- /Tool/Java_refactor/util.py: -------------------------------------------------------------------------------- 1 | import javalang 2 | import secrets 3 | import random 4 | import json 5 | 6 | from os import listdir 7 | from os.path import isfile, join 8 | 9 | 10 | def get_radom_var_name(): 11 | res_string = '' 12 | for x in range(8): 13 | res_string += random.choice('abcdefghijklmnopqrstuvwxyz') 14 | return res_string 15 | 16 | 17 | def get_dead_for_condition(): 18 | var = get_radom_var_name() 19 | return "int "+var+" = 0; "+var+" < 0; "+var+"++" 20 | 21 | 22 | def get_random_false_stmt(): 23 | res = [random.choice(["true", "false"]) for x in range(10)] 24 | res.append("false") 25 | res_str = " && ".join(res) 26 | return res_str 27 | 28 | 29 | def get_tree(data): 30 | tokens = javalang.tokenizer.tokenize(data) 31 | parser = javalang.parser.Parser(tokens) 32 | tree = parser.parse_member_declaration() 33 | return tree 34 | 35 | 36 | def verify_method_syntax(data): 37 | try: 38 | tokens = javalang.tokenizer.tokenize(data) 39 | parser = javalang.parser.Parser(tokens) 40 | tree = parser.parse_member_declaration() 41 | print("syantax check passed") 42 | except: 43 | print("syantax check failed") 44 | 45 | 46 | def get_random_type_name_and_value_statment(): 47 | datatype = random.choice( 48 | 'byte,short,int,long,float,double,boolean,char,String'.split(',')) 49 | var_name = get_radom_var_name() 50 | 51 | if datatype == "byte": 52 | var_value = get_random_int(-128, 127) 53 | elif datatype == "short": 54 | var_value = get_random_int(-10000, 10000) 55 | elif datatype == "boolean": 56 | var_value = random.choice(["true", "false"]) 57 | elif datatype == "char": 58 | var_value = str(random.choice( 59 | 'a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z'.split(','))) 60 | var_value = '"'+var_value+'"' 61 | elif datatype == "String": 62 | var_value = str(get_radom_var_name()) 63 | var_value = '"'+var_value+'"' 64 | else: 65 | var_value = get_random_int(-1000000000, 1000000000) 66 | 67 | mutant = str(datatype) + ' ' + str(var_name) + ' = ' + str(var_value)+";" 68 | return mutant 69 | 70 | 71 | def generate_file_name_list_file_from_dir(method_path): 72 | filenames = [f for f in listdir( 73 | method_path) if isfile(join(method_path, f))] 74 | with open(method_path+'\\'+'all_file_names.txt', 'w') as f: 75 | f.write(json.dumps(filenames)) 76 | print("done") 77 | 78 | 79 | def get_file_name_list(method_path): 80 | with open(method_path+'\\'+'all_file_names.txt') as f: 81 | data = json.load(f) 82 | return data 83 | 84 | 85 | def get_random_int(min, max): 86 | return random.randint(min, max) 87 | 88 | 89 | def format_code_chuncks(code_chuncks): 90 | for idx, c in enumerate(code_chuncks): 91 | c = c.replace(' . ', '.') 92 | c = c.replace(' ( ', '(') 93 | c = c.replace(' ) ', ')') 94 | c = c.replace(' ;', ';') 95 | c = c.replace('[ ]', '[]') 96 | code_chuncks[idx] = c 97 | return code_chuncks 98 | 99 | 100 | def format_code(c): 101 | c = c.replace(' . 
', '.') 102 | c = c.replace(' ( ', '(') 103 | c = c.replace(' ) ', ')') 104 | c = c.replace(' ;', ';') 105 | c = c.replace('[ ]', '[]') 106 | return c 107 | 108 | 109 | def get_method_header(string): 110 | method_header = '' 111 | tree = get_tree(string) 112 | # print("tree") 113 | 114 | tokens = list(javalang.tokenizer.tokenize(string)) 115 | # print(tokens) 116 | chunck_start_poss = [s.position.column for s in tree.body] 117 | # print(chunck_start_poss) 118 | if len(chunck_start_poss) > 0: 119 | method_header = ' '.join([t.value for t in tokens 120 | if t.position.column < chunck_start_poss[0]]) 121 | 122 | method_header = format_code_chuncks([method_header])[0] 123 | return method_header 124 | 125 | 126 | def get_method_statement(string): 127 | code_chuncks = [] 128 | tree = get_tree(string) 129 | tokens = list(javalang.tokenizer.tokenize(string)) 130 | chunck_start_poss = [s.position.column for s in tree.body] 131 | 132 | if len(chunck_start_poss) > 1: 133 | 134 | for idx, statement in enumerate(chunck_start_poss[:-1]): 135 | statment = ' '.join([t.value for t in tokens 136 | if t.position.column >= chunck_start_poss[idx] 137 | and t.position.column < chunck_start_poss[idx+1]]) 138 | code_chuncks.append(statment) 139 | last_statment = ' '.join([t.value for t in tokens 140 | if t.position.column >= chunck_start_poss[-1]][:-1]) 141 | code_chuncks.append(last_statment) 142 | 143 | if len(chunck_start_poss) == 1: 144 | last_statment = ' '.join([t.value for t in tokens 145 | if t.position.column >= chunck_start_poss[0]][:-1]) 146 | code_chuncks.append(last_statment) 147 | code_chuncks = format_code_chuncks(code_chuncks) 148 | return code_chuncks 149 | 150 | 151 | def scan_tree(tree): 152 | for path, node in tree: 153 | print("=======================") 154 | print(node) 155 | 156 | 157 | def get_all_type(tree): 158 | res_list=[] 159 | for path, node in tree.filter(javalang.tree.ReferenceType): 160 | if node.name != None: 161 | res_list.append(node.name) 162 | return list(set(res_list)) 163 | 164 | 165 | def scan_local_vars(tree): 166 | for path, node in tree.filter(javalang.tree.LocalVariableDeclaration): 167 | print("name=========type=============") 168 | print(node.declarators[0].name, "\t", node.type.name) 169 | 170 | 171 | def get_local_vars(tree): 172 | var_list = [] 173 | for path, node in tree.filter(javalang.tree.LocalVariableDeclaration): 174 | var_list.append([node.declarators[0].name, node.type.name]) 175 | 176 | return var_list 177 | 178 | 179 | def get_local_assignments(tree): 180 | var_list = [] 181 | for path, node in tree.filter(javalang.tree.Assignment): 182 | var_list.append([node.declarators[0].name, node.type.name]) 183 | return var_list 184 | 185 | 186 | def get_branch_if_else_mutant(): 187 | mutant = 'if ('+get_random_false_stmt()+') {' + \ 188 | get_random_type_name_and_value_statment() + \ 189 | '}' + \ 190 | 'else{' + \ 191 | get_random_type_name_and_value_statment() + \ 192 | '}' 193 | return mutant 194 | 195 | 196 | def get_branch_if_mutant(): 197 | mutant = 'if ('+get_random_false_stmt()+') {' + \ 198 | get_random_type_name_and_value_statment() + \ 199 | '}' 200 | return mutant 201 | 202 | 203 | def get_branch_while_mutant(): 204 | mutant = 'while ('+get_random_false_stmt()+') {' + \ 205 | get_random_type_name_and_value_statment() + \ 206 | '}' 207 | return mutant 208 | 209 | 210 | def get_branch_for_mutant(): 211 | dead_for_condition = get_dead_for_condition() 212 | mutant = 'for ('+dead_for_condition+') {' + \ 213 | get_random_type_name_and_value_statment() + \ 
214 | '}' 215 | return mutant 216 | 217 | def get_branch_switch_mutant(): 218 | var_name = get_radom_var_name() 219 | mutant = 'int ' + var_name+' = 0;' +\ 220 | 'switch ('+var_name+') {' + \ 221 | 'case 1:' + \ 222 | get_random_type_name_and_value_statment() + \ 223 | 'break;' +\ 224 | 'default:' + \ 225 | get_random_type_name_and_value_statment() + \ 226 | 'break;' +\ 227 | '}' 228 | return mutant 229 | -------------------------------------------------------------------------------- /Tool/Python_refactor/generate_refactoring.py: -------------------------------------------------------------------------------- 1 | import os, random 2 | from shutil import copyfile 3 | from refactoring_methods import * 4 | 5 | 6 | def return_function_code(code, method_names): 7 | final_codes = [] 8 | final_names = [] 9 | Class_list, raw_code = extract_class(code) 10 | for class_name in Class_list: 11 | function_list, class_name = extract_function(class_name) 12 | for fun_code in function_list: 13 | for method_name in method_names: 14 | method_name_tem = method_name.replace('|', '') 15 | if method_name_tem.upper() in fun_code.split('\n')[0].upper(): 16 | 17 | final_codes.append(fun_code) 18 | final_names.append(method_name) 19 | return final_codes, final_names 20 | 21 | 22 | def generate_adversarial(k, code, method_names): 23 | method_name = method_names[0] 24 | function_list = [] 25 | class_name = '' 26 | Class_list, raw_code = extract_class(code) 27 | for class_name in Class_list: 28 | function_list, class_name = extract_function(class_name) 29 | 30 | refac = [] 31 | new_refactored_code = '' 32 | for code in function_list: 33 | if method_name not in code.split('\n')[0]: 34 | continue 35 | new_rf = code 36 | new_refactored_code = code 37 | for t in range(k): 38 | refactors_list = [rename_argument, 39 | return_optimal, 40 | add_argumemts, 41 | rename_api, 42 | rename_local_variable, 43 | add_local_variable, 44 | rename_method_name, 45 | enhance_if, 46 | add_print, 47 | duplication, 48 | apply_plus_zero_math, 49 | dead_branch_if_else, 50 | dead_branch_if, 51 | dead_branch_while, 52 | dead_branch_for, 53 | # dead_branch_switch 54 | ]# 55 | vv = 0 56 | while new_rf == new_refactored_code and vv <= 20: 57 | try: 58 | vv += 1 59 | refactor = random.choice(refactors_list) 60 | print('*'*50 , refactor , '*'*50) 61 | new_refactored_code = refactor(new_refactored_code) 62 | 63 | except Exception as error: 64 | print('error:\t', error) 65 | 66 | new_rf = new_refactored_code 67 | print('----------------------------OUT of WHILE----------------------------------', vv) 68 | print('----------------------------CHANGED THJIS TIME:----------------------------------', vv) 69 | refac.append(new_refactored_code) 70 | code_body = raw_code.strip() + ' ' + class_name.strip() 71 | for i in range(len(refac)): 72 | final_refactor = code_body.replace('vesal' + str(i), str(refac[i])) 73 | code_body = final_refactor 74 | return new_refactored_code 75 | 76 | 77 | def generate_adversarial_json(k, code): 78 | final_refactor = '' 79 | function_list = [] 80 | class_name = '' 81 | vv = 0 82 | if len(function_list) == 0: 83 | function_list.append(code) 84 | refac = [] 85 | for code in function_list: 86 | new_rf = code 87 | new_refactored_code = code 88 | for t in range(k): 89 | refactors_list = [rename_argument, 90 | return_optimal, 91 | add_argumemts, 92 | rename_api, 93 | rename_local_variable, 94 | add_local_variable, 95 | rename_method_name, 96 | enhance_if, 97 | add_print, 98 | duplication, 99 | apply_plus_zero_math, 100 | 
dead_branch_if_else, 101 | dead_branch_if, 102 | dead_branch_while, 103 | dead_branch_for, 104 | # dead_branch_switch 105 | ] 106 | vv = 0 107 | while new_rf == new_refactored_code and vv <= 20: 108 | try: 109 | vv += 1 110 | refactor = random.choice(refactors_list) 111 | print('*' * 50, refactor, '*' * 50) 112 | new_refactored_code = refactor(new_refactored_code) 113 | 114 | except Exception as error: 115 | print('error:\t', error) 116 | 117 | new_rf = new_refactored_code 118 | refac.append(new_refactored_code) 119 | 120 | print("refactoring finished") 121 | return refac 122 | 123 | 124 | def generate_adversarial_file_level(k, code): 125 | new_refactored_code = '' 126 | new_rf = code 127 | new_refactored_code = code 128 | for t in range(k): 129 | refactors_list = [ 130 | rename_argument, 131 | return_optimal, 132 | add_argumemts, 133 | rename_api, 134 | rename_local_variable, 135 | add_local_variable, 136 | rename_method_name, 137 | enhance_if, 138 | add_print, 139 | duplication, 140 | apply_plus_zero_math, 141 | dead_branch_if_else, 142 | dead_branch_if, 143 | dead_branch_while, 144 | dead_branch_for 145 | ] 146 | vv = 0 147 | while new_rf == new_refactored_code and vv <= 20: 148 | try: 149 | vv += 1 150 | refactor = random.choice(refactors_list) 151 | print('*' * 50, refactor, '*' * 50) 152 | new_refactored_code = refactor(new_refactored_code) 153 | except Exception as error: 154 | print('error:\t', error) 155 | new_rf = new_refactored_code 156 | return new_refactored_code 157 | 158 | 159 | if __name__ == '__main__': 160 | K = 1 161 | filename = '**.py' 162 | open_file = open(filename, 'r', encoding='ISO-8859-1') 163 | code = open_file.read() 164 | new_code = generate_adversarial_file_level(K, code) 165 | print(new_code) 166 | -------------------------------------------------------------------------------- /Tool/Python_refactor/processing_source_code.py: -------------------------------------------------------------------------------- 1 | import re, random 2 | from nltk.corpus import wordnet 3 | import wordninja 4 | from util import * 5 | import ast 6 | 7 | 8 | def word_synonym_replacement(word): 9 | if len(word) <= 3: 10 | return word + '_new' 11 | word_set = wordninja.split(word) 12 | while True: 13 | if word_set == []: 14 | return word + '_new' 15 | word_tar = random.choice(word_set) 16 | word_syn = wordnet.synsets(word_tar) 17 | if word_syn == []: 18 | word_set.remove(word_tar) 19 | else: 20 | break 21 | word_ret = [] 22 | for syn in word_syn: 23 | word_ret = word_ret + syn.lemma_names() 24 | if word_tar in word_ret: 25 | word_ret.remove(word_tar) 26 | try: 27 | word_new = random.choice(word_ret) 28 | except: 29 | word_new = word 30 | 31 | return word.replace(word_tar, word_new), word_ret 32 | 33 | 34 | def extract_method_name(string): 35 | match_ret = re.search('\w+\s*\(',string) 36 | if match_ret: 37 | method_name = match_ret.group()[:-1].strip() 38 | return method_name 39 | else: 40 | return None 41 | 42 | 43 | def extract_argument(string): 44 | end_pos = string.find(')') 45 | sta_pas = string.find('(') 46 | arguments = string[sta_pas + 1: end_pos + 1].strip()[:-1] 47 | arguments_list = arguments.split(',') 48 | if ' ' in arguments_list: 49 | arguments_list.remove(' ') 50 | if '' in arguments_list: 51 | arguments_list.remove('') 52 | return arguments_list 53 | 54 | 55 | def extract_brace_python(string, start_pos): 56 | fragment = string[start_pos:] 57 | line_list = fragment.split('\n') 58 | return_string = '' 59 | return_string += line_list[0] + '\n' 60 | space_min = 0 61 | for _ 
in range(1, len(line_list)): 62 | space_count = 0 63 | for char in line_list[_]: 64 | if char == ' ': 65 | space_count += 1 66 | else: 67 | break 68 | if _ == 1: 69 | space_min = space_count 70 | return_string += line_list[_] + '\n' 71 | elif space_count < space_min and space_count != len(line_list[_]): 72 | break 73 | else: 74 | return_string += line_list[_] + '\n' 75 | return_string = return_string[:-1] 76 | return return_string 77 | 78 | 79 | def extract_class(string): 80 | 81 | class_list = [] 82 | while ' class ' in string: 83 | start_pos = string.find(' class ') 84 | class_text = extract_brace_python(string, start_pos) 85 | class_list.append(class_text) 86 | string = string.replace(class_text, '') 87 | 88 | while 'class ' in string: 89 | start_pos = string.find('class ') 90 | class_text = extract_brace_python(string, start_pos) 91 | class_list.append(class_text) 92 | string = string.replace(class_text, '') 93 | 94 | return class_list, string 95 | 96 | 97 | def extract_function_python(string): 98 | i = 0 99 | function_list = [] 100 | # print(string) 101 | while True: 102 | match_ret = re.search('(def).+\s*\(', string) 103 | # print(match_ret) 104 | if match_ret: 105 | function_head = match_ret.group() 106 | start_pos = string.find(function_head) 107 | function_text = extract_brace_python(string, start_pos) 108 | function_list.append(function_text) 109 | string = string.replace(function_text, 'vesal'+ str(i)) 110 | i+=1 111 | else: 112 | break 113 | return function_list, string 114 | 115 | 116 | def extract_for_loop(string): 117 | 118 | for_list = [] 119 | while True: 120 | # match_ret = re.search('for\s+\(', string) 121 | match_ret = re.search(' for ', string) 122 | if match_ret: 123 | for_head = match_ret.group() 124 | start_pos = string.find(for_head) 125 | for_text = extract_brace_python(string, start_pos) 126 | for_list.append(for_text) 127 | string = string.replace(for_text, '') 128 | else: 129 | break 130 | return for_list 131 | 132 | 133 | def extract_if(string): 134 | 135 | if_list = [] 136 | while True: 137 | match_ret = re.search(' if ', string) 138 | if match_ret: 139 | if_head = match_ret.group() 140 | start_pos = string.find(if_head) 141 | if_text = extract_brace_python(string, start_pos) 142 | if_list.append(if_text) 143 | string = string.replace(if_text, '') 144 | else: 145 | break 146 | return if_list 147 | 148 | 149 | def extract_while_loop(string): 150 | 151 | while_list = [] 152 | while True: 153 | match_ret = re.search(' while ', string) 154 | if match_ret: 155 | while_head = match_ret.group() 156 | start_pos = string.find(while_head) 157 | while_text = extract_brace_python(string, start_pos) 158 | while_list.append(while_text) 159 | string = string.replace(while_text, '') 160 | else: 161 | break 162 | return while_list, string 163 | 164 | 165 | def hack(source): 166 | root = ast.parse(source) 167 | 168 | for node in ast.walk(root): 169 | if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store): 170 | yield node.id 171 | elif isinstance(node, ast.Attribute): 172 | yield node.attr 173 | elif isinstance(node, ast.FunctionDef): 174 | yield node.name 175 | 176 | 177 | def extract_local_variable(string): 178 | return list(hack(string)) 179 | 180 | 181 | if __name__ == '__main__': 182 | filename = 'test.py' 183 | open_file = open(filename, 'r', encoding='ISO-8859-1') 184 | code = open_file.read() 185 | Class_list, raw_code = extract_class(code) 186 | print(Class_list) 187 | for class_name in Class_list: 188 | function_list, class_name = 
extract_function_python(class_name) 189 | print(function_list) 190 | print(extract_local_variable(function_list[0])) 191 | -------------------------------------------------------------------------------- /Tool/Python_refactor/refactoring_methods.py: -------------------------------------------------------------------------------- 1 | import os, random, re 2 | 3 | from processing_source_code import * 4 | 5 | 6 | def rename_local_variable(method_string): 7 | local_var_list = extract_local_variable(method_string) 8 | if len(local_var_list) == 0: 9 | return method_string 10 | 11 | mutation_index = random.randint(0, len(local_var_list) - 1) 12 | return method_string.replace(local_var_list[mutation_index], word_synonym_replacement(local_var_list[mutation_index])[0]) 13 | 14 | 15 | def add_local_variable(method_string): 16 | local_var_list = extract_local_variable(method_string) 17 | if len(local_var_list) == 0: 18 | return method_string 19 | 20 | mutation_index = random.randint(0, len(local_var_list) - 1) 21 | match_ret = re.search(local_var_list[mutation_index] + '=\w', method_string) 22 | if match_ret is None: 23 | match_ret = re.search(local_var_list[mutation_index] + ' = ', method_string) 24 | if match_ret is None: 25 | match_ret = re.search(local_var_list[mutation_index] + '= ', method_string) 26 | if match_ret: 27 | var_definition = match_ret.group()[:-1] 28 | new_var_definition = var_definition.replace(local_var_list[mutation_index], word_synonym_replacement(local_var_list[mutation_index])[0]) 29 | method_string = method_string.replace(var_definition, var_definition + '' + new_var_definition) 30 | return method_string 31 | else: 32 | return method_string 33 | 34 | 35 | def duplication(method_string): 36 | local_var_list = extract_local_variable(method_string) 37 | if len(local_var_list) == 0: 38 | return method_string 39 | mutation_index = random.randint(0, len(local_var_list) - 1) 40 | match_ret = re.search(local_var_list[mutation_index] + '=\w', method_string) 41 | if match_ret is None: 42 | match_ret = re.search(local_var_list[mutation_index] + ' = ', method_string) 43 | if match_ret is None: 44 | match_ret = re.search(local_var_list[mutation_index] + '= ', method_string) 45 | if match_ret: 46 | var_definition = match_ret.group()[:-1] 47 | new_var_definition = var_definition 48 | method_string = method_string.replace(var_definition, var_definition + new_var_definition) 49 | # print(method_string) 50 | return method_string 51 | else: 52 | # print(method_string) 53 | return method_string 54 | 55 | 56 | def rename_api(method_string): 57 | match_ret = re.findall(' \s*\w+\s*\(', method_string) 58 | match_ret = match_ret[1:] 59 | if match_ret != []: 60 | api_name = random.choice(match_ret)[1:-1] 61 | return method_string.replace(api_name, word_synonym_replacement(api_name)[0]) 62 | else: 63 | return method_string 64 | 65 | 66 | def rename_method_name(method_string): 67 | method_name = extract_method_name(method_string) 68 | if method_name: 69 | return method_string.replace(method_name, word_synonym_replacement(method_name)[0]) 70 | else: 71 | return method_string 72 | 73 | 74 | def rename_argument(method_string): 75 | arguments_list = extract_argument(method_string) 76 | if len(arguments_list) == 0: 77 | return method_string 78 | 79 | mutation_index = random.randint(0, len(arguments_list) - 1) 80 | return method_string.replace(arguments_list[mutation_index], word_synonym_replacement(arguments_list[mutation_index])) 81 | 82 | 83 | def return_optimal(method_string): 84 | if 'return ' in 
method_string: 85 | return_statement = method_string[method_string.find('return ') : method_string.find('\n', method_string.find('return ') + 1)] 86 | return_object = return_statement.replace('return ', '') 87 | if return_object == 'null': 88 | return method_string 89 | optimal_statement = 'return 0 if (' + return_object + ' == None) else ' + return_object 90 | method_string = method_string.replace(return_statement, optimal_statement) 91 | return method_string 92 | 93 | 94 | def enhance_for_loop(method_string): 95 | for_loop_list = extract_for_loop(method_string) 96 | if for_loop_list == []: 97 | return method_string 98 | 99 | mutation_index = random.randint(0, len(for_loop_list) - 1) 100 | for_text = for_loop_list[mutation_index] 101 | for_info = for_text[for_text.find('(') + 1 : for_text.find(')')] 102 | if ' range(' in for_text: 103 | if ',' not in for_info: 104 | new_for_info = '0, ' + for_info 105 | method_string = method_string.replace(for_info, new_for_info) 106 | elif len(for_info.split(',')) == 2: 107 | new_for_info = for_info + ' ,1' 108 | method_string = method_string.replace(for_info, new_for_info) 109 | else: 110 | new_for_info = for_info + '+0' 111 | method_string = method_string.replace(for_info, new_for_info) 112 | return method_string 113 | 114 | else: 115 | return method_string 116 | 117 | 118 | def add_print(method_string): 119 | statement_list = method_string.split('\n') 120 | mutation_index = random.randint(1, len(statement_list) - 1) 121 | statement = statement_list[mutation_index] 122 | if statement == '': 123 | return method_string 124 | space_count = 0 125 | if mutation_index == len(statement_list) - 1: 126 | refer_line = statement_list[-1] 127 | for char in refer_line: 128 | if char == ' ': 129 | space_count += 1 130 | else: 131 | break 132 | else: 133 | refer_line = statement_list[mutation_index] 134 | for char in refer_line: 135 | if char == ' ': 136 | space_count += 1 137 | else: 138 | break 139 | new_statement = '' 140 | for _ in range(space_count): 141 | new_statement += ' ' 142 | new_statement += 'print("' + str(random.choice(word_synonym_replacement(statement)[1])) + '")' 143 | method_string = method_string.replace(statement, '\n' + new_statement + '\n' + statement) 144 | return method_string 145 | 146 | 147 | def enhance_if(method_string): 148 | if_list = extract_if(method_string) 149 | mutation_index = random.randint(0, len(if_list) - 1) 150 | if_text = if_list[mutation_index] 151 | if_info = if_text[if_text.find('if ') + 3: if_text.find(':')] 152 | new_if_info = if_info 153 | if 'true' in if_info: 154 | new_if_info = if_info.replace('true', ' (0==0) ') 155 | if 'flase' in if_info: 156 | new_if_info = if_info.replace('flase', ' (1==0) ') 157 | if '!=' in if_info and '(' not in if_info and 'and' not in if_info and 'or' not in if_info: 158 | new_if_info = if_info.replace('!=', ' is not ') 159 | if '<' in if_info and '<=' not in if_info and '(' not in if_info and 'and' not in if_info and 'or' not in if_info: 160 | new_if_info = if_info.split('<')[1] + ' > ' + if_info.split('<')[0] 161 | if '>' in if_info and '>=' not in if_info and '(' not in if_info and 'and' not in if_info and 'or' not in if_info: 162 | new_if_info = if_info.split('>')[1] + ' < ' + if_info.split('>')[0] 163 | if '<=' in if_info and '(' not in if_info and 'and' not in if_info and 'or' not in if_info: 164 | new_if_info = if_info.split('<=')[1] + ' >= ' + if_info.split('<=')[0] 165 | if '>=' in if_info and '(' not in if_info and 'and' not in if_info and 'or' not in if_info: 166 | new_if_info 
= if_info.split('>=')[1] + ' <= ' + if_info.split('>=')[0] 167 | if '==' in if_info: 168 | new_if_info = if_info.replace('==', ' is ') 169 | 170 | return method_string.replace(if_info, new_if_info) 171 | 172 | 173 | def add_argumemts(method_string): 174 | arguments_list = extract_argument(method_string) 175 | arguments_info = method_string[method_string.find('(') + 1: method_string.find(')')] 176 | if len(arguments_list) == 0: 177 | arguments_info = word_synonym_replacement(extract_method_name(method_string))[0] 178 | return method_string[0 : method_string.find('()') + 1] + arguments_info + method_string[method_string.find('()') + 1 :] 179 | mutation_index = random.randint(0, len(arguments_list) - 1) 180 | org_argument = arguments_list[mutation_index] 181 | new_argument = word_synonym_replacement(arguments_list[mutation_index]) 182 | new_arguments_info = arguments_info.replace(org_argument, org_argument + ', ' + new_argument) 183 | method_string = method_string.replace(arguments_info, new_arguments_info, 1) 184 | return method_string 185 | 186 | 187 | def enhance_filed(method_string): 188 | arguments_list = extract_argument(method_string) 189 | line_list = method_string.split('\n') 190 | refer_line = line_list[1] 191 | if len(arguments_list) == 0: 192 | return method_string 193 | space_count = 0 194 | for char in refer_line: 195 | if char == ' ': 196 | space_count += 1 197 | else: 198 | break 199 | mutation_index = random.randint(0, len(arguments_list) - 1) 200 | space_str = '' 201 | for _ in range(space_count): 202 | space_str += ' ' 203 | extra_info = "\n" + space_str + "if " + arguments_list[mutation_index].strip().split(' ')[-1] + " == None: print('please check your input')" 204 | method_string = method_string[0 : method_string.find(':') + 1] + extra_info + method_string[method_string.find(':') + 1 : ] 205 | return method_string 206 | 207 | 208 | def apply_plus_zero_math(data): 209 | variable_list = extract_local_variable(data) 210 | success_flag = 0 211 | for variable_name in variable_list: 212 | match_ret = re.findall(variable_name + '\s*=\s\w*\n', data) 213 | if len(match_ret) > 0: 214 | code_line = match_ret[0] 215 | value = code_line.split('\n')[0].split('=')[1] 216 | ori_value = value 217 | if '+' in value or '-' in value or '*' in value or '/' in value or '//' in value: 218 | value = value + ' + 0' 219 | success_flag = 1 220 | try: 221 | value_float = float(value) 222 | value = value + ' + 0' 223 | success_flag = 1 224 | except ValueError: 225 | continue 226 | if success_flag == 1: 227 | mutant = code_line.split(ori_value)[0] 228 | mutant = mutant + value + '\n' 229 | method_string = data.replace(code_line, mutant) 230 | return method_string 231 | if success_flag == 0: 232 | return data 233 | 234 | 235 | def dead_branch_if_else(data): 236 | statement_list = data.split('\n') 237 | mutation_index = random.randint(1, len(statement_list) - 1) 238 | statement = statement_list[mutation_index] 239 | space_count = 0 240 | if statement == '': 241 | return data 242 | if mutation_index == len(statement_list) - 1: 243 | refer_line = statement_list[-1] 244 | for char in refer_line: 245 | if char == ' ': 246 | space_count += 1 247 | else: 248 | break 249 | else: 250 | refer_line = statement_list[mutation_index] 251 | for char in refer_line: 252 | if char == ' ': 253 | space_count += 1 254 | else: 255 | break 256 | new_statement = '' 257 | for _ in range(space_count): 258 | new_statement += ' ' 259 | new_statement += get_branch_if_else_mutant() 260 | method_string = data.replace(statement, 
'\n' + new_statement + '\n' + statement) 261 | return method_string 262 | 263 | 264 | def dead_branch_if(data): 265 | statement_list = data.split('\n') 266 | mutation_index = random.randint(1, len(statement_list) - 1) 267 | statement = statement_list[mutation_index] 268 | space_count = 0 269 | if statement == '': 270 | return data 271 | if mutation_index == len(statement_list) - 1: 272 | refer_line = statement_list[-1] 273 | for char in refer_line: 274 | if char == ' ': 275 | space_count += 1 276 | else: 277 | break 278 | else: 279 | refer_line = statement_list[mutation_index] 280 | for char in refer_line: 281 | if char == ' ': 282 | space_count += 1 283 | else: 284 | break 285 | new_statement = '' 286 | for _ in range(space_count): 287 | new_statement += ' ' 288 | new_statement += get_branch_if_mutant() 289 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 290 | 291 | return method_string 292 | 293 | 294 | def dead_branch_while(data): 295 | statement_list = data.split('\n') 296 | mutation_index = random.randint(1, len(statement_list) - 1) 297 | statement = statement_list[mutation_index] 298 | space_count = 0 299 | if statement == '': 300 | return data 301 | if mutation_index == len(statement_list) - 1: 302 | refer_line = statement_list[-1] 303 | for char in refer_line: 304 | if char == ' ': 305 | space_count += 1 306 | else: 307 | break 308 | else: 309 | refer_line = statement_list[mutation_index] 310 | for char in refer_line: 311 | if char == ' ': 312 | space_count += 1 313 | else: 314 | break 315 | new_statement = '' 316 | print(space_count) 317 | for _ in range(space_count): 318 | new_statement += ' ' 319 | new_statement += get_branch_while_mutant() 320 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 321 | # print(method_string) 322 | return method_string 323 | 324 | 325 | def dead_branch_for(data): 326 | statement_list = data.split('\n') 327 | mutation_index = random.randint(1, len(statement_list) - 1) 328 | statement = statement_list[mutation_index] 329 | space_count = 0 330 | if statement == '': 331 | return data 332 | if mutation_index == len(statement_list) - 1: 333 | refer_line = statement_list[-1] 334 | for char in refer_line: 335 | if char == ' ': 336 | space_count += 1 337 | else: 338 | break 339 | else: 340 | refer_line = statement_list[mutation_index] 341 | for char in refer_line: 342 | if char == ' ': 343 | space_count += 1 344 | else: 345 | break 346 | new_statement = '' 347 | for _ in range(space_count): 348 | new_statement += ' ' 349 | new_statement += get_branch_for_mutant() 350 | method_string = data.replace(statement, '\n' + new_statement + '\n' + statement) 351 | return method_string 352 | 353 | 354 | if __name__ == '__main__': 355 | filename = 'test.py' 356 | open_file = open(filename, 'r', encoding='ISO-8859-1') 357 | code = open_file.read() 358 | Class_list, raw_code = extract_class(code) 359 | for class_name in Class_list: 360 | function_list, class_name = extract_function_python(class_name) 361 | candidate_code = function_list[0] 362 | mutated_code = apply_plus_zero_math(candidate_code) 363 | print(candidate_code) 364 | print(mutated_code) 365 | -------------------------------------------------------------------------------- /Tool/Python_refactor/util.py: -------------------------------------------------------------------------------- 1 | import javalang 2 | import secrets 3 | import random 4 | import json 5 | 6 | from os import listdir 7 | from os.path import isfile, join 8 | 9 | 10 | def 
get_radom_var_name(): 11 | res_string = '' 12 | for x in range(8): 13 | res_string += random.choice('abcdefghijklmnopqrstuvwxyz') 14 | return res_string 15 | 16 | 17 | def get_dead_for_condition(): 18 | var = get_radom_var_name() 19 | return "int "+var+" = 0; "+var+" < 0; "+var+"++" 20 | 21 | 22 | def get_random_false_stmt(): 23 | res = [random.choice(["True", "False"]) for x in range(10)] 24 | res.append("False") 25 | res_str = " and ".join(res) 26 | return res_str 27 | 28 | 29 | def get_tree(data): 30 | tokens = javalang.tokenizer.tokenize(data) 31 | parser = javalang.parser.Parser(tokens) 32 | tree = parser.parse_member_declaration() 33 | return tree 34 | 35 | 36 | def verify_method_syntax(data): 37 | try: 38 | tokens = javalang.tokenizer.tokenize(data) 39 | parser = javalang.parser.Parser(tokens) 40 | tree = parser.parse_member_declaration() 41 | print("syantax check passed") 42 | except: 43 | print("syantax check failed") 44 | 45 | 46 | def get_random_type_name_and_value_statment(): 47 | datatype = random.choice( 48 | 'byte,short,int,long,float,double,boolean,char,String'.split(',')) 49 | var_name = get_radom_var_name() 50 | 51 | if datatype == "byte": 52 | var_value = get_random_int(-128, 127) 53 | elif datatype == "short": 54 | var_value = get_random_int(-10000, 10000) 55 | elif datatype == "boolean": 56 | var_value = random.choice(["True", "False"]) 57 | elif datatype == "char": 58 | var_value = str(random.choice( 59 | 'a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z'.split(','))) 60 | var_value = '"'+var_value+'"' 61 | elif datatype == "String": 62 | var_value = str(get_radom_var_name()) 63 | var_value = '"'+var_value+'"' 64 | else: 65 | var_value = get_random_int(-1000000000, 1000000000) 66 | 67 | mutant = str(var_name) + ' = ' + str(var_value) 68 | return mutant 69 | 70 | 71 | def generate_file_name_list_file_from_dir(method_path): 72 | filenames = [f for f in listdir( 73 | method_path) if isfile(join(method_path, f))] 74 | with open(method_path+'\\'+'all_file_names.txt', 'w') as f: 75 | f.write(json.dumps(filenames)) 76 | print("done") 77 | 78 | 79 | def get_file_name_list(method_path): 80 | with open(method_path+'\\'+'all_file_names.txt') as f: 81 | data = json.load(f) 82 | return data 83 | 84 | 85 | def get_random_int(min, max): 86 | return random.randint(min, max) 87 | 88 | 89 | def format_code_chuncks(code_chuncks): 90 | for idx, c in enumerate(code_chuncks): 91 | c = c.replace(' . ', '.') 92 | c = c.replace(' ( ', '(') 93 | c = c.replace(' ) ', ')') 94 | c = c.replace(' ;', ';') 95 | c = c.replace('[ ]', '[]') 96 | code_chuncks[idx] = c 97 | return code_chuncks 98 | 99 | 100 | def format_code(c): 101 | c = c.replace(' . 
', '.') 102 | c = c.replace(' ( ', '(') 103 | c = c.replace(' ) ', ')') 104 | c = c.replace(' ;', ';') 105 | c = c.replace('[ ]', '[]') 106 | return c 107 | 108 | 109 | def get_method_header(string): 110 | method_header = '' 111 | tree = get_tree(string) 112 | 113 | tokens = list(javalang.tokenizer.tokenize(string)) 114 | 115 | chunck_start_poss = [s.position.column for s in tree.body] 116 | if len(chunck_start_poss) > 0: 117 | method_header = ' '.join([t.value for t in tokens 118 | if t.position.column < chunck_start_poss[0]]) 119 | 120 | method_header = format_code_chuncks([method_header])[0] 121 | return method_header 122 | 123 | 124 | def get_method_statement(string): 125 | code_chuncks = [] 126 | tree = get_tree(string) 127 | tokens = list(javalang.tokenizer.tokenize(string)) 128 | chunck_start_poss = [s.position.column for s in tree.body] 129 | 130 | if len(chunck_start_poss) > 1: 131 | for idx, statement in enumerate(chunck_start_poss[:-1]): 132 | statment = ' '.join([t.value for t in tokens 133 | if t.position.column >= chunck_start_poss[idx] 134 | and t.position.column < chunck_start_poss[idx+1]]) 135 | code_chuncks.append(statment) 136 | 137 | last_statment = ' '.join([t.value for t in tokens 138 | if t.position.column >= chunck_start_poss[-1]][:-1]) 139 | code_chuncks.append(last_statment) 140 | 141 | if len(chunck_start_poss) == 1: 142 | last_statment = ' '.join([t.value for t in tokens 143 | if t.position.column >= chunck_start_poss[0]][:-1]) 144 | code_chuncks.append(last_statment) 145 | code_chuncks = format_code_chuncks(code_chuncks) 146 | return code_chuncks 147 | 148 | 149 | def scan_tree(tree): 150 | for path, node in tree: 151 | print("=======================") 152 | print(node) 153 | 154 | 155 | def get_all_type(tree): 156 | res_list=[] 157 | for path, node in tree.filter(javalang.tree.ReferenceType): 158 | if node.name != None: 159 | res_list.append(node.name) 160 | return list(set(res_list)) 161 | 162 | 163 | def scan_local_vars(tree): 164 | for path, node in tree.filter(javalang.tree.LocalVariableDeclaration): 165 | print("name=========type=============") 166 | print(node.declarators[0].name, "\t", node.type.name) 167 | 168 | 169 | def get_local_vars(tree): 170 | var_list = [] 171 | for path, node in tree.filter(javalang.tree.LocalVariableDeclaration): 172 | var_list.append([node.declarators[0].name, node.type.name]) 173 | return var_list 174 | 175 | 176 | def get_local_assignments(tree): 177 | var_list = [] 178 | for path, node in tree.filter(javalang.tree.Assignment): 179 | var_list.append([node.declarators[0].name, node.type.name]) 180 | return var_list 181 | 182 | 183 | def get_branch_if_else_mutant(): 184 | mutant = get_random_type_name_and_value_statment() + ' if '+get_random_false_stmt() + ' else ' + str(get_random_int(-1000000000, 1000000000)) 185 | return mutant 186 | 187 | 188 | def get_branch_if_mutant(): 189 | mutant = 'if '+get_random_false_stmt()+': ' + \ 190 | get_random_type_name_and_value_statment() 191 | return mutant 192 | 193 | 194 | def get_branch_while_mutant(): 195 | mutant = 'while '+get_random_false_stmt()+': ' + \ 196 | get_random_type_name_and_value_statment() 197 | return mutant 198 | 199 | 200 | def get_branch_for_mutant(): 201 | var = get_radom_var_name() 202 | mutant = 'for ' + var + ' in range(0): ' + get_random_type_name_and_value_statment() 203 | return mutant 204 | 205 | 206 | -------------------------------------------------------------------------------- /img/overview.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zemingd/Mixup4Code/7cdf38e5968fb4e4db50ca077f1c3276280ca09c/img/overview.png -------------------------------------------------------------------------------- /model/BagOfToken.py: -------------------------------------------------------------------------------- 1 | """ 2 | Program for code classification by Bag or Tokens technique 3 | 4 | History of traing is pickled to be plotted with other application 5 | 6 | Implements several simple classifiers of sorce code: 7 | - Linear (1 layer neural network) 8 | - 2 layer neural network giving ccuracy for 5 classes 9 | - General N layer network 10 | 11 | The programm processes all files in the given directory 12 | Each file there contains all samples of one class of samples, 13 | i.e. tokenised source code 14 | 15 | Number of classes (lables) is defined automatically 16 | """ 17 | import sys 18 | import os 19 | import argparse 20 | import pickle 21 | from keras.utils.np_utils import to_categorical 22 | 23 | 24 | 25 | main_dir = os.path.dirname( 26 | os.path.dirname(os.path.realpath(__file__))) 27 | sys.path.extend([f"{main_dir}/Dataset", 28 | f"{main_dir}/ModelMaker", 29 | f"{main_dir}/CommonFunctions"]) 30 | 31 | from ProgramArguments import (makeArgParserCodeML, parseArguments) 32 | from Utilities import * 33 | from BagTokDataset import BagTokDataset 34 | from DsUtilities import DataRand 35 | from SeqModelMaker import SeqModelFactory 36 | 37 | def main(args): 38 | """ 39 | Main function of program for classifying source code 40 | Parameters: 41 | - args -- Parsed command line arguments 42 | as object returned by ArgumentParser 43 | """ 44 | resetSeeds() 45 | DataRand.setDsSeeds(args.seed_ds) 46 | 47 | if args.ckpt_dir: 48 | _latest_checkpoint = setupCheckpoint(args.ckpt_dir) 49 | _checkpoint_callback = makeCkptCallback(args.ckpt_dir) 50 | _callbacks=[_checkpoint_callback] 51 | else: 52 | _latest_checkpoint = None 53 | _callbacks = None 54 | 55 | _ds = BagTokDataset(args.dataset, 56 | min_n_solutions = max(args.min_solutions, 3), 57 | max_n_problems = args.problems, 58 | short_code_th = args.short_code, 59 | long_code_th = args.long_code, 60 | test_part = args.testpart, 61 | balanced_split = args.balanced_split) 62 | 63 | 64 | _ds_refactor = BagTokDataset('../', 65 | min_n_solutions = max(args.min_solutions, 3), 66 | max_n_problems = args.problems, 67 | short_code_th = args.short_code, 68 | long_code_th = args.long_code, 69 | test_part = args.testpart, 70 | balanced_split = args.balanced_split) 71 | 72 | 73 | 74 | print(f"Classification of source code among {_ds.n_labels} classes") 75 | print("Technique of fully connected neural network on bag of tokens\n") 76 | _model_factory = SeqModelFactory(_ds.n_token_types, _ds.n_labels) 77 | if _latest_checkpoint: 78 | print("Restoring DNN from", _latest_checkpoint) 79 | _dnn = tf.keras.models.load_model(_latest_checkpoint) 80 | else: 81 | print("Constructing DNN") 82 | _dnn = _model_factory.denseDNN(args.dense) 83 | 84 | _val_ds, _train_ds = _ds.trainValidDs(args.valpart, args.batch) 85 | 86 | a,b=mixup_data_refactor(_train_ds[0],to_categorical(_train_ds[1],num_classes=250), _train_ds_refactor[0],to_categorical(_train_ds_refactor[1],num_classes=250),alpha=0.6) 87 | 88 | _history = _dnn.fit(a, 89 | b, 90 | epochs = args.epochs, 91 | batch_size = args.batch, 92 | validation_data = (_val_ds[0], 93 | to_categorical(_val_ds[1], num_classes=250)), 94 | verbose = args.progress, callbacks = 
_callbacks) 95 | 96 | 97 | 98 | with open(args.history, 'wb') as _jar: 99 | pickle.dump(_history.history, _jar) 100 | 101 | ####################################################################### 102 | # Command line arguments of are described below 103 | ####################################################################### 104 | if __name__ == '__main__': 105 | print("\nCODE CLASSIFICATION WITH BAG OF TOKENS TECHNIQUE") 106 | 107 | #Command-line arguments 108 | parser = makeArgParserCodeML( 109 | "Bag of tokens program source code classifier", 110 | task = "classification") 111 | args = parseArguments(parser) 112 | 113 | main(args) 114 | -------------------------------------------------------------------------------- /model/GAT_model.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data,DataLoader 2 | import torch 3 | import torch.nn as nn 4 | from torch_geometric.nn import GATConv 5 | import random 6 | import copy 7 | import numpy as np 8 | from torch_geometric.nn import global_max_pool 9 | import torch.nn.functional as F 10 | 11 | def main(): 12 | data_buggy = [] 13 | data_fixed = [] 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | for i in range(100): 17 | edge_index = torch.tensor(Dataset_Buggy_Edge[i], dtype=torch.long) 18 | x = torch.tensor(Dataset_Buggy_Node[i]), dtype=torch.float) 19 | y = torch.tensor([0], dtype=torch.long) 20 | y = torch.nn.functional.one_hot(y, 2) 21 | data_buggy.append(Data(x=x, edge_index=edge_index.t().contiguous(), y=y)) 22 | 23 | for i in range(100): 24 | edge_index = torch.tensor(Dataset_Fixed_Edge[i], dtype=torch.long) 25 | x = torch.tensor(Dataset_Fixed_Node[i]), dtype=torch.float) 26 | y = torch.tensor([1], dtype=torch.long) 27 | y = torch.nn.functional.one_hot(y, 2) 28 | data_fixed.append(Data(x=x, edge_index=edge_index.t().contiguous(), y=y)) 29 | 30 | data = data_fixed + data_buggy 31 | random.shuffle(data) 32 | 33 | train_dataset = data[i:] 34 | test_dataset = data[:i] 35 | 36 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 37 | test_loader = DataLoader(test_dataset,batch_size=1, shuffle=True) 38 | 39 | 40 | class Net(torch.nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | self.conv1 = GATConv(Vertex_feature_dim, 160) 44 | self.conv2 = GATConv(160,160) 45 | self.linear = torch.nn.Linear(160, Num_class) 46 | pass 47 | 48 | def forward(self, x,batch,edge_index): 49 | out = self.conv1(x, edge_index) 50 | max_out = global_max_pool(out,batch) 51 | out = self.linear(max_out) 52 | out = nn.functional.dropout(out, p=0.5, training=self.training) 53 | return out 54 | pass 55 | pass 56 | 57 | 58 | model = Net().to(device) 59 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 60 | 61 | 62 | for i in range(100): 63 | for epoch in range(1): 64 | model.train() 65 | loss_per_epoch = 0 66 | for batch in train_loader: 67 | optimizer.zero_grad() 68 | batch = batch.to(device) 69 | out = model(batch.x, batch.batch, batch.edge_index) 70 | out_mixup, batch.y_mixup = mixup_data(out, batch.y) 71 | log_prob = torch.nn.functional.log_softmax(out_mixup, dim=1) 72 | loss = -torch.sum(log_prob * batch.y_mixup) / 32 73 | loss_per_epoch += loss.item() 74 | loss.backward() 75 | optimizer.step() 76 | print('Loss per epoch: {}'.format(str(loss_per_epoch))) 77 | 78 | model.eval() 79 | correct = 0 80 | total = 0 81 | for batch in test_loader: 82 | batch = batch.to(device) 83 | out = model(batch.x, batch.batch, batch.edge_index) 
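# Evaluation block below: logits are reduced to a predicted class with argmax and
# compared against the one-hot label decoded the same way; accuracy is correct / total.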
84 | pred = torch.argmax(out, dim=1) 85 | batch.y = torch.argmax(batch.y, dim=1) 86 | total += 1 87 | if pred.item() == batch.y.item(): 88 | correct += 1 89 | print('Eval Acc: {}'.format(str(correct / total))) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | 95 | -------------------------------------------------------------------------------- /model/GCN_model.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data,DataLoader 2 | import torch 3 | import torch.nn as nn 4 | from torch_geometric.nn import GCNConv 5 | import random 6 | import copy 7 | import numpy as np 8 | 9 | from torch_geometric.nn import global_max_pool 10 | import torch.nn.functional as F 11 | 12 | 13 | def main(): 14 | data_buggy = [] 15 | data_fixed = [] 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | for i in range(100): 19 | edge_index = torch.tensor(Dataset_Buggy_Edge[i], dtype=torch.long) 20 | x = torch.tensor(Dataset_Buggy_Node[i]), dtype=torch.float) 21 | y = torch.tensor([0], dtype=torch.long) 22 | y = torch.nn.functional.one_hot(y, 2) 23 | data_buggy.append(Data(x=x, edge_index=edge_index.t().contiguous(), y=y)) 24 | 25 | for i in range(100): 26 | edge_index = torch.tensor(Dataset_Fixed_Edge[i], dtype=torch.long) 27 | x = torch.tensor(Dataset_Fixed_Node[i]), dtype=torch.float) 28 | y = torch.tensor([1], dtype=torch.long) 29 | y = torch.nn.functional.one_hot(y, 2) 30 | data_fixed.append(Data(x=x, edge_index=edge_index.t().contiguous(), y=y)) 31 | 32 | data = data_fixed + data_buggy 33 | random.shuffle(data) 34 | 35 | train_dataset = data[i:] 36 | test_dataset = data[:i] 37 | 38 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 39 | test_loader = DataLoader(test_dataset,batch_size=1, shuffle=True) 40 | 41 | class Net(torch.nn.Module): 42 | def __init__(self): 43 | super().__init__() 44 | self.conv1 = GCNConv(Vertex_feature_dim, 160) 45 | self.conv2 = GCNConv(160,160) 46 | self.linear = torch.nn.Linear(160, Num_class) 47 | pass 48 | 49 | def forward(self, x,batch,edge_index): 50 | out = self.conv1(x, edge_index) 51 | max_out = global_max_pool(out,batch) 52 | out = self.linear(max_out) 53 | out = nn.functional.dropout(out, p=0.5, training=self.training) 54 | return out 55 | pass 56 | pass 57 | 58 | model = Net().to(device) 59 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 60 | 61 | 62 | losslabel = [] 63 | xlabel = [] 64 | ACClabel =[] 65 | 66 | for i in range(100): 67 | for epoch in range(1): 68 | model.train() 69 | loss_per_epoch = 0 70 | for batch in train_loader: 71 | optimizer.zero_grad() 72 | batch = batch.to(device) 73 | out = model(batch.x, batch.batch, batch.edge_index) 74 | out_mixup, batch.y_mixup = mixup_data(out, batch.y) 75 | log_prob = torch.nn.functional.log_softmax(out_mixup, dim=1) 76 | loss = -torch.sum(log_prob * batch.y_mixup) / 32 77 | loss_per_epoch += loss.item() 78 | loss.backward() 79 | optimizer.step() 80 | print('Loss per epoch: {}'.format(str(loss_per_epoch))) 81 | 82 | model.eval() 83 | correct = 0 84 | total = 0 85 | for batch in test_loader: 86 | batch = batch.to(device) 87 | out = model(batch.x, batch.batch, batch.edge_index) 88 | pred = torch.argmax(out, dim=1) 89 | batch.y = torch.argmax(batch.y, dim=1) 90 | total += 1 91 | if pred.item() == batch.y.item(): 92 | correct += 1 93 | print('Eval Acc: {}'.format(str(correct / total))) 94 | 95 | if __name__ == "__main__": 96 | main() 97 | 98 | 
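The training loops in GAT_model.py and GCN_model.py above call mixup_data(out, batch.y) without defining or importing it in these files. The sketch below is illustrative only, not the repository's implementation; the helper name, the Beta(alpha, alpha) coefficient, and the default alpha are assumptions. It shows the kind of interpolation those loops rely on:

import numpy as np
import torch

def mixup_data_sketch(logits, one_hot_labels, alpha=0.1):
    # Illustrative helper only: blend each example's logits and one-hot label with
    # those of a randomly permuted partner from the same batch, so the loss
    # -sum(log_softmax(mixed_logits) * mixed_labels) used above sees soft targets.
    one_hot_labels = one_hot_labels.float()
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.
    index = torch.randperm(logits.size(0), device=logits.device)
    mixed_logits = lam * logits + (1 - lam) * logits[index, :]
    mixed_labels = lam * one_hot_labels + (1 - lam) * one_hot_labels[index, :]
    return mixed_logits, mixed_labels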
-------------------------------------------------------------------------------- /model/GGNN_model.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Data,DataLoader 2 | import torch 3 | import copy 4 | from torch_geometric.nn import GatedGraphConv 5 | import random 6 | import numpy as np 7 | from torch_geometric.nn import global_max_pool 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | 11 | def main(): 12 | data_buggy = [] 13 | data_fixed = [] 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | for i in range(100): 17 | edge_index = torch.tensor(Dataset_Buggy_Edge[i], dtype=torch.long) 18 | x = torch.tensor(Dataset_Buggy_Node[i]), dtype=torch.float) 19 | y = torch.tensor([0], dtype=torch.long) 20 | y = torch.nn.functional.one_hot(y, 2) 21 | data_buggy.append(Data(x=x, edge_index=edge_index.t().contiguous(), y=y)) 22 | 23 | for i in range(100): 24 | edge_index = torch.tensor(Dataset_Fixed_Edge[i], dtype=torch.long) 25 | x = torch.tensor(Dataset_Fixed_Node[i]), dtype=torch.float) 26 | y = torch.tensor([1], dtype=torch.long) 27 | y = torch.nn.functional.one_hot(y, 2) 28 | data_fixed.append(Data(x=x, edge_index=edge_index.t().contiguous(), y=y)) 29 | 30 | data = data_fixed + data_buggy 31 | random.shuffle(data) 32 | 33 | train_dataset = data[i:] 34 | test_dataset = data[:i] 35 | 36 | train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) 37 | test_loader = DataLoader(test_dataset,batch_size=1, shuffle=True) 38 | 39 | class Net(torch.nn.Module): 40 | def __init__(self, out_channels, num_layers, Vertex_feature_dim): 41 | super().__init__() 42 | self.out_channels = out_channels 43 | self.num_layers = num_layers 44 | self.Vertex_feature_dim = Vertex_feature_dim 45 | self.conv1 = GatedGraphConv(out_channels, num_layers) 46 | self.linear = torch.nn.Linear(out_channels, 2) 47 | pass 48 | 49 | def forward(self, x,batch,edge_index): 50 | out = self.conv1(x, edge_index) 51 | max_out = global_max_pool(out,batch) 52 | out = self.linear(max_out) 53 | out = nn.functional.dropout(out,p=0.5, training=self.training) 54 | return out 55 | pass 56 | pass 57 | 58 | out_channels = out_channels 59 | num_layers = num_layers 60 | Vertex_feature_dim = Vertex_feature_dim 61 | 62 | model = Net(out_channels, num_layers, Vertex_feature_dim).to(device) 63 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 64 | 65 | for i in range(100): 66 | for epoch in range(1): 67 | model.train() 68 | loss_per_epoch = 0 69 | for batch in train_loader: 70 | optimizer.zero_grad() 71 | batch = batch.to(device) 72 | out = model(batch.x, batch.batch, batch.edge_index) 73 | out_mixup, batch.y_mixup = mixup_data(out, batch.y) 74 | log_prob = torch.nn.functional.log_softmax(out_mixup, dim=1) 75 | loss = -torch.sum(log_prob * batch.y_mixup) / 32 76 | loss_per_epoch += loss.item() 77 | loss.backward() 78 | optimizer.step() 79 | print('Loss per epoch: {}'.format(str(loss_per_epoch))) 80 | 81 | model.eval() 82 | correct = 0 83 | total = 0 84 | for batch in test_loader: 85 | batch = batch.to(device) 86 | out = model(batch.x, batch.batch, batch.edge_index) 87 | pred = torch.argmax(out, dim=1) 88 | batch.y = torch.argmax(batch.y,dim=1) 89 | total += 1 90 | if pred.item() == batch.y.item(): 91 | correct += 1 92 | print('Eval Acc: {}'.format(str(correct / total))) 93 | 94 | if __name__ == "__main__": 95 | main() 96 | 97 | -------------------------------------------------------------------------------- 
/model/SeqofToken.py:
--------------------------------------------------------------------------------
1 | """
2 | Program for code classification by the Sequence of Tokens technique
3 | using a convolutional neural network in multi-GPU mode.
4 | The training history is pickled so it can be plotted by another application.
5 | Implements a convolutional neural network classifier of
6 | source code represented as a sequence of tokens.
7 | The number of convolutional and dense layers,
8 | and their dimensions, are specified by program arguments.
9 | The program processes all files in the given directory.
10 | Each file there contains all samples of one class,
11 | i.e. tokenised source code.
12 | The number of classes (labels) is determined automatically.
13 | Program arguments are defined below in the definition of the
14 | argparse ArgumentParser object.
15 | """
16 | import sys
17 | import os
18 | import argparse
19 | import pickle
20 | import tensorflow as tf
21 | import numpy as np
22 | from keras.utils.np_utils import to_categorical
23 | 
24 | 
25 | main_dir = os.path.dirname(
26 |     os.path.dirname(os.path.realpath(__file__)))
27 | sys.path.extend([f"{main_dir}/Dataset",
28 |                  f"{main_dir}/ModelMaker",
29 |                  f"{main_dir}/CommonFunctions"])
30 | sys.path.append(main_dir)  # repository root, so that Mixup.py is importable (assumed layout)
31 | from ProgramArguments import *
32 | from Utilities import *
33 | from DsUtilities import DataRand
34 | from SeqTokDataset import SeqTokDataset
35 | from SeqModelMaker import SeqModelFactory
36 | from ExperimentalModel import ExperimentModelFactory
37 | from ModelUtils import UniqueSeed
38 | from Mixup import mixup_data_refactor  # mixup helper; assumed to be provided by the repository's Mixup.py
39 | def makeDNN(n_tokens, n_labels, args):
40 |     """
41 |     Make the classification DNN according to program arguments
42 |     Parameters:
43 |     - n_tokens -- number of token types in the sequences
44 |     - n_labels -- number of class labels
45 |     - args -- parsed main program arguments
46 |     """
47 |     _convolutions = \
48 |         list(zip(args.filters, args.kernels, args.strides)
49 |              if args.strides
50 |              else zip(args.filters, args.kernels))
51 |     if args.dnn == "basic":
52 |         _model_factory = SeqModelFactory(n_tokens, n_labels)
53 |         _dnn = _model_factory.cnnDNN(_convolutions, args.dense,
54 |                                      pool = args.pool,
55 |                                      conv_act = args.conv_act,
56 |                                      regular = (args.l1, args.l2),
57 |                                      regul_dense_only = args.regul_dense_only,
58 |                                      input_type = args.coding,
59 |                                      dropout_rate = args.dropout,
60 |                                      optimizer = args.optimizer,
61 |                                      embedding_dim = args.embed)
62 |         print("Basic dnn for source code classification is constructed")
63 |     else:
64 |         _model_factory = ExperimentModelFactory(n_labels,
65 |                                                 regularizer = (args.l1, args.l2))
66 |         _dnn = _model_factory.doublePoolClassCNN(
67 |             n_tokens,
68 |             _convolutions, args.dense,
69 |             conv_act = args.conv_act,
70 |             input_type = args.coding,
71 |             embedding_dim = args.embed,
72 |             regul_dense_only = args.regul_dense_only,
73 |             dropout_rate = args.dropout,
74 |             optimizer = args.optimizer)
75 |         print("Experimental dnn with both max and average pooling is constructed")
76 |     return _dnn
77 | 
78 | 
79 | def main(args):
80 |     """
81 |     Main function of the program for classifying source code
82 |     Parameters:
83 |     - args -- parsed command line arguments
84 |               as object returned by ArgumentParser
85 |     """
86 |     resetSeeds()
87 |     DataRand.setDsSeeds(args.seed_ds)
88 |     UniqueSeed.setSeed(args.seed_model)
89 | 
90 |     early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy',
91 |                                                   patience=args.patience)
92 |     callbacks = [early_stop]
93 |     # if args.ckpt_dir:
94 |     #     latest_checkpoint = setupCheckpoint(args.ckpt_dir)
95 |     #     checkpoint_callback = makeCkptCallback(args.ckpt_dir)
96 |     #     callbacks.append(checkpoint_callback)
97 |     # else:
98 |     #     latest_checkpoint = None
99 | 
100 |     _ds = SeqTokDataset(args.dataset,
101 |                         min_n_solutions = max(args.min_solutions, 3),
102 |                         max_n_problems = args.problems,
103 |                         short_code_th = args.short_code,
104 |                         long_code_th = args.long_code,
105 |                         max_seq_length = args.seq_len,
106 |                         test_part = args.testpart,
107 |                         balanced_split = args.balanced_split)
108 | 
109 |     _ds_refactor = SeqTokDataset('../',
110 |                                  min_n_solutions = max(args.min_solutions, 3),
111 |                                  max_n_problems = args.problems,
112 |                                  short_code_th = args.short_code,
113 |                                  long_code_th = args.long_code,
114 |                                  max_seq_length = args.seq_len,
115 |                                  test_part = args.testpart,
116 |                                  balanced_split = args.balanced_split)
117 | 
118 |     print(f"Classification of source code among {_ds.n_labels} classes")
119 |     print("Technique of convolutional neural network on sequence of tokens\n")
120 | 
121 |     _dnn = makeDNN(_ds.n_token_types, _ds.n_labels, args)
122 | 
123 |     _val_ds, _train_ds = _ds.trainValidDs(args.valpart, 153600)
124 |     _val_ds_refactor, _train_ds_refactor = _ds_refactor.trainValidDs(args.valpart, 153600)
125 | 
126 |     _tds = _train_ds[0]
127 |     _tds = _tds.shuffle(50, reshuffle_each_iteration=True,
128 |                         seed = UniqueSeed.getSeed()).prefetch(2)
129 | 
130 | 
131 |     _tds_refactor = _train_ds_refactor[0]
132 |     _tds_refactor = _tds_refactor.shuffle(50, reshuffle_each_iteration=True,
133 |                                           seed = UniqueSeed.getSeed()).prefetch(2)
134 |     for _ in _tds:  # iterating leaves the last yielded batch in train_x / train_y
135 |         train_x = _[0]
136 |         train_y = _[1]
137 |         pass
138 | 
139 |     ## Mixup-refactor
140 |     for _ in _tds_refactor:
141 |         train_x_refactor = _[0]
142 |         train_y_refactor = _[1]
143 |         pass
144 | 
145 |     x_val = np.load("../../../../x_val_python.npy")
146 |     y_val = np.load("../../../../y_val_python.npy")
147 |     y_val = to_categorical(y_val, 800)
148 | 
149 |     #################################### MixUP:End #########################################
150 |     alpha = args.epochs / 10
151 |     val_acc = 0
152 |     del _tds_refactor
153 |     del _tds
154 |     for i in range(10):
155 |         Mixup_x, Mixup_y = mixup_data_refactor(np.array(train_x), np.array(train_y), np.array(train_x_refactor),
156 |                                                np.array(train_y_refactor), alpha=alpha)
157 |         history = _dnn.fit(np.array(Mixup_x),  # train on the mixed samples (the original passed train_x/train_y here, bypassing the mixup)
158 |                            np.array(Mixup_y),
159 |                            epochs = 1,
160 |                            verbose = args.progress,
161 |                            batch_size=200)
162 |         _loss, _acc = _dnn.evaluate(x_val, y_val)
163 |         print("val acc: ", _acc)
164 |         if _acc > val_acc:
165 |             _dnn.save(args.ckpt_dir)
166 |             val_acc = _acc
167 |     print("final acc: ", val_acc)
168 | 
169 | 
170 | #######################################################################
171 | # Command line arguments are described below
172 | #######################################################################
173 | if __name__ == '__main__':
174 |     print("\nCODE CLASSIFICATION WITH SEQUENCE OF TOKENS TECHNIQUE")
175 | 
176 |     # Command-line arguments
177 |     parser = makeArgParserCodeML(
178 |         "Sequence of tokens source code classifier",
179 |         task = "classification")
180 |     parser = addSeqTokensArgs(parser)
181 |     parser = addRegularizationArgs(parser)
182 |     args = parseArguments(parser)
183 | 
184 |     checkConvolution(args)
185 | 
186 |     main(args)
187 | 
--------------------------------------------------------------------------------