# Repository layout:
#
# ├── CodeBERT
# │   ├── model.py
# │   └── run.py
# ├── GraphCodeBERT
# │   ├── model.py
# │   ├── parser
# │   │   ├── DFG.py
# │   │   ├── __init__.py
# │   │   ├── build.py
# │   │   ├── build.sh
# │   │   └── utils.py
# │   └── run.py
# ├── Mixup.py
# ├── README.md
# ├── Tool
# │   ├── Java_refactor
# │   │   ├── generate_refactoring.py
# │   │   ├── processing_source_code.py
# │   │   ├── refactoring_methods.py
# │   │   └── util.py
# │   └── Python_refactor
# │       ├── generate_refactoring.py
# │       ├── processing_source_code.py
# │       ├── refactoring_methods.py
# │       └── util.py
# ├── img
# │   └── overview.png
# └── model
#     ├── BagOfToken.py
#     ├── GAT_model.py
#     ├── GCN_model.py
#     ├── GGNN_model.py
#     └── SeqofToken.py
#
# ------------------------------------------------------------------------------
# /CodeBERT/model.py
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch
from torch.autograd import Variable
import copy
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
import numpy as np


def _mixup_runs(x, y, x_partner, y_partner, alpha, runs):
    """Concatenate ``runs`` Mixup draws of ``(x, y)`` with a shuffled partner batch.

    Each run draws ``lam ~ Beta(alpha, alpha)`` (``lam = 1`` when ``alpha <= 0``),
    permutes the partner batch along dim 0 and returns
    ``lam * x + (1 - lam) * x_partner[perm]`` (and likewise for ``y``).
    """
    if runs < 1:
        raise ValueError("runs must be >= 1, got {}".format(runs))
    batch_size = x.size(0)
    mixed_xs, mixed_ys = [], []
    # Accumulate across runs (the old code re-created the lists every
    # iteration, so only the last run was returned).
    for _ in range(runs):
        # Plain Python float so the scalar never participates in dtype promotion.
        lam = float(np.random.beta(alpha, alpha)) if alpha > 0. else 1.
        # Draw on the CPU generator (as before) and move to x's device so the
        # code works with or without CUDA.
        index = torch.randperm(batch_size).to(x.device)
        mixed_xs.append(lam * x + (1 - lam) * x_partner[index])
        mixed_ys.append(lam * y + (1 - lam) * y_partner[index])
    return torch.cat(mixed_xs, dim=0), torch.cat(mixed_ys, dim=0)


def mixup_data(x, y, alpha=0.1, runs=1, use_cuda=True):
    """Mix a batch with a shuffled copy of itself, ``runs`` times.

    Args:
        x: batch tensor; dim 0 is the batch dimension.
        y: targets aligned with ``x`` along dim 0 (e.g. one-hot labels).
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing.
        runs: number of independent mixes concatenated along dim 0 (>= 1).
            It now defaults to 1; without a default the signature was a
            SyntaxError (non-default argument after a default one).
        use_cuda: kept for backward compatibility only; the permutation is
            always placed on ``x.device``.

    Returns:
        ``(mixed_x, mixed_y)``, each with ``runs * x.size(0)`` rows.

    Raises:
        ValueError: if ``runs < 1``.
    """
    return _mixup_runs(x, y, x, y, alpha, runs)


def mixup_data_refactor(x, y, x_refactor, y_refactor, alpha, runs, use_cuda=True):
    """Mix a batch with a shuffled batch of refactored counterparts, ``runs`` times.

    Same as :func:`mixup_data`, except the partner samples come from
    ``x_refactor`` / ``y_refactor`` (which must have the same batch size as
    ``x``) instead of from ``x`` / ``y`` themselves. ``use_cuda`` is kept only
    for backward compatibility.

    Raises:
        ValueError: if ``runs < 1``.
    """
    return _mixup_runs(x, y, x_refactor, y_refactor, alpha, runs)


class Model(nn.Module):
    """Sequence-classification encoder trained with logit-level Mixup."""

    def __init__(self, encoder, config, tokenizer, args):
        super(Model, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.args = args

    def forward(self, input_ids=None, labels=None):
        """Return ``(loss, log_probs)`` when ``labels`` is given, else ``log_probs``.

        ``labels`` are expected to be one-hot / soft targets; the loss is the
        summed negative log-likelihood ``-sum(log_probs * labels)``. Mixup is
        applied only in training mode, so eval/test score the real samples.
        """
        # Token id 1 is masked out of attention (presumably RoBERTa's pad id — confirm).
        logits = self.encoder(input_ids, attention_mask=input_ids.ne(1))[0]
        if self.training and labels is not None:
            logits, labels = mixup_data(logits, labels)  # Mixup Data
        prob = torch.nn.functional.log_softmax(logits, -1)
        if labels is not None:
            loss = -torch.sum(prob * labels)
            return loss, prob
        return prob


# ------------------------------------------------------------------------------
# /CodeBERT/run.py
# ------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 20 | """ 21 | 22 | from __future__ import absolute_import, division, print_function 23 | 24 | import argparse 25 | import glob 26 | import logging 27 | import os 28 | import pickle 29 | import random 30 | import re 31 | import shutil 32 | 33 | import numpy as np 34 | import torch 35 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset 36 | from torch.utils.data.distributed import DistributedSampler 37 | import json 38 | 39 | 40 | from tqdm import tqdm, trange 41 | import multiprocessing 42 | from model import Model 43 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 44 | RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer) 45 | 46 | logger = logging.getLogger(__name__) 47 | 48 | class InputFeatures(object): 49 | """A single training/test features for a example.""" 50 | def __init__(self, 51 | input_tokens, 52 | input_ids, 53 | label, 54 | 55 | ): 56 | self.input_tokens = input_tokens 57 | self.input_ids = input_ids 58 | self.label=label 59 | 60 | 61 | 
def convert_examples_to_features(js,tokenizer,args): 62 | #source 63 | code=' '.join(js['code'].split()) 64 | code_tokens=tokenizer.tokenize(code)[:args.block_size-2] 65 | source_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] 66 | source_ids = tokenizer.convert_tokens_to_ids(source_tokens) 67 | padding_length = args.block_size - len(source_ids) 68 | source_ids+=[tokenizer.pad_token_id]*padding_length 69 | return InputFeatures(source_tokens,source_ids,js['label']) 70 | 71 | class TextDataset(Dataset): 72 | def __init__(self, tokenizer, args, file_path=None): 73 | self.examples = [] 74 | with open(file_path) as f: 75 | for line in f: 76 | js=json.loads(line.strip()) 77 | self.examples.append(convert_examples_to_features(js,tokenizer,args)) 78 | if 'train' in file_path: 79 | for idx, example in enumerate(self.examples[:3]): 80 | logger.info("*** Example ***") 81 | logger.info("label: {}".format(example.label)) 82 | logger.info("input_tokens: {}".format([x.replace('\u0120','_') for x in example.input_tokens])) 83 | logger.info("input_ids: {}".format(' '.join(map(str, example.input_ids)))) 84 | 85 | def __len__(self): 86 | return len(self.examples) 87 | 88 | def __getitem__(self, i): 89 | return torch.tensor(self.examples[i].input_ids),torch.tensor(self.examples[i].label) 90 | 91 | 92 | def set_seed(seed=42): 93 | random.seed(seed) 94 | os.environ['PYHTONHASHSEED'] = str(seed) 95 | np.random.seed(seed) 96 | torch.manual_seed(seed) 97 | torch.cuda.manual_seed(seed) 98 | torch.backends.cudnn.deterministic = True 99 | 100 | def train(args, train_dataset, model, tokenizer): 101 | """ Train the model """ 102 | train_sampler = RandomSampler(train_dataset) 103 | 104 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, 105 | batch_size=args.train_batch_size,num_workers=4,pin_memory=True) 106 | 107 | 108 | 109 | # Prepare optimizer and schedule (linear warmup and decay) 110 | no_decay = ['bias', 'LayerNorm.weight'] 111 | 
optimizer_grouped_parameters = [ 112 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 113 | 'weight_decay': args.weight_decay}, 114 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 115 | ] 116 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 117 | max_steps = len(train_dataloader) * args.num_train_epochs 118 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=max_steps*0.1, 119 | num_training_steps=max_steps) 120 | 121 | # Train! 122 | logger.info("***** Running training *****") 123 | logger.info(" Num examples = %d", len(train_dataset)) 124 | logger.info(" Num Epochs = %d", args.num_train_epochs) 125 | logger.info(" batch size = %d", args.train_batch_size) 126 | logger.info(" Total optimization steps = %d", max_steps) 127 | best_acc=0.0 128 | model.zero_grad() 129 | 130 | for idx in range(args.num_train_epochs): 131 | bar = tqdm(train_dataloader,total=len(train_dataloader)) 132 | losses=[] 133 | for step, batch in enumerate(bar): 134 | inputs = batch[0].to(args.device) 135 | labels=batch[1].to(args.device) 136 | #print(inputs.size()) 137 | #print(labels.size()) 138 | labels = torch.nn.functional.one_hot(labels, args.num_labels) 139 | # print(labels.size()) 140 | model.train() 141 | loss,logits = model(inputs,labels) 142 | 143 | if args.n_gpu > 1: 144 | loss = loss.mean() # mean() to average on multi-gpu parallel training 145 | 146 | loss.backward() 147 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 148 | 149 | losses.append(loss.item()) 150 | bar.set_description("epoch {} loss {}".format(idx,round(np.mean(losses),3))) 151 | optimizer.step() 152 | optimizer.zero_grad() 153 | scheduler.step() 154 | 155 | results = evaluate(args, model, tokenizer) 156 | for key, value in results.items(): 157 | logger.info(" %s = %s", key, round(value,4)) 158 | 159 | # Save model 
checkpoint 160 | if results['eval_acc']>best_acc: 161 | best_acc=results['eval_acc'] 162 | logger.info(" "+"*"*20) 163 | logger.info(" Best acc:%s",round(best_acc,4)) 164 | logger.info(" "+"*"*20) 165 | 166 | checkpoint_prefix = 'checkpoint-best-acc' 167 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 168 | if not os.path.exists(output_dir): 169 | os.makedirs(output_dir) 170 | model_to_save = model.module if hasattr(model,'module') else model 171 | output_dir = os.path.join(output_dir, '{}'.format('model.bin')) 172 | torch.save(model_to_save.state_dict(), output_dir) 173 | logger.info("Saving model checkpoint to %s", output_dir) 174 | 175 | def evaluate(args, model, tokenizer): 176 | eval_output_dir = args.output_dir 177 | 178 | eval_dataset = TextDataset(tokenizer, args,args.eval_data_file) 179 | 180 | if not os.path.exists(eval_output_dir): 181 | os.makedirs(eval_output_dir) 182 | 183 | 184 | eval_sampler = SequentialSampler(eval_dataset) 185 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size,num_workers=4,pin_memory=True) 186 | 187 | # Eval! 
188 | logger.info("***** Running evaluation *****") 189 | logger.info(" Num examples = %d", len(eval_dataset)) 190 | logger.info(" Batch size = %d", args.eval_batch_size) 191 | eval_loss = 0.0 192 | nb_eval_steps = 0 193 | model.eval() 194 | logits=[] 195 | labels=[] 196 | for batch in eval_dataloader: 197 | inputs = batch[0].to(args.device) 198 | label=batch[1].to(args.device) 199 | label = torch.nn.functional.one_hot(label, args.num_labels) 200 | 201 | with torch.no_grad(): 202 | lm_loss,logit = model(inputs,label) 203 | eval_loss += lm_loss.mean().item() 204 | logits.append(logit.cpu().numpy()) 205 | labels.append(label.cpu().numpy()) 206 | nb_eval_steps += 1 207 | logits=np.concatenate(logits,0) 208 | labels=np.concatenate(labels,0) 209 | #print('logits:',logits) 210 | preds=logits.argmax(-1) 211 | #print('preds:',preds) 212 | labels=labels.argmax(-1) 213 | eval_acc=np.mean(labels==preds) 214 | eval_loss = eval_loss / nb_eval_steps 215 | perplexity = torch.tensor(eval_loss) 216 | 217 | result = { 218 | "eval_loss": float(perplexity), 219 | "eval_acc":round(eval_acc,4), 220 | } 221 | return result 222 | 223 | def test(args, model, tokenizer): 224 | # Note that DistributedSampler samples randomly 225 | eval_dataset = TextDataset(tokenizer, args,args.test_data_file) 226 | eval_sampler = SequentialSampler(eval_dataset) 227 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 228 | 229 | # Eval! 
230 | logger.info("***** Running Test *****") 231 | logger.info(" Num examples = %d", len(eval_dataset)) 232 | logger.info(" Batch size = %d", args.eval_batch_size) 233 | eval_loss = 0.0 234 | nb_eval_steps = 0 235 | model.eval() 236 | logits=[] 237 | labels=[] 238 | for batch in tqdm(eval_dataloader,total=len(eval_dataloader)): 239 | inputs = batch[0].to(args.device) 240 | label=batch[1].to(args.device) 241 | label = torch.nn.functional.one_hot(label, args.num_labels) 242 | with torch.no_grad(): 243 | _, logit = model(inputs,label) 244 | logits.append(logit.cpu().numpy()) 245 | labels.append(label.cpu().numpy()) 246 | nb_eval_steps += 1 247 | 248 | logits=np.concatenate(logits,0) 249 | labels=np.concatenate(labels,0) 250 | preds=logits.argmax(-1) 251 | labels= logits.argmax(-1) 252 | test_acc = np.mean(labels==preds) 253 | 254 | result = { 255 | "test_acc":round(test_acc,4), 256 | } 257 | return result 258 | 259 | 260 | 261 | #with open(os.path.join(args.output_dir,"predictions.txt"),'w') as f: 262 | # for example,pred in zip(eval_dataset.examples,preds): 263 | # f.write(str(pred)+'\n') 264 | 265 | def main(): 266 | parser = argparse.ArgumentParser() 267 | 268 | ## Required parameters 269 | parser.add_argument("--train_data_file", default=None, type=str, required=True, 270 | help="The input training data file (a text file).") 271 | parser.add_argument("--output_dir", default=None, type=str, required=True, 272 | help="The output directory where the model predictions and checkpoints will be written.") 273 | 274 | ## Other parameters 275 | parser.add_argument("--eval_data_file", default=None, type=str, 276 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 277 | parser.add_argument("--test_data_file", default=None, type=str, 278 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 279 | parser.add_argument("--model_name_or_path", default=None, type=str, 280 | help="The model 
checkpoint for weights initialization.") 281 | parser.add_argument("--tokenizer_name", default="", type=str, 282 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") 283 | parser.add_argument("--block_size", default=-1, type=int, 284 | help="Optional input sequence length after tokenization.") 285 | parser.add_argument("--do_train", action='store_true', 286 | help="Whether to run training.") 287 | parser.add_argument("--do_eval", action='store_true', 288 | help="Whether to run eval on the dev set.") 289 | parser.add_argument("--do_test", action='store_true', 290 | help="Whether to run eval on the dev set.") 291 | parser.add_argument("--train_batch_size", default=4, type=int, 292 | help="Batch size per GPU/CPU for training.") 293 | parser.add_argument("--eval_batch_size", default=4, type=int, 294 | help="Batch size per GPU/CPU for evaluation.") 295 | parser.add_argument("--learning_rate", default=5e-5, type=float, 296 | help="The initial learning rate for Adam.") 297 | parser.add_argument("--weight_decay", default=0.0, type=float, 298 | help="Weight deay if we apply some.") 299 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 300 | help="Epsilon for Adam optimizer.") 301 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 302 | help="Max gradient norm.") 303 | parser.add_argument("--warmup_steps", default=0, type=int, 304 | help="Linear warmup over warmup_steps.") 305 | parser.add_argument('--seed', type=int, default=42, 306 | help="random seed for initialization") 307 | parser.add_argument('--num_train_epochs', type=int, default=42, 308 | help="num_train_epochs") 309 | parser.add_argument('--num_labels', type=int, default=None, 310 | help = 'num_labels') 311 | 312 | args = parser.parse_args() 313 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 314 | args.n_gpu = torch.cuda.device_count() 315 | 316 | args.device = device 317 | # Setup logging 318 | 
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 319 | datefmt='%m/%d/%Y %H:%M:%S', 320 | level=logging.INFO) 321 | logger.warning("device: %s, n_gpu: %s", device, args.n_gpu) 322 | 323 | # Set seed 324 | set_seed(args.seed) 325 | 326 | config = RobertaConfig.from_pretrained(args.model_name_or_path) 327 | config.num_labels=args.num_labels 328 | tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) 329 | model = RobertaForSequenceClassification.from_pretrained(args.model_name_or_path,config=config) 330 | 331 | model=Model(model,config,tokenizer,args) 332 | 333 | # multi-gpu training (should be after apex fp16 initialization) 334 | model.to(args.device) 335 | if args.n_gpu > 1: 336 | model = torch.nn.DataParallel(model) 337 | 338 | logger.info("Training/evaluation parameters %s", args) 339 | 340 | # Training 341 | if args.do_train: 342 | train_dataset = TextDataset(tokenizer, args,args.train_data_file) 343 | train(args, train_dataset, model, tokenizer) 344 | 345 | # Evaluation 346 | results = {} 347 | if args.do_eval: 348 | checkpoint_prefix = 'checkpoint-best-acc/model.bin' 349 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 350 | model.load_state_dict(torch.load(output_dir)) 351 | model.to(args.device) 352 | result=evaluate(args, model, tokenizer) 353 | logger.info("***** Final Eval results *****") 354 | for key in sorted(result.keys()): 355 | logger.info(" %s = %s", key, str(round(result[key],4))) 356 | 357 | if args.do_test: 358 | checkpoint_prefix = 'checkpoint-best-acc/model.bin' 359 | output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) 360 | model.load_state_dict(torch.load(output_dir)) 361 | model.to(args.device) 362 | result = test(args, model, tokenizer) 363 | logger.info("***** Final Test results *****") 364 | for key in sorted(result.keys()): 365 | logger.info(" %s = %s", key, str(round(result[key],4))) 366 | 367 | 368 | return results 369 | 370 | 371 | if 
__name__ == "__main__": 372 | main() 373 | -------------------------------------------------------------------------------- /GraphCodeBERT/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | import torch 4 | import torch.nn as nn 5 | import torch 6 | from torch.autograd import Variable 7 | import copy 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from torch.nn import CrossEntropyLoss, MSELoss 11 | from torch.utils.data import SequentialSampler, DataLoader 12 | import pdb 13 | 14 | 15 | def mixup_data(x, y, alpha=0.1, runs, use_cuda=True): 16 | for i in range(runs): 17 | output_x = torch.Tensor(0) 18 | output_x= output_x.numpy().tolist() 19 | output_y = torch.Tensor(0) 20 | output_y = output_y.numpy().tolist() 21 | batch_size = x.size()[0] 22 | if alpha > 0.: 23 | lam = np.random.beta(alpha, alpha) 24 | else: 25 | lam = 1. 26 | 27 | if use_cuda: 28 | index = torch.randperm(batch_size).cuda() 29 | else: 30 | index = torch.randperm(batch_size) 31 | mixed_x = lam * x + (1 - lam) * x[index, :] 32 | mixed_y = lam * y + (1 - lam) * y[index, :] 33 | output_x.append(mixed_x) 34 | output_y.append(mixed_y) 35 | return torch.cat(output_x,dim=0), torch.cat(output_y,dim=0) 36 | 37 | 38 | def mixup_data_refactor( x, y, x_refactor, y_refactor, alpha, runs, use_cuda=True): 39 | for i in range(runs): 40 | output_x = torch.Tensor(0) 41 | output_x= output_x.numpy().tolist() 42 | output_y = torch.Tensor(0) 43 | output_y = output_y.numpy().tolist() 44 | batch_size = x.size()[0] 45 | if alpha > 0.: 46 | lam = np.random.beta(alpha, alpha) 47 | else: 48 | lam = 1. 
49 | if use_cuda: 50 | index = torch.randperm(batch_size).cuda() 51 | else: 52 | index = torch.randperm(batch_size) 53 | mixed_x = lam * x + (1 - lam) * x_refactor[index, :] 54 | mixed_y = lam * y + (1 - lam) * y_refactor[index, :] 55 | output_x.append(mixed_x) 56 | output_y.append(mixed_y) 57 | return torch.cat(output_x,dim=0), torch.cat(output_y,dim=0) 58 | 59 | class RobertaClassificationHead(nn.Module): 60 | """Head for sentence-level classification tasks.""" 61 | 62 | def __init__(self, config): 63 | super().__init__() 64 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 65 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 66 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 67 | 68 | def forward(self, features, **kwargs): 69 | x = features[:, 0, :] 70 | x = self.dropout(x) 71 | x = self.dense(x) 72 | x = torch.tanh(x) 73 | x = self.dropout(x) 74 | x = self.out_proj(x) 75 | return x 76 | 77 | class Model(nn.Module): 78 | def __init__(self, encoder,config,tokenizer,args): 79 | super(Model, self).__init__() 80 | self.encoder = encoder 81 | self.config=config 82 | self.tokenizer=tokenizer 83 | self.classifier=RobertaClassificationHead(config) 84 | self.args=args 85 | self.query = 0 86 | 87 | 88 | def forward(self, inputs_ids=None, attn_mask=None, position_idx=None, labels=None): 89 | 90 | nodes_mask=position_idx.eq(0) 91 | token_mask=position_idx.ge(2) 92 | 93 | inputs_embeddings=self.encoder.roberta.embeddings.word_embeddings(inputs_ids) 94 | nodes_to_token_mask=nodes_mask[:,:,None]&token_mask[:,None,:]&attn_mask 95 | nodes_to_token_mask=nodes_to_token_mask/(nodes_to_token_mask.sum(-1)+1e-10)[:,:,None] 96 | avg_embeddings=torch.einsum("abc,acd->abd",nodes_to_token_mask,inputs_embeddings) 97 | inputs_embeddings=inputs_embeddings*(~nodes_mask)[:,:,None]+avg_embeddings*nodes_mask[:,:,None] 98 | outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings, attention_mask=attn_mask, position_ids=position_idx, 
token_type_ids=position_idx.eq(-1).long())[0] 99 | 100 | logits=self.classifier(outputs) 101 | logits, labels = mixup_data(logits,labels) # Mixup Data 102 | prob=torch.nn.functional.log_softmax(logits,-1) 103 | 104 | if labels is not None: 105 | loss = -torch.sum(prob*labels) 106 | return loss,prob 107 | else: 108 | return prob 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/DFG.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from tree_sitter import Language, Parser 5 | from .utils import (remove_comments_and_docstrings, 6 | tree_to_token_index, 7 | index_to_code_token, 8 | tree_to_variable_index) 9 | 10 | 11 | def DFG_python(root_node,index_to_code,states): 12 | assignment=['assignment','augmented_assignment','for_in_clause'] 13 | if_statement=['if_statement'] 14 | for_statement=['for_statement'] 15 | while_statement=['while_statement'] 16 | do_first_statement=['for_in_clause'] 17 | def_statement=['default_parameter'] 18 | states=states.copy() 19 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 20 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 21 | if root_node.type==code: 22 | return [],states 23 | elif code in states: 24 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 25 | else: 26 | if root_node.type=='identifier': 27 | states[code]=[idx] 28 | return [(code,idx,'comesFrom',[],[])],states 29 | elif root_node.type in def_statement: 30 | name=root_node.child_by_field_name('name') 31 | value=root_node.child_by_field_name('value') 32 | DFG=[] 33 | if value is None: 34 | indexs=tree_to_variable_index(name,index_to_code) 35 | for index in indexs: 36 | idx,code=index_to_code[index] 37 | DFG.append((code,idx,'comesFrom',[],[])) 38 | states[code]=[idx] 39 | 
return sorted(DFG,key=lambda x:x[1]),states 40 | else: 41 | name_indexs=tree_to_variable_index(name,index_to_code) 42 | value_indexs=tree_to_variable_index(value,index_to_code) 43 | temp,states=DFG_python(value,index_to_code,states) 44 | DFG+=temp 45 | for index1 in name_indexs: 46 | idx1,code1=index_to_code[index1] 47 | for index2 in value_indexs: 48 | idx2,code2=index_to_code[index2] 49 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 50 | states[code1]=[idx1] 51 | return sorted(DFG,key=lambda x:x[1]),states 52 | elif root_node.type in assignment: 53 | if root_node.type=='for_in_clause': 54 | right_nodes=[root_node.children[-1]] 55 | left_nodes=[root_node.child_by_field_name('left')] 56 | else: 57 | if root_node.child_by_field_name('right') is None: 58 | return [],states 59 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] 60 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] 61 | if len(right_nodes)!=len(left_nodes): 62 | left_nodes=[root_node.child_by_field_name('left')] 63 | right_nodes=[root_node.child_by_field_name('right')] 64 | if len(left_nodes)==0: 65 | left_nodes=[root_node.child_by_field_name('left')] 66 | if len(right_nodes)==0: 67 | right_nodes=[root_node.child_by_field_name('right')] 68 | DFG=[] 69 | for node in right_nodes: 70 | temp,states=DFG_python(node,index_to_code,states) 71 | DFG+=temp 72 | 73 | for left_node,right_node in zip(left_nodes,right_nodes): 74 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 75 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 76 | temp=[] 77 | for token1_index in left_tokens_index: 78 | idx1,code1=index_to_code[token1_index] 79 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 80 | [index_to_code[x][0] for x in right_tokens_index])) 81 | states[code1]=[idx1] 82 | DFG+=temp 83 | return sorted(DFG,key=lambda x:x[1]),states 84 | elif root_node.type in 
if_statement: 85 | DFG=[] 86 | current_states=states.copy() 87 | others_states=[] 88 | tag=False 89 | if 'else' in root_node.type: 90 | tag=True 91 | for child in root_node.children: 92 | if 'else' in child.type: 93 | tag=True 94 | if child.type not in ['elif_clause','else_clause']: 95 | temp,current_states=DFG_python(child,index_to_code,current_states) 96 | DFG+=temp 97 | else: 98 | temp,new_states=DFG_python(child,index_to_code,states) 99 | DFG+=temp 100 | others_states.append(new_states) 101 | others_states.append(current_states) 102 | if tag is False: 103 | others_states.append(states) 104 | new_states={} 105 | for dic in others_states: 106 | for key in dic: 107 | if key not in new_states: 108 | new_states[key]=dic[key].copy() 109 | else: 110 | new_states[key]+=dic[key] 111 | for key in new_states: 112 | new_states[key]=sorted(list(set(new_states[key]))) 113 | return sorted(DFG,key=lambda x:x[1]),new_states 114 | elif root_node.type in for_statement: 115 | DFG=[] 116 | for i in range(2): 117 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] 118 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] 119 | if len(right_nodes)!=len(left_nodes): 120 | left_nodes=[root_node.child_by_field_name('left')] 121 | right_nodes=[root_node.child_by_field_name('right')] 122 | if len(left_nodes)==0: 123 | left_nodes=[root_node.child_by_field_name('left')] 124 | if len(right_nodes)==0: 125 | right_nodes=[root_node.child_by_field_name('right')] 126 | for node in right_nodes: 127 | temp,states=DFG_python(node,index_to_code,states) 128 | DFG+=temp 129 | for left_node,right_node in zip(left_nodes,right_nodes): 130 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 131 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 132 | temp=[] 133 | for token1_index in left_tokens_index: 134 | idx1,code1=index_to_code[token1_index] 135 | 
temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 136 | [index_to_code[x][0] for x in right_tokens_index])) 137 | states[code1]=[idx1] 138 | DFG+=temp 139 | if root_node.children[-1].type=="block": 140 | temp,states=DFG_python(root_node.children[-1],index_to_code,states) 141 | DFG+=temp 142 | dic={} 143 | for x in DFG: 144 | if (x[0],x[1],x[2]) not in dic: 145 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 146 | else: 147 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 148 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 149 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 150 | return sorted(DFG,key=lambda x:x[1]),states 151 | elif root_node.type in while_statement: 152 | DFG=[] 153 | for i in range(2): 154 | for child in root_node.children: 155 | temp,states=DFG_python(child,index_to_code,states) 156 | DFG+=temp 157 | dic={} 158 | for x in DFG: 159 | if (x[0],x[1],x[2]) not in dic: 160 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 161 | else: 162 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 163 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 164 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 165 | return sorted(DFG,key=lambda x:x[1]),states 166 | else: 167 | DFG=[] 168 | for child in root_node.children: 169 | if child.type in do_first_statement: 170 | temp,states=DFG_python(child,index_to_code,states) 171 | DFG+=temp 172 | for child in root_node.children: 173 | if child.type not in do_first_statement: 174 | temp,states=DFG_python(child,index_to_code,states) 175 | DFG+=temp 176 | 177 | return sorted(DFG,key=lambda x:x[1]),states 178 | 179 | 180 | def DFG_java(root_node,index_to_code,states): 181 | assignment=['assignment_expression'] 182 | def_statement=['variable_declarator'] 183 | increment_statement=['update_expression'] 184 | if_statement=['if_statement','else'] 185 | 
for_statement=['for_statement'] 186 | enhanced_for_statement=['enhanced_for_statement'] 187 | while_statement=['while_statement'] 188 | do_first_statement=[] 189 | states=states.copy() 190 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 191 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 192 | if root_node.type==code: 193 | return [],states 194 | elif code in states: 195 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 196 | else: 197 | if root_node.type=='identifier': 198 | states[code]=[idx] 199 | return [(code,idx,'comesFrom',[],[])],states 200 | elif root_node.type in def_statement: 201 | name=root_node.child_by_field_name('name') 202 | value=root_node.child_by_field_name('value') 203 | DFG=[] 204 | if value is None: 205 | indexs=tree_to_variable_index(name,index_to_code) 206 | for index in indexs: 207 | idx,code=index_to_code[index] 208 | DFG.append((code,idx,'comesFrom',[],[])) 209 | states[code]=[idx] 210 | return sorted(DFG,key=lambda x:x[1]),states 211 | else: 212 | name_indexs=tree_to_variable_index(name,index_to_code) 213 | value_indexs=tree_to_variable_index(value,index_to_code) 214 | temp,states=DFG_java(value,index_to_code,states) 215 | DFG+=temp 216 | for index1 in name_indexs: 217 | idx1,code1=index_to_code[index1] 218 | for index2 in value_indexs: 219 | idx2,code2=index_to_code[index2] 220 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 221 | states[code1]=[idx1] 222 | return sorted(DFG,key=lambda x:x[1]),states 223 | elif root_node.type in assignment: 224 | left_nodes=root_node.child_by_field_name('left') 225 | right_nodes=root_node.child_by_field_name('right') 226 | DFG=[] 227 | temp,states=DFG_java(right_nodes,index_to_code,states) 228 | DFG+=temp 229 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 230 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 231 | for index1 in name_indexs: 232 | idx1,code1=index_to_code[index1] 233 
| for index2 in value_indexs: 234 | idx2,code2=index_to_code[index2] 235 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 236 | states[code1]=[idx1] 237 | return sorted(DFG,key=lambda x:x[1]),states 238 | elif root_node.type in increment_statement: 239 | DFG=[] 240 | indexs=tree_to_variable_index(root_node,index_to_code) 241 | for index1 in indexs: 242 | idx1,code1=index_to_code[index1] 243 | for index2 in indexs: 244 | idx2,code2=index_to_code[index2] 245 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 246 | states[code1]=[idx1] 247 | return sorted(DFG,key=lambda x:x[1]),states 248 | elif root_node.type in if_statement: 249 | DFG=[] 250 | current_states=states.copy() 251 | others_states=[] 252 | flag=False 253 | tag=False 254 | if 'else' in root_node.type: 255 | tag=True 256 | for child in root_node.children: 257 | if 'else' in child.type: 258 | tag=True 259 | if child.type not in if_statement and flag is False: 260 | temp,current_states=DFG_java(child,index_to_code,current_states) 261 | DFG+=temp 262 | else: 263 | flag=True 264 | temp,new_states=DFG_java(child,index_to_code,states) 265 | DFG+=temp 266 | others_states.append(new_states) 267 | others_states.append(current_states) 268 | if tag is False: 269 | others_states.append(states) 270 | new_states={} 271 | for dic in others_states: 272 | for key in dic: 273 | if key not in new_states: 274 | new_states[key]=dic[key].copy() 275 | else: 276 | new_states[key]+=dic[key] 277 | for key in new_states: 278 | new_states[key]=sorted(list(set(new_states[key]))) 279 | return sorted(DFG,key=lambda x:x[1]),new_states 280 | elif root_node.type in for_statement: 281 | DFG=[] 282 | for child in root_node.children: 283 | temp,states=DFG_java(child,index_to_code,states) 284 | DFG+=temp 285 | flag=False 286 | for child in root_node.children: 287 | if flag: 288 | temp,states=DFG_java(child,index_to_code,states) 289 | DFG+=temp 290 | elif child.type=="local_variable_declaration": 291 | flag=True 292 | dic={} 293 | 
for x in DFG: 294 | if (x[0],x[1],x[2]) not in dic: 295 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 296 | else: 297 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 298 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 299 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 300 | return sorted(DFG,key=lambda x:x[1]),states 301 | elif root_node.type in enhanced_for_statement: 302 | name=root_node.child_by_field_name('name') 303 | value=root_node.child_by_field_name('value') 304 | body=root_node.child_by_field_name('body') 305 | DFG=[] 306 | for i in range(2): 307 | temp,states=DFG_java(value,index_to_code,states) 308 | DFG+=temp 309 | name_indexs=tree_to_variable_index(name,index_to_code) 310 | value_indexs=tree_to_variable_index(value,index_to_code) 311 | for index1 in name_indexs: 312 | idx1,code1=index_to_code[index1] 313 | for index2 in value_indexs: 314 | idx2,code2=index_to_code[index2] 315 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 316 | states[code1]=[idx1] 317 | temp,states=DFG_java(body,index_to_code,states) 318 | DFG+=temp 319 | dic={} 320 | for x in DFG: 321 | if (x[0],x[1],x[2]) not in dic: 322 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 323 | else: 324 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 325 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 326 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 327 | return sorted(DFG,key=lambda x:x[1]),states 328 | elif root_node.type in while_statement: 329 | DFG=[] 330 | for i in range(2): 331 | for child in root_node.children: 332 | temp,states=DFG_java(child,index_to_code,states) 333 | DFG+=temp 334 | dic={} 335 | for x in DFG: 336 | if (x[0],x[1],x[2]) not in dic: 337 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 338 | else: 339 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 340 | 
dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 341 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 342 | return sorted(DFG,key=lambda x:x[1]),states 343 | else: 344 | DFG=[] 345 | for child in root_node.children: 346 | if child.type in do_first_statement: 347 | temp,states=DFG_java(child,index_to_code,states) 348 | DFG+=temp 349 | for child in root_node.children: 350 | if child.type not in do_first_statement: 351 | temp,states=DFG_java(child,index_to_code,states) 352 | DFG+=temp 353 | 354 | return sorted(DFG,key=lambda x:x[1]),states 355 | 356 | def DFG_csharp(root_node,index_to_code,states): 357 | assignment=['assignment_expression'] 358 | def_statement=['variable_declarator'] 359 | increment_statement=['postfix_unary_expression'] 360 | if_statement=['if_statement','else'] 361 | for_statement=['for_statement'] 362 | enhanced_for_statement=['for_each_statement'] 363 | while_statement=['while_statement'] 364 | do_first_statement=[] 365 | states=states.copy() 366 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 367 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 368 | if root_node.type==code: 369 | return [],states 370 | elif code in states: 371 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 372 | else: 373 | if root_node.type=='identifier': 374 | states[code]=[idx] 375 | return [(code,idx,'comesFrom',[],[])],states 376 | elif root_node.type in def_statement: 377 | if len(root_node.children)==2: 378 | name=root_node.children[0] 379 | value=root_node.children[1] 380 | else: 381 | name=root_node.children[0] 382 | value=None 383 | DFG=[] 384 | if value is None: 385 | indexs=tree_to_variable_index(name,index_to_code) 386 | for index in indexs: 387 | idx,code=index_to_code[index] 388 | DFG.append((code,idx,'comesFrom',[],[])) 389 | states[code]=[idx] 390 | return sorted(DFG,key=lambda x:x[1]),states 391 | else: 392 | 
name_indexs=tree_to_variable_index(name,index_to_code) 393 | value_indexs=tree_to_variable_index(value,index_to_code) 394 | temp,states=DFG_csharp(value,index_to_code,states) 395 | DFG+=temp 396 | for index1 in name_indexs: 397 | idx1,code1=index_to_code[index1] 398 | for index2 in value_indexs: 399 | idx2,code2=index_to_code[index2] 400 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 401 | states[code1]=[idx1] 402 | return sorted(DFG,key=lambda x:x[1]),states 403 | elif root_node.type in assignment: 404 | left_nodes=root_node.child_by_field_name('left') 405 | right_nodes=root_node.child_by_field_name('right') 406 | DFG=[] 407 | temp,states=DFG_csharp(right_nodes,index_to_code,states) 408 | DFG+=temp 409 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 410 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 411 | for index1 in name_indexs: 412 | idx1,code1=index_to_code[index1] 413 | for index2 in value_indexs: 414 | idx2,code2=index_to_code[index2] 415 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 416 | states[code1]=[idx1] 417 | return sorted(DFG,key=lambda x:x[1]),states 418 | elif root_node.type in increment_statement: 419 | DFG=[] 420 | indexs=tree_to_variable_index(root_node,index_to_code) 421 | for index1 in indexs: 422 | idx1,code1=index_to_code[index1] 423 | for index2 in indexs: 424 | idx2,code2=index_to_code[index2] 425 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 426 | states[code1]=[idx1] 427 | return sorted(DFG,key=lambda x:x[1]),states 428 | elif root_node.type in if_statement: 429 | DFG=[] 430 | current_states=states.copy() 431 | others_states=[] 432 | flag=False 433 | tag=False 434 | if 'else' in root_node.type: 435 | tag=True 436 | for child in root_node.children: 437 | if 'else' in child.type: 438 | tag=True 439 | if child.type not in if_statement and flag is False: 440 | temp,current_states=DFG_csharp(child,index_to_code,current_states) 441 | DFG+=temp 442 | else: 443 | flag=True 444 | 
temp,new_states=DFG_csharp(child,index_to_code,states) 445 | DFG+=temp 446 | others_states.append(new_states) 447 | others_states.append(current_states) 448 | if tag is False: 449 | others_states.append(states) 450 | new_states={} 451 | for dic in others_states: 452 | for key in dic: 453 | if key not in new_states: 454 | new_states[key]=dic[key].copy() 455 | else: 456 | new_states[key]+=dic[key] 457 | for key in new_states: 458 | new_states[key]=sorted(list(set(new_states[key]))) 459 | return sorted(DFG,key=lambda x:x[1]),new_states 460 | elif root_node.type in for_statement: 461 | DFG=[] 462 | for child in root_node.children: 463 | temp,states=DFG_csharp(child,index_to_code,states) 464 | DFG+=temp 465 | flag=False 466 | for child in root_node.children: 467 | if flag: 468 | temp,states=DFG_csharp(child,index_to_code,states) 469 | DFG+=temp 470 | elif child.type=="local_variable_declaration": 471 | flag=True 472 | dic={} 473 | for x in DFG: 474 | if (x[0],x[1],x[2]) not in dic: 475 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 476 | else: 477 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 478 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 479 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 480 | return sorted(DFG,key=lambda x:x[1]),states 481 | elif root_node.type in enhanced_for_statement: 482 | name=root_node.child_by_field_name('left') 483 | value=root_node.child_by_field_name('right') 484 | body=root_node.child_by_field_name('body') 485 | DFG=[] 486 | for i in range(2): 487 | temp,states=DFG_csharp(value,index_to_code,states) 488 | DFG+=temp 489 | name_indexs=tree_to_variable_index(name,index_to_code) 490 | value_indexs=tree_to_variable_index(value,index_to_code) 491 | for index1 in name_indexs: 492 | idx1,code1=index_to_code[index1] 493 | for index2 in value_indexs: 494 | idx2,code2=index_to_code[index2] 495 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 496 | 
states[code1]=[idx1] 497 | temp,states=DFG_csharp(body,index_to_code,states) 498 | DFG+=temp 499 | dic={} 500 | for x in DFG: 501 | if (x[0],x[1],x[2]) not in dic: 502 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 503 | else: 504 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 505 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 506 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 507 | return sorted(DFG,key=lambda x:x[1]),states 508 | elif root_node.type in while_statement: 509 | DFG=[] 510 | for i in range(2): 511 | for child in root_node.children: 512 | temp,states=DFG_csharp(child,index_to_code,states) 513 | DFG+=temp 514 | dic={} 515 | for x in DFG: 516 | if (x[0],x[1],x[2]) not in dic: 517 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 518 | else: 519 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 520 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 521 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 522 | return sorted(DFG,key=lambda x:x[1]),states 523 | else: 524 | DFG=[] 525 | for child in root_node.children: 526 | if child.type in do_first_statement: 527 | temp,states=DFG_csharp(child,index_to_code,states) 528 | DFG+=temp 529 | for child in root_node.children: 530 | if child.type not in do_first_statement: 531 | temp,states=DFG_csharp(child,index_to_code,states) 532 | DFG+=temp 533 | 534 | return sorted(DFG,key=lambda x:x[1]),states 535 | 536 | 537 | 538 | 539 | def DFG_ruby(root_node,index_to_code,states): 540 | assignment=['assignment','operator_assignment'] 541 | if_statement=['if','elsif','else','unless','when'] 542 | for_statement=['for'] 543 | while_statement=['while_modifier','until'] 544 | do_first_statement=[] 545 | def_statement=['keyword_parameter'] 546 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 547 | states=states.copy() 548 | 
idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 549 | if root_node.type==code: 550 | return [],states 551 | elif code in states: 552 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 553 | else: 554 | if root_node.type=='identifier': 555 | states[code]=[idx] 556 | return [(code,idx,'comesFrom',[],[])],states 557 | elif root_node.type in def_statement: 558 | name=root_node.child_by_field_name('name') 559 | value=root_node.child_by_field_name('value') 560 | DFG=[] 561 | if value is None: 562 | indexs=tree_to_variable_index(name,index_to_code) 563 | for index in indexs: 564 | idx,code=index_to_code[index] 565 | DFG.append((code,idx,'comesFrom',[],[])) 566 | states[code]=[idx] 567 | return sorted(DFG,key=lambda x:x[1]),states 568 | else: 569 | name_indexs=tree_to_variable_index(name,index_to_code) 570 | value_indexs=tree_to_variable_index(value,index_to_code) 571 | temp,states=DFG_ruby(value,index_to_code,states) 572 | DFG+=temp 573 | for index1 in name_indexs: 574 | idx1,code1=index_to_code[index1] 575 | for index2 in value_indexs: 576 | idx2,code2=index_to_code[index2] 577 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 578 | states[code1]=[idx1] 579 | return sorted(DFG,key=lambda x:x[1]),states 580 | elif root_node.type in assignment: 581 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=','] 582 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=','] 583 | if len(right_nodes)!=len(left_nodes): 584 | left_nodes=[root_node.child_by_field_name('left')] 585 | right_nodes=[root_node.child_by_field_name('right')] 586 | if len(left_nodes)==0: 587 | left_nodes=[root_node.child_by_field_name('left')] 588 | if len(right_nodes)==0: 589 | right_nodes=[root_node.child_by_field_name('right')] 590 | if root_node.type=="operator_assignment": 591 | left_nodes=[root_node.children[0]] 592 | right_nodes=[root_node.children[-1]] 593 | 594 | DFG=[] 595 | for node in 
right_nodes: 596 | temp,states=DFG_ruby(node,index_to_code,states) 597 | DFG+=temp 598 | 599 | for left_node,right_node in zip(left_nodes,right_nodes): 600 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 601 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 602 | temp=[] 603 | for token1_index in left_tokens_index: 604 | idx1,code1=index_to_code[token1_index] 605 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 606 | [index_to_code[x][0] for x in right_tokens_index])) 607 | states[code1]=[idx1] 608 | DFG+=temp 609 | return sorted(DFG,key=lambda x:x[1]),states 610 | elif root_node.type in if_statement: 611 | DFG=[] 612 | current_states=states.copy() 613 | others_states=[] 614 | tag=False 615 | if 'else' in root_node.type: 616 | tag=True 617 | for child in root_node.children: 618 | if 'else' in child.type: 619 | tag=True 620 | if child.type not in if_statement: 621 | temp,current_states=DFG_ruby(child,index_to_code,current_states) 622 | DFG+=temp 623 | else: 624 | temp,new_states=DFG_ruby(child,index_to_code,states) 625 | DFG+=temp 626 | others_states.append(new_states) 627 | others_states.append(current_states) 628 | if tag is False: 629 | others_states.append(states) 630 | new_states={} 631 | for dic in others_states: 632 | for key in dic: 633 | if key not in new_states: 634 | new_states[key]=dic[key].copy() 635 | else: 636 | new_states[key]+=dic[key] 637 | for key in new_states: 638 | new_states[key]=sorted(list(set(new_states[key]))) 639 | return sorted(DFG,key=lambda x:x[1]),new_states 640 | elif root_node.type in for_statement: 641 | DFG=[] 642 | for i in range(2): 643 | left_nodes=[root_node.child_by_field_name('pattern')] 644 | right_nodes=[root_node.child_by_field_name('value')] 645 | assert len(right_nodes)==len(left_nodes) 646 | for node in right_nodes: 647 | temp,states=DFG_ruby(node,index_to_code,states) 648 | DFG+=temp 649 | for left_node,right_node in 
zip(left_nodes,right_nodes): 650 | left_tokens_index=tree_to_variable_index(left_node,index_to_code) 651 | right_tokens_index=tree_to_variable_index(right_node,index_to_code) 652 | temp=[] 653 | for token1_index in left_tokens_index: 654 | idx1,code1=index_to_code[token1_index] 655 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index], 656 | [index_to_code[x][0] for x in right_tokens_index])) 657 | states[code1]=[idx1] 658 | DFG+=temp 659 | temp,states=DFG_ruby(root_node.child_by_field_name('body'),index_to_code,states) 660 | DFG+=temp 661 | dic={} 662 | for x in DFG: 663 | if (x[0],x[1],x[2]) not in dic: 664 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 665 | else: 666 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 667 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 668 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 669 | return sorted(DFG,key=lambda x:x[1]),states 670 | elif root_node.type in while_statement: 671 | DFG=[] 672 | for i in range(2): 673 | for child in root_node.children: 674 | temp,states=DFG_ruby(child,index_to_code,states) 675 | DFG+=temp 676 | dic={} 677 | for x in DFG: 678 | if (x[0],x[1],x[2]) not in dic: 679 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 680 | else: 681 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 682 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 683 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 684 | return sorted(DFG,key=lambda x:x[1]),states 685 | else: 686 | DFG=[] 687 | for child in root_node.children: 688 | if child.type in do_first_statement: 689 | temp,states=DFG_ruby(child,index_to_code,states) 690 | DFG+=temp 691 | for child in root_node.children: 692 | if child.type not in do_first_statement: 693 | temp,states=DFG_ruby(child,index_to_code,states) 694 | DFG+=temp 695 | 696 | return sorted(DFG,key=lambda x:x[1]),states 697 
| 698 | def DFG_go(root_node,index_to_code,states): 699 | assignment=['assignment_statement',] 700 | def_statement=['var_spec'] 701 | increment_statement=['inc_statement'] 702 | if_statement=['if_statement','else'] 703 | for_statement=['for_statement'] 704 | enhanced_for_statement=[] 705 | while_statement=[] 706 | do_first_statement=[] 707 | states=states.copy() 708 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 709 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 710 | if root_node.type==code: 711 | return [],states 712 | elif code in states: 713 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 714 | else: 715 | if root_node.type=='identifier': 716 | states[code]=[idx] 717 | return [(code,idx,'comesFrom',[],[])],states 718 | elif root_node.type in def_statement: 719 | name=root_node.child_by_field_name('name') 720 | value=root_node.child_by_field_name('value') 721 | DFG=[] 722 | if value is None: 723 | indexs=tree_to_variable_index(name,index_to_code) 724 | for index in indexs: 725 | idx,code=index_to_code[index] 726 | DFG.append((code,idx,'comesFrom',[],[])) 727 | states[code]=[idx] 728 | return sorted(DFG,key=lambda x:x[1]),states 729 | else: 730 | name_indexs=tree_to_variable_index(name,index_to_code) 731 | value_indexs=tree_to_variable_index(value,index_to_code) 732 | temp,states=DFG_go(value,index_to_code,states) 733 | DFG+=temp 734 | for index1 in name_indexs: 735 | idx1,code1=index_to_code[index1] 736 | for index2 in value_indexs: 737 | idx2,code2=index_to_code[index2] 738 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 739 | states[code1]=[idx1] 740 | return sorted(DFG,key=lambda x:x[1]),states 741 | elif root_node.type in assignment: 742 | left_nodes=root_node.child_by_field_name('left') 743 | right_nodes=root_node.child_by_field_name('right') 744 | DFG=[] 745 | temp,states=DFG_go(right_nodes,index_to_code,states) 746 | DFG+=temp 747 | 
name_indexs=tree_to_variable_index(left_nodes,index_to_code) 748 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 749 | for index1 in name_indexs: 750 | idx1,code1=index_to_code[index1] 751 | for index2 in value_indexs: 752 | idx2,code2=index_to_code[index2] 753 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 754 | states[code1]=[idx1] 755 | return sorted(DFG,key=lambda x:x[1]),states 756 | elif root_node.type in increment_statement: 757 | DFG=[] 758 | indexs=tree_to_variable_index(root_node,index_to_code) 759 | for index1 in indexs: 760 | idx1,code1=index_to_code[index1] 761 | for index2 in indexs: 762 | idx2,code2=index_to_code[index2] 763 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 764 | states[code1]=[idx1] 765 | return sorted(DFG,key=lambda x:x[1]),states 766 | elif root_node.type in if_statement: 767 | DFG=[] 768 | current_states=states.copy() 769 | others_states=[] 770 | flag=False 771 | tag=False 772 | if 'else' in root_node.type: 773 | tag=True 774 | for child in root_node.children: 775 | if 'else' in child.type: 776 | tag=True 777 | if child.type not in if_statement and flag is False: 778 | temp,current_states=DFG_go(child,index_to_code,current_states) 779 | DFG+=temp 780 | else: 781 | flag=True 782 | temp,new_states=DFG_go(child,index_to_code,states) 783 | DFG+=temp 784 | others_states.append(new_states) 785 | others_states.append(current_states) 786 | if tag is False: 787 | others_states.append(states) 788 | new_states={} 789 | for dic in others_states: 790 | for key in dic: 791 | if key not in new_states: 792 | new_states[key]=dic[key].copy() 793 | else: 794 | new_states[key]+=dic[key] 795 | for key in states: 796 | if key not in new_states: 797 | new_states[key]=states[key] 798 | else: 799 | new_states[key]+=states[key] 800 | for key in new_states: 801 | new_states[key]=sorted(list(set(new_states[key]))) 802 | return sorted(DFG,key=lambda x:x[1]),new_states 803 | elif root_node.type in for_statement: 804 | DFG=[] 
805 | for child in root_node.children: 806 | temp,states=DFG_go(child,index_to_code,states) 807 | DFG+=temp 808 | flag=False 809 | for child in root_node.children: 810 | if flag: 811 | temp,states=DFG_go(child,index_to_code,states) 812 | DFG+=temp 813 | elif child.type=="for_clause": 814 | if child.child_by_field_name('update') is not None: 815 | temp,states=DFG_go(child.child_by_field_name('update'),index_to_code,states) 816 | DFG+=temp 817 | flag=True 818 | dic={} 819 | for x in DFG: 820 | if (x[0],x[1],x[2]) not in dic: 821 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 822 | else: 823 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 824 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 825 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 826 | return sorted(DFG,key=lambda x:x[1]),states 827 | else: 828 | DFG=[] 829 | for child in root_node.children: 830 | if child.type in do_first_statement: 831 | temp,states=DFG_go(child,index_to_code,states) 832 | DFG+=temp 833 | for child in root_node.children: 834 | if child.type not in do_first_statement: 835 | temp,states=DFG_go(child,index_to_code,states) 836 | DFG+=temp 837 | 838 | return sorted(DFG,key=lambda x:x[1]),states 839 | 840 | 841 | 842 | 843 | def DFG_php(root_node,index_to_code,states): 844 | assignment=['assignment_expression','augmented_assignment_expression'] 845 | def_statement=['simple_parameter'] 846 | increment_statement=['update_expression'] 847 | if_statement=['if_statement','else_clause'] 848 | for_statement=['for_statement'] 849 | enhanced_for_statement=['foreach_statement'] 850 | while_statement=['while_statement'] 851 | do_first_statement=[] 852 | states=states.copy() 853 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 854 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 855 | if root_node.type==code: 856 | return [],states 857 | elif code in states: 858 | 
return [(code,idx,'comesFrom',[code],states[code].copy())],states 859 | else: 860 | if root_node.type=='identifier': 861 | states[code]=[idx] 862 | return [(code,idx,'comesFrom',[],[])],states 863 | elif root_node.type in def_statement: 864 | name=root_node.child_by_field_name('name') 865 | value=root_node.child_by_field_name('default_value') 866 | DFG=[] 867 | if value is None: 868 | indexs=tree_to_variable_index(name,index_to_code) 869 | for index in indexs: 870 | idx,code=index_to_code[index] 871 | DFG.append((code,idx,'comesFrom',[],[])) 872 | states[code]=[idx] 873 | return sorted(DFG,key=lambda x:x[1]),states 874 | else: 875 | name_indexs=tree_to_variable_index(name,index_to_code) 876 | value_indexs=tree_to_variable_index(value,index_to_code) 877 | temp,states=DFG_php(value,index_to_code,states) 878 | DFG+=temp 879 | for index1 in name_indexs: 880 | idx1,code1=index_to_code[index1] 881 | for index2 in value_indexs: 882 | idx2,code2=index_to_code[index2] 883 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 884 | states[code1]=[idx1] 885 | return sorted(DFG,key=lambda x:x[1]),states 886 | elif root_node.type in assignment: 887 | left_nodes=root_node.child_by_field_name('left') 888 | right_nodes=root_node.child_by_field_name('right') 889 | DFG=[] 890 | temp,states=DFG_php(right_nodes,index_to_code,states) 891 | DFG+=temp 892 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 893 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 894 | for index1 in name_indexs: 895 | idx1,code1=index_to_code[index1] 896 | for index2 in value_indexs: 897 | idx2,code2=index_to_code[index2] 898 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 899 | states[code1]=[idx1] 900 | return sorted(DFG,key=lambda x:x[1]),states 901 | elif root_node.type in increment_statement: 902 | DFG=[] 903 | indexs=tree_to_variable_index(root_node,index_to_code) 904 | for index1 in indexs: 905 | idx1,code1=index_to_code[index1] 906 | for index2 in indexs: 907 | 
idx2,code2=index_to_code[index2] 908 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 909 | states[code1]=[idx1] 910 | return sorted(DFG,key=lambda x:x[1]),states 911 | elif root_node.type in if_statement: 912 | DFG=[] 913 | current_states=states.copy() 914 | others_states=[] 915 | flag=False 916 | tag=False 917 | if 'else' in root_node.type: 918 | tag=True 919 | for child in root_node.children: 920 | if 'else' in child.type: 921 | tag=True 922 | if child.type not in if_statement and flag is False: 923 | temp,current_states=DFG_php(child,index_to_code,current_states) 924 | DFG+=temp 925 | else: 926 | flag=True 927 | temp,new_states=DFG_php(child,index_to_code,states) 928 | DFG+=temp 929 | others_states.append(new_states) 930 | others_states.append(current_states) 931 | new_states={} 932 | for dic in others_states: 933 | for key in dic: 934 | if key not in new_states: 935 | new_states[key]=dic[key].copy() 936 | else: 937 | new_states[key]+=dic[key] 938 | for key in states: 939 | if key not in new_states: 940 | new_states[key]=states[key] 941 | else: 942 | new_states[key]+=states[key] 943 | for key in new_states: 944 | new_states[key]=sorted(list(set(new_states[key]))) 945 | return sorted(DFG,key=lambda x:x[1]),new_states 946 | elif root_node.type in for_statement: 947 | DFG=[] 948 | for child in root_node.children: 949 | temp,states=DFG_php(child,index_to_code,states) 950 | DFG+=temp 951 | flag=False 952 | for child in root_node.children: 953 | if flag: 954 | temp,states=DFG_php(child,index_to_code,states) 955 | DFG+=temp 956 | elif child.type=="assignment_expression": 957 | flag=True 958 | dic={} 959 | for x in DFG: 960 | if (x[0],x[1],x[2]) not in dic: 961 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 962 | else: 963 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 964 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 965 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 966 | return 
sorted(DFG,key=lambda x:x[1]),states 967 | elif root_node.type in enhanced_for_statement: 968 | name=None 969 | value=None 970 | for child in root_node.children: 971 | if child.type=='variable_name' and value is None: 972 | value=child 973 | elif child.type=='variable_name' and name is None: 974 | name=child 975 | break 976 | body=root_node.child_by_field_name('body') 977 | DFG=[] 978 | for i in range(2): 979 | temp,states=DFG_php(value,index_to_code,states) 980 | DFG+=temp 981 | name_indexs=tree_to_variable_index(name,index_to_code) 982 | value_indexs=tree_to_variable_index(value,index_to_code) 983 | for index1 in name_indexs: 984 | idx1,code1=index_to_code[index1] 985 | for index2 in value_indexs: 986 | idx2,code2=index_to_code[index2] 987 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 988 | states[code1]=[idx1] 989 | temp,states=DFG_php(body,index_to_code,states) 990 | DFG+=temp 991 | dic={} 992 | for x in DFG: 993 | if (x[0],x[1],x[2]) not in dic: 994 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 995 | else: 996 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 997 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 998 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 999 | return sorted(DFG,key=lambda x:x[1]),states 1000 | elif root_node.type in while_statement: 1001 | DFG=[] 1002 | for i in range(2): 1003 | for child in root_node.children: 1004 | temp,states=DFG_php(child,index_to_code,states) 1005 | DFG+=temp 1006 | dic={} 1007 | for x in DFG: 1008 | if (x[0],x[1],x[2]) not in dic: 1009 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 1010 | else: 1011 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 1012 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 1013 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 1014 | return sorted(DFG,key=lambda x:x[1]),states 1015 | else: 1016 | DFG=[] 1017 | for child in 
root_node.children: 1018 | if child.type in do_first_statement: 1019 | temp,states=DFG_php(child,index_to_code,states) 1020 | DFG+=temp 1021 | for child in root_node.children: 1022 | if child.type not in do_first_statement: 1023 | temp,states=DFG_php(child,index_to_code,states) 1024 | DFG+=temp 1025 | 1026 | return sorted(DFG,key=lambda x:x[1]),states 1027 | 1028 | 1029 | def DFG_javascript(root_node,index_to_code,states): 1030 | assignment=['assignment_pattern','augmented_assignment_expression'] 1031 | def_statement=['variable_declarator'] 1032 | increment_statement=['update_expression'] 1033 | if_statement=['if_statement','else'] 1034 | for_statement=['for_statement'] 1035 | enhanced_for_statement=[] 1036 | while_statement=['while_statement'] 1037 | do_first_statement=[] 1038 | states=states.copy() 1039 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 1040 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)] 1041 | if root_node.type==code: 1042 | return [],states 1043 | elif code in states: 1044 | return [(code,idx,'comesFrom',[code],states[code].copy())],states 1045 | else: 1046 | if root_node.type=='identifier': 1047 | states[code]=[idx] 1048 | return [(code,idx,'comesFrom',[],[])],states 1049 | elif root_node.type in def_statement: 1050 | name=root_node.child_by_field_name('name') 1051 | value=root_node.child_by_field_name('value') 1052 | DFG=[] 1053 | if value is None: 1054 | indexs=tree_to_variable_index(name,index_to_code) 1055 | for index in indexs: 1056 | idx,code=index_to_code[index] 1057 | DFG.append((code,idx,'comesFrom',[],[])) 1058 | states[code]=[idx] 1059 | return sorted(DFG,key=lambda x:x[1]),states 1060 | else: 1061 | name_indexs=tree_to_variable_index(name,index_to_code) 1062 | value_indexs=tree_to_variable_index(value,index_to_code) 1063 | temp,states=DFG_javascript(value,index_to_code,states) 1064 | DFG+=temp 1065 | for index1 in name_indexs: 1066 | idx1,code1=index_to_code[index1] 
1067 | for index2 in value_indexs: 1068 | idx2,code2=index_to_code[index2] 1069 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2])) 1070 | states[code1]=[idx1] 1071 | return sorted(DFG,key=lambda x:x[1]),states 1072 | elif root_node.type in assignment: 1073 | left_nodes=root_node.child_by_field_name('left') 1074 | right_nodes=root_node.child_by_field_name('right') 1075 | DFG=[] 1076 | temp,states=DFG_javascript(right_nodes,index_to_code,states) 1077 | DFG+=temp 1078 | name_indexs=tree_to_variable_index(left_nodes,index_to_code) 1079 | value_indexs=tree_to_variable_index(right_nodes,index_to_code) 1080 | for index1 in name_indexs: 1081 | idx1,code1=index_to_code[index1] 1082 | for index2 in value_indexs: 1083 | idx2,code2=index_to_code[index2] 1084 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 1085 | states[code1]=[idx1] 1086 | return sorted(DFG,key=lambda x:x[1]),states 1087 | elif root_node.type in increment_statement: 1088 | DFG=[] 1089 | indexs=tree_to_variable_index(root_node,index_to_code) 1090 | for index1 in indexs: 1091 | idx1,code1=index_to_code[index1] 1092 | for index2 in indexs: 1093 | idx2,code2=index_to_code[index2] 1094 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2])) 1095 | states[code1]=[idx1] 1096 | return sorted(DFG,key=lambda x:x[1]),states 1097 | elif root_node.type in if_statement: 1098 | DFG=[] 1099 | current_states=states.copy() 1100 | others_states=[] 1101 | flag=False 1102 | tag=False 1103 | if 'else' in root_node.type: 1104 | tag=True 1105 | for child in root_node.children: 1106 | if 'else' in child.type: 1107 | tag=True 1108 | if child.type not in if_statement and flag is False: 1109 | temp,current_states=DFG_javascript(child,index_to_code,current_states) 1110 | DFG+=temp 1111 | else: 1112 | flag=True 1113 | temp,new_states=DFG_javascript(child,index_to_code,states) 1114 | DFG+=temp 1115 | others_states.append(new_states) 1116 | others_states.append(current_states) 1117 | if tag is False: 1118 | 
others_states.append(states) 1119 | new_states={} 1120 | for dic in others_states: 1121 | for key in dic: 1122 | if key not in new_states: 1123 | new_states[key]=dic[key].copy() 1124 | else: 1125 | new_states[key]+=dic[key] 1126 | for key in states: 1127 | if key not in new_states: 1128 | new_states[key]=states[key] 1129 | else: 1130 | new_states[key]+=states[key] 1131 | for key in new_states: 1132 | new_states[key]=sorted(list(set(new_states[key]))) 1133 | return sorted(DFG,key=lambda x:x[1]),new_states 1134 | elif root_node.type in for_statement: 1135 | DFG=[] 1136 | for child in root_node.children: 1137 | temp,states=DFG_javascript(child,index_to_code,states) 1138 | DFG+=temp 1139 | flag=False 1140 | for child in root_node.children: 1141 | if flag: 1142 | temp,states=DFG_javascript(child,index_to_code,states) 1143 | DFG+=temp 1144 | elif child.type=="variable_declaration": 1145 | flag=True 1146 | dic={} 1147 | for x in DFG: 1148 | if (x[0],x[1],x[2]) not in dic: 1149 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 1150 | else: 1151 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 1152 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 1153 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 1154 | return sorted(DFG,key=lambda x:x[1]),states 1155 | elif root_node.type in while_statement: 1156 | DFG=[] 1157 | for i in range(2): 1158 | for child in root_node.children: 1159 | temp,states=DFG_javascript(child,index_to_code,states) 1160 | DFG+=temp 1161 | dic={} 1162 | for x in DFG: 1163 | if (x[0],x[1],x[2]) not in dic: 1164 | dic[(x[0],x[1],x[2])]=[x[3],x[4]] 1165 | else: 1166 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3])) 1167 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4]))) 1168 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])] 1169 | return sorted(DFG,key=lambda x:x[1]),states 1170 | else: 1171 | DFG=[] 1172 | for 
child in root_node.children: 1173 | if child.type in do_first_statement: 1174 | temp,states=DFG_javascript(child,index_to_code,states) 1175 | DFG+=temp 1176 | for child in root_node.children: 1177 | if child.type not in do_first_statement: 1178 | temp,states=DFG_javascript(child,index_to_code,states) 1179 | DFG+=temp 1180 | 1181 | return sorted(DFG,key=lambda x:x[1]),states 1182 | 1183 | 1184 | 1185 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import (remove_comments_and_docstrings, 2 | tree_to_token_index, 3 | index_to_code_token, 4 | tree_to_variable_index) 5 | from .DFG import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp -------------------------------------------------------------------------------- /GraphCodeBERT/parser/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
3 | 4 | from tree_sitter import Language, Parser 5 | 6 | Language.build_library( 7 | # Store the library in the `build` directory 8 | 'my-languages.so', 9 | 10 | # Include one or more languages 11 | [ 12 | 'tree-sitter-go', 13 | 'tree-sitter-javascript', 14 | 'tree-sitter-python', 15 | 'tree-sitter-php', 16 | 'tree-sitter-java', 17 | 'tree-sitter-ruby', 18 | 'tree-sitter-c-sharp', 19 | ] 20 | ) 21 | 22 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/build.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/tree-sitter/tree-sitter-go 2 | git clone https://github.com/tree-sitter/tree-sitter-javascript 3 | git clone https://github.com/tree-sitter/tree-sitter-python 4 | git clone https://github.com/tree-sitter/tree-sitter-ruby 5 | git clone https://github.com/tree-sitter/tree-sitter-php 6 | git clone https://github.com/tree-sitter/tree-sitter-java 7 | git clone https://github.com/tree-sitter/tree-sitter-c-sharp 8 | python build.py 9 | -------------------------------------------------------------------------------- /GraphCodeBERT/parser/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from io import StringIO 3 | import tokenize 4 | def remove_comments_and_docstrings(source,lang): 5 | if lang in ['python']: 6 | """ 7 | Returns 'source' minus comments and docstrings. 
8 | """ 9 | io_obj = StringIO(source) 10 | out = "" 11 | prev_toktype = tokenize.INDENT 12 | last_lineno = -1 13 | last_col = 0 14 | for tok in tokenize.generate_tokens(io_obj.readline): 15 | token_type = tok[0] 16 | token_string = tok[1] 17 | start_line, start_col = tok[2] 18 | end_line, end_col = tok[3] 19 | ltext = tok[4] 20 | if start_line > last_lineno: 21 | last_col = 0 22 | if start_col > last_col: 23 | out += (" " * (start_col - last_col)) 24 | # Remove comments: 25 | if token_type == tokenize.COMMENT: 26 | pass 27 | # This series of conditionals removes docstrings: 28 | elif token_type == tokenize.STRING: 29 | if prev_toktype != tokenize.INDENT: 30 | # This is likely a docstring; double-check we're not inside an operator: 31 | if prev_toktype != tokenize.NEWLINE: 32 | if start_col > 0: 33 | out += token_string 34 | else: 35 | out += token_string 36 | prev_toktype = token_type 37 | last_col = end_col 38 | last_lineno = end_line 39 | temp=[] 40 | for x in out.split('\n'): 41 | if x.strip()!="": 42 | temp.append(x) 43 | return '\n'.join(temp) 44 | elif lang in ['ruby']: 45 | return source 46 | else: 47 | def replacer(match): 48 | s = match.group(0) 49 | if s.startswith('/'): 50 | return " " # note: a space and not an empty string 51 | else: 52 | return s 53 | pattern = re.compile( 54 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', 55 | re.DOTALL | re.MULTILINE 56 | ) 57 | temp=[] 58 | for x in re.sub(pattern, replacer, source).split('\n'): 59 | if x.strip()!="": 60 | temp.append(x) 61 | return '\n'.join(temp) 62 | 63 | def tree_to_token_index(root_node): 64 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment': 65 | return [(root_node.start_point,root_node.end_point)] 66 | else: 67 | code_tokens=[] 68 | for child in root_node.children: 69 | code_tokens+=tree_to_token_index(child) 70 | return code_tokens 71 | 72 | def tree_to_variable_index(root_node,index_to_code): 73 | if (len(root_node.children)==0 or 
root_node.type=='string') and root_node.type!='comment': 74 | index=(root_node.start_point,root_node.end_point) 75 | _,code=index_to_code[index] 76 | if root_node.type!=code: 77 | return [(root_node.start_point,root_node.end_point)] 78 | else: 79 | return [] 80 | else: 81 | code_tokens=[] 82 | for child in root_node.children: 83 | code_tokens+=tree_to_variable_index(child,index_to_code) 84 | return code_tokens 85 | 86 | def index_to_code_token(index,code): 87 | start_point=index[0] 88 | end_point=index[1] 89 | if start_point[0]==end_point[0]: 90 | s=code[start_point[0]][start_point[1]:end_point[1]] 91 | else: 92 | s="" 93 | s+=code[start_point[0]][start_point[1]:] 94 | for i in range(start_point[0]+1,end_point[0]): 95 | s+=code[i] 96 | s+=code[end_point[0]][:end_point[1]] 97 | return s 98 | -------------------------------------------------------------------------------- /GraphCodeBERT/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 
18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 20 | """ 21 | 22 | from __future__ import absolute_import, division, print_function 23 | 24 | import argparse 25 | import glob 26 | import logging 27 | import os 28 | import pickle 29 | import random 30 | import re 31 | import shutil 32 | import json 33 | import numpy as np 34 | import torch 35 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset 36 | from torch.utils.data.distributed import DistributedSampler 37 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 38 | RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer) 39 | from tqdm import tqdm, trange 40 | import multiprocessing 41 | from model import Model 42 | import pdb 43 | cpu_cont = 16 44 | logger = logging.getLogger(__name__) 45 | 46 | from parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript 47 | from parser import (remove_comments_and_docstrings, 48 | tree_to_token_index, 49 | index_to_code_token, 50 | tree_to_variable_index) 51 | from tree_sitter import Language, Parser 52 | 53 | root_path = "tg" 54 | 55 | dfg_function = { 56 | 'python': DFG_python, 57 | 'java': DFG_java, 58 | 'ruby': DFG_ruby, 59 | 'go': DFG_go, 60 | 'php': DFG_php, 61 | 'javascript': DFG_javascript 62 | } 63 | 64 | # load parsers 65 | parsers = {} 66 | for lang in dfg_function: 67 | LANGUAGE = Language('parser/my-languages.so', lang) 68 | parser = Parser() 69 | parser.set_language(LANGUAGE) 70 | parser = [parser, dfg_function[lang]] 71 | parsers[lang] = parser 72 | 73 | 74 | # remove comments, tokenize code and extract dataflow 75 | def extract_dataflow(code, parser, lang): 76 | # remove comments 77 | code = code.replace("\\n", "\n") 78 | try: 79 | code = remove_comments_and_docstrings(code, lang) 80 | except: 81 | pass 82 | # obtain dataflow 83 | if lang 
== "php": 84 | code = "" 85 | try: 86 | tree = parser[0].parse(bytes(code, 'utf8')) 87 | root_node = tree.root_node 88 | tokens_index = tree_to_token_index(root_node) 89 | code = code.split('\n') 90 | code_tokens = [index_to_code_token(x, code) for x in tokens_index] 91 | index_to_code = {} 92 | for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)): 93 | index_to_code[index] = (idx, code) 94 | try: 95 | DFG, _ = parser[1](root_node, index_to_code, {}) 96 | except: 97 | DFG = [] 98 | DFG = sorted(DFG, key=lambda x: x[1]) 99 | indexs = set() 100 | for d in DFG: 101 | if len(d[-1]) != 0: 102 | indexs.add(d[1]) 103 | for x in d[-1]: 104 | indexs.add(x) 105 | new_DFG = [] 106 | for d in DFG: 107 | if d[1] in indexs: 108 | new_DFG.append(d) 109 | dfg = new_DFG 110 | except: 111 | dfg = [] 112 | return code_tokens, dfg 113 | 114 | 115 | class InputFeatures(object): 116 | """A single training/test features for a example.""" 117 | 118 | def __init__(self, 119 | input_tokens, 120 | input_ids, 121 | position_idx, 122 | dfg_to_code, 123 | dfg_to_dfg, 124 | label 125 | 126 | ): 127 | self.input_tokens = input_tokens 128 | self.input_ids = input_ids 129 | self.position_idx = position_idx 130 | self.dfg_to_code = dfg_to_code 131 | self.dfg_to_dfg = dfg_to_dfg 132 | self.label = label 133 | 134 | 135 | def convert_examples_to_features(code, label, tokenizer, args): 136 | # source 137 | # code=' '.join(js['func'].split()) 138 | parser = parsers["java"] 139 | code_tokens, dfg = extract_dataflow(code, parser, "java") 140 | 141 | code_tokens = [tokenizer.tokenize('@ ' + x)[1:] if idx != 0 else tokenizer.tokenize(x) for idx, x in enumerate(code_tokens)] 142 | ori2cur_pos = {} 143 | ori2cur_pos[-1] = (0, 0) 144 | for i in range(len(code_tokens)): 145 | ori2cur_pos[i] = (ori2cur_pos[i - 1][1], ori2cur_pos[i - 1][1] + len(code_tokens[i])) 146 | code_tokens = [y for x in code_tokens for y in x] 147 | 148 | code_tokens = code_tokens[:args.code_length + args.data_flow_length - 
2 - min(len(dfg), args.data_flow_length)] 149 | source_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token] 150 | source_ids = tokenizer.convert_tokens_to_ids(source_tokens) 151 | position_idx = [i + tokenizer.pad_token_id + 1 for i in range(len(source_tokens))] 152 | dfg = dfg[:args.code_length + args.data_flow_length - len(source_tokens)] 153 | source_tokens += [x[0] for x in dfg] 154 | position_idx += [0 for x in dfg] 155 | source_ids += [tokenizer.unk_token_id for x in dfg] 156 | padding_length = args.code_length + args.data_flow_length - len(source_ids) 157 | position_idx += [tokenizer.pad_token_id] * padding_length 158 | source_ids += [tokenizer.pad_token_id] * padding_length 159 | 160 | reverse_index = {} 161 | for idx, x in enumerate(dfg): 162 | reverse_index[x[1]] = idx 163 | for idx, x in enumerate(dfg): 164 | dfg[idx] = x[:-1] + ([reverse_index[i] for i in x[-1] if i in reverse_index],) 165 | dfg_to_dfg = [x[-1] for x in dfg] 166 | dfg_to_code = [ori2cur_pos[x[1]] for x in dfg] 167 | length = len([tokenizer.cls_token]) 168 | dfg_to_code = [(x[0] + length, x[1] + length) for x in dfg_to_code] 169 | 170 | return InputFeatures(source_tokens, source_ids, position_idx, dfg_to_code, dfg_to_dfg, label) 171 | 172 | 173 | class TextDataset(Dataset): 174 | def __init__(self, tokenizer, args, file_path=None): 175 | self.examples = [] 176 | self.args = args 177 | file_type = file_path.split('/')[-1].split('.')[0] 178 | folder = '/'.join(file_path.split('/')[:-1]) 179 | 180 | cache_file_path = os.path.join(folder, 'cached_{}'.format(file_type)) 181 | code_pairs_file_path = os.path.join(folder, 'cached_{}.pkl'.format(file_type)) 182 | 183 | print('\n cached_features_file: ', cache_file_path) 184 | try: 185 | self.examples = torch.load(cache_file_path) 186 | with open(code_pairs_file_path, 'rb') as f: 187 | code_files = pickle.load(f) 188 | logger.info("Loading features from cached file %s", cache_file_path) 189 | 190 | except: 191 | 
logger.info("Creating features from dataset file at %s", file_path) 192 | code_files = [] 193 | with open(file_path) as f: 194 | for line in f: 195 | js = json.loads(line.strip()) 196 | code = ' '.join(js['func'].split()) 197 | label = js['target'] 198 | #label = torch.nn.functional.one_hot(label, args.num_labels) 199 | self.examples.append(convert_examples_to_features(code, label, tokenizer, args)) 200 | code_files.append(code) 201 | assert (len(self.examples) == len(code_files)) 202 | with open(code_pairs_file_path, 'wb') as f: 203 | pickle.dump(code_files, f) 204 | logger.info("Saving features into cached file %s", cache_file_path) 205 | torch.save(self.examples, cache_file_path) 206 | 207 | def __len__(self): 208 | return len(self.examples) 209 | 210 | def __getitem__(self, item): 211 | # calculate graph-guided masked function 212 | attn_mask = np.zeros((self.args.code_length + self.args.data_flow_length, 213 | self.args.code_length + self.args.data_flow_length), dtype=np.bool) 214 | # calculate begin index of node and max length of input 215 | 216 | node_index = sum([i > 1 for i in self.examples[item].position_idx]) 217 | max_length = sum([i != 1 for i in self.examples[item].position_idx]) 218 | # sequence can attend to sequence 219 | attn_mask[:node_index, :node_index] = True 220 | # special tokens attend to all tokens 221 | 222 | for idx, i in enumerate(self.examples[item].input_ids): 223 | if i in [0, 2]: 224 | attn_mask[idx, :max_length] = True 225 | # nodes attend to code tokens that are identified from 226 | for idx, (a, b) in enumerate(self.examples[item].dfg_to_code): 227 | if a < node_index and b < node_index: 228 | attn_mask[idx + node_index, a:b] = True 229 | attn_mask[a:b, idx + node_index] = True 230 | # nodes attend to adjacent nodes 231 | for idx, nodes in enumerate(self.examples[item].dfg_to_dfg): 232 | for a in nodes: 233 | if a + node_index < len(self.examples[item].position_idx): 234 | attn_mask[idx + node_index, a + node_index] = True 235 | 
236 | return (torch.tensor(self.examples[item].input_ids), 237 | torch.tensor(attn_mask), 238 | torch.tensor(self.examples[item].position_idx), 239 | torch.tensor(self.examples[item].label)) 240 | 241 | 242 | def set_seed(args): 243 | random.seed(args.seed) 244 | np.random.seed(args.seed) 245 | torch.manual_seed(args.seed) 246 | if args.n_gpu > 0: 247 | torch.cuda.manual_seed_all(args.seed) 248 | 249 | 250 | def train(args, train_dataset, model, tokenizer): 251 | """ Train the model """ 252 | 253 | # build dataloader 254 | train_sampler = RandomSampler(train_dataset) 255 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=4) 256 | 257 | args.max_steps = args.epochs * len(train_dataloader) 258 | args.save_steps = len(train_dataloader) 259 | args.warmup_steps = args.max_steps // 5 260 | model.to(args.device) 261 | 262 | # Prepare optimizer and schedule (linear warmup and decay) 263 | no_decay = ['bias', 'LayerNorm.weight'] 264 | optimizer_grouped_parameters = [ 265 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 266 | 'weight_decay': args.weight_decay}, 267 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 268 | ] 269 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 270 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, 271 | num_training_steps=args.max_steps) 272 | 273 | # multi-gpu training 274 | if args.n_gpu > 1: 275 | model = torch.nn.DataParallel(model) 276 | 277 | # Train! 
278 | logger.info("***** Running training *****") 279 | logger.info(" Num examples = %d", len(train_dataset)) 280 | logger.info(" Num Epochs = %d", args.epochs) 281 | logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size // max(args.n_gpu, 1)) 282 | logger.info(" Total train batch size = %d", args.train_batch_size * args.gradient_accumulation_steps) 283 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 284 | logger.info(" Total optimization steps = %d", args.max_steps) 285 | 286 | global_step = 0 287 | tr_loss, logging_loss, avg_loss, tr_nb, tr_num, train_loss = 0.0, 0.0, 0.0, 0, 0, 0 288 | best_f1 = 0 289 | 290 | model.zero_grad() 291 | 292 | for idx in range(args.epochs): 293 | bar = tqdm(train_dataloader, total=len(train_dataloader)) 294 | tr_num = 0 295 | train_loss = 0 296 | for step, batch in enumerate(bar): 297 | inputs_ids = batch[0].to(args.device) 298 | attn_mask = batch[1].to(args.device) 299 | position_idx = batch[2].to(args.device) 300 | labels = batch[3].to(args.device) 301 | labels = torch.nn.functional.one_hot(labels, args.num_labels) 302 | model.train() 303 | loss, logits = model(inputs_ids, attn_mask, position_idx, labels) 304 | 305 | if args.n_gpu > 1: 306 | loss = loss.mean() 307 | 308 | if args.gradient_accumulation_steps > 1: 309 | loss = loss / args.gradient_accumulation_steps 310 | 311 | loss.backward() 312 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 313 | 314 | tr_loss += loss.item() 315 | tr_num += 1 316 | train_loss += loss.item() 317 | if avg_loss == 0: 318 | avg_loss = tr_loss 319 | 320 | avg_loss = round(train_loss / tr_num, 5) 321 | bar.set_description("epoch {} loss {}".format(idx, avg_loss)) 322 | 323 | if (step + 1) % args.gradient_accumulation_steps == 0: 324 | optimizer.step() 325 | optimizer.zero_grad() 326 | scheduler.step() 327 | global_step += 1 328 | output_flag = True 329 | avg_loss = round(np.exp((tr_loss - logging_loss) / (global_step - 
tr_nb)), 4) 330 | 331 | if global_step % args.save_steps == 0: 332 | results = evaluate(args, model, tokenizer, eval_when_training=True) 333 | 334 | # Save model checkpoint 335 | if results['eval_precision'] > best_f1: 336 | best_f1 = results['eval_precision'] 337 | logger.info(" " + "*" * 20) 338 | logger.info(" Best precision:%s", round(best_f1, 4)) 339 | logger.info(" " + "*" * 20) 340 | 341 | output_dir1 = f"{root_path}/{args.output_dir}" 342 | if not os.path.exists(output_dir1): 343 | os.makedirs(output_dir1) 344 | model_to_save = model.module if hasattr(model, 'module') else model 345 | output_dir = os.path.join(output_dir1, '{}'.format('model.bin')) 346 | torch.save(model_to_save.state_dict(), output_dir) 347 | logger.info("Saving model checkpoint to %s", output_dir) 348 | 349 | 350 | def evaluate(args, model, tokenizer, eval_when_training=False): 351 | # build dataloader 352 | eval_dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file) 353 | eval_sampler = SequentialSampler(eval_dataset) 354 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=4) 355 | 356 | # multi-gpu evaluate 357 | if args.n_gpu > 1 and eval_when_training is False: 358 | model = torch.nn.DataParallel(model) 359 | 360 | # Eval! 
361 | logger.info("***** Running evaluation *****") 362 | logger.info(" Num examples = %d", len(eval_dataset)) 363 | logger.info(" Batch size = %d", args.eval_batch_size) 364 | 365 | eval_loss = 0.0 366 | nb_eval_steps = 0 367 | model.eval() 368 | logits = [] 369 | y_trues = [] 370 | for batch in eval_dataloader: 371 | inputs_ids = batch[0].to(args.device) 372 | attn_mask = batch[1].to(args.device) 373 | position_idx = batch[2].to(args.device) 374 | label = batch[3].to(args.device) 375 | labels = torch.nn.functional.one_hot(labels, args.num_labels) 376 | with torch.no_grad(): 377 | lm_loss, logit = model(inputs_ids, attn_mask, position_idx, label) 378 | eval_loss += lm_loss.mean().item() 379 | logits.append(logit.cpu().numpy()) 380 | y_trues.append(label.cpu().numpy()) 381 | nb_eval_steps += 1 382 | logits = np.concatenate(logits, 0) 383 | y_trues = np.concatenate(y_trues, 0) 384 | 385 | y_preds = [] 386 | for logit in logits: 387 | y_preds.append(np.argmax(logit)) 388 | 389 | from sklearn.metrics import recall_score 390 | recall = recall_score(y_trues, y_preds, average='macro') 391 | from sklearn.metrics import precision_score 392 | precision = precision_score(y_trues, y_preds, average='macro') 393 | from sklearn.metrics import f1_score 394 | f1 = f1_score(y_trues, y_preds, average='macro') 395 | 396 | eval_acc = np.mean(y_trues == np.array(y_preds)) 397 | result = { 398 | "eval_precision": float(precision), 399 | "eval_acc": eval_acc, 400 | } 401 | 402 | # logger.info("***** Eval results {} *****".format(prefix)) 403 | for key in sorted(result.keys()): 404 | logger.info(" %s = %s", key, str(round(result[key], 4))) 405 | 406 | return result 407 | 408 | 409 | def main(): 410 | parser = argparse.ArgumentParser() 411 | 412 | ## Required parameters 413 | parser.add_argument("--data_dir", default="datasets/GraphCodeBERT/code_classification", type=str, 414 | help="The output directory where the model predictions and checkpoints will be written.") 415 | 
parser.add_argument("--output_dir", default="models/GraphCodeBERT/code_classification/model", type=str, 416 | help="The output directory where the model predictions and checkpoints will be written.") 417 | 418 | ## Other parameters 419 | parser.add_argument("--model_name_or_path", default="microsoft/graphcodebert-base", type=str, 420 | help="The model checkpoint for weights initialization.") 421 | parser.add_argument("--number_labels", type=int, default=250, 422 | help="The model checkpoint for weights initialization.") 423 | 424 | parser.add_argument("--config_name", default="microsoft/graphcodebert-base", type=str, 425 | help="Optional pretrained config name or path if not the same as model_name_or_path") 426 | parser.add_argument("--tokenizer_name", default="microsoft/graphcodebert-base", type=str, 427 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") 428 | 429 | parser.add_argument("--code_length", default=384, type=int, 430 | help="Optional Code input sequence length after tokenization.") 431 | parser.add_argument("--data_flow_length", default=128, type=int, 432 | help="Optional Data Flow input sequence length after tokenization.") 433 | parser.add_argument("--do_train", action='store_true', 434 | help="Whether to run training.") 435 | parser.add_argument("--do_eval", action='store_true', 436 | help="Whether to run eval on the dev set.") 437 | parser.add_argument("--do_test", action='store_true', 438 | help="Whether to run eval on the dev set.") 439 | parser.add_argument("--language_type", type=str, 440 | help="The programming language type of dataset") 441 | parser.add_argument("--evaluate_during_training", action='store_true', 442 | help="Run evaluation during training at each logging step.") 443 | 444 | parser.add_argument("--train_batch_size", default=32, type=int, 445 | help="Batch size per GPU/CPU for training.") 446 | parser.add_argument("--eval_batch_size", default=64, type=int, 447 | help="Batch size per 
GPU/CPU for evaluation.") 448 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 449 | help="Number of updates steps to accumulate before performing a backward/update pass.") 450 | parser.add_argument("--learning_rate", default=2e-5, type=float, 451 | help="The initial learning rate for Adam.") 452 | parser.add_argument("--weight_decay", default=0.0, type=float, 453 | help="Weight deay if we apply some.") 454 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 455 | help="Epsilon for Adam optimizer.") 456 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 457 | help="Max gradient norm.") 458 | parser.add_argument("--epochs", default=20, type=int, 459 | help="Total number of training epochs to perform.") 460 | parser.add_argument("--max_steps", default=-1, type=int, 461 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 462 | parser.add_argument("--warmup_steps", default=0, type=int, 463 | help="Linear warmup over warmup_steps.") 464 | parser.add_argument('--num_labels', type=int, default=None, 465 | help = 'num_labels') 466 | 467 | parser.add_argument('--seed', type=int, default=123456, 468 | help="random seed for initialization") 469 | 470 | args = parser.parse_args() 471 | 472 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 473 | args.n_gpu = torch.cuda.device_count() 474 | 475 | args.device = device 476 | args.train_data_file = f"{root_path}/{args.data_dir}/train.jsonl" 477 | args.eval_data_file = f"{root_path}/{args.data_dir}/test.jsonl" 478 | args.test_data_file = f"{root_path}/{args.data_dir}/test.jsonl" 479 | 480 | # Setup logging 481 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) 482 | logger.warning("device: %s, n_gpu: %s", device, args.n_gpu) 483 | 484 | # Set seed 485 | set_seed(args) 486 | config = RobertaConfig.from_pretrained(args.config_name if 
args.config_name else args.model_name_or_path) 487 | config.num_labels = args.number_labels 488 | tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) 489 | model = RobertaForSequenceClassification.from_pretrained(args.model_name_or_path, config=config) 490 | model = Model(model, config, tokenizer, args) 491 | 492 | logger.info("Training/evaluation parameters %s", args) 493 | 494 | # Training 495 | if args.do_train: 496 | train_dataset = TextDataset(tokenizer, args, args.train_data_file) 497 | train(args, train_dataset, model, tokenizer) 498 | 499 | # Evaluation 500 | if args.do_eval: 501 | checkpoint_prefix = f"{root_path}/{args.output_dir}/model.bin" 502 | model.load_state_dict(torch.load(checkpoint_prefix)) 503 | model.to(args.device) 504 | result = evaluate(args, model, tokenizer) 505 | logger.info("***** Eval results *****") 506 | for key in sorted(result.keys()): 507 | logger.info(" %s = %s", key, str(round(result[key], 4))) 508 | 509 | 510 | 511 | if __name__ == "__main__": 512 | main() 513 | 514 | -------------------------------------------------------------------------------- /Mixup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | ##### Mixup-Tensorflow ################ 6 | def mixup_data( x, y, alpha, runs): 7 | if runs is None: 8 | runs = 1 9 | output_x = [] 10 | output_y = [] 11 | batch_size = x.shape[0] 12 | for i in range(runs): 13 | lam_vector = np.random.beta(alpha, alpha, batch_size) 14 | index = np.random.permutation(batch_size) 15 | mixed_x = (x.T * lam_vector).T + (x[index, :].T * (1.0 - lam_vector)).T 16 | output_x.append(mixed_x) 17 | if y is None: 18 | return np.concatenate(output_x, axis=0) 19 | mixed_y = (y.T * lam_vector).T + (y[index].T * (1.0 - lam_vector)).T 20 | output_y.append(mixed_y) 21 | return np.concatenate(output_x, axis=0), np.concatenate(output_y, axis=0) 22 | 23 | 24 | 25 | def mixup_data_refactor( x, y, x_refactor, 
y_refactor, alpha, runs): 26 | if runs is None: 27 | runs = 1 28 | output_x = [] 29 | output_y = [] 30 | batch_size = x.shape[0] 31 | for i in range(runs): 32 | lam_vector = np.random.beta(alpha, alpha, batch_size) 33 | index = np.random.permutation(batch_size) 34 | 35 | mixed_x = (x.T * lam_vector).T + (x_refactor[index, :].T * (1.0 - lam_vector)).T 36 | output_x.append(mixed_x) 37 | if y is None: 38 | return np.concatenate(output_x, axis=0) 39 | mixed_y = (y.T * lam_vector).T + (y_refactor[index].T * (1.0 - lam_vector)).T 40 | output_y.append(mixed_y) 41 | return np.concatenate(output_x, axis=0), np.concatenate(output_y, axis=0) 42 | 43 | 44 | 45 | 46 | ##### Mixup-Pytorch ################ 47 | def mixup_data(x, y, alpha=0.1, runs, use_cuda=True): 48 | for i in range(runs): 49 | output_x = torch.Tensor(0) 50 | output_x= output_x.numpy().tolist() 51 | output_y = torch.Tensor(0) 52 | output_y = output_y.numpy().tolist() 53 | batch_size = x.size()[0] 54 | if alpha > 0.: 55 | lam = np.random.beta(alpha, alpha) 56 | else: 57 | lam = 1. 58 | 59 | if use_cuda: 60 | index = torch.randperm(batch_size).cuda() 61 | else: 62 | index = torch.randperm(batch_size) 63 | mixed_x = lam * x + (1 - lam) * x[index, :] 64 | mixed_y = lam * y + (1 - lam) * y[index, :] 65 | output_x.append(mixed_x) 66 | output_y.append(mixed_y) 67 | return torch.cat(output_x,dim=0), torch.cat(output_y,dim=0) 68 | 69 | 70 | def mixup_data_refactor( x, y, x_refactor, y_refactor, alpha, runs, use_cuda=True): 71 | for i in range(runs): 72 | output_x = torch.Tensor(0) 73 | output_x= output_x.numpy().tolist() 74 | output_y = torch.Tensor(0) 75 | output_y = output_y.numpy().tolist() 76 | batch_size = x.size()[0] 77 | if alpha > 0.: 78 | lam = np.random.beta(alpha, alpha) 79 | else: 80 | lam = 1. 
81 | if use_cuda: 82 | index = torch.randperm(batch_size).cuda() 83 | else: 84 | index = torch.randperm(batch_size) 85 | mixed_x = lam * x + (1 - lam) * x_refactor[index, :] 86 | mixed_y = lam * y + (1 - lam) * y_refactor[index, :] 87 | output_x.append(mixed_x) 88 | output_y.append(mixed_y) 89 | return torch.cat(output_x,dim=0), torch.cat(output_y,dim=0) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MixCode: Enhancing Code Classification by Mixup-Based Data Augmentation 2 | Implementation of SANER2023 paper MixCode: Enhancing Code Classification by Mixup-Based Data Augmentation [[arxiv]](https://arxiv.org/abs/2210.03003). 3 | 4 | We build this project on the top of [Project_CodeNet](https://github.com/IBM/Project_CodeNet). Please refer to this project for more details. 5 | 6 | ## Introduction 7 | MIXCODE aims to effectively supplement valid training data without manually collecting or labeling new code, inspired by the recent advance named Mixup in computer vision. Specifically, 1) first utilize multiple code refactoring methods to generate transformed code that holds consistent labels with the original data; 2) adapt the Mixup technique to linearly mix the original code with the transformed code to augment the training data. 8 | 9 |