# ├── .gitignore
# ├── DocChecker
# │   ├── bleu.py
# │   ├── dataloader.py
# │   ├── diff_utils.py
# │   ├── model.py
# │   ├── requirements.txt
# │   ├── run.py
# │   └── utils.py
# ├── README.md
# ├── assets
# │   ├── logo.jpg
# │   └── overview.png
# └── requirements.txt
#
# ------------------------------------------------------------------------------
# /.gitignore:
# ------------------------------------------------------------------------------
# ./saved_model
# __pycache__
# ./wandb
# ./pretrained_model
# ./env
# ./preprocess.py
# ./tmp.py
# ./tmp.json
# ./test.ipynb
#
# NOTE(review): git does not appear to strip a leading "./" from ignore
# patterns, so the "./..." entries above likely never match anything;
# consider "/saved_model/", "/wandb/", "/tmp.py", ... instead -- verify
# with `git check-ignore -v <path>`.

# ------------------------------------------------------------------------------
# /DocChecker/bleu.py:
# ------------------------------------------------------------------------------
#!/usr/bin/python

'''
This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
'''

# $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $

'''Provides:

cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
score_cooked(alltest, n=4): Score a list of cooked test sentences.

(The original MOSES script also provided score_set(s, testid, refids, n=4);
it is not part of this adaptation.)

The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible.
'''

import math
import os
import re
import subprocess
import sys
import xml.sax.saxutils

# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
nonorm = 0

preserve_case = False
eff_ref_len = "shortest"

normalize1 = [
    # The pattern had been reduced to '' (a no-op) -- restored so that
    # "<skipped>" tags are actually stripped, as the comment intends.
    ('<skipped>', ''),  # strip "skipped" tags
    (r'-\n', ''),  # strip end-of-line hyphenation and join lines
    (r'\n', ' '),  # join lines
    # (r'(\d)\s+(?=\d)', r'\1'), # join digits
]
normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1]

normalize2 = [
    (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 '),  # tokenize punctuation. apostrophe is missing
    (r'([^0-9])([\.,])', r'\1 \2 '),  # tokenize period and comma unless preceded by a digit
    (r'([\.,])([^0-9])', r' \1 \2'),  # tokenize period and comma unless followed by a digit
    (r'([0-9])(-)', r'\1 \2 ')  # tokenize dash when preceded by a digit
]
normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2]


def normalize(s):
    '''Normalize and tokenize text.

    This is lifted from NIST mteval-v11a.pl.

    :param s: a string, or an iterable of tokens (joined with spaces first).
    :return: list of normalized tokens.
    '''
    # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
    if nonorm:
        return s.split()
    if not isinstance(s, str):
        s = " ".join(s)
    # language-independent part:
    for (pattern, replace) in normalize1:
        s = pattern.sub(replace, s)
    # unescape() already handles &amp; &lt; &gt;; &quot; must be supplied.
    # (The map had been garbled to {'"': '"'}, which is a no-op.)
    s = xml.sax.saxutils.unescape(s, {'&quot;': '"'})
    # language-dependent part (assuming Western languages):
    s = " %s " % s
    if not preserve_case:
        s = s.lower()  # this might not be identical to the original
    for (pattern, replace) in normalize2:
        s = pattern.sub(replace, s)
    return s.split()


def count_ngrams(words, n=4):
    '''Count every k-gram of *words* for k = 1..n.

    :return: dict mapping each n-gram tuple to its number of occurrences.
    '''
    counts = {}
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])
            counts[ngram] = counts.get(ngram, 0) + 1
    return counts


def cook_refs(refs, n=4):
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.

    :param refs: reference sentences (strings or token lists) for one segment.
    :param n: maximum n-gram order.
    :return: (list of normalized reference lengths,
              dict n-gram -> maximum count of that n-gram over all references).
    '''
    refs = [normalize(ref) for ref in refs]
    maxcounts = {}
    for ref in refs:
        counts = count_ngrams(ref, n)
        for (ngram, count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
    return ([len(ref) for ref in refs], maxcounts)


def cook_test(test, item, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.

    :param test: candidate sentence (string or token list).
    :param item: the (reflens, refmaxcounts) pair returned by cook_refs().
    :param n: maximum n-gram order.
    :return: dict with "testlen", "reflen" (chosen according to the
        module-level eff_ref_len), "guess" (number of k-grams in the
        candidate, k = 1..n) and "correct" (k-gram matches clipped by the
        reference maximum counts).
    '''
    (reflens, refmaxcounts) = item
    test = normalize(test)
    result = {}
    result["testlen"] = len(test)

    # Calculate effective reference sentence length.
    if eff_ref_len == "shortest":
        result["reflen"] = min(reflens)
    elif eff_ref_len == "average":
        result["reflen"] = float(sum(reflens)) / len(reflens)
    elif eff_ref_len == "closest":
        min_diff = None
        for reflen in reflens:
            if min_diff is None or abs(reflen - len(test)) < min_diff:
                min_diff = abs(reflen - len(test))
                result['reflen'] = reflen

    result["guess"] = [max(len(test) - k + 1, 0) for k in range(1, n + 1)]

    result['correct'] = [0] * n
    counts = count_ngrams(test, n)
    for (ngram, count) in counts.items():
        result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)

    return result


def score_cooked(allcomps, n=4, ground=0, smooth=1):
    '''Score a list of cooked test sentences (corpus-level aggregation).

    :param allcomps: dicts produced by cook_test().
    :param n: maximum n-gram order.
    :param ground: unused; kept for interface compatibility.
    :param smooth: when 1, add-one smoothing is applied to orders k >= 2
        in the combined score.
    :return: list of n + 1 floats: [BLEU including the brevity penalty,
        unsmoothed 1-gram precision, ..., unsmoothed n-gram precision].
    '''
    totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}
    for comps in allcomps:
        for key in ['testlen', 'reflen']:
            totalcomps[key] += comps[key]
        for key in ['guess', 'correct']:
            for k in range(n):
                totalcomps[key][k] += comps[key][k]
    logbleu = 0.0
    all_bleus = []
    for k in range(n):
        correct = totalcomps['correct'][k]
        guess = totalcomps['guess'][k]
        addsmooth = 0
        if smooth == 1 and k > 0:
            addsmooth = 1
        logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min)
        if guess == 0:
            # exp(-1e7) underflows to 0.0 below.
            all_bleus.append(-10000000)
        else:
            all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess))

    logbleu /= float(n)
    all_bleus.insert(0, logbleu)

    # Brevity penalty in log space (only ever <= 0).
    brevPenalty = min(0, 1 - float(totalcomps['reflen'] + 1) / (totalcomps['testlen'] + 1))
    for i in range(len(all_bleus)):
        if i == 0:
            all_bleus[i] += brevPenalty
        all_bleus[i] = math.exp(all_bleus[i])
    return all_bleus


def bleu(refs, candidate, ground=0, smooth=1):
    '''Sentence-level BLEU of *candidate* against *refs*.

    :param refs: list of reference sentences to test against.
    :param candidate: the candidate sentence to score.
    :param ground: forwarded to score_cooked() (unused there).
    :param smooth: forwarded to score_cooked().
    :return: the list returned by score_cooked().
    '''
    refs = cook_refs(refs)
    test = cook_test(candidate, refs)
    return score_cooked([test], ground=ground, smooth=smooth)


def splitPuncts(line):
    '''Separate punctuation from words.

    Returns the word runs and the individual punctuation characters of
    *line*, joined by single spaces.
    '''
    return ' '.join(re.findall(r"[\w]+|[^\s\w]", line))


def computeMaps(predictions, goldfile):
    '''Build reference and prediction maps keyed by example id.

    :param predictions: iterable of "id<TAB>prediction" lines (NOT a path);
        a line containing only an id means an empty prediction.
    :param goldfile: path to a file of "id<TAB>reference" lines.
    :return: (goldMap, predictionMap), both dict id -> list of
        punctuation-split, lower-cased strings. Gold ids without a
        prediction are skipped.
    '''
    predictionMap = {}
    goldMap = {}

    for row in predictions:
        cols = row.strip().split('\t')
        if len(cols) == 1:
            (rid, pred) = (cols[0], '')
        else:
            (rid, pred) = (cols[0], cols[1])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    with open(goldfile, 'r') as gf:
        for row in gf:
            (rid, pred) = row.split('\t')
            if rid in predictionMap:  # Only insert if the id exists for the method
                goldMap.setdefault(rid, []).append(splitPuncts(pred.strip().lower()))

    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
    return (goldMap, predictionMap)


def computeMaps_label(predictions, goldfile):
    '''Like computeMaps(), for three-column "id<TAB>label<TAB>text" data.

    A prediction line with exactly two columns is treated as an empty
    prediction. The label columns are read but not used.
    '''
    predictionMap = {}
    goldMap = {}
    for row in predictions:
        cols = row.strip().split('\t')
        if len(cols) == 2:
            (rid, pred) = (cols[0], '')
        else:
            (rid, pred) = (cols[0], cols[2])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    with open(goldfile, 'r') as gf:
        for row in gf:
            (rid, _label, pred) = row.split('\t')
            if rid in predictionMap:  # Only insert if the id exists for the method
                goldMap.setdefault(rid, []).append(splitPuncts(pred.strip().lower()))

    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
    return (goldMap, predictionMap)


def computeMaps_ensemble(predictions, goldfile):
    '''Like computeMaps_label(), additionally computing the exact-match rate.

    :return: (x_match, (goldMap, predictionMap)). x_match is the number of
        matching gold rows divided by the number of gold ids, compared
        against each id's first reference; 0.0 when no id matched.
    '''
    predictionMap = {}
    goldMap = {}
    x_match = 0
    for row in predictions:
        cols = row.strip().split('\t')
        if len(cols) == 2:
            (rid, pred) = (cols[0], '')
        else:
            (rid, pred) = (cols[0], cols[2])
        predictionMap[rid] = [splitPuncts(pred.strip().lower())]

    with open(goldfile, 'r') as gf:
        for row in gf:
            (rid, _label, pred) = row.split('\t')
            if rid in predictionMap:  # Only insert if the id exists for the method
                goldMap.setdefault(rid, []).append(splitPuncts(pred.strip().lower()))
                if goldMap[rid][0] == predictionMap[rid][0]:
                    x_match += 1

    sys.stderr.write('Total: ' + str(len(goldMap)) + '\n')
    # Guard against ZeroDivisionError when no gold id had a prediction.
    x_match = x_match / len(goldMap) if goldMap else 0.0
    return x_match, (goldMap, predictionMap)


def computeMaps_coditT5(predfile, goldfile):
    '''Build maps for CoditT5 output files, keyed by line number.

    :param predfile: path to the prediction file; only the tokens after the
        last separator token on each line are kept.
    :param goldfile: path to the reference file, one reference per line.
    '''
    predictionMap = {}
    goldMap = {}
    with open(predfile, 'r') as pf, open(goldfile, 'r') as gf:
        for rid, row in enumerate(gf):
            goldMap[rid] = [splitPuncts(row.strip().lower())]

        for rid, row in enumerate(pf):
            pred = row.strip().split()
            # NOTE(review): the separator literal is '' -- an angle-bracket
            # token appears to have been stripped from this source. str.split()
            # never yields '', so .index('') raises ValueError on every line;
            # restore the original separator token.
            start_id = len(pred) - 1 - pred[::-1].index('')
            pred = ' '.join(pred[start_id + 1:])
            predictionMap[rid] = [splitPuncts(pred.lower())]

    return (goldMap, predictionMap)


def bleuFromMaps(m1, m2):
    '''Average sentence-level BLEU between two maps, scaled to 0-100.

    :param m1: reference map (id -> list of reference strings).
    :param m2: prediction map (id -> [prediction string]).
    :return: list of 5 floats: [BLEU, p1, p2, p3, p4], averaged over ids
        present in both maps. All zeros when no id is shared (previously the
        int 0, which broke callers that index the result).
    '''
    score = [0] * 5
    num = 0.0

    for key in m1:
        if key in m2:
            bl = bleu(m1[key], m2[key][0])
            score = [score[i] + bl[i] for i in range(len(bl))]
            num += 1
    if num == 0.0:
        return [0.0] * 5
    return [s * 100.0 / num for s in score]


if __name__ == '__main__':
    # Usage: python bleu.py [reference_file]
    # Predictions are read from <output_dir>/dev.output ("id<TAB>text" lines);
    # the reference file defaults to <output_dir>/dev.gold.
    output_dir = '/cm/shared/anhdtv7/BLIP/saved_model/pretrained_CSN'
    reference_file = sys.argv[1] if len(sys.argv) > 1 else os.path.join(output_dir, "dev.gold")

    # computeMaps() expects prediction *lines*; passing the path string made it
    # iterate over single characters and produce a meaningless score.
    with open(os.path.join(output_dir, "dev.output"), 'r') as f:
        predictions = f.readlines()

    (goldMap, predictionMap) = computeMaps(predictions, reference_file)
    print(bleuFromMaps(goldMap, predictionMap)[0])

# ------------------------------------------------------------------------------
# /DocChecker/dataloader.py:
# ------------------------------------------------------------------------------
import json
import ast
import random
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
from os import listdir
from os.path import isfile, join
from utils import get_tqdm
import os


class Example(object):
    """A single training/test example.

    :param idx: position of the example within the file it was read from.
    :param source: whitespace-normalized source text (code or code diff).
    :param target: whitespace-normalized target text (comment / docstring).
    :param label: integer label (0 when the dataset provides none).
    """

    def __init__(self,
                 idx,
                 source,
                 target,
                 label,
                 ):
        self.idx = idx
        self.source = source
        self.target = target
        self.label = label


def _collapse_ws(text):
    """Collapse every run of whitespace in *text* to one space and trim the ends."""
    return ' '.join(text.split())


def read_examples(filename, args, stage='train'):
    """Read examples for ``args.task`` from *filename*.

    :param filename: dataset root path (a plain string prefix; the readers
        append "<lan>/<stage>..." or "<stage>.jsonl" without a separator).
    :param args: needs ``task``, plus ``language`` / ``post_hoc`` as required
        by the selected reader.
    :param stage: split name, e.g. 'train', 'valid' or 'test'.
    :raises ValueError: if ``args.task`` is not 'just_in_time', 'pretrain' or
        'cup' (previously None was returned and the caller failed later).
    """
    if args.task == 'just_in_time':
        return read_examples_justInTime(filename, stage, args, post_hoc=args.post_hoc)
    elif args.task == 'pretrain':
        return read_examples_CSN(filename, stage, args)
    elif args.task == 'cup':
        return read_examples_cup(filename, stage, args)
    raise ValueError('Unsupported task: %r' % (args.task,))


def read_examples_cup(root_folder, stage, args):
    """Read the CUP file ``<root_folder><stage>.jsonl``.

    The source is the ``diff_code_change`` field and the target the
    ``dst_desc`` field. Every example gets label 0.
    """
    examples = []
    filename = root_folder + stage + '.jsonl'
    with open(filename, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            js = json.loads(line.strip())
            code = _collapse_ws(js['diff_code_change'].replace('\n', ' '))
            # NOTE(review): newlines are deleted rather than replaced by a
            # space, so words on adjacent lines of dst_desc are glued together
            # ("the\nvalue" -> "thevalue"). Kept to match existing
            # preprocessing -- confirm this is intended.
            nl = _collapse_ws(js['dst_desc'].replace('\n', ''))
            examples.append(
                Example(
                    idx=idx,
                    source=code,
                    target=nl,
                    label=0
                )
            )
    return examples


def _csn_example(idx, js):
    """Build an Example from one CodeSearchNet-style JSON record.

    Sets ``js['idx'] = idx`` when the record has no 'idx' (mutates *js*).
    """
    if 'idx' not in js:
        js['idx'] = idx
    code = _collapse_ws(' '.join(js['code_tokens']).replace('\n', ' '))
    nl = _collapse_ws(' '.join(js['docstring_tokens']).replace('\n', ''))
    return Example(
        idx=idx,
        source=code,
        target=nl,
        label=js.get('label', 0)
    )


def read_examples_CSN(root_folder, stage, args):
    """Read ``<root_folder><lan>/<stage>.jsonl`` for every lan in args.language.

    The source is the joined ``code_tokens``, the target the joined
    ``docstring_tokens``; the label is ``js['label']`` or 0. idx restarts at 0
    for each language.
    """
    examples = []
    for lan in args.language:
        filename = root_folder + lan + '/' + stage + '.jsonl'
        with open(filename, encoding="utf-8") as f:
            for idx, line in enumerate(f):
                examples.append(_csn_example(idx, json.loads(line.strip())))
    return examples


def read_examples_justInTime(root_folder, stage, args, post_hoc=None):
    """Read ``<root_folder><lan>/<stage>.json`` for every lan in args.language.

    Each file holds a Python-literal list of records. The source is the new
    code (post-hoc mode) or the span-level code diff (otherwise). The target
    is the old comment. The stored label is flipped: (label + 1) % 2.
    """
    examples = []
    for lan in args.language:
        filename = root_folder + lan + '/' + stage + '.json'
        with open(filename, encoding="utf-8") as f:
            # literal_eval only parses Python literals; it does not execute code.
            data = ast.literal_eval(f.read())

        code_key = 'new_code_subtokens' if post_hoc else 'span_diff_code_subtokens'
        for idx, js in enumerate(data):
            code = _collapse_ws(' '.join(js[code_key]).replace('\n', ' '))
            nl = _collapse_ws(' '.join(js['old_comment_subtokens']).replace('\n', ''))
            label = (js['label'] + 1) % 2

            examples.append(
                Example(
                    idx=idx,
                    source=code,
                    target=nl,
                    label=label,
                )
            )
    return examples


def read_examples_infer(root_folder, stage, args):
    """Like read_examples_CSN(), but also return the raw JSON records.

    :return: (examples, raw_examples), aligned index by index. A record
        without an 'idx' key gets one added.
    """
    examples = []
    raw_examples = []
    for lan in args.language:
        filename = root_folder + lan + '/' + stage + '.jsonl'
        with open(filename, encoding="utf-8") as f:
            for idx, line in enumerate(f):
                js = json.loads(line.strip())
                raw_examples.append(js)
                examples.append(_csn_example(idx, js))
    return examples, raw_examples


class InputFeatures(object):
    """Token-id features for a single example.

    :param example_id: index of the example in the list being converted.
    :param source_ids: padded source token ids.
    :param target_ids: padded target token ids.
    :param label: the example's label.
    """

    def __init__(self,
                 example_id,
                 source_ids,
                 target_ids,
                 label,
                 ):
        self.example_id = example_id
        self.source_ids = source_ids
        self.target_ids = target_ids
        self.label = label


def convert_examples_to_features(item):
    """Convert one example to token ids.

    :param item: tuple (example_index, example, tokenizer, args). It is packed
        as one argument so the function can be used with ``Pool.map``.
    :return: InputFeatures with source/target ids padded to
        args.max_source_length / args.max_target_length.
    """
    example_index, example, tokenizer, args = item

    if example_index % 5000 == 0:
        print(example_index)
    if 'unixcoder' in args.model_name_or_path:
        # NOTE(review): the "" literals below look like special tokens whose
        # angle-bracket text was stripped from this source; as written they
        # are converted to whatever id the tokenizer assigns to "". Restore
        # the intended mode/mask tokens.
        source_tokens = tokenizer.tokenize(example.source)[:args.max_source_length - 5]
        source_tokens = [tokenizer.cls_token, "",
                         tokenizer.sep_token, ""] + source_tokens + [tokenizer.sep_token]
        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)

        padding_length = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_token_id] * padding_length

        target_tokens = tokenizer.tokenize(example.target)[:args.max_target_length - 2]
        target_tokens = [""] + target_tokens + [tokenizer.sep_token]
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        padding_length = args.max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_token_id] * padding_length

    else:
        # NOTE(review): str.replace('', '') is a no-op; the original literals
        # (presumably an EOS-like token and its replacement) appear to have
        # been stripped. Without that replacement, an EOS token in the text
        # can trip the assertions below.
        source_str = example.source.replace('', '')
        source_ids = tokenizer.encode(source_str, max_length=args.max_source_length, padding='max_length', truncation=True)
        assert source_ids.count(tokenizer.eos_token_id) == 1

        target_str = example.target
        target_str = target_str.replace('', '')
        target_ids = tokenizer.encode(target_str, max_length=args.max_target_length, padding='max_length',
                                      truncation=True)
        assert target_ids.count(tokenizer.eos_token_id) == 1

    label = example.label

    return InputFeatures(
        example_index,
        source_ids,
        target_ids,
        label,
    )


def get_dataloader(args, filename, tokenizer, pool, stage='train', label=False, sequential=False, num_sample=None, infer=False, lan=None):
    """Read, tokenize and batch a dataset split.

    :param args: run configuration; reads train_batch_size,
        gradient_accumulation_steps and whatever the readers/converter need.
    :param filename: dataset root path passed to the reader.
    :param tokenizer: tokenizer handed to convert_examples_to_features().
    :param pool: a multiprocessing pool used to convert examples in parallel.
    :param stage: split name; splits whose name contains 'test' keep the
        last partial batch.
    :param label: include the label tensor in the dataset.
    :param sequential: iterate in order instead of randomly.
    :param num_sample: if not None, keep a random subset of at most this size.
    :param infer: read with read_examples_infer() and also return the raw records.
    :param lan: unused; kept for backward compatibility.
    :return: (examples, dataloader), or (examples, dataloader, raw_examples)
        when *infer* is true.
    """
    raw_examples = None
    if infer:
        examples, raw_examples = read_examples_infer(filename, stage=stage, args=args)
    else:
        examples = read_examples(filename, args, stage=stage)

    if num_sample is not None:
        # Sample indices so raw_examples stays aligned with examples in infer
        # mode. random.sample picks the same positions for a sequence of the
        # same length, so the non-infer selection is unchanged.
        picked = random.sample(range(len(examples)), min(num_sample, len(examples)))
        examples = [examples[i] for i in picked]
        if raw_examples is not None:
            raw_examples = [raw_examples[i] for i in picked]

    print('Reading raw samples has done !!!')
    print(len(examples))
    tuple_examples = [(idx, example, tokenizer, args) for idx, example in enumerate(examples)]

    features = pool.map(convert_examples_to_features, get_tqdm(tuple_examples))
    all_source_ids = torch.tensor(
        [f.source_ids for f in features], dtype=torch.long)
    all_target_ids = torch.tensor(
        [f.target_ids for f in features], dtype=torch.long)

    print('Converting raw samples to ids has done !!!')

    if label:
        all_label = torch.tensor([f.label for f in features], dtype=torch.long)
        data = TensorDataset(all_source_ids, all_target_ids, all_label)
    else:
        data = TensorDataset(all_source_ids, all_target_ids)

    sampler = SequentialSampler(data) if sequential else RandomSampler(data)

    # Keep every example at test time; drop the ragged last batch otherwise.
    drop_last = 'test' not in stage
    dataloader = DataLoader(data, sampler=sampler, batch_size=args.train_batch_size //
                            args.gradient_accumulation_steps, drop_last=drop_last)

    print('Creating dataloader has done !!!')

    if infer:
        return examples, dataloader, raw_examples
    return examples, dataloader

# ------------------------------------------------------------------------------
# /DocChecker/diff_utils.py:
# ------------------------------------------------------------------------------
import difflib

# NOTE(review): every edit-token constant in this module is the empty string.
# Their angle-bracket values appear to have been stripped from this source.
# Because they all compare equal, checks such as `node.edit_type == KEEP`
# cannot tell edit kinds apart, and `children.index(REPLACE_NEW)` matches
# the wrong position. Restore the original distinct token strings.
REPLACE = ''
REPLACE_OLD = ''
REPLACE_NEW = ''
REPLACE_END = ''
REPLACE_OLD_KEEP_BEFORE = ''
REPLACE_NEW_KEEP_BEFORE = ''
REPLACE_OLD_KEEP_AFTER = '' 10 | REPLACE_NEW_KEEP_AFTER = '' 11 | REPLACE_OLD_DELETE_KEEP_BEFORE = '' 12 | REPLACE_NEW_DELETE_KEEP_BEFORE = '' 13 | REPLACE_OLD_DELETE_KEEP_AFTER = '' 14 | REPLACE_NEW_DELETE_KEEP_AFTER = '' 15 | 16 | INSERT = '' 17 | INSERT_OLD = '' 18 | INSERT_NEW = '' 19 | INSERT_END = '' 20 | INSERT_OLD_KEEP_BEFORE = '' 21 | INSERT_NEW_KEEP_BEFORE = '' 22 | INSERT_OLD_KEEP_AFTER = '' 23 | INSERT_NEW_KEEP_AFTER = '' 24 | 25 | DELETE = '' 26 | DELETE_END = '' 27 | 28 | KEEP = '' 29 | KEEP_END = '' 30 | 31 | class EditNode: 32 | def __init__(self, edit_type, children, prev, next): 33 | self.edit_type = edit_type 34 | self.children = children 35 | self.prev = prev 36 | self.next = next 37 | 38 | def get_edit_keywords(): 39 | return [REPLACE, REPLACE_OLD, REPLACE_NEW, REPLACE_END, REPLACE_OLD_KEEP_BEFORE, REPLACE_NEW_KEEP_BEFORE, REPLACE_OLD_KEEP_AFTER, 40 | REPLACE_NEW_KEEP_AFTER, REPLACE_OLD_DELETE_KEEP_BEFORE, REPLACE_NEW_DELETE_KEEP_BEFORE, REPLACE_OLD_DELETE_KEEP_AFTER, 41 | REPLACE_NEW_DELETE_KEEP_AFTER, INSERT, INSERT_OLD, INSERT_NEW, INSERT_END, INSERT_OLD_KEEP_BEFORE, INSERT_NEW_KEEP_BEFORE, 42 | INSERT_OLD_KEEP_AFTER, INSERT_NEW_KEEP_AFTER, DELETE, DELETE_END, KEEP, KEEP_END] 43 | 44 | def get_index(search_tokens, full_tokens): 45 | if len(search_tokens) == 0: 46 | return 0 47 | 48 | possible_positions = [k for k in range(len(full_tokens)) if full_tokens[k] == search_tokens[0]] 49 | if len(possible_positions) == 0: 50 | return -1 51 | 52 | if len(possible_positions) == 1: 53 | return possible_positions[0] 54 | 55 | for p in possible_positions: 56 | s_pos = 1 57 | f_pos = p + 1 58 | invalid = False 59 | 60 | while s_pos < len(search_tokens) and f_pos < len(full_tokens): 61 | if search_tokens[s_pos] != full_tokens[f_pos]: 62 | invalid = True 63 | break 64 | 65 | s_pos += 1 66 | f_pos += 1 67 | 68 | if not invalid: 69 | return p 70 | return -1 71 | 72 | def get_valid_positions(search_str, full_str): 73 | search_sequence = search_str.split() 74 
| full_sequence = full_str.split() 75 | 76 | if len(search_sequence) == 0: 77 | return 0 78 | 79 | possible_positions = [p for p in range(len(full_sequence)) if full_sequence[p] == search_sequence[0]] 80 | valid_positions = [] 81 | 82 | for p in possible_positions: 83 | valid = True 84 | for i in range(len(search_sequence)): 85 | if p+i >= len(full_sequence) or full_sequence[p+i] != search_sequence[i]: 86 | valid = False 87 | break 88 | if valid: 89 | valid_positions.append(p) 90 | 91 | return valid_positions 92 | 93 | def get_frequency(search_str, full_str): 94 | return len(get_valid_positions(search_str, full_str)) 95 | 96 | def get_coarse_diff_structure(old_tokens, new_tokens): 97 | nodes = [] 98 | last_node = None 99 | for edit_type, o_start, o_end, n_start, n_end in difflib.SequenceMatcher(None, old_tokens, new_tokens).get_opcodes(): 100 | if edit_type == 'equal': 101 | edit_node = EditNode(KEEP, old_tokens[o_start:o_end], last_node, None) 102 | elif edit_type == 'replace': 103 | edit_node = EditNode(REPLACE, old_tokens[o_start:o_end] + [REPLACE_NEW] + new_tokens[n_start:n_end], last_node, None) 104 | elif edit_type == 'insert': 105 | edit_node = EditNode(INSERT, new_tokens[n_start:n_end], last_node, None) 106 | else: 107 | edit_node = EditNode(DELETE, old_tokens[o_start:o_end], last_node, None) 108 | 109 | if last_node: 110 | last_node.next = edit_node 111 | last_node = edit_node 112 | nodes.append(edit_node) 113 | return nodes 114 | 115 | def merge_diff_actions(diff_structure): 116 | mega_nodes = [] 117 | curr_mega_node = [] 118 | for node in diff_structure: 119 | if len(node.children) == 1: 120 | curr_mega_node.append(node) 121 | else: 122 | if len(curr_mega_node) == 1: 123 | curr_mega_node.append(node) 124 | mega_nodes.append(curr_mega_node) 125 | curr_mega_node = [] 126 | else: 127 | if len(curr_mega_node) > 0: 128 | mega_nodes.append(curr_mega_node) 129 | curr_mega_node = [] 130 | mega_nodes.append([node]) 131 | 132 | if len(curr_mega_node) == 1: 133 | 
mega_nodes[-1].extend(curr_mega_node) 134 | elif len(curr_mega_node) > 0: 135 | mega_nodes.append(curr_mega_node) 136 | 137 | new_nodes = [] 138 | for m_node in mega_nodes: 139 | if len(m_node) == 1: 140 | new_nodes.append(m_node[0]) 141 | continue 142 | 143 | old_tokens = [] 144 | new_tokens = [] 145 | for sub in m_node: 146 | if sub.edit_type == KEEP: 147 | old_tokens.extend(sub.children) 148 | new_tokens.extend(sub.children) 149 | elif sub.edit_type == INSERT: 150 | new_tokens.extend(sub.children) 151 | elif sub.edit_type == DELETE: 152 | old_tokens.extend(sub.children) 153 | else: 154 | rep_idx = sub.children.index(REPLACE_NEW) 155 | old_tokens.extend(sub.children[:rep_idx]) 156 | new_tokens.extend(sub.children[rep_idx+1:]) 157 | 158 | replace_node = EditNode(REPLACE, old_tokens + [REPLACE_NEW] + new_tokens, None, None) 159 | new_nodes.append(replace_node) 160 | 161 | n = 0 162 | final_new_nodes = [] 163 | while n < len(new_nodes): 164 | while n < len(new_nodes) and new_nodes[n].edit_type not in [INSERT, REPLACE, DELETE]: 165 | final_new_nodes.append(new_nodes[n]) 166 | n += 1 167 | 168 | to_merge = [] 169 | while n < len(new_nodes) and new_nodes[n].edit_type in [INSERT, REPLACE, DELETE]: 170 | to_merge.append(new_nodes[n]) 171 | n += 1 172 | 173 | if len(to_merge) > 0: 174 | old_tokens = [] 175 | new_tokens = [] 176 | for node in to_merge: 177 | if node.edit_type == INSERT: 178 | new_tokens.extend(node.children) 179 | elif node.edit_type == DELETE: 180 | old_tokens.extend(node.children) 181 | elif node.edit_type == REPLACE: 182 | rep_idx = node.children.index(REPLACE_NEW) 183 | old_tokens.extend(node.children[:rep_idx]) 184 | new_tokens.extend(node.children[rep_idx+1:]) 185 | 186 | replace_node = EditNode(REPLACE, old_tokens + [REPLACE_NEW] + new_tokens, None, None) 187 | final_new_nodes.append(replace_node) 188 | 189 | if n < len(new_nodes): 190 | final_new_nodes.append(new_nodes[n]) 191 | 192 | n += 1 193 | 194 | new_nodes = final_new_nodes 195 | for n, node 
in enumerate(new_nodes): 196 | if n > 0: 197 | new_nodes[n].next = node 198 | node.prev = new_nodes[n] 199 | if n+1 < len(new_nodes): 200 | new_nodes[n+1].prev = node 201 | node.next = new_nodes[n+1] 202 | 203 | return new_nodes 204 | 205 | def compute_code_diffs(old_tokens, new_tokens): 206 | spans = [] 207 | tokens = [] 208 | commands = [] 209 | 210 | for edit_type, o_start, o_end, n_start, n_end in difflib.SequenceMatcher(None, old_tokens, new_tokens).get_opcodes(): 211 | if edit_type == 'equal': 212 | spans.extend([KEEP] + old_tokens[o_start:o_end] + [KEEP_END]) 213 | for i in range(o_start, o_end): 214 | tokens.extend([KEEP, old_tokens[i]]) 215 | commands.append(KEEP) 216 | elif edit_type == 'replace': 217 | spans.extend([REPLACE_OLD] + old_tokens[o_start:o_end] + [REPLACE_NEW] + new_tokens[n_start:n_end] + [REPLACE_END]) 218 | for i in range(o_start, o_end): 219 | tokens.extend([REPLACE_OLD, old_tokens[i]]) 220 | commands.append(REPLACE_OLD) 221 | for j in range(n_start, n_end): 222 | tokens.extend([REPLACE_NEW, new_tokens[j]]) 223 | commands.extend([REPLACE_NEW, new_tokens[j]]) 224 | elif edit_type == 'insert': 225 | spans.extend([INSERT] + new_tokens[n_start:n_end] + [INSERT_END]) 226 | for j in range(n_start, n_end): 227 | tokens.extend([INSERT, new_tokens[j]]) 228 | commands.extend([INSERT, new_tokens[j]]) 229 | else: 230 | spans.extend([DELETE] + old_tokens[o_start:o_end] + [DELETE_END]) 231 | for i in range(o_start, o_end): 232 | tokens.extend([DELETE, old_tokens[i]]) 233 | commands.append(DELETE) 234 | 235 | return spans, tokens, commands 236 | 237 | def compute_minimal_code_diffs(old_tokens, new_tokens): 238 | spans = [] 239 | tokens = [] 240 | commands = [] 241 | 242 | for edit_type, o_start, o_end, n_start, n_end in difflib.SequenceMatcher(None, old_tokens, new_tokens).get_opcodes(): 243 | if edit_type == 'equal': 244 | continue 245 | elif edit_type == 'replace': 246 | spans.extend([REPLACE_OLD] + old_tokens[o_start:o_end] + [REPLACE_NEW] + 
new_tokens[n_start:n_end] + [REPLACE_END]) 247 | for i in range(o_start, o_end): 248 | tokens.extend([REPLACE_OLD, old_tokens[i]]) 249 | commands.append(REPLACE_OLD) 250 | for j in range(n_start, n_end): 251 | tokens.extend([REPLACE_NEW, new_tokens[j]]) 252 | commands.extend([REPLACE_NEW, new_tokens[j]]) 253 | elif edit_type == 'insert': 254 | spans.extend([INSERT] + new_tokens[n_start:n_end] + [INSERT_END]) 255 | for j in range(n_start, n_end): 256 | tokens.extend([INSERT, new_tokens[j]]) 257 | commands.extend([INSERT, new_tokens[j]]) 258 | else: 259 | spans.extend([DELETE] + old_tokens[o_start:o_end] + [DELETE_END]) 260 | for i in range(o_start, o_end): 261 | tokens.extend([DELETE, old_tokens[i]]) 262 | commands.append(DELETE) 263 | 264 | return spans, tokens, commands 265 | 266 | def compute_comment_diffs(old_tokens, new_tokens): 267 | spans = [] 268 | tokens = [] 269 | commands = [] 270 | 271 | diff_nodes = get_coarse_diff_structure(old_tokens, new_tokens) 272 | diff_nodes = merge_diff_actions(diff_nodes) 273 | 274 | for node in diff_nodes: 275 | if node.edit_type == KEEP: 276 | spans.extend([KEEP] + node.children + [KEEP_END]) 277 | for i in range(len(node.children)): 278 | tokens.extend([KEEP, node.children[i]]) 279 | commands.append(KEEP) 280 | elif node.edit_type == REPLACE: 281 | o_end = node.children.index(REPLACE_NEW) 282 | n_start = o_end + 1 283 | n_end = len(node.children) 284 | spans.extend([REPLACE_OLD] + node.children + [REPLACE_END]) 285 | for i in range(o_end): 286 | tokens.extend([REPLACE_OLD, node.children[i]]) 287 | commands.append(REPLACE_OLD) 288 | for j in range(n_start, n_end): 289 | tokens.extend([REPLACE_NEW, node.children[j]]) 290 | commands.extend([REPLACE_NEW, node.children[j]]) 291 | elif node.edit_type == INSERT: 292 | spans.extend([INSERT] + node.children + [INSERT_END]) 293 | for j in range(len(node.children)): 294 | tokens.extend([INSERT, node.children[j]]) 295 | commands.extend([INSERT, node.children[j]]) 296 | else: 297 | 
spans.extend([DELETE] + node.children + [DELETE_END]) 298 | for i in range(len(node.children)): 299 | tokens.extend([DELETE, node.children[i]]) 300 | commands.append(DELETE) 301 | 302 | return spans, tokens, commands 303 | 304 | def compute_minimal_comment_diffs(old_tokens, new_tokens): 305 | spans = [] 306 | tokens = [] 307 | commands = [] 308 | 309 | old_str = ' '.join(old_tokens) 310 | diff_nodes = get_coarse_diff_structure(old_tokens, new_tokens) 311 | 312 | new_nodes = [] 313 | 314 | for n, node in enumerate(diff_nodes): 315 | if node.edit_type == KEEP: 316 | new_nodes.append(node) 317 | 318 | elif node.edit_type == DELETE: 319 | search_str = ' '.join(node.children) 320 | if get_frequency(search_str, old_str) == 1: 321 | node.children.insert(0, DELETE) 322 | new_nodes.append(node) 323 | continue 324 | 325 | if node.prev and node.prev.edit_type == KEEP: 326 | adopted_children = [] 327 | found_substring = False 328 | while not found_substring and len(node.prev.children) > 0: 329 | adopted_children.insert(0, node.prev.children.pop()) 330 | search_str = ' '.join(adopted_children + node.children) 331 | found_substring = get_frequency(search_str, old_str) == 1 332 | 333 | if found_substring: 334 | new_children = [REPLACE_OLD_DELETE_KEEP_BEFORE] + adopted_children + node.children + [REPLACE_NEW_DELETE_KEEP_BEFORE] + adopted_children 335 | new_node = EditNode(REPLACE, new_children, node.prev, node.next) 336 | node.prev.next = new_node 337 | if node.next: 338 | node.next.prev = new_node 339 | new_nodes.append(new_node) 340 | continue 341 | else: 342 | node.prev.children.extend(adopted_children) 343 | 344 | if node.next and node.next.edit_type == KEEP: 345 | adopted_children = [] 346 | found_substring = False 347 | while not found_substring and len(node.next.children) > 0: 348 | adopted_children.append(node.next.children.pop(0)) 349 | search_str = ' '.join(node.children + adopted_children) 350 | found_substring = get_frequency(search_str, old_str) == 1 351 | 352 | if 
found_substring: 353 | new_children = [REPLACE_OLD_DELETE_KEEP_AFTER] + node.children + adopted_children + [REPLACE_NEW_DELETE_KEEP_AFTER] + adopted_children 354 | new_node = EditNode(REPLACE, new_children, node.prev, node.next) 355 | 356 | if node.prev: 357 | node.prev.next = new_node 358 | 359 | node.next.prev = new_node 360 | new_nodes.append(new_node) 361 | continue 362 | else: 363 | node.next.children = adopted_children + node.next.children 364 | 365 | return get_full_replace_span(old_tokens, new_tokens), tokens, commands 366 | 367 | elif node.edit_type == REPLACE: 368 | rep_idx = node.children.index(REPLACE_NEW) 369 | rep_old_children = node.children[:rep_idx] 370 | rep_new_children = node.children[rep_idx+1:] 371 | search_str = ' '.join(rep_old_children) 372 | 373 | if get_frequency(search_str, old_str) == 1: 374 | node.children.insert(0, REPLACE_OLD) 375 | new_nodes.append(node) 376 | continue 377 | 378 | if node.prev and node.prev.edit_type == KEEP: 379 | adopted_children = [] 380 | found_substring = False 381 | while not found_substring and len(node.prev.children) > 0: 382 | adopted_children.insert(0, node.prev.children.pop()) 383 | search_str = ' '.join(adopted_children + rep_old_children) 384 | found_substring = get_frequency(search_str, old_str) == 1 385 | 386 | if found_substring: 387 | new_children = [REPLACE_OLD_KEEP_BEFORE] + adopted_children + rep_old_children + [REPLACE_NEW_KEEP_BEFORE] + adopted_children + rep_new_children 388 | new_node = EditNode(REPLACE, new_children, node.prev, node.next) 389 | node.prev.next = new_node 390 | if node.next: 391 | node.next.prev = new_node 392 | new_nodes.append(new_node) 393 | continue 394 | else: 395 | node.prev.children.extend(adopted_children) 396 | 397 | if node.next and node.next.edit_type == KEEP: 398 | adopted_children = [] 399 | found_substring = False 400 | while not found_substring and len(node.next.children) > 0: 401 | adopted_children.append(node.next.children.pop(0)) 402 | search_str = ' 
'.join(rep_old_children + adopted_children) 403 | found_substring = get_frequency(search_str, old_str) == 1 404 | 405 | if found_substring: 406 | new_children = [REPLACE_OLD_KEEP_AFTER] + rep_old_children + adopted_children + [REPLACE_NEW_KEEP_AFTER] + rep_new_children + adopted_children 407 | new_node = EditNode(REPLACE, new_children, node.prev, node.next) 408 | 409 | if node.prev: 410 | node.prev.next = new_node 411 | 412 | node.next.prev = new_node 413 | new_nodes.append(new_node) 414 | continue 415 | else: 416 | node.next.children = adopted_children + node.next.children 417 | 418 | return get_full_replace_span(old_tokens, new_tokens), tokens, commands 419 | 420 | elif node.edit_type == INSERT: 421 | if node.prev and node.prev.edit_type == KEEP: 422 | adopted_children = [] 423 | found_substring = False 424 | while not found_substring and len(node.prev.children) > 0: 425 | adopted_children.insert(0, node.prev.children.pop()) 426 | search_str = ' '.join(adopted_children) 427 | found_substring = get_frequency(search_str, old_str) == 1 428 | 429 | if found_substring: 430 | new_children = [INSERT_OLD_KEEP_BEFORE] + adopted_children + [INSERT_NEW_KEEP_BEFORE] + adopted_children + node.children 431 | new_node = EditNode(INSERT, new_children, node.prev, node.next) 432 | node.prev.next = new_node 433 | if node.next: 434 | node.next.prev = new_node 435 | new_nodes.append(new_node) 436 | continue 437 | else: 438 | node.prev.children.extend(adopted_children) 439 | 440 | if node.next and node.next.edit_type == KEEP: 441 | adopted_children = [] 442 | found_substring = False 443 | while not found_substring and len(node.next.children) > 0: 444 | adopted_children.append(node.next.children.pop(0)) 445 | search_str = ' '.join(adopted_children) 446 | found_substring = get_frequency(search_str, old_str) == 1 447 | 448 | if found_substring: 449 | new_children = [INSERT_OLD_KEEP_AFTER] + adopted_children + [INSERT_NEW_KEEP_AFTER] + node.children + adopted_children 450 | new_node = 
EditNode(INSERT, new_children, node.prev, node.next) 451 | 452 | if node.prev: 453 | node.prev.next = new_node 454 | 455 | node.next.prev = new_node 456 | new_nodes.append(new_node) 457 | continue 458 | else: 459 | node.next.children = adopted_children + node.next.children 460 | 461 | return get_full_replace_span(old_tokens, new_tokens), tokens, commands 462 | 463 | for node in new_nodes: 464 | if 'INSERT' in node.edit_type: 465 | spans.extend(node.children + [INSERT_END]) 466 | elif 'REPLACE' in node.edit_type: 467 | spans.extend(node.children + [REPLACE_END]) 468 | elif 'DELETE' in node.edit_type: 469 | spans.extend(node.children + [DELETE_END]) 470 | return spans, tokens, commands 471 | 472 | def get_full_replace_span(old_tokens, new_tokens): 473 | return [REPLACE_OLD] + old_tokens + [REPLACE_NEW] + new_tokens + [REPLACE_END] 474 | 475 | def is_insert(token): 476 | return 'INSERT' in token 477 | 478 | def is_keep(token): 479 | return 'KEEP' in token 480 | 481 | def is_replace(token): 482 | return 'REPLACE' in token 483 | 484 | def is_delete(token): 485 | return 'DELETE' in token 486 | 487 | def is_insert_end(token): 488 | return is_insert(token) and is_end(token) 489 | 490 | def is_insert_old(token): 491 | return is_insert(token) and 'OLD' in token 492 | 493 | def is_insert_new(token): 494 | return is_insert(token) and 'NEW' in token 495 | 496 | def is_keep_end(token): 497 | return is_keep(token) and is_end(token) 498 | 499 | def is_replace_end(token): 500 | return is_replace(token) and is_end(token) 501 | 502 | def is_replace_old(token): 503 | return is_replace(token) and 'OLD' in token 504 | 505 | def is_replace_new(token): 506 | return is_replace(token) and 'NEW' in token 507 | 508 | def is_delete_end(token): 509 | return is_delete(token) and is_end(token) 510 | 511 | def is_edit_keyword(token): 512 | return is_insert(token) or is_keep(token) or is_replace(token) or is_delete(token) 513 | 514 | def is_start(token): 515 | return is_edit_keyword(token) and 
'NEW' not in token and not is_end(token) 516 | 517 | def is_end(token): 518 | return is_edit_keyword(token) and 'END' in token 519 | 520 | def is_new(token): 521 | return is_edit_keyword(token) and 'NEW' in token 522 | 523 | def get_location(search_tokens, reference_tokens): 524 | ref_str = ' '.join(reference_tokens) 525 | for i in range(len(search_tokens)): 526 | for j in range(len(search_tokens), i, -1): 527 | search_str = ' '.join(search_tokens[i:j]) 528 | valid_positions = get_valid_positions(search_str, ref_str) 529 | if len(valid_positions) > 0: 530 | return valid_positions[0], i, len(valid_positions) > 1 531 | return -1, -1, False 532 | 533 | def format_minimal_diff_spans(reference_tokens, diff_span_tokens): 534 | ptr = 0 535 | new_comment_tokens = [] 536 | 537 | post_delete = [] 538 | post_replace = [] 539 | 540 | i = 0 541 | while i < len(diff_span_tokens): 542 | token = diff_span_tokens[i] 543 | 544 | if not is_start(token): 545 | i += 1 546 | continue 547 | 548 | if is_delete(token): 549 | j = i + 1 550 | delete_tokens = [] 551 | multiple_delete = False 552 | 553 | while j < len(diff_span_tokens) and not is_delete_end(diff_span_tokens[j]): 554 | delete_tokens.append(diff_span_tokens[j]) 555 | j += 1 556 | 557 | idx, d_start, multiple_delete = get_location(delete_tokens, reference_tokens[ptr:]) 558 | 559 | if multiple_delete: 560 | post_delete.append(delete_tokens) 561 | 562 | if idx >= 0: 563 | before_match = delete_tokens[:d_start] 564 | for r in range(ptr, ptr+idx): 565 | if reference_tokens[r] in before_match: 566 | before_match.pop(before_match.index(reference_tokens[r])) 567 | else: 568 | new_comment_tokens.append(reference_tokens[r]) 569 | 570 | ptr += idx 571 | remaining_delete_tokens = delete_tokens[d_start:] 572 | for d in remaining_delete_tokens: 573 | if ptr < len(reference_tokens) and d in reference_tokens[ptr:]: 574 | idx = reference_tokens[ptr:].index(d) 575 | new_comment_tokens.extend(reference_tokens[ptr:ptr+idx]) 576 | ptr += idx + 1 577 
| 578 | elif is_insert_old(token): 579 | j = i + 1 580 | delete_tokens = [] 581 | insert_tokens = [] 582 | multiple_insert = False 583 | 584 | while j < len(diff_span_tokens) and not is_insert_new(diff_span_tokens[j]): 585 | delete_tokens.append(diff_span_tokens[j]) 586 | j += 1 587 | 588 | can_add = False 589 | idx, d_start, multiple_insert = get_location(delete_tokens, reference_tokens[ptr:]) 590 | 591 | if idx >= 0: 592 | can_add = True 593 | before_match = delete_tokens[:d_start] 594 | for r in range(ptr, ptr+idx): 595 | if reference_tokens[r] in before_match: 596 | before_match.pop(before_match.index(reference_tokens[r])) 597 | else: 598 | new_comment_tokens.append(reference_tokens[r]) 599 | 600 | ptr += idx 601 | remaining_delete_tokens = delete_tokens[d_start:] 602 | for d in remaining_delete_tokens: 603 | if ptr < len(reference_tokens) and d in reference_tokens[ptr:]: 604 | idx = reference_tokens[ptr:].index(d) 605 | new_comment_tokens.extend(reference_tokens[ptr:ptr+idx]) 606 | ptr += idx + 1 607 | 608 | j += 1 609 | while j < len(diff_span_tokens) and not is_insert_end(diff_span_tokens[j]): 610 | insert_tokens.append(diff_span_tokens[j]) 611 | if can_add: 612 | new_comment_tokens.append(diff_span_tokens[j]) 613 | j += 1 614 | 615 | if multiple_insert: 616 | post_replace.append((delete_tokens, insert_tokens)) 617 | 618 | elif is_replace_old(token): 619 | j = i + 1 620 | delete_tokens = [] 621 | insert_tokens = [] 622 | multiple_replace = False 623 | 624 | while j < len(diff_span_tokens) and not is_replace_new(diff_span_tokens[j]): 625 | delete_tokens.append(diff_span_tokens[j]) 626 | j += 1 627 | 628 | can_add = False 629 | idx, d_start, multiple_replace = get_location(delete_tokens, reference_tokens[ptr:]) 630 | if idx >= 0: 631 | can_add = True 632 | before_match = delete_tokens[:d_start] 633 | for r in range(ptr, ptr+idx): 634 | if reference_tokens[r] in before_match: 635 | before_match.pop(before_match.index(reference_tokens[r])) 636 | else: 637 | 
new_comment_tokens.append(reference_tokens[r]) 638 | 639 | ptr += idx 640 | remaining_delete_tokens = delete_tokens[d_start:] 641 | for d in remaining_delete_tokens: 642 | if ptr < len(reference_tokens) and d in reference_tokens[ptr:]: 643 | idx = reference_tokens[ptr:].index(d) 644 | new_comment_tokens.extend(reference_tokens[ptr:ptr+idx]) 645 | ptr += idx + 1 646 | 647 | j += 1 648 | while j < len(diff_span_tokens) and not is_replace_end(diff_span_tokens[j]): 649 | insert_tokens.append(diff_span_tokens[j]) 650 | if can_add: 651 | new_comment_tokens.append(diff_span_tokens[j]) 652 | j += 1 653 | 654 | if multiple_replace: 655 | post_replace.append((delete_tokens, insert_tokens)) 656 | else: 657 | raise ValueError('Invalid: {}'.format(token)) 658 | i = j+1 659 | 660 | if ptr < len(reference_tokens): 661 | new_comment_tokens.extend(reference_tokens[ptr:]) 662 | 663 | if len(post_delete) > 0: 664 | delete_positions = [] 665 | for d in post_delete: 666 | start_positions = get_valid_positions(' '.join(d), ' '.join(new_comment_tokens)) 667 | for s in start_positions: 668 | delete_positions.extend(range(s, s+len(d))) 669 | 670 | cleaned_new_comment_tokens = [] 671 | for i, tok in enumerate(new_comment_tokens): 672 | if i not in delete_positions: 673 | cleaned_new_comment_tokens.append(tok) 674 | 675 | new_comment_tokens = cleaned_new_comment_tokens 676 | 677 | for d, i in post_replace: 678 | valid_positions = get_valid_positions(' '.join(d), ' '.join(new_comment_tokens)) 679 | for v in valid_positions: 680 | if v + len(i) >= len(new_comment_tokens) or new_comment_tokens[v:v+len(i)] != i: 681 | new_comment_tokens[v:v+len(d)] = i 682 | 683 | return ' '.join(new_comment_tokens) 684 | 685 | def format_diff_commands(reference_tokens, commands): 686 | i = 0 687 | ref_ptr = 0 688 | output = [] 689 | 690 | while i < len(commands): 691 | command = commands[i] 692 | if command in [DELETE, REPLACE_OLD]: 693 | ref_ptr += 1 694 | elif command == KEEP: 695 | if ref_ptr < 
len(reference_tokens): 696 | output.append(reference_tokens[ref_ptr]) 697 | ref_ptr += 1 698 | elif command not in [INSERT, REPLACE_NEW]: 699 | output.append(command) 700 | i += 1 701 | return ' '.join(output) 702 | 703 | def format_diff_tokens(diff_tokens): 704 | i = 0 705 | output = [] 706 | last_command = KEEP 707 | 708 | while i < len(diff_tokens): 709 | token = diff_tokens[i] 710 | if token in [INSERT, DELETE, REPLACE_OLD, REPLACE_NEW, KEEP]: 711 | last_command = token 712 | elif last_command in [INSERT, REPLACE_NEW, KEEP]: 713 | output.append(token) 714 | i += 1 715 | return ' '.join(output) 716 | 717 | def format_diff_spans(reference_tokens, diff_span_tokens): 718 | def get_next_keep_token(start_idx, sequence): 719 | while start_idx < len(sequence) and sequence[start_idx] != KEEP: 720 | start_idx += 1 721 | 722 | start_idx += 1 723 | if start_idx < len(sequence): 724 | return sequence[start_idx] 725 | return None 726 | 727 | ptr = 0 728 | output = reference_tokens.copy() 729 | 730 | i = 0 731 | while i < len(diff_span_tokens): 732 | token = diff_span_tokens[i] 733 | i += 1 734 | 735 | if token not in [INSERT, DELETE, REPLACE_OLD, KEEP]: 736 | continue 737 | 738 | if token == INSERT: 739 | j = i 740 | 741 | next_keep_token = get_next_keep_token(j, diff_span_tokens) 742 | if next_keep_token: 743 | copy_ptr = ptr 744 | while copy_ptr < len(output) and output[copy_ptr] != next_keep_token: 745 | copy_ptr += 1 746 | if copy_ptr < len(output): 747 | ptr = copy_ptr 748 | elif ptr < len(output): 749 | ptr = len(output) 750 | 751 | while j < len(diff_span_tokens) and diff_span_tokens[j] != INSERT_END: 752 | output.insert(ptr, diff_span_tokens[j]) 753 | ptr += 1 754 | j += 1 755 | 756 | i = j+1 757 | 758 | elif token == DELETE: 759 | j = i 760 | while j < len(diff_span_tokens) and diff_span_tokens[j] != DELETE_END: 761 | copy_ptr = max(0, ptr-1) 762 | while copy_ptr < len(output) and diff_span_tokens[j] != output[copy_ptr]: 763 | copy_ptr += 1 764 | if copy_ptr < 
len(output): 765 | output.pop(copy_ptr) 766 | ptr = copy_ptr 767 | else: 768 | ptr += 1 769 | j += 1 770 | i = j+1 771 | 772 | elif token == KEEP: 773 | j = i 774 | while j < len(diff_span_tokens) and diff_span_tokens[j] != KEEP_END: 775 | if ptr < len(output) and diff_span_tokens[j] == output[ptr]: 776 | ptr += 1 777 | j += 1 778 | i = j+1 779 | else: 780 | j = i 781 | while j < len(diff_span_tokens) and diff_span_tokens[j] != REPLACE_NEW: 782 | copy_ptr = max(0, ptr-1) 783 | while copy_ptr < len(output) and diff_span_tokens[j] != output[copy_ptr]: 784 | copy_ptr += 1 785 | if copy_ptr < len(output): 786 | output.pop(copy_ptr) 787 | ptr = copy_ptr 788 | else: 789 | ptr += 1 790 | j += 1 791 | 792 | j += 1 793 | next_keep_token = get_next_keep_token(j, diff_span_tokens) 794 | if next_keep_token: 795 | copy_ptr = ptr 796 | while copy_ptr < len(output) and output[copy_ptr] != next_keep_token: 797 | copy_ptr += 1 798 | if copy_ptr < len(output): 799 | ptr = copy_ptr 800 | elif ptr < len(output): 801 | ptr = len(output) 802 | 803 | while j < len(diff_span_tokens) and diff_span_tokens[j] != REPLACE_END: 804 | output.insert(ptr, diff_span_tokens[j]) 805 | ptr += 1 806 | j += 1 807 | i = j+1 808 | return ' '.join(output) 809 | 810 | -------------------------------------------------------------------------------- /DocChecker/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | class Seq2Seq(nn.Module): 11 | """ 12 | Build Seqence-to-Sequence. 13 | 14 | Parameters: 15 | * `encoder`- encoder of seq2seq model. e.g. roberta 16 | * `decoder`- decoder of seq2seq model. e.g. transformer 17 | * `config`- configuration of encoder model. 18 | * `beam_size`- beam size for beam search. 19 | * `max_length`- max length of target for beam search. 
20 | * `sos_id`- start of symbol ids in target for beam search. 21 | * `eos_id`- end of symbol ids in target for beam search. 22 | """ 23 | def __init__(self, encoder,decoder, config, beam_size=4, max_length=32, sos_id=None, eos_id=None, queue_size=57600, 24 | momentum = 0.995, embed_dim=256, device='cuda'): 25 | super(Seq2Seq, self).__init__() 26 | 27 | self.device=device 28 | 29 | self.encoder = encoder 30 | self.decoder = decoder 31 | 32 | self.encoder_m = encoder 33 | self.decoder_m = decoder 34 | 35 | self.encoder_proj = nn.Linear(config.hidden_size, embed_dim) 36 | self.decoder_proj = nn.Linear(config.hidden_size, embed_dim) 37 | self.encoder_proj_m = nn.Linear(config.hidden_size, embed_dim) 38 | self.decoder_proj_m = nn.Linear(config.hidden_size, embed_dim) 39 | 40 | self.config=config 41 | self.register_buffer( 42 | "bias", torch.tril(torch.ones((1024, 1024), dtype=torch.uint8)).view(1,1024, 1024) 43 | ) 44 | 45 | self.itm_head = nn.Linear(config.hidden_size, 2) 46 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 47 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 48 | self.lm_head.weight = self.encoder.embeddings.word_embeddings.weight 49 | self.lsm = nn.LogSoftmax(dim=-1) 50 | 51 | self.model_pairs = [[self.encoder,self.encoder_m], 52 | [self.encoder_proj, self.encoder_proj_m], 53 | [self.decoder,self.decoder_m], 54 | [self.decoder_proj, self.decoder_proj_m] 55 | ] 56 | self.copy_params() 57 | self.register_buffer("code_queue", torch.randn(embed_dim, queue_size)) 58 | self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) 59 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 60 | self.code_queue = nn.functional.normalize(self.code_queue, dim=0) 61 | self.text_queue = nn.functional.normalize(self.text_queue, dim=0) 62 | 63 | self.queue_size = queue_size 64 | self.momentum = momentum 65 | self.temp = nn.Parameter(0.07*torch.ones([])) 66 | 67 | self.beam_size = beam_size 68 | 
self.max_length = max_length 69 | self.sos_id = sos_id 70 | self.eos_id = eos_id 71 | 72 | def forward(self, source_ids, target_ids=None, labels=None, stage=None, alpha=0.01, just_in_time=False, source_text_ids=None): 73 | 74 | # print(self.queue_ptr) 75 | with torch.no_grad(): 76 | self.temp.clamp_(0.001,0.5) 77 | 78 | if stage =='test_original': 79 | return self.generate(source_ids) 80 | elif stage == 'dev' or stage=='get_positive': 81 | gen_sentence = self.generate(source_ids) 82 | 83 | 84 | mask_source = source_ids.ne(1)[:,None,:]*source_ids.ne(1)[:,:,None] 85 | encoder_output = self.encoder(source_ids,attention_mask=mask_source,use_cache=True) 86 | 87 | # get mask for output of encoder 88 | mask_encoder = source_ids.ne(1) 89 | mask_encoder = torch.unsqueeze(mask_encoder,-1) 90 | mask_encoder = mask_encoder.expand(-1, -1, self.config.hidden_size) 91 | encoder_output_contrastive = encoder_output.last_hidden_state*mask_encoder 92 | 93 | code_embeds = torch.mean(encoder_output_contrastive, dim=1) 94 | code_feat = F.normalize(self.encoder_proj(code_embeds), dim=-1) 95 | 96 | if source_text_ids != None: 97 | TARGET = target_ids 98 | target_ids = source_text_ids 99 | else: 100 | TARGET = target_ids 101 | 102 | ids = torch.cat((source_ids,target_ids),-1) 103 | mask = self.bias[:,source_ids.size(-1):ids.size(-1),:ids.size(-1)].bool() 104 | mask = mask & ids[:,None,:].ne(1) 105 | out = self.decoder(target_ids,attention_mask=mask,past_key_values=encoder_output.past_key_values).last_hidden_state 106 | 107 | # get mask for output of decoder 108 | mask_decoder = target_ids.ne(1) 109 | mask_decoder = torch.unsqueeze(mask_decoder,-1) 110 | mask_decoder = mask_decoder.expand(-1, -1, self.config.hidden_size) 111 | decoder_output_contrastive = out*mask_decoder 112 | 113 | # text_embeds = out[:, 0, :] 114 | text_embeds = torch.mean(decoder_output_contrastive, dim=1) 115 | text_feat = F.normalize(self.decoder_proj(text_embeds), dim=-1) 116 | 117 | sim_pos = code_feat @ 
text_feat.t() / self.temp 118 | if stage == 'get_positive': 119 | pred_output = self.itm_head(text_embeds) 120 | return gen_sentence, sim_pos, pred_output 121 | 122 | elif stage=='inference': 123 | pred_output = self.itm_head(text_embeds) 124 | _, pred = pred_output.max(1) 125 | pred = torch.tensor(pred, dtype=torch.int64) 126 | return pred, self.generate(source_ids) 127 | 128 | elif stage == 'test': 129 | pred_output = self.itm_head(text_embeds) 130 | _, pred = pred_output.max(1) 131 | pred = torch.tensor(pred, dtype=torch.int64) 132 | hits = (pred == labels).float() 133 | return pred, hits#, self.generate(source_ids) 134 | 135 | # ============= loss lm ==================== 136 | 137 | ids_lm = torch.cat((source_ids,TARGET),-1) 138 | mask_lm = self.bias[:,source_ids.size(-1):ids_lm.size(-1),:ids_lm.size(-1)].bool() 139 | mask_lm = mask_lm & ids_lm[:,None,:].ne(1) 140 | out_lm = self.decoder(TARGET,attention_mask=mask_lm,past_key_values=encoder_output.past_key_values).last_hidden_state 141 | lm_logits = self.lm_head(out_lm) 142 | # Shift so that tokens < n predict n 143 | active_loss = TARGET[..., 1:].ne(1).view(-1) 144 | shift_logits = lm_logits[..., :-1, :].contiguous() 145 | shift_labels = TARGET[..., 1:].contiguous() 146 | 147 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 148 | loss_lm = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss], 149 | shift_labels.view(-1)[active_loss]) 150 | 151 | 152 | # ============= loss contrastive =================== 153 | with torch.no_grad(): 154 | self._momentum_update() 155 | encoder_output_m = self.encoder_m(source_ids,attention_mask=mask_source,use_cache=True) 156 | code_embeds_m = torch.mean(encoder_output_m.last_hidden_state*mask_encoder, dim=1) 157 | code_feat_m = F.normalize(self.encoder_proj_m(code_embeds_m),dim=-1) 158 | code_feat_all = torch.cat([code_feat_m.t(),self.code_queue.clone().detach()],dim=1) 159 | 160 | output_m = 
self.decoder_m(target_ids,attention_mask=mask,past_key_values=encoder_output_m.past_key_values).last_hidden_state 161 | text_output_m = torch.mean(output_m*mask_decoder, dim=1) 162 | text_feat_m = F.normalize(self.decoder_proj_m(text_output_m),dim=-1) 163 | text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) 164 | 165 | sim_i2t_m = code_feat_m @ text_feat_all / self.temp 166 | sim_t2i_m = text_feat_m @ code_feat_all / self.temp 167 | 168 | sim_targets = torch.zeros(sim_i2t_m.size()).to(self.device) 169 | sim_targets.fill_diagonal_(1) 170 | 171 | sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets 172 | sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets 173 | 174 | sim_i2t = code_feat @ text_feat_all / self.temp 175 | sim_t2i = text_feat @ code_feat_all / self.temp 176 | 177 | loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() 178 | loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() 179 | 180 | loss_ita = (loss_i2t+loss_t2i)/2 181 | 182 | self._dequeue_and_enqueue(code_feat_m, text_feat_m) 183 | 184 | #============== code-text Matching ===================### 185 | 186 | # forward the positve code-text pair 187 | bs = source_ids.size(0) 188 | output_pos = text_embeds 189 | with torch.no_grad(): 190 | weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4 191 | weights_t2i.fill_diagonal_(0) 192 | weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4 193 | weights_i2t.fill_diagonal_(0) 194 | 195 | # select a negative text for each code 196 | 197 | 198 | text_ids_neg = [] 199 | for b in range(bs): 200 | neg_idx = torch.multinomial(weights_i2t[b], 1).item() 201 | text_ids_neg.append(target_ids[neg_idx]) 202 | 203 | 204 | text_ids_neg = torch.stack(text_ids_neg,dim=0) 205 | text_ids_neg = text_ids_neg.to(torch.long) 206 | 207 | ids = torch.cat((source_ids,text_ids_neg),-1) 208 | mask = 
self.bias[:,source_ids.size(-1):ids.size(-1),:ids.size(-1)].bool() 209 | mask = mask & ids[:,None,:].ne(1) 210 | 211 | output_neg = self.decoder(text_ids_neg,attention_mask=mask,past_key_values=encoder_output.past_key_values).last_hidden_state 212 | 213 | mask_decoder_neg = text_ids_neg.ne(1) 214 | mask_decoder_neg = torch.unsqueeze(mask_decoder_neg,-1) 215 | mask_decoder_neg = mask_decoder_neg.expand(-1, -1, self.config.hidden_size) 216 | decoder_output_neg = output_neg*mask_decoder_neg 217 | text_embeds_neg = torch.mean(decoder_output_neg, dim=1) 218 | 219 | if just_in_time: 220 | vl_output = self.itm_head(output_pos) 221 | itm_labels = labels 222 | else: 223 | vl_embeddings = torch.cat([output_pos, text_embeds_neg],dim=0) 224 | vl_output = self.itm_head(vl_embeddings) 225 | 226 | itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(bs,dtype=torch.long)], 227 | dim=0).to(self.device) 228 | 229 | 230 | itm_labels = itm_labels.to(self.device) 231 | loss_itm = F.cross_entropy(vl_output, itm_labels) 232 | 233 | _, pred = vl_output.max(1) 234 | hits = pred == itm_labels 235 | 236 | if stage=='dev': 237 | return loss_lm, loss_ita, loss_itm, gen_sentence 238 | else: 239 | return loss_lm, loss_ita, loss_itm, hits 240 | 241 | def generate(self, source_ids): 242 | mask = source_ids.ne(1)[:,None,:]*source_ids.ne(1)[:,:,None] 243 | encoder_output = self.encoder(source_ids,attention_mask=mask,use_cache=True) 244 | preds = [] 245 | zero = torch.cuda.LongTensor(1).fill_(0) 246 | 247 | source_len = list(source_ids.ne(1).sum(-1).cpu().numpy()) 248 | for i in range(source_ids.shape[0]): 249 | context = [[x[i:i+1,:,:source_len[i]].repeat(self.beam_size,1,1,1) for x in y] 250 | for y in encoder_output.past_key_values] 251 | beam = Beam(self.beam_size,self.sos_id,self.eos_id) 252 | input_ids = beam.getCurrentState() 253 | context_ids = source_ids[i:i+1,:source_len[i]].repeat(self.beam_size,1) 254 | for _ in range(self.max_length): 255 | if beam.done(): 256 | break 257 
| # input_ids = input_ids.to(self.device) 258 | ids = torch.cat((context_ids,input_ids),-1) 259 | mask = self.bias[:,context_ids.size(-1):ids.size(-1),:ids.size(-1)].bool() 260 | mask = mask & ids[:,None,:].ne(1) 261 | out = self.decoder(input_ids,attention_mask=mask,past_key_values=context).last_hidden_state 262 | hidden_states = out[:,-1,:] 263 | out = self.lsm(self.lm_head(hidden_states)).data 264 | beam.advance(out) 265 | input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin())) 266 | input_ids = torch.cat((input_ids,beam.getCurrentState()),-1) 267 | hyp = beam.getHyp(beam.getFinal()) 268 | pred = beam.buildTargetTokens(hyp)[:self.beam_size] 269 | pred = [torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1,-1) for p in pred] 270 | preds.append(torch.cat(pred,0).unsqueeze(0)) 271 | 272 | preds = torch.cat(preds,0) 273 | 274 | return preds 275 | 276 | @torch.no_grad() 277 | def copy_params(self): 278 | for model_pair in self.model_pairs: 279 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 280 | param_m.data.copy_(param.data) # initialize 281 | param_m.requires_grad = False # not update by gradient 282 | param.requires_grad = True 283 | 284 | @torch.no_grad() 285 | def _momentum_update(self): 286 | for model_pair in self.model_pairs: 287 | for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): 288 | param_m.data = param_m.data * self.momentum + param.data * (1. 
- self.momentum) 289 | 290 | @torch.no_grad() 291 | def _dequeue_and_enqueue(self, code_feat, text_feat): 292 | # gather keys before updating queue 293 | code_feats = concat_all_gather(code_feat) 294 | text_feats = concat_all_gather(text_feat) 295 | 296 | batch_size = code_feats.shape[0] 297 | 298 | ptr = int(self.queue_ptr) 299 | 300 | assert self.queue_size % code_feats.shape[0] == 0 # for simplicity 301 | self.code_queue[:, ptr:ptr + batch_size] = code_feats.T 302 | self.text_queue[:, ptr:ptr + batch_size] = text_feats.T 303 | ptr = (ptr + batch_size) % self.queue_size # move pointer 304 | self.queue_ptr[0] = ptr 305 | 306 | @torch.no_grad() 307 | def concat_all_gather(tensor): 308 | """ 309 | Performs all_gather operation on the provided tensors. 310 | *** Warning ***: torch.distributed.all_gather has no gradient. 311 | """ 312 | tensors_gather = [torch.ones_like(tensor) 313 | for _ in range(torch.distributed.get_world_size())] 314 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 315 | 316 | output = torch.cat(tensors_gather, dim=0) 317 | return output 318 | 319 | # class Beam(object): 320 | # def __init__(self, size,sos,eos, device='cuda'): 321 | # self.size = size 322 | # print(device) 323 | # if device == 'cuda': 324 | # self.tt = torch.cuda 325 | # else: 326 | # self.tt = torch 327 | # self.device=device 328 | # # The score for each translation on the beam. 329 | # self.scores = self.tt.FloatTensor(size).zero_() 330 | # # The backpointers at each time-step. 331 | # self.prevKs = [] 332 | # # The outputs at each time-step. 333 | # self.nextYs = [self.tt.LongTensor(size) 334 | # .fill_(0)] 335 | # self.nextYs[0][0] = sos 336 | # # Has EOS topped the beam yet. 337 | # self._eos = eos 338 | # self.eosTop = False 339 | # # Time and k pair for finished. 340 | # self.finished = [] 341 | 342 | # def getCurrentState(self): 343 | # "Get the outputs for the current timestep." 
344 | # batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) 345 | # batch = batch.to(self.device) 346 | # return batch 347 | 348 | # def getCurrentOrigin(self): 349 | # "Get the backpointers for the current timestep." 350 | # return self.prevKs[-1] 351 | 352 | # def advance(self, wordLk): 353 | # """ 354 | # Given prob over words for every last beam `wordLk` and attention 355 | # `attnOut`: Compute and update the beam search. 356 | # Parameters: 357 | # * `wordLk`- probs of advancing from the last step (K x words) 358 | # * `attnOut`- attention at the last step 359 | # Returns: True if beam search is complete. 360 | # """ 361 | # numWords = wordLk.size(1) 362 | 363 | # # Sum the previous scores. 364 | # if len(self.prevKs) > 0: 365 | # beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 366 | 367 | # # Don't let EOS have children. 368 | # for i in range(self.nextYs[-1].size(0)): 369 | # if self.nextYs[-1][i] == self._eos: 370 | # beamLk[i] = -1e20 371 | # else: 372 | # beamLk = wordLk[0] 373 | # flatBeamLk = beamLk.view(-1) 374 | # bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 375 | 376 | # self.scores = bestScores 377 | 378 | # # bestScoresId is flattened beam x word array, so calculate which 379 | # # word and beam each score came from 380 | # prevK = bestScoresId // numWords 381 | # self.prevKs.append(prevK) 382 | # self.nextYs.append((bestScoresId - prevK * numWords)) 383 | 384 | 385 | # for i in range(self.nextYs[-1].size(0)): 386 | # if self.nextYs[-1][i] == self._eos: 387 | # s = self.scores[i] 388 | # self.finished.append((s, len(self.nextYs) - 1, i)) 389 | 390 | # # End condition is when top-of-beam is EOS and no global score. 
#         if self.nextYs[-1][0] == self._eos:
#             self.eosTop = True
#
#     def done(self):
#         return self.eosTop and len(self.finished) >= self.size
#
#     def getFinal(self):
#         if len(self.finished) == 0:
#             self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
#         self.finished.sort(key=lambda a: -a[0])
#         if len(self.finished) != self.size:
#             unfinished=[]
#             for i in range(self.nextYs[-1].size(0)):
#                 if self.nextYs[-1][i] != self._eos:
#                     s = self.scores[i]
#                     unfinished.append((s, len(self.nextYs) - 1, i))
#             unfinished.sort(key=lambda a: -a[0])
#             self.finished+=unfinished[:self.size-len(self.finished)]
#         return self.finished[:self.size]
#
#     def getHyp(self, beam_res):
#         """
#         Walk back to construct the full hypothesis.
#         """
#         hyps=[]
#         for _,timestep, k in beam_res:
#             hyp = []
#             for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
#                 hyp.append(self.nextYs[j+1][k])
#                 k = self.prevKs[j][k]
#             hyps.append(hyp[::-1])
#         return hyps
#
#     def buildTargetTokens(self, preds):
#         sentence=[]
#         for pred in preds:
#             tokens = []
#             for tok in pred:
#                 if tok==self._eos:
#                     break
#                 tokens.append(tok)
#             sentence.append(tokens)
#         return sentence


class Beam(object):
    """Beam-search bookkeeping for decoding a single sequence.

    For every decoding step the beam stores the token chosen on each of the
    ``size`` beams (``nextYs``) and the index of the beam each of those
    tokens extends (``prevKs``), together with the cumulative score of each
    live beam (``scores``).

    Args:
        size: beam width, i.e. number of hypotheses kept per step.
        sos: token id placed on beam 0 at step 0.
        eos: token id that terminates a hypothesis.
        device: device for the internal tensors. ``None`` (the default)
            keeps the previous behaviour of allocating on CUDA; pass e.g.
            ``"cpu"`` or a ``torch.device`` to decode elsewhere.
    """

    def __init__(self, size, sos, eos, device=None):
        self.size = size
        self.device = torch.device("cuda") if device is None else torch.device(device)
        # Legacy tensor-type namespace; kept so existing readers of `tt` still work.
        self.tt = torch.cuda if self.device.type == "cuda" else torch
        # The score for each translation on the beam.
        self.scores = torch.zeros(size, dtype=torch.float, device=self.device)
        # The backpointers at each time-step.
        self.prevKs = []
        # The outputs at each time-step. Only beam 0 holds `sos`; the other
        # rows are 0 and are ignored by the first advance(), which reads
        # only wordLk[0].
        start = torch.zeros(size, dtype=torch.long, device=self.device)
        start[0] = sos
        self.nextYs = [start]
        # Has EOS topped the beam yet.
        self._eos = eos
        self.eosTop = False
        # (score, timestep, beam index) triples of finished hypotheses.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep, shaped (size, 1)."
        return self.nextYs[-1].view(-1, 1)

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """Extend every live beam by one token and keep the best ``size``.

        Args:
            wordLk: per-beam scores for the next token, shaped
                ``(size, num_words)``. They are summed with the running beam
                scores, so presumably log-probabilities. On the first call
                only row 0 is used.

        Returns:
            None. Use :meth:`done` to check whether the search is complete.
        """
        numWords = wordLk.size(1)

        if len(self.prevKs) > 0:
            # Sum the previous scores.
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)

            # Don't let EOS have children: a finished beam is never extended.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            # First step: expand only row 0, so the first `size` hypotheses
            # are distinct continuations of `sos`.
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # bestScoresId indexes the flattened (beam x word) array; recover
        # which beam each winner extends and which word it appends.
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)

        # Record every beam that just emitted EOS as a finished hypothesis.
        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                self.finished.append((self.scores[i], len(self.nextYs) - 1, i))

        # End condition is when top-of-beam is EOS.
        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        """Return True once EOS tops the beam and ``size`` hypotheses finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """Return up to ``size`` best hypotheses as (score, timestep, beam) triples.

        Finished hypotheses come first, best score first. If fewer than
        ``size`` have finished, the best still-open beams of the last step
        fill the remaining slots. If nothing has finished, beam 0 of the last
        step is first recorded in ``self.finished``.
        """
        last = len(self.nextYs) - 1
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], last, 0))
        self.finished.sort(key=lambda a: -a[0])
        # `<` rather than `!=`: with more than `size` finished the old
        # negative slice bound appended stray unfinished entries.
        if len(self.finished) < self.size:
            # Beams of the last step already recorded (e.g. the forced beam 0
            # above) must not be returned a second time.
            taken = {k for _, t, k in self.finished if t == last}
            unfinished = [
                (self.scores[i], last, i)
                for i in range(self.nextYs[-1].size(0))
                if self.nextYs[-1][i] != self._eos and i not in taken
            ]
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[: self.size - len(self.finished)]
        return self.finished[: self.size]

    def getHyp(self, beam_res):
        """Walk the back-pointers to rebuild each hypothesis in ``beam_res``.

        Args:
            beam_res: (score, timestep, beam) triples, e.g. from getFinal().

        Returns:
            One list of 0-d token tensors per triple, oldest token first.
            The initial step-0 token is not included.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Truncate each hypothesis in ``preds`` just before its first EOS."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence


# --------------------------------------------------------------------------------
# /DocChecker/requirements.txt:
# --------------------------------------------------------------------------------
# matplotlib==3.7.0
# numpy==1.24.2
# Requests==2.30.0
# scikit_learn==1.2.1
# torch==1.13.1+cu116
# tqdm==4.64.1
# transformers==4.26.1
# wandb==0.13.10
# --------------------------------------------------------------------------------
# /DocChecker/run.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# NOTE(review): observations on the run.py listing that follows. Indentation
# was lost in this dump, so loop/branch nesting is not asserted here.
# - main() reads args.cpu_cont (multiprocessing.Pool), args.device and
#   args.n_gpu, but utils.init_distributed_mode() assigns them only after its
#   distributed setup; its non-distributed branch returns early, so a
#   single-process run appears to hit AttributeError -- verify.
# - `scaler = torch.cuda.amp.GradScaler()` is never used, and
#   --max_grad_norm is parsed but no gradient clipping precedes
#   optimizer.step().
# - `map_location` is computed but unused; both checkpoint loads hard-code
#   map_location='cuda:0' regardless of args.device / args.rank.
# - With --load_model the distributed path loads
#   'checkpoint-best-bleu/pytorch_model.bin' while the other path loads
#   'checkpoint-epoch-3/pytorch_model.bin' -- confirm this is intended.
# - --task just_in_time unconditionally sets args.num_train_epochs = 30,
#   overriding --num_train_epochs.
# - The --do_test help text duplicates --do_eval's ("... on the dev set."),
#   and the wandb key "tota loss" looks like a typo.
# - In the final --do_test section the checkpoint reload is commented out,
#   so `checkpoint_prefix` is unused and the in-memory model is evaluated.
# - TODO confirm which loop each early-stopping `break` (patience == 5 / 8)
#   exits; the pretrain dev evaluation appears to run inside the per-batch
#   training loop, in which case the break ends only the current epoch.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | import os 24 | import bleu as bleu 25 | import torch 26 | import logging 27 | import argparse 28 | import numpy as np 29 | from io import open 30 | from tqdm import tqdm 31 | import utils 32 | import wandb 33 | from dataloader import * 34 | import multiprocessing 35 | 36 | from transformers import (AdamW, get_linear_schedule_with_warmup) 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | 44 | ## Required parameters 45 | parser.add_argument("--model_name_or_path", default='microsoft/unixcoder-base', type=str, 46 | help="Path to pre-trained model: e.g. 
roberta-base" ) 47 | parser.add_argument("--output_dir", default='../saved_model/', type=str, 48 | help="The output directory where the model predictions and checkpoints will be written.") 49 | parser.add_argument("--load_model_dir", default='../saved_model/pretrained_CSN', type=str, 50 | help="The output directory where the pretrained model was saved") 51 | 52 | ## Other parameters 53 | parser.add_argument("--output_clean_dir", type=str, 54 | help="The output directory where the model predictions and checkpoints will be written.") 55 | 56 | parser.add_argument("--task", default='just_in_time', type=str, 57 | choices=['pretrain', 'just_in_time', 'cup'],) 58 | 59 | parser.add_argument("--data_folder", default=None, type=str, required=True, 60 | help="The folder that contains dataset") 61 | 62 | parser.add_argument("--run_name", default='', type=str, 63 | help="name for each running in wandb") 64 | parser.add_argument("--max_source_length", default=200, type=int, 65 | help="The maximum total source sequence length after tokenization. Sequences longer " 66 | "than this will be truncated, sequences shorter will be padded.") 67 | parser.add_argument("--max_target_length", default=32, type=int, 68 | help="The maximum total target sequence length after tokenization. 
Sequences longer " 69 | "than this will be truncated, sequences shorter will be padded.") 70 | parser.add_argument("--wandb", action='store_true', 71 | help="whether to visualize training phase by wandb") 72 | parser.add_argument("--post_hoc", action='store_true', 73 | help="whether to run the setting of post hoc (for Just_in_time task)") 74 | parser.add_argument("--do_train", action='store_true', 75 | help="Whether to run training.") 76 | parser.add_argument("--do_eval", action='store_true', 77 | help="Whether to run eval on the dev set.") 78 | parser.add_argument("--do_test", action='store_true', 79 | help="Whether to run eval on the dev set.") 80 | parser.add_argument("--no_cuda", action='store_true', 81 | help="Avoid using CUDA when available") 82 | parser.add_argument("--load_model", action='store_true', 83 | help="Whether to load the pretrained checkpoint.") 84 | 85 | parser.add_argument("--train_batch_size", default=100, type=int, 86 | help="Batch size per GPU/CPU for training.") 87 | parser.add_argument("--eval_batch_size", default=128, type=int, 88 | help="Batch size per GPU/CPU for evaluation.") 89 | parser.add_argument("--alpha", default=1/3, type=float, 90 | help="hyperparam in loss function for language model loss.") 91 | parser.add_argument("--beta", default=1/3, type=float, 92 | help="hyperpapram in loss function for contrastive learning loss.") 93 | parser.add_argument("--queue_size", default=57600, type=int, 94 | help="size for the queue in the model.") 95 | parser.add_argument('--gradient_accumulation_steps', type=int, default=2, 96 | help="Number of updates steps to accumulate before performing a backward/update pass.") 97 | parser.add_argument("--learning_rate", default=0.00005, type=float, 98 | help="The initial learning rate for Adam.") 99 | parser.add_argument("--beam_size", default=4, type=int, 100 | help="beam size for beam search") 101 | parser.add_argument("--weight_decay", default=0.0, type=float, 102 | help="Weight deay if we apply 
some.") 103 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 104 | help="Epsilon for Adam optimizer.") 105 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 106 | help="Max gradient norm.") 107 | parser.add_argument("--num_train_epochs", default=10, type=int, 108 | help="Total number of training epochs to perform.") 109 | parser.add_argument('--seed', type=int, default=42, 110 | help="random seed for initialization") 111 | parser.add_argument('--distributed', action='store_true') 112 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 113 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 114 | 115 | # print arguments 116 | args = parser.parse_args() 117 | 118 | #set dataset 119 | if args.task == 'just_in_time': 120 | args.language = ['Summary','Return', 'Param'] 121 | if args.post_hoc: 122 | args.output_dir += 'post_hoc/' 123 | else: 124 | args.output_dir += 'just_in_time/' 125 | args.num_train_epochs = 30 126 | elif args.task == 'pretrain': 127 | args.language = [ 'python', 'go','java','javascript','php','ruby'] 128 | args.output_dir += 'pretrained_CSN/' 129 | elif args.task == "cup": 130 | args.output_dir += 'cup/' 131 | if args.wandb: 132 | wandb.init(project="DocCheckerNet", name = args.run_name) 133 | 134 | # set log 135 | 136 | if os.path.exists(args.output_dir) is False: 137 | os.makedirs(args.output_dir) 138 | 139 | logging.basicConfig(filename=args.output_dir + '/run.log', 140 | filemode='a', 141 | format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 142 | datefmt = '%m/%d/%Y %H:%M:%S', 143 | level = logging.INFO) 144 | # set device 145 | utils.init_distributed_mode(args) 146 | 147 | if args.wandb: 148 | wandb.config = { 149 | "learning_rate": args.learning_rate, 150 | "epochs": args.num_train_epochs, 151 | "batch_size": args.train_batch_size, 152 | "beam_size": args.beam_size 153 | } 154 | 155 | # Set seed 156 | 
utils.set_seed(args.seed) 157 | 158 | pool = multiprocessing.Pool(args.cpu_cont) 159 | scaler = torch.cuda.amp.GradScaler() 160 | 161 | config, tokenizer, model = utils.build_or_load_gen(args) 162 | 163 | logger.info("Training/evaluation parameters %s", args) 164 | 165 | map_location = {"cuda:0": "cuda:%d" % args.rank} if args.distributed else None 166 | if args.distributed: 167 | model.cuda() 168 | if args.load_model: 169 | checkpoint_prefix = 'checkpoint-best-bleu/pytorch_model.bin' 170 | output_dir = os.path.join(args.load_model_dir, checkpoint_prefix) 171 | model.load_state_dict(torch.load(output_dir,map_location='cuda:0')) 172 | else: 173 | model.to(args.device) 174 | if args.load_model: 175 | checkpoint_prefix = 'checkpoint-epoch-3/pytorch_model.bin' 176 | output_dir = os.path.join(args.load_model_dir, checkpoint_prefix) 177 | model.load_state_dict(torch.load(output_dir, map_location='cuda:0')) 178 | model.queue_ptr[0] = 0 179 | if args.do_train: 180 | # Prepare training data loader 181 | if args.task == 'just_in_time': 182 | train_examples, train_dataloader = get_dataloader(args, args.data_folder, tokenizer=tokenizer, pool=pool, stage='train', label=True) 183 | elif args.task == 'pretrain': 184 | train_examples, train_dataloader = get_dataloader(args, args.data_folder, tokenizer=tokenizer, pool=pool,stage='train') 185 | elif args.task == 'cup': 186 | train_examples, train_dataloader = get_dataloader(args, args.data_folder, tokenizer=tokenizer, pool=pool,stage='train') 187 | # Prepare optimizer and schedule (linear warmup and decay) 188 | no_decay = ['bias', 'LayerNorm.weight'] 189 | optimizer_grouped_parameters = [ 190 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 191 | 'weight_decay': args.weight_decay}, 192 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 193 | ] 194 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, 
eps=args.adam_epsilon) 195 | scheduler = get_linear_schedule_with_warmup(optimizer, 196 | num_warmup_steps=int(len(train_dataloader)*args.num_train_epochs*0.1), 197 | num_training_steps=len(train_dataloader)*args.num_train_epochs) 198 | 199 | #Start training 200 | logger.info("***** Running training *****") 201 | logger.info(" Num examples = %d", len(train_examples)) 202 | logger.info(" Batch size = %d", args.train_batch_size * args.gradient_accumulation_steps) 203 | logger.info(" Num epoch = %d", args.num_train_epochs) 204 | 205 | 206 | model.train() 207 | patience, best_bleu, best_loss, best_acc, best_f1, losses, dev_dataset = 0, 0, 10000, 0, 0, [], {} 208 | losses_lm = [] 209 | losses_ita = [] 210 | losses_itm = [] 211 | 212 | for epoch in range(0, args.num_train_epochs): 213 | total_num = 0 214 | total_acc = 0 215 | for batch in tqdm(train_dataloader): 216 | 217 | if args.task == 'just_in_time': 218 | batch = tuple(t.to(args.device) for t in batch) 219 | source_ids, target_ids, label= batch 220 | 221 | total_num += source_ids.size(0) 222 | loss_lm, loss_ita, loss_itm, hits = model(source_ids=source_ids,target_ids=target_ids, labels=label, just_in_time=True) 223 | 224 | total_acc += hits.sum().data.cpu().numpy() 225 | elif args.task == 'pretrain' or args.task == 'cup': 226 | batch = tuple(t.to(args.device) for t in batch) 227 | source_ids,target_ids = batch[0], batch[1] 228 | 229 | total_num += 2*source_ids.size(0) 230 | loss_lm, loss_ita, loss_itm, hits = model(source_ids=source_ids,target_ids=target_ids) 231 | 232 | total_acc += hits.sum().data.cpu().numpy() 233 | 234 | 235 | loss = args.alpha*loss_lm + args.beta*loss_ita + (1-args.alpha-args.beta)*loss_itm 236 | 237 | if args.wandb: 238 | wandb.log({ 239 | "tota loss": loss, 240 | "loss lm": loss_lm, 241 | "loss contrastive": loss_ita, 242 | "loss binary classification": loss_itm, 243 | "train acc": total_acc/total_num*100}) 244 | 245 | # breaknum_train_epochs 246 | if args.n_gpu > 1: 247 | loss = loss.mean() 
# mean() to average on multi-gpu. 248 | if args.gradient_accumulation_steps > 1: 249 | loss = loss / args.gradient_accumulation_steps 250 | losses.append(loss.item()) 251 | losses_ita.append(loss_ita.item()) 252 | losses_itm.append(loss_itm.item()) 253 | losses_lm.append(loss_lm.item()) 254 | loss.backward() 255 | if len(losses) % args.gradient_accumulation_steps == 0: 256 | #Update parameters 257 | optimizer.step() 258 | optimizer.zero_grad() 259 | scheduler.step() 260 | if len(losses) // args.gradient_accumulation_steps % 100 == 0: 261 | logger.info("epoch {} step {} total loss {} loss_lm {} loss_contrastive {} loss_binary {} acc {:.2f}".format(epoch, 262 | len(losses)//args.gradient_accumulation_steps, 263 | round(np.mean(losses[-100*args.gradient_accumulation_steps:]),4), 264 | round(np.mean(losses_lm[-100*args.gradient_accumulation_steps:]),4), 265 | round(np.mean(losses_ita[-100*args.gradient_accumulation_steps:]),4), 266 | round(np.mean(losses_itm[-100*args.gradient_accumulation_steps:]),4), 267 | total_acc/total_num*100)) 268 | 269 | acc = total_acc/total_num*100 270 | 271 | if (len(losses) // args.gradient_accumulation_steps % 5000 == 0 and args.do_eval and args.task == 'pretrain'): 272 | #Eval model with dev dataset 273 | if 'dev' in dev_dataset: 274 | eval_examples, eval_dataloader = dev_dataset['dev'] 275 | else: 276 | eval_examples, eval_dataloader = get_dataloader(args, args.data_folder, pool=pool,tokenizer=tokenizer, stage='valid', label=False, sequential=True, num_sample=1000) 277 | dev_dataset['dev']= eval_examples, eval_dataloader 278 | 279 | logger.info("\n***** Running evaluation *****") 280 | logger.info(" Num examples = %d", len(eval_examples)) 281 | logger.info(" Batch size = %d", args.eval_batch_size) 282 | losses_eval = [] 283 | 284 | model.eval() 285 | p=[] 286 | # pred_ids = [] 287 | for batch in eval_dataloader: 288 | batch = tuple(t.to(args.device) for t in batch) 289 | source_ids = batch[0] 290 | target_ids = batch[1] 291 | with 
torch.no_grad(): 292 | loss_lm, loss_ita,loss_itm, pred_sentence = model(source_ids=source_ids,target_ids=target_ids, stage='dev') 293 | for pred in pred_sentence: 294 | t = pred[0].cpu().numpy() 295 | t = list(t) 296 | if 0 in t: 297 | t = t[:t.index(0)] 298 | text = tokenizer.decode(t,clean_up_tokenization_spaces=False) 299 | p.append(text) 300 | 301 | loss = args.alpha*loss_lm + args.beta*loss_ita + (1-args.alpha-args.beta)*loss_itm 302 | losses_eval.append(loss.item()) 303 | 304 | model.train() 305 | predictions = [] 306 | with open(args.output_dir+"/dev.output",'w') as f, open(args.output_dir+"/dev.gold",'w') as f1: 307 | for ref,gold in zip(p,eval_examples): 308 | predictions.append(str(gold.idx)+'\t'+ref) 309 | f.write(str(gold.idx)+'\t'+ref+'\n') 310 | f1.write(str(gold.idx)+'\t'+gold.target+'\n') 311 | 312 | (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "dev.gold")) 313 | dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2) 314 | logger.info(" %s = %s "%("bleu-4",str(dev_bleu))) 315 | logger.info(" "+"*"*20) 316 | if dev_bleu > best_bleu: 317 | logger.info(" Best bleu:%s",dev_bleu) 318 | logger.info(" "+"*"*20) 319 | best_bleu = dev_bleu 320 | # Save best checkpoint for best bleu 321 | output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu') 322 | if not os.path.exists(output_dir): 323 | os.makedirs(output_dir) 324 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 325 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 326 | torch.save(model_to_save.state_dict(), output_model_file) 327 | patience =0 328 | else: 329 | patience +=1 330 | if patience == 5: 331 | break 332 | 333 | if np.mean(losses_eval) < best_loss: 334 | logger.info(" Best loss:%s", np.mean(losses_eval)) 335 | logger.info(" "+"*"*20) 336 | best_loss = np.mean(losses_eval) 337 | # Save best checkpoint for best bleu 338 | output_dir = os.path.join(args.output_dir, 
'checkpoint-best-loss') 339 | if not os.path.exists(output_dir): 340 | os.makedirs(output_dir) 341 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 342 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 343 | torch.save(model_to_save.state_dict(), output_model_file) 344 | patience =0 345 | 346 | 347 | 348 | 349 | if args.task == 'pretrain' : 350 | output_dir = os.path.join(args.output_dir, 'checkpoint-epoch-{}'.format(epoch)) 351 | if not os.path.exists(output_dir): 352 | os.makedirs(output_dir) 353 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 354 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 355 | torch.save(model_to_save.state_dict(), output_model_file) 356 | elif args.task == 'just_in_time': 357 | if 'dev' in dev_dataset: 358 | eval_examples, eval_dataloader = dev_dataset['dev'] 359 | else: 360 | eval_examples, eval_dataloader = get_dataloader(args, args.data_folder, pool=pool,tokenizer=tokenizer, stage='valid', label=True) 361 | dev_dataset['dev']= eval_examples, eval_dataloader 362 | 363 | logger.info("\n***** Running evaluation *****") 364 | logger.info(" Num examples = %d", len(eval_examples)) 365 | logger.info(" Batch size = %d", args.eval_batch_size) 366 | model.eval() 367 | p=[] 368 | 369 | total_num = 0 370 | total_acc = 0 371 | target_labels = [] 372 | pred_labels = [] 373 | for batch in eval_dataloader: 374 | batch = tuple(t.to(args.device) for t in batch) 375 | source_ids, target_ids, labels = batch 376 | 377 | bs = source_ids.size(0) 378 | with torch.no_grad(): 379 | pred, hits = model(source_ids, target_ids, labels=labels, stage='test') 380 | total_num += bs 381 | total_acc += hits.sum().data.cpu().numpy() 382 | target_labels.extend(labels.tolist()) 383 | pred_labels.extend(pred.tolist()) 384 | acc = total_acc/total_num*100 385 | precision, recall, F1_score = utils.compute_score(pred_labels, target_labels) 
386 | model.train() 387 | if acc > best_acc: 388 | logger.info(" Best acc:%s", acc) 389 | logger.info(" "+"*"*20) 390 | best_acc = acc 391 | # Save best checkpoint for best acc 392 | output_dir = os.path.join(args.output_dir, 'checkpoint-best-acc') 393 | if not os.path.exists(output_dir): 394 | os.makedirs(output_dir) 395 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 396 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 397 | torch.save(model_to_save.state_dict(), output_model_file) 398 | patience = 0 399 | else: 400 | patience += 1 401 | if patience == 8: 402 | break 403 | 404 | if F1_score > best_f1: 405 | logger.info(" Best F1:%s", F1_score) 406 | logger.info(" "+"*"*20) 407 | best_f1 = F1_score 408 | # Save best checkpoint for best F1 score 409 | output_dir = os.path.join(args.output_dir, 'checkpoint-best-F1') 410 | if not os.path.exists(output_dir): 411 | os.makedirs(output_dir) 412 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 413 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 414 | torch.save(model_to_save.state_dict(), output_model_file) 415 | patience = 0 416 | elif args.task == "cup": 417 | if 'dev' in dev_dataset: 418 | eval_examples, eval_dataloader = dev_dataset['dev'] 419 | else: 420 | eval_examples, eval_dataloader = get_dataloader(args, args.data_folder, pool=pool,tokenizer=tokenizer, stage='valid', label=False, sequential=True, num_sample=1000) 421 | dev_dataset['dev']= eval_examples, eval_dataloader 422 | 423 | logger.info("\n***** Running evaluation *****") 424 | logger.info(" Num examples = %d", len(eval_examples)) 425 | logger.info(" Batch size = %d", args.eval_batch_size) 426 | losses_eval = [] 427 | 428 | model.eval() 429 | p=[] 430 | # pred_ids = [] 431 | for batch in eval_dataloader: 432 | batch = tuple(t.to(args.device) for t in batch) 433 | source_ids = batch[0] 434 | target_ids = batch[1] 435 
| with torch.no_grad(): 436 | loss_lm, loss_ita,loss_itm, pred_sentence = model(source_ids=source_ids,target_ids=target_ids, stage='dev') 437 | for pred in pred_sentence: 438 | t = pred[0].cpu().numpy() 439 | t = list(t) 440 | if 0 in t: 441 | t = t[:t.index(0)] 442 | text = tokenizer.decode(t,clean_up_tokenization_spaces=False) 443 | p.append(text) 444 | 445 | loss = args.alpha*loss_lm + args.beta*loss_ita + (1-args.alpha-args.beta)*loss_itm 446 | losses_eval.append(loss.item()) 447 | 448 | model.train() 449 | predictions = [] 450 | with open(args.output_dir+"/dev.output",'w') as f, open(args.output_dir+"/dev.gold",'w') as f1: 451 | for ref,gold in zip(p,eval_examples): 452 | predictions.append(str(gold.idx)+'\t'+ref) 453 | f.write(str(gold.idx)+'\t'+ref+'\n') 454 | f1.write(str(gold.idx)+'\t'+gold.target+'\n') 455 | 456 | (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "dev.gold")) 457 | dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2) 458 | logger.info(" %s = %s "%("bleu-4",str(dev_bleu))) 459 | logger.info(" "+"*"*20) 460 | if dev_bleu > best_bleu: 461 | logger.info(" Best bleu:%s",dev_bleu) 462 | logger.info(" "+"*"*20) 463 | best_bleu = dev_bleu 464 | # Save best checkpoint for best bleu 465 | output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu') 466 | if not os.path.exists(output_dir): 467 | os.makedirs(output_dir) 468 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 469 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 470 | torch.save(model_to_save.state_dict(), output_model_file) 471 | patience =0 472 | else: 473 | patience +=1 474 | if patience == 5: 475 | break 476 | 477 | if np.mean(losses_eval) < best_loss: 478 | logger.info(" Best loss:%s", np.mean(losses_eval)) 479 | logger.info(" "+"*"*20) 480 | best_loss = np.mean(losses_eval) 481 | # Save best checkpoint for best bleu 482 | output_dir = 
os.path.join(args.output_dir, 'checkpoint-best-loss') 483 | if not os.path.exists(output_dir): 484 | os.makedirs(output_dir) 485 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 486 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 487 | torch.save(model_to_save.state_dict(), output_model_file) 488 | patience =0 489 | 490 | if args.task == 'just_in_time' and args.do_test: 491 | if 'test' in dev_dataset: 492 | eval_examples, eval_dataloader = dev_dataset['test'] 493 | else: 494 | eval_examples, eval_dataloader = get_dataloader(args, args.data_folder, tokenizer=tokenizer, pool=pool,stage='test', label=True, sequential=True) 495 | dev_dataset['test'] = eval_examples, eval_dataloader 496 | 497 | logger.info("\n***** Running evaluation on the synthetic *****") 498 | logger.info(" Num examples = %d", len(eval_examples)) 499 | logger.info(" Batch size = %d", args.eval_batch_size) 500 | model.eval() 501 | p=[] 502 | total_num = 0 503 | total_acc = 0 504 | target_labels = [] 505 | pred_labels = [] 506 | for batch in eval_dataloader: 507 | batch = tuple(t.to(args.device) for t in batch) 508 | source_ids, target_ids, labels = batch 509 | bs = source_ids.size(0) 510 | with torch.no_grad(): 511 | pred_label, hits = model(source_ids, target_ids, labels=labels, stage='test') 512 | 513 | total_num += bs 514 | total_acc += hits.sum().data.cpu().numpy() 515 | target_labels.extend(labels.tolist()) 516 | pred_labels.extend(pred_label.tolist()) 517 | acc = total_acc/total_num*100 518 | precision, recall, f1 = utils.compute_score(pred_labels, target_labels) 519 | logger.info(' Testing ACC = {:.2f}'.format(acc)) 520 | logger.info(' Recall score = {:.3f}'.format(recall)) 521 | logger.info(' Precision score = {:.3f}'.format(precision)) 522 | logger.info(' F1 score= {:.3f}'.format(f1)) 523 | 524 | 525 | if args.do_test: 526 | if args.task == 'just_in_time': 527 | checkpoint_prefix = 'checkpoint-best-acc/pytorch_model.bin' 
528 | eval_examples, eval_dataloader = get_dataloader(args, args.data_folder, tokenizer=tokenizer, pool=pool,stage='test', label=True, sequential=True) 529 | else: 530 | checkpoint_prefix = 'checkpoint-best-bleu/pytorch_model.bin' 531 | eval_examples, eval_dataloader, raw_examples = get_dataloader(args, args.data_folder, tokenizer=tokenizer, pool=pool,stage='test', label=False, sequential=True, infer=True) 532 | 533 | # output_dir = os.path.join(args.load_model_dir, checkpoint_prefix) 534 | # model.load_state_dict(torch.load(output_dir,map_location='cuda:0')) 535 | 536 | # output_dir = os.path.join(args.output_dir, checkpoint_prefix) 537 | # model_to_load = model.module if hasattr(model, 'module') else model 538 | # model.load_state_dict(torch.load(output_dir,map_location='cuda:0')) 539 | model.eval() 540 | 541 | logger.info("\n***** Running evaluation on the test set *****") 542 | logger.info(" Num examples = %d", len(eval_examples)) 543 | logger.info(" Batch size = %d", args.eval_batch_size) 544 | 545 | if args.task == 'just_in_time': 546 | model.eval() 547 | p=[] 548 | total_num = 0 549 | total_acc = 0 550 | target_labels = [] 551 | pred_labels = [] 552 | for batch in eval_dataloader: 553 | batch = tuple(t.to(args.device) for t in batch) 554 | source_ids, target_ids, labels = batch 555 | 556 | bs = source_ids.size(0) 557 | with torch.no_grad(): 558 | pred, hits = model(source_ids, target_ids, labels=labels, stage='test') 559 | 560 | total_num += bs 561 | total_acc += hits.sum().data.cpu().numpy() 562 | target_labels.extend(labels.tolist()) 563 | pred_labels.extend(pred.tolist()) 564 | acc = total_acc/total_num*100 565 | precision, recall, f1 = utils.compute_score(pred_labels, target_labels) 566 | utils.confusion_matrix(pred_labels, target_labels, args.output_dir) 567 | logger.info(' Testing ACC = {:.2f}'.format(acc)) 568 | logger.info(' Recall score = {:.3f}'.format(recall)) 569 | logger.info(' Precision score = {:.3f}'.format(precision)) 570 | logger.info(' F1 
score= {:.3f}'.format(f1)) 571 | with open(args.output_dir+"/test.output",'w') as f: 572 | for label,target,gold in zip(pred_labels,target_labels,eval_examples): 573 | f.write(str(gold.idx)+'\t'+str(label)+'\t'+str(target)+'\n') 574 | elif args.task == "pretrain" or args.task=="cup": 575 | p = [] 576 | labels=[] 577 | for batch in tqdm(eval_dataloader,total=len(eval_dataloader)): 578 | batch = tuple(t.to(args.device) for t in batch) 579 | source_ids = batch[0] 580 | target_ids = batch[1] 581 | 582 | 583 | with torch.no_grad(): 584 | pred_labels, preds = model(source_ids,target_ids=target_ids, stage='inference') 585 | for label, pred in zip(pred_labels,preds): 586 | t = pred[0].cpu().numpy() 587 | t = list(t) 588 | if 0 in t: 589 | t = t[:t.index(0)] 590 | text = tokenizer.decode(t,clean_up_tokenization_spaces=False) 591 | p.append(text) 592 | labels.append(int(label)) 593 | 594 | 595 | model.train() 596 | predictions=[] 597 | with open(args.output_dir+"/test.output",'w') as f, open(args.output_dir+"/test.gold",'w') as f1: 598 | # test original 599 | for ref,label,gold in zip(p,labels,eval_examples): 600 | predictions.append(str(gold.idx)+'\t'+ref) 601 | f.write(str(gold.idx)+'\t'+ref+'\n') 602 | f1.write(str(gold.idx)+'\t'+gold.target+'\n') 603 | 604 | (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "test.gold")) 605 | dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2) 606 | logger.info(" %s = %s "%("bleu-4",str(dev_bleu))) 607 | logger.info(" "+"*"*20) 608 | print(dev_bleu) 609 | 610 | if __name__ == "__main__": 611 | main() 612 | 613 | -------------------------------------------------------------------------------- /DocChecker/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ast 3 | from collections import deque 4 | import random 5 | import numpy as np 6 | import torch 7 | import multiprocessing 8 | import json 9 | from transformers import 
RobertaTokenizer, RobertaConfig, RobertaModel 10 | from model import Seq2Seq 11 | from diff_utils import * 12 | from statistics import mean, median 13 | import matplotlib.pyplot as plt 14 | from sklearn import metrics 15 | import requests 16 | 17 | def download_file_from_google_drive(id, destination): 18 | URL = "https://docs.google.com/uc?export=download" 19 | 20 | session = requests.Session() 21 | 22 | response = session.get(URL, params = { 'id' : id }, stream = True) 23 | token = get_confirm_token(response) 24 | 25 | if token: 26 | params = { 'id' : id, 'confirm' : token } 27 | response = session.get(URL, params = params, stream = True) 28 | 29 | save_response_content(response, destination) 30 | 31 | def get_confirm_token(response): 32 | for key, value in response.cookies.items(): 33 | if key.startswith('download_warning'): 34 | return value 35 | 36 | return None 37 | 38 | def save_response_content(response, destination): 39 | CHUNK_SIZE = 32768 40 | 41 | with open(destination, "wb") as f: 42 | for chunk in response.iter_content(CHUNK_SIZE): 43 | if chunk: # filter out keep-alive new chunks 44 | f.write(chunk) 45 | 46 | def setup_for_distributed(is_master): 47 | """ 48 | This function disables printing when not in master process 49 | """ 50 | import builtins as __builtin__ 51 | builtin_print = __builtin__.print 52 | 53 | def print(*args, **kwargs): 54 | force = kwargs.pop('force', False) 55 | if is_master or force: 56 | builtin_print(*args, **kwargs) 57 | 58 | __builtin__.print = print 59 | 60 | 61 | def init_distributed_mode(args): 62 | cpu_cont = multiprocessing.cpu_count() 63 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 64 | args.rank = int(os.environ["RANK"]) 65 | args.world_size = int(os.environ['WORLD_SIZE']) 66 | args.gpu = int(os.environ['LOCAL_RANK']) 67 | elif 'SLURM_PROCID' in os.environ: 68 | args.rank = int(os.environ['SLURM_PROCID']) 69 | args.gpu = args.rank % torch.cuda.device_count() 70 | else: 71 | print('Not using distributed 
mode') 72 | args.distributed = False 73 | return 74 | 75 | args.distributed = True 76 | 77 | torch.cuda.set_device(args.gpu) 78 | args.dist_backend = 'nccl' 79 | print('| distributed init (rank {}, word {}): {}'.format( 80 | args.rank, args.world_size, args.dist_url), flush=True) 81 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 82 | world_size=args.world_size, rank=args.rank) 83 | torch.distributed.barrier() 84 | device = torch.device("cuda", args.gpu) 85 | args.n_gpu = torch.cuda.device_count() 86 | args.device = device 87 | args.cpu_cont = cpu_cont 88 | setup_for_distributed(args.rank == 0) 89 | 90 | def compute_score(predicted_labels, gold_labels): 91 | true_positives = 0.0 92 | true_negatives = 0.0 93 | false_positives = 0.0 94 | false_negatives = 0.0 95 | 96 | assert(len(predicted_labels) == len(gold_labels)) 97 | 98 | for i in range(len(gold_labels)): 99 | if gold_labels[i]: 100 | if predicted_labels[i]: 101 | true_positives += 1 102 | else: 103 | false_negatives += 1 104 | else: 105 | if predicted_labels[i]: 106 | false_positives += 1 107 | else: 108 | true_negatives += 1 109 | 110 | if verbose: 111 | print('True positives: {}'.format(true_positives)) 112 | print('False positives: {}'.format(false_positives)) 113 | print('True negatives: {}'.format(true_negatives)) 114 | print('False negatives: {}'.format(false_negatives)) 115 | 116 | try: 117 | precision = true_positives/(true_positives + false_positives) 118 | except: 119 | precision = 0.0 120 | try: 121 | recall = true_positives/(true_positives + false_negatives) 122 | except: 123 | recall = 0.0 124 | try: 125 | f1 = 2*((precision * recall)/(precision + recall)) 126 | except: 127 | f1 = 0.0 128 | 129 | return precision, recall, f1 130 | 131 | def confusion_matrix(predict, label, path): 132 | 133 | confusion_matrix = metrics.confusion_matrix(label, predict) 134 | 135 | cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, 
display_labels = [0, 1]) 136 | 137 | cm_display.plot() 138 | plt.savefig(path+'/matrix.png') 139 | plt.show() 140 | 141 | def build_or_load_gen(args): 142 | # build model 143 | if 'unixcoder' in args.model_name_or_path: 144 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) 145 | config = RobertaConfig.from_pretrained(args.model_name_or_path) 146 | # import!!!you must set is_decoder as True for generation 147 | 148 | config.is_decoder = True 149 | encoder = RobertaModel.from_pretrained(args.model_name_or_path,config=config) 150 | 151 | tokenizer.add_tokens(["", '', "", "","", "", "", 152 | "" ,"","", ""],special_tokens=True) 153 | config.vocab_size = len(tokenizer) 154 | encoder.resize_token_embeddings(len(tokenizer)) 155 | 156 | # print(encoder) 157 | model = Seq2Seq(encoder=encoder,decoder=encoder,config=config, 158 | beam_size=args.beam_size,max_length=args.max_target_length, 159 | sos_id=tokenizer.convert_tokens_to_ids([""])[0],eos_id=tokenizer.sep_token_id, 160 | queue_size=args.queue_size, device= args.device) 161 | 162 | return config, tokenizer, model 163 | 164 | def set_seed(seed=42): 165 | random.seed(seed) 166 | os.environ['PYHTONHASHSEED'] = str(seed) 167 | np.random.seed(seed) 168 | torch.manual_seed(seed) 169 | torch.cuda.manual_seed(seed) 170 | torch.backends.cudnn.deterministic = True 171 | 172 | 173 | def get_tqdm(iterator, desc=""): 174 | return iterator 175 | 176 | 177 | def compute_score(predicted_labels, gold_labels, verbose=False): 178 | true_positives = 0.0 179 | true_negatives = 0.0 180 | false_positives = 0.0 181 | false_negatives = 0.0 182 | 183 | assert(len(predicted_labels) == len(gold_labels)) 184 | 185 | for i in range(len(gold_labels)): 186 | if gold_labels[i]==0: 187 | if predicted_labels[i]==0: 188 | true_positives += 1 189 | else: 190 | false_negatives += 1 191 | else: 192 | if predicted_labels[i]==0: 193 | false_positives += 1 194 | else: 195 | true_negatives += 1 196 | 197 | if verbose: 198 | print('True 
positives: {}'.format(true_positives)) 199 | print('False positives: {}'.format(false_positives)) 200 | print('True negatives: {}'.format(true_negatives)) 201 | print('False negatives: {}'.format(false_negatives)) 202 | 203 | try: 204 | precision = true_positives/(true_positives + false_positives) 205 | except: 206 | precision = 0.0 207 | try: 208 | recall = true_positives/(true_positives + false_negatives) 209 | except: 210 | recall = 0.0 211 | try: 212 | f1 = 2*((precision * recall)/(precision + recall)) 213 | except: 214 | f1 = 0.0 215 | return precision, recall, f1 216 | 217 | 218 | def count_probability_justInTime(): 219 | examples = [] 220 | language = ['Summary','Return', 'Param'] 221 | root_folder = './dataset/just_in_time/' 222 | stages = ['train', 'test', 'valid', 'test_clean'] 223 | count_code = [] 224 | count_nl = [] 225 | for stage in stages: 226 | for lan in language: 227 | filename = root_folder + lan + '/' + stage + '.json' 228 | with open(filename, encoding="utf-8") as f: 229 | data = ast.literal_eval(f.read()) 230 | 231 | for idx, js in enumerate(data): 232 | code = ' '.join(js['span_diff_code_subtokens']).replace('\n', ' ') 233 | code = ' '.join(code.strip().split()) 234 | nl = ' '.join(js['new_comment_subtokens']).replace('\n', '') 235 | nl = ' '.join(nl.strip().split()) 236 | count_code.append(len(js['span_diff_code_subtokens'])) 237 | count_nl.append(len(js['new_comment_subtokens'])) 238 | print('code max: ', max(count_code)) 239 | print('code min: ', min(count_code)) 240 | print('code mean: ', mean(count_code)) 241 | print('code median: ', median(count_code)) 242 | print('nl max: ', max(count_nl)) 243 | print('nl min: ', min(count_nl)) 244 | print('nl mean: ', mean(count_nl)) 245 | print('nl median: ', median(count_nl)) 246 | 247 | def generate_diff_comment_justInTime(): 248 | examples = [] 249 | language = ['Summary','Return', 'Param'] 250 | root_folder = './dataset/just_in_time/' 251 | stages = [ 'test','train','valid'] 252 | count_nl = [] 
253 | count_pl = [] 254 | for stage in stages: 255 | for lan in language: 256 | filename = root_folder + lan + '/' + stage + '_old.json' 257 | with open(filename, encoding="utf-8") as f: 258 | data = ast.literal_eval(f.read()) 259 | D = [] 260 | for idx, js in enumerate(data): 261 | try: 262 | nl_new = js["new_comment_subtokens"] 263 | nl_old = js["old_comment_subtokens"] 264 | diff_comment,_,_ = compute_comment_diffs(nl_old, nl_new) 265 | count_nl.append(len(diff_comment)) 266 | count_pl.append(len(js['span_diff_code_subtokens'])) 267 | js["span_diff_comment_subtokens"] = diff_comment 268 | print(' '.join(diff_comment)) 269 | D.append(js) 270 | except: 271 | None 272 | with open(root_folder + lan + '/' + stage + '.json', 'w') as fo: 273 | json.dump(D, fo) 274 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

4 | logo 5 |

6 | 7 | 8 | license 9 | 10 | 11 | python 12 | 13 | 14 | downloads 15 | 16 | 17 | # DocChecker: Bootstrapping Code-Text Pretrained Language Model to Detect Inconsistency Between Code and Comment 18 | 19 |
20 | 21 | # Table of contents 22 | - [Introduction](#introduction) 23 | - [Installation](#installation-guide) 24 | - [Getting Started](#getting-started) 25 | - [Inferencing Pipeline](#inferencing-pipeline) 26 | - [Pre-training Pipeline](#pre-training-pipeline) 27 | - [Installation for Pre-training](#installation-for-pre-training) 28 | - [Dataset for Pre-training](#dataset-for-pre-training) 29 | - [Fine-tuning Pipeline](#fine-tuning-pipeline) 30 | - [Dataset for Fine-tuning](#dataset-for-fine-tuning) 31 | - [Playground](#playground) 32 | - [Citing Us](#citing-us) 33 | - [Contact Us](#contact-us) 34 | - [License](#license) 35 | 36 | ___________ 37 | # Introduction 38 | Comments on source code serve as critical documentation for enabling developers to understand the code's functionality and use it properly. However, it is challenging to ensure that comments accurately reflect the corresponding code, particularly as the software evolves over time. Although increasing interest has been taken in developing automated methods for identifying and fixing inconsistencies between code and comments, the existing methods have primarily relied on heuristic rules. 39 | 40 | DocChecker is trained on top of an encoder-decoder model to learn from code-text pairs. It is jointly pre-trained with three objectives: code-text contrastive learning, binary classification, and text generation. DocChecker is a tool that can be used to detect noisy code-comment pairs and generate synthetic comments, enabling it to identify comments that do not match their associated code snippets and correct them. 41 | Its effectiveness is demonstrated on the Just-In-Time dataset in comparison with other state-of-the-art methods. 42 | 43 |

44 | overview 45 |

46 | 47 | # Installation Guide 48 | 49 | 1. (Optional) Create a conda environment: 50 | 51 | ```bash 52 | conda create -n docchecker python=3.8 53 | conda activate docchecker 54 | ``` 55 | 56 | 2. Install from [PyPI](https://pypi.org/project/docchecker/): 57 | ```bash 58 | pip install docchecker 59 | ``` 60 | 61 | 3. Alternatively, build DocChecker from source: 62 | 63 | ```bash 64 | git clone https://github.com/FSoft-AI4Code/DocChecker.git 65 | cd DocChecker 66 | pip install -r requirements.txt . 67 | ``` 68 | 69 | # Getting Started 70 | ## Inferencing pipeline 71 | 72 | Getting started with DocChecker is simple and quick: use the ``inference()`` function. 73 | 74 | ```python 75 | from DocChecker.utils import inference 76 | ``` 77 | There are a few notable arguments that need to be considered: 78 | 79 | Parameters: 80 | 81 | - ``input_file_path`` (str): the path of a file containing source code; use it to check all the functions in that file. 82 | - ``raw_code`` (str): a string of source code, used if `input_file_path` is not given. 83 | - ``language`` (str, required): the programming language of your raw_code. We support 10 popular programming languages, including Java, JavaScript, Python, Ruby, Rust, Golang, C#, C++, C, and PHP. 84 | - ``output_file_path`` (str): if `output_file_path` is given, the results from our tool will be written to `output_file_path`; otherwise, they will be printed on the screen. 85 | 86 | Returns: 87 | 88 | - list of dictionaries, including: 89 | - ``function_name``: the name of each function in the raw code 90 | - ``code``: code snippet 91 | - ``docstring``: the docstring corresponding to the code snippet 92 | - ``predict``: the prediction of DocChecker.
It returns “Inconsistent!” or “Consistent!”, indicating whether the docstring is inconsistent or consistent with the code in a code-text pair 93 | - ``recommended_docstring``: If a code-text pair is considered “Inconsistent!”, DocChecker replaces its docstring with a more comprehensive generated one; otherwise, it keeps the original version. 94 | 95 | Here's an example showing how to load the DocChecker model and perform inference on the inconsistency detection task: 96 | 97 | ```python 98 | from DocChecker.utils import inference 99 | 100 | code = """ 101 | def inject_func_as_unbound_method(class_, func, method_name=None): 102 | # This is actually quite simple 103 | if method_name is None: 104 | method_name = get_funcname(func) 105 | setattr(class_, method_name, func) 106 | 107 | def e(message, exit_code=None): 108 | # Print an error log message. 109 | print_log(message, YELLOW, BOLD) 110 | if exit_code is not None: 111 | sys.exit(exit_code) 112 | """ 113 | 114 | inference(raw_code=code,language='python') 115 | 116 | >>[ 117 | { 118 | "function_name": "inject_func_as_unbound_method", 119 | "code": "def inject_func_as_unbound_method(class_, func, method_name=None):\n \n if method_name is None:\n method_name = get_funcname(func)\n setattr(class_, method_name, func)", 120 | "docstring": " This is actually quite simple", 121 | "predict": "Inconsistent!", 122 | "recommended_docstring": "Inject a function as an unbound method." 123 | }, 124 | { 125 | "function_name": "e", 126 | "code": "def e(message, exit_code=None):\n \n print_log(message, YELLOW, BOLD)\n if exit_code is not None:\n sys.exit(exit_code)", 127 | "docstring": "Print an error log message.", 128 | "predict": "Consistent!", 129 | "recommended_docstring": "Print an error log message." 130 | } 131 | ] 132 | ``` 133 | 134 | ## Pre-training Pipeline 135 | We also provide our source code so that you can re-pretrain DocChecker.
136 | 137 | ### Installation for Pre-training 138 | Set up the environment and install dependencies for pre-training. 139 | ```bash 140 | cd ./DocChecker 141 | pip install -r requirements.txt 142 | ``` 143 | 144 | ### Dataset for Pre-training 145 | The dataset we used comes from [CodeXGLUE](https://github.com/microsoft/CodeXGLUE/tree/main/Code-Text/code-to-text). 146 | It can be downloaded with the following commands: 147 | 148 | ```bash 149 | wget https://github.com/microsoft/CodeXGLUE/raw/main/Code-Text/code-to-text/dataset.zip 150 | unzip dataset.zip 151 | rm dataset.zip 152 | cd dataset 153 | wget https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/python.zip 154 | wget https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip 155 | wget https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/ruby.zip 156 | wget https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/javascript.zip 157 | wget https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/go.zip 158 | wget https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/php.zip 159 | 160 | unzip python.zip 161 | unzip java.zip 162 | unzip ruby.zip 163 | unzip javascript.zip 164 | unzip go.zip 165 | unzip php.zip 166 | rm *.zip 167 | rm *.pkl 168 | 169 | python preprocess.py 170 | rm -r */final 171 | cd .. 172 | ``` 173 | 174 | To re-pretrain, run the following command: 175 | ```shell 176 | python -m torch.distributed.run --nproc_per_node=2 run.py \ 177 | --do_train \ 178 | --do_eval \ 179 | --task pretrain \ 180 | --data_folder dataset/pretrain_dataset \ 181 | --num_train_epochs 10 182 | ``` 183 | 184 | ## Fine-tuning Pipeline 185 | To demonstrate the performance of our approach, we fine-tune DocChecker on the Just-In-Time task. The purpose of this task is to determine whether the comment is semantically out of sync with the corresponding code function.
186 | 187 | ### Dataset for Fine-tuning 188 | 189 | Download data for the [Just-In-Time](https://github.com/panthap2/deep-jit-inconsistency-detection) task from [here](https://drive.google.com/drive/folders/1heqEQGZHgO6gZzCjuQD1EyYertN4SAYZ?usp=sharing). 190 | 191 | We also provide fine-tuning settings for DocChecker, whose results are reported in the paper. 192 | 193 | ```shell 194 | 195 | # Training 196 | python -m torch.distributed.run --nproc_per_node=2 run.py \ 197 | --do_train \ 198 | --do_eval \ 199 | --post_hoc \ 200 | --task just_in_time \ 201 | --load_model \ 202 | --data_folder dataset/just_in_time \ 203 | --num_train_epochs 30 204 | 205 | # Testing 206 | python -m torch.distributed.run --nproc_per_node=2 run.py \ 207 | --do_test \ 208 | --post_hoc \ 209 | --task just_in_time \ 210 | --data_folder dataset/just_in_time 211 | ``` 212 | 213 | # Playground 214 | We provide an interface for DocChecker at this [link](http://4.193.50.237:5000/). 215 | A demonstration video is available on [YouTube](https://youtu.be/KFbyaSf2I3c). 216 | 217 | # Citing Us 218 | More details can be found in our [paper](https://arxiv.org/abs/2306.06347). 219 | If you use this code or our package, please consider citing us: 220 | 221 | ```bibtex 222 | @article{DocChecker, 223 | title={Bootstrapping Code-Text Pretrained Language Model to Detect Inconsistency Between Code and Comment}, 224 | author={Anh T. V. Dau and Jin L.C. Guo and Nghi D. Q. Bui}, 225 | journal={EACL 2024 - Demonstration track}, 226 | pages={}, 227 | year={2024} 228 | } 229 | ``` 230 | 231 | # Contact us 232 | If you have any questions, comments or suggestions, please do not hesitate to contact us.
233 | - Website: [fpt-aicenter](https://www.fpt-aicenter.com/ai-residency/) 234 | - Email: support.ailab@fpt.com 235 | 236 | # License 237 | [Apache License Version 2.0](LICENSE.txt) 238 | -------------------------------------------------------------------------------- /assets/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/DocChecker/ca92c1117de077ffe8af879737ac3b12a2e3dd45/assets/logo.jpg -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/DocChecker/ca92c1117de077ffe8af879737ac3b12a2e3dd45/assets/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | codetext==0.0.6 2 | gdown==4.7.1 3 | torch==1.13.1 4 | transformers==4.26.1 5 | docchecker==0.1.1 --------------------------------------------------------------------------------