├── .gitignore
├── README.md
├── appendix
│   └── Appendix.pdf
├── evaluate.py
├── images
│   └── race.png
├── metric
│   ├── cider
│   │   ├── cider.py
│   │   └── cider_scorer.py
│   ├── meteor
│   │   ├── data
│   │   │   └── paraphrase-en.gz
│   │   ├── meteor-1.5.jar
│   │   └── meteor.py
│   ├── rouge
│   │   └── rouge.py
│   └── smooth_bleu.py
├── model.py
├── results
│   ├── cpp
│   │   ├── test.gold
│   │   └── test.pred
│   ├── csharp
│   │   ├── test.gold
│   │   └── test.pred
│   ├── java
│   │   ├── test.gold
│   │   └── test.pred
│   ├── javascript
│   │   ├── test.gold
│   │   └── test.pred
│   └── python
│       ├── test.gold
│       └── test.pred
├── run.py
├── run.sh
├── run_small.sh
└── util.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | *.bin
9 | events.out.*
10 | saved_model/*
11 | *.tar.gz
12 | dataset/*
13 | #
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RACE: Retrieval-Augmented Commit Message Generation
2 | 
3 | 
4 | ## Model Architecture
5 | We propose RACE, a novel model that retrieves a similar commit message as an exemplar, uses it to guide the neural model in learning the content of and the intent behind the code diff, and generates a readable and informative commit message.
6 | RACE consists of two modules: a retrieval module and a generation module. It first retrieves, from the large parallel training corpus, the code diff that is most semantically similar to the input, together with its paired commit message. The semantic similarity between two code diffs is measured by the cosine similarity of the vectors produced by a code diff encoder. RACE then treats the retrieved commit message as an exemplar and uses it to guide the neural network in generating an understandable and concise commit message.
7 | 
8 | 
9 | ![RACE architecture](./images/race.png)
10 | 
11 | ## Environment
12 | 
13 | ```
14 | conda create -n RACE python=3.6 -y
15 | conda activate RACE
16 | pip install torch==1.10 transformers==4.12.5 tqdm==4.64.1 prettytable==2.5.0 gdown==4.5.1 more-itertools==8.14.0 tensorboardX==2.5.1 setuptools==59.5.0 tensorboard==2.10.1
17 | ```
18 | ## Dataset
19 | 
20 | 
21 | 
22 | The MCMD dataset covers five programming languages (PLs): Java, C#, Cpp, Python and JavaScript. The dataset can be downloaded [here](https://zenodo.org/record/7196966#.Y0juJHZBxmM). More info about MCMD can be found [here](https://github.com/DeepSoftwareAnalytics/CommitMsgEmpirical/tree/main/dataset). We use the filtered dataset in our work.
23 | 
24 | **Dataset statistics**
25 | 
26 | | Language   | Training | Valid  | Test   |
27 | | :--------- | :------: | :----: | :----: |
28 | | Java       | 160,018  | 19,825 | 20,159 |
29 | | C#         | 149,907  | 18,688 | 18,702 |
30 | | Cpp        | 160,948  | 20,000 | 20,141 |
31 | | Python     | 206,777  | 25,912 | 25,837 |
32 | | JavaScript | 197,529  | 24,899 | 24,773 |
33 | 
34 | Use the following commands to download and unpack the dataset.
35 | ```
36 | wget https://zenodo.org/record/7196966/files/dataset.tar.gz
37 | tar zxvf dataset.tar.gz
38 | ```
39 | It takes about 1 minute.
40 | 
41 | * The original data is saved in `dataset/java/`.
42 | * The processed data is saved in `dataset/java/contextual_medits/`.
43 | * The retrieval results are saved in `dataset/java/contextual_medits/codet5_retrieval_result` (see the retrieval sketch below).
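The retrieval results above are produced by the retrieval module (`run.py --do_retrieval`), which ranks training code diffs by the cosine similarity of their encoder representations, as described in the architecture section. The snippet below is a minimal illustrative sketch of that idea, not the repository's actual retrieval code; the checkpoint name, masked mean pooling, and variable names are assumptions made for the example.

```python
import torch
from transformers import RobertaTokenizer, T5EncoderModel

# Illustrative sketch only -- the real retrieval is implemented in run.py (--do_retrieval).
tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-base")
encoder = T5EncoderModel.from_pretrained("Salesforce/codet5-base").eval()

def embed(diffs):
    """Encode code diffs and mean-pool token states into one unit-length vector each."""
    batch = tokenizer(diffs, padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        hidden = encoder(**batch).last_hidden_state        # [batch, seq_len, dim]
    mask = batch.attention_mask.unsqueeze(-1)              # ignore padding positions
    vectors = (hidden * mask).sum(1) / mask.sum(1)         # masked mean pooling
    return torch.nn.functional.normalize(vectors, dim=-1)  # unit norm: cosine = dot product

train_diffs = ["- old_call()\n+ new_call()", "- x = 1\n+ x = 2"]  # placeholder training diffs
query_diff = "- foo(a)\n+ foo(a, b)"                              # placeholder input diff
scores = embed([query_diff]) @ embed(train_diffs).T               # cosine similarities, shape [1, N]
best = scores.argmax(dim=-1).item()  # index of the most similar training diff (and its paired commit message)
```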
44 | 45 | 46 | ## Training 47 | 48 | ``` 49 | language=java 50 | bash run.sh $language 51 | ``` 52 | 53 | ## Evaluation 54 | ``` 55 | python evaluate.py --refs_filename [The path of the reference file] --preds_filename [The path of the predicted file] 56 | ``` 57 | For example, 58 | ``` 59 | lang=javascript 60 | python evaluate.py --refs_filename results/${lang}/test.gold --preds_filename results/${lang}/test.pred 61 | 62 | ``` 63 | Output 64 | ``` 65 | BLEU: 25.66 66 | Meteor: 15.46 67 | Rouge-L: 32.02 68 | Cider: 1.76 69 | ``` 70 | 71 | ## Results 72 | 73 | | Language | Result Dir | 74 | | :--------- | :------: | 75 | | Java | `results/java/test.pred`| 76 | | C#| `results/csharp/test.pred`| 77 | | Cpp |`results/cpp/test.pred` | 78 | | Python | `results/python/test.pred` | 79 | | JavaScript | `results/javascript/test.pred` | 80 | 81 | 82 | -------------------------------------------------------------------------------- /appendix/Appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSoftwareAnalytics/RACE/cab39487ff94d6ebbfbfec3bb8821435767b38f4/appendix/Appendix.pdf -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # !-*-coding:utf-8 -*- 3 | import json 4 | import numpy as np 5 | import sys 6 | 7 | sys.path.append("metric") 8 | from metric.smooth_bleu import codenn_smooth_bleu 9 | from metric.meteor.meteor import Meteor 10 | from metric.rouge.rouge import Rouge 11 | from metric.cider.cider import Cider 12 | import warnings 13 | import argparse 14 | import logging 15 | import prettytable as pt 16 | 17 | warnings.filterwarnings('ignore') 18 | logging.basicConfig(format='[%(asctime)s - %(levelname)s - %(name)s] %(message)s', 19 | datefmt='%m/%d/%Y %H:%M:%S', 20 | level=logging.INFO) 21 | 22 | 23 | def Commitbleus(refs, preds): 24 | 25 | r_str_list = [] 26 | p_str_list = [] 27 | for r, p in zip(refs, preds): 28 | if len(r[0]) == 0 or len(p) == 0: 29 | continue 30 | r_str_list.append([" ".join([str(token_id) for token_id in r[0]])]) 31 | p_str_list.append(" ".join([str(token_id) for token_id in p])) 32 | try: 33 | bleu_list = codenn_smooth_bleu(r_str_list, p_str_list) 34 | except: 35 | bleu_list = [0, 0, 0, 0] 36 | codenn_bleu = bleu_list[0] 37 | 38 | B_Norm = round(codenn_bleu, 4) 39 | 40 | return B_Norm 41 | 42 | 43 | def read_to_list(filename): 44 | f = open(filename, 'r',encoding="utf-8") 45 | res = [] 46 | for row in f: 47 | # (rid, text) = row.split('\t') 48 | res.append(row.lower().split()) 49 | return res 50 | 51 | def metetor_rouge_cider(refs, preds): 52 | 53 | refs_dict = {} 54 | preds_dict = {} 55 | for i in range(len(preds)): 56 | preds_dict[i] = [" ".join(preds[i])] 57 | refs_dict[i] = [" ".join(refs[i][0])] 58 | 59 | score_Meteor, scores_Meteor = Meteor().compute_score(refs_dict, preds_dict) 60 | print("Meteor: ", round(score_Meteor*100,2)) 61 | 62 | score_Rouge, scores_Rouge = Rouge().compute_score(refs_dict, preds_dict) 63 | print("Rouge-L: ", round(score_Rouge*100,2)) 64 | 65 | score_Cider, scores_Cider = Cider().compute_score(refs_dict, preds_dict) 66 | print("Cider: ",round(score_Cider,2) ) 67 | 68 | 69 | 70 | def main(): 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--refs_filename', type=str, default="../saved_model/tlcodesum/UNLC/ref.txt", required=False) 73 | parser.add_argument('--preds_filename', type=str, 
default="../saved_model/tlcodesum/UNLC/dlen500-clen30-dvoc30000-cvoc30000-bs-ddim64-cdim-rhs64-lr0_Medit_pred.txt", required=False) 74 | args = parser.parse_args() 75 | refs = read_to_list(args.refs_filename) 76 | refs = [[t] for t in refs] 77 | preds = read_to_list(args.preds_filename) 78 | bleus_score = Commitbleus(refs, preds) 79 | print("BLEU: %.2f"%bleus_score) 80 | metetor_rouge_cider(refs, preds) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /images/race.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSoftwareAnalytics/RACE/cab39487ff94d6ebbfbfec3bb8821435767b38f4/images/race.png -------------------------------------------------------------------------------- /metric/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | # import sys 10 | # sys.path.append("./") 11 | 12 | from cider.cider_scorer import CiderScorer 13 | 14 | 15 | import pdb 16 | 17 | class Cider: 18 | """ 19 | Main Class to compute the CIDEr metric 20 | 21 | """ 22 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 23 | # set cider to sum over 1 to 4-grams 24 | self._n = n 25 | # set the standard deviation parameter for gaussian penalty 26 | self._sigma = sigma 27 | 28 | def compute_score(self, gts, res): 29 | """ 30 | Main function to compute CIDEr score 31 | :param hypo_for_image (dict) : dictionary with key and value 32 | ref_for_image (dict) : dictionary with key and value 33 | :return: cider (float) : computed CIDEr score for the corpus 34 | """ 35 | 36 | assert(gts.keys() == res.keys()) 37 | imgIds = gts.keys() 38 | 39 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 40 | 41 | for id in imgIds: 42 | hypo = res[id] 43 | ref = gts[id] 44 | 45 | # Sanity check. 46 | assert(type(hypo) is list) 47 | assert(len(hypo) == 1) 48 | assert(type(ref) is list) 49 | assert(len(ref) > 0) 50 | 51 | cider_scorer += (hypo[0], ref) 52 | 53 | (score, scores) = cider_scorer.compute_score() 54 | 55 | return score, scores 56 | 57 | def method(self): 58 | return "CIDEr" 59 | 60 | 61 | if __name__ == '__main__': 62 | predict, ground_truth = {}, {} 63 | predict[1] = ["I love you "] 64 | ground_truth[1] = ["I love you ", "I like you "] 65 | predict[2] = ["he love you "] 66 | ground_truth[2] = ["he love you "] 67 | predict[3] = ["she love you "] 68 | ground_truth[3] = ["she love you"] 69 | score_Cider, scores_Cider = Cider().compute_score(ground_truth, predict) 70 | print(score_Cider, scores_Cider) -------------------------------------------------------------------------------- /metric/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. 
This is optional: cook_refs and cook_test 15 | can take string arguments as well. 16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in range(1,n+1): 23 | for i in range(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 
112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.items(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].items(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /metric/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeepSoftwareAnalytics/RACE/cab39487ff94d6ebbfbfec3bb8821435767b38f4/metric/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /metric/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSoftwareAnalytics/RACE/cab39487ff94d6ebbfbfec3bb8821435767b38f4/metric/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /metric/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 12 | METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 19 | '-', '-', '-stdio', '-l', 'en', '-norm'] 20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 21 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 22 | stdin=subprocess.PIPE, \ 23 | stdout=subprocess.PIPE, \ 24 | stderr=subprocess.PIPE) 25 | # Used to guarantee thread safety 26 | self.lock = threading.Lock() 27 | 28 | def compute_score(self, gts, res): 29 | assert(gts.keys() == res.keys()) 30 | imgIds = gts.keys() 31 | scores = [] 32 | 33 | eval_line = 'EVAL' 34 | self.lock.acquire() 35 | for i in imgIds: 36 | assert(len(res[i]) == 1) 37 | stat = self._stat(res[i][0], gts[i]) 38 | eval_line += ' ||| {}'.format(stat) 39 | 40 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode()) 41 | self.meteor_p.stdin.flush() 42 | for i in range(0,len(imgIds)): 43 | scores.append(float(self.meteor_p.stdout.readline().strip())) 44 | score = float(self.meteor_p.stdout.readline().strip()) 45 | self.lock.release() 46 | 47 | return score, scores 48 | 49 | def method(self): 50 | return "METEOR" 51 | 52 | def _stat(self, hypothesis_str, reference_list): 53 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 54 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 55 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 56 | self.meteor_p.stdin.write('{}\n'.format(score_line).encode()) 57 | self.meteor_p.stdin.flush() 58 | return self.meteor_p.stdout.readline().decode().strip() 59 | 60 | def _score(self, hypothesis_str, reference_list): 61 | self.lock.acquire() 62 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 63 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 64 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 65 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 66 | stats = self.meteor_p.stdout.readline().strip() 67 | eval_line = 'EVAL ||| {}'.format(stats) 68 | # EVAL ||| stats 69 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 70 | score = float(self.meteor_p.stdout.readline().strip()) 71 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 72 | # thanks for Andrej for pointing this out 73 | score = float(self.meteor_p.stdout.readline().strip()) 74 | self.lock.release() 75 | return score 76 | 77 | def __del__(self): 78 | self.lock.acquire() 79 
| self.meteor_p.stdin.close() 80 | self.meteor_p.kill() 81 | self.meteor_p.wait() 82 | self.lock.release() 83 | 84 | if __name__ == '__main__': 85 | predict, ground_truth = {}, {} 86 | predict[1] = ["I am enshi"] 87 | # ground_truth[1] = ["I am enshi", "my name is enshi"] 88 | ground_truth[1] = ["I am enshi"] 89 | predict[2] = ["I am enshi"] 90 | ground_truth[2] = ["I was enshi"] 91 | predict[3] = ["I am enshi"] 92 | ground_truth[3] = ["I am Tom"] 93 | 94 | 95 | score_Meteor, scores_Meteor = Meteor().compute_score(ground_truth, predict) 96 | print("Meteor: ", score_Meteor) 97 | -------------------------------------------------------------------------------- /metric/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 
| :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /metric/smooth_bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | ''' 4 | This script was adapted from the original version by hieuhoang1972 which is part of MOSES. 5 | ''' 6 | 7 | # $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $ 8 | 9 | '''Provides: 10 | 11 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 12 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 13 | score_cooked(alltest, n=4): Score a list of cooked test sentences. 14 | 15 | score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids. 16 | 17 | The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible. 18 | ''' 19 | 20 | import sys, math, re, xml.sax.saxutils 21 | import subprocess 22 | import os 23 | 24 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 25 | nonorm = 0 26 | 27 | preserve_case = False 28 | eff_ref_len = "shortest" 29 | 30 | normalize1 = [ 31 | ('', ''), # strip "skipped" tags 32 | (r'-\n', ''), # strip end-of-line hyphenation and join lines 33 | (r'\n', ' '), # join lines 34 | # (r'(\d)\s+(?=\d)', r'\1'), # join digits 35 | ] 36 | normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1] 37 | 38 | normalize2 = [ 39 | (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 '), # tokenize punctuation. apostrophe is missing 40 | (r'([^0-9])([\.,])', r'\1 \2 '), # tokenize period and comma unless preceded by a digit 41 | (r'([\.,])([^0-9])', r' \1 \2'), # tokenize period and comma unless followed by a digit 42 | (r'([0-9])(-)', r'\1 \2 ') # tokenize dash when preceded by a digit 43 | ] 44 | normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2] 45 | 46 | 47 | def normalize(s): 48 | '''Normalize and tokenize text. 
This is lifted from NIST mteval-v11a.pl.''' 49 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 50 | if (nonorm): 51 | return s.split() 52 | if type(s) is not str: 53 | s = " ".join(s) 54 | # language-independent part: 55 | for (pattern, replace) in normalize1: 56 | s = re.sub(pattern, replace, s) 57 | s = xml.sax.saxutils.unescape(s, {'"': '"'}) 58 | # language-dependent part (assuming Western languages): 59 | s = " %s " % s 60 | if not preserve_case: 61 | s = s.lower() # this might not be identical to the original 62 | for (pattern, replace) in normalize2: 63 | s = re.sub(pattern, replace, s) 64 | return s.split() 65 | 66 | 67 | def count_ngrams(words, n=4): 68 | counts = {} 69 | for k in range(1, n + 1): 70 | for i in range(len(words) - k + 1): 71 | ngram = tuple(words[i:i + k]) 72 | counts[ngram] = counts.get(ngram, 0) + 1 73 | return counts 74 | 75 | 76 | def cook_refs(refs, n=4): 77 | '''Takes a list of reference sentences for a single segment 78 | and returns an object that encapsulates everything that BLEU 79 | needs to know about them.''' 80 | 81 | refs = [normalize(ref) for ref in refs] 82 | maxcounts = {} 83 | for ref in refs: 84 | counts = count_ngrams(ref, n) 85 | for (ngram, count) in counts.items(): 86 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 87 | return ([len(ref) for ref in refs], maxcounts) 88 | 89 | 90 | def cook_test(test, item, n=4): 91 | '''Takes a test sentence and returns an object that 92 | encapsulates everything that BLEU needs to know about it.''' 93 | (reflens, refmaxcounts) = item 94 | test = normalize(test) 95 | result = {} 96 | result["testlen"] = len(test) 97 | 98 | # Calculate effective reference sentence length. 99 | 100 | if eff_ref_len == "shortest": 101 | result["reflen"] = min(reflens) 102 | elif eff_ref_len == "average": 103 | result["reflen"] = float(sum(reflens)) / len(reflens) 104 | elif eff_ref_len == "closest": 105 | min_diff = None 106 | for reflen in reflens: 107 | if min_diff is None or abs(reflen - len(test)) < min_diff: 108 | min_diff = abs(reflen - len(test)) 109 | result['reflen'] = reflen 110 | 111 | result["guess"] = [max(len(test) - k + 1, 0) for k in range(1, n + 1)] 112 | 113 | result['correct'] = [0] * n 114 | counts = count_ngrams(test, n) 115 | for (ngram, count) in counts.items(): 116 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 117 | 118 | return result 119 | 120 | 121 | def score_cooked(allcomps, n=4, ground=0, smooth=1): 122 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 123 | for comps in allcomps: 124 | for key in ['testlen', 'reflen']: 125 | totalcomps[key] += comps[key] 126 | for key in ['guess', 'correct']: 127 | for k in range(n): 128 | totalcomps[key][k] += comps[key][k] 129 | logbleu = 0.0 130 | all_bleus = [] 131 | for k in range(n): 132 | correct = totalcomps['correct'][k] 133 | guess = totalcomps['guess'][k] 134 | addsmooth = 0 135 | if smooth == 1 and k > 0: 136 | addsmooth = 1 137 | logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min) 138 | if guess == 0: 139 | all_bleus.append(-10000000) 140 | else: 141 | all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess)) 142 | 143 | logbleu /= float(n) 144 | all_bleus.insert(0, logbleu) 145 | 146 | brevPenalty = min(0, 1 - float(totalcomps['reflen'] + 1) / (totalcomps['testlen'] + 1)) 147 | for i in range(len(all_bleus)): 148 | if i == 0: 149 | all_bleus[i] += brevPenalty 150 | 
all_bleus[i] = math.exp(all_bleus[i]) 151 | return all_bleus 152 | 153 | 154 | def bleu(refs, candidate, ground=0, smooth=1): 155 | refs = cook_refs(refs) 156 | test = cook_test(candidate, refs) 157 | return score_cooked([test], ground=ground, smooth=smooth) 158 | 159 | 160 | def splitPuncts(line): 161 | return ' '.join(re.findall(r"[\w]+|[^\s\w]", line)) 162 | 163 | 164 | def computeMaps(predictions, goldfile): 165 | predictionMap = {} 166 | goldMap = {} 167 | gf = open(goldfile, 'r') 168 | 169 | for row in predictions: 170 | cols = row.strip().split('\t') 171 | if len(cols) == 1: 172 | (rid, pred) = (cols[0], '') 173 | else: 174 | (rid, pred) = (cols[0], cols[1]) 175 | predictionMap[rid] = [splitPuncts(pred.strip().lower())] 176 | 177 | for row in gf: 178 | (rid, pred) = row.split('\t') 179 | if rid in predictionMap: # Only insert if the id exists for the method 180 | if rid not in goldMap: 181 | goldMap[rid] = [] 182 | goldMap[rid].append(splitPuncts(pred.strip().lower())) 183 | 184 | sys.stderr.write('Total: ' + str(len(goldMap)) + '\n') 185 | return (goldMap, predictionMap) 186 | 187 | 188 | # m1 is the reference map 189 | # m2 is the prediction map 190 | def bleuFromMaps(m1, m2): 191 | score = [0] * 5 192 | num = 0.0 193 | 194 | for key in m1: 195 | if key in m2: 196 | bl = bleu(m1[key], m2[key][0]) 197 | score = [score[i] + bl[i] for i in range(0, len(bl))] 198 | num += 1 199 | return [s * 100.0 / num for s in score] 200 | 201 | 202 | if __name__ == '__main__': 203 | reference_file = sys.argv[1] 204 | predictions = [] 205 | for row in sys.stdin: 206 | predictions.append(row) 207 | (goldMap, predictionMap) = computeMaps(predictions, reference_file) 208 | print(bleuFromMaps(goldMap, predictionMap)[0]) 209 | 210 | def codenn_smooth_bleu(groundtruth, prediction): 211 | """ 212 | 213 | :param groundtruth: list of list containing one sentence 214 | :param prediction: list of sentences 215 | :return: 216 | """ 217 | score = [0] * 5 218 | num = 0.0 219 | 220 | for g, p in zip(groundtruth, prediction): 221 | # g is a list of str 222 | # p is a str 223 | bl = bleu(g, p) 224 | score = [score[i] + bl[i] for i in range(0, len(bl))] 225 | num += 1 226 | return [s * 100.0 / num for s in score] -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
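#
# Overview: this module defines ECMGModel, a retrieval-augmented wrapper around a
# T5ForConditionalGeneration (CodeT5) backbone. In forward(), it
#   * encodes the input code diff and the retrieved similar code diff with a shared encoder;
#   * computes a gate from the two mean-pooled diff representations (W_sim followed by a sigmoid);
#   * concatenates the gated input-diff encoding with the gated encoding of the retrieved
#     commit message to form the decoder's cross-attention memory;
#   * decodes the commit message from that combined memory.
# The Beam class at the bottom implements the beam search used at inference time.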
3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | import copy 11 | import sys 12 | # from models.RvNNRvNNASTCodeAttn import BatchASTEncoder 13 | import logging 14 | from util import REPLACE, REPLACE_OLD, REPLACE_NEW,REPLACE_END,INSERT,INSERT_OLD,INSERT_NEW ,INSERT_END,DELETE,DELETE_END,KEEP,KEEP_END 15 | from transformers import T5ForConditionalGeneration 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | 20 | 21 | class ECMGModel(T5ForConditionalGeneration): 22 | def __init__(self, base_model,config,args=None,sos_id=None, eos_id=None): 23 | super().__init__(config) 24 | # self.base_model = base_model 25 | self.encoder = base_model.encoder 26 | self.decoder = base_model.decoder 27 | self.lm_head = base_model.lm_head 28 | self.pooler = nn.Sequential(nn.Linear(config.d_model,config.d_model ), nn.Tanh(), nn.Dropout(0.5) ) 29 | self.W_sim = nn.Linear(2 * config.d_model, 1) 30 | self.W_c = nn.Linear(config.d_model, config.d_model) 31 | self.args=args 32 | self.register_buffer("bias", torch.tril(torch.ones(2048, 2048))) 33 | 34 | self.beam_size=args.beam_size 35 | self.max_length=args.max_target_length 36 | self.sos_id=sos_id 37 | self.eos_id=eos_id 38 | 39 | self.lsm = nn.LogSoftmax(dim=-1) 40 | 41 | def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None, 42 | retrieved_source_ids=None, retrieved_source_mask=None, retrieved_target_ids=None, retrieved_target_mask=None,use_cache=None,return_dict=None 43 | ): 44 | use_cache = use_cache if use_cache is not None else self.config.use_cache 45 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 46 | encoder_outputs_of_retrieved_target_ids = self.encoder(input_ids=retrieved_target_ids ,attention_mask=retrieved_target_mask).last_hidden_state 47 | 48 | bs = source_ids.shape[0] 49 | inputs = torch.cat((source_ids , retrieved_source_ids, ), 0) 50 | inputs_mask = torch.cat((source_mask , retrieved_source_mask), 0) 51 | # outputs 是个二元组 [0]--> [bs, sequence_len, dim]; [1] --> [bs, dim] 52 | encoder_outputs = self.encoder(input_ids=inputs, attention_mask=inputs_mask) 53 | outputs = encoder_outputs.last_hidden_state # [bs*2,seq_len, dim] 54 | encoder_outputs_of_input_source_ids = outputs[:bs] # [bs,seq_len,dim] 55 | encoder_outputs_of_retrieved_source_ids = outputs[bs:] # [bs,seq_len,dim] 56 | 57 | input_code_representation = encoder_outputs_of_input_source_ids.mean(1) # [bs, model_dim] 58 | similar_code_representation = encoder_outputs_of_retrieved_source_ids.mean(1) # [bs, model_dim] 59 | cat_two_code_outputs = torch.cat((input_code_representation, similar_code_representation), dim=-1) 60 | sim = F.sigmoid(self.W_sim(cat_two_code_outputs)) # [batch_size, 1, 1] 61 | sim = sim.reshape(-1, 1, 1) # [batch_size, 1, 1] 62 | 63 | # combine the input and retrieved result 64 | combined_encoder_output = torch.cat((self.W_c(encoder_outputs_of_input_source_ids) * (1 - sim) , encoder_outputs_of_retrieved_target_ids* sim),dim=1) 65 | combined_encoder_mask = torch.cat((source_mask,retrieved_target_mask),dim=1) 66 | 67 | if target_ids is not None: 68 | decoder_outputs = self.decoder( 69 | input_ids=target_ids , # [bs, length] 70 | attention_mask=target_mask, # [bs, length] 71 | inputs_embeds=None, 72 | past_key_values=None, 73 | encoder_hidden_states=combined_encoder_output, 74 | encoder_attention_mask=combined_encoder_mask, 75 | head_mask=None, 76 | cross_attn_head_mask=None, 77 | 
use_cache=use_cache, 78 | output_attentions=None, 79 | output_hidden_states=None, 80 | return_dict=return_dict, 81 | ) 82 | sequence_output = decoder_outputs[0] # [bs, length, dim] 83 | lm_logits = self.lm_head(sequence_output) 84 | 85 | active_loss = target_mask[..., 1:].ne(0).view(-1) == 1 #[bs * (seq-1)] 86 | shift_logits = lm_logits[..., :-1, :].contiguous() #[bs, seq_length-1,vocab_size] 87 | shift_labels = target_ids[..., 1:].contiguous() #[bs * (seq-1)] 88 | 89 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 90 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss], 91 | shift_labels.view(-1)[active_loss]) 92 | 93 | outputs = loss,loss*active_loss.sum(),active_loss.sum() 94 | return outputs 95 | # return outputs 96 | else: 97 | preds=[] 98 | zero=torch.cuda.LongTensor(1).fill_(0) 99 | for i in range(source_ids.shape[0]): 100 | context=combined_encoder_output[i:i+1] # [1,seq_len, dim ] 101 | context_mask=combined_encoder_mask[i:i+1,:] 102 | beam = Beam(self.beam_size,self.sos_id,self.eos_id) 103 | input_ids=beam.getCurrentState() # [bs , 1 ] 104 | context=context.repeat(self.beam_size, 1, 1) # [beam_size,seq_len, dim ] 105 | context_mask=context_mask.repeat(self.beam_size,1) # [beam_size, seq_len] 106 | for _ in range(self.max_length): 107 | if beam.done(): 108 | break 109 | decoder_outputs = self.decoder( 110 | input_ids=input_ids, # [bs, length] 111 | attention_mask=None, # [bs, length] 112 | inputs_embeds=None, 113 | past_key_values=None, 114 | encoder_hidden_states=context, 115 | encoder_attention_mask=context_mask, 116 | head_mask=None, 117 | cross_attn_head_mask=None, 118 | use_cache=use_cache, 119 | output_attentions=None, 120 | output_hidden_states=None, 121 | return_dict=return_dict, 122 | ) 123 | hidden_states=decoder_outputs[0][:,-1,:] #[beam_size, dim] 124 | out = self.lsm(self.lm_head(hidden_states)).data #[beam_size, vocab_size] 125 | beam.advance(out) 126 | #https://blog.csdn.net/kdongyi/article/details/103099589 127 | # copy the choose beam and expand it 128 | input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin())) 129 | input_ids=torch.cat((input_ids,beam.getCurrentState()),-1) 130 | hyp= beam.getHyp(beam.getFinal()) 131 | pred=beam.buildTargetTokens(hyp)[:self.beam_size] 132 | pred=[torch.cat([x.view(-1) for x in p]+[zero]*(self.max_length-len(p))).view(1,-1) for p in pred] 133 | preds.append(torch.cat(pred,0).unsqueeze(0)) 134 | 135 | preds=torch.cat(preds,0) 136 | return preds 137 | 138 | 139 | 140 | 141 | class Beam(object): 142 | def __init__(self, size,sos,eos): 143 | self.size = size 144 | self.tt = torch.cuda 145 | # The score for each translation on the beam. 146 | self.scores = self.tt.FloatTensor(size).zero_() 147 | # The backpointers at each time-step. 148 | self.prevKs = [] 149 | # The outputs at each time-step. 150 | self.nextYs = [self.tt.LongTensor(size) 151 | .fill_(0)] 152 | self.nextYs[0][0] = sos # sos 1, eos 2 153 | # Has EOS topped the beam yet. 154 | self._eos = eos 155 | self.eosTop = False 156 | # Time and k pair for finished. 157 | self.finished = [] 158 | 159 | def getCurrentState(self): 160 | "Get the outputs for the current timestep." 161 | batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) 162 | return batch 163 | 164 | def getCurrentOrigin(self): 165 | "Get the backpointers for the current timestep." 
166 | return self.prevKs[-1] 167 | 168 | def advance(self, wordLk): 169 | """ 170 | Given prob over words for every last beam `wordLk` and attention 171 | `attnOut`: Compute and update the beam search. 172 | 173 | Parameters: 174 | 175 | * `wordLk`- probs of advancing from the last step (K x words) 176 | * `attnOut`- attention at the last step 177 | 178 | Returns: True if beam search is complete. 179 | """ 180 | numWords = wordLk.size(1) # wordLk [beam_size. vocab_size] 181 | 182 | # Sum the previous scores. 183 | if len(self.prevKs) > 0: 184 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 185 | 186 | # Don't let EOS have children. 187 | for i in range(self.nextYs[-1].size(0)): 188 | if self.nextYs[-1][i] == self._eos: 189 | beamLk[i] = -1e20 190 | else: 191 | beamLk = wordLk[0] 192 | flatBeamLk = beamLk.view(-1) 193 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) # beam size 194 | 195 | self.scores = bestScores 196 | 197 | # bestScoresId is flattened beam x word array, so calculate which 198 | # word and beam each score came from 199 | prevK = bestScoresId // numWords # divide and floor 200 | self.prevKs.append(prevK) # which beam 201 | self.nextYs.append((bestScoresId - prevK * numWords)) # which word 202 | 203 | 204 | for i in range(self.nextYs[-1].size(0)): 205 | if self.nextYs[-1][i] == self._eos: 206 | s = self.scores[i] 207 | self.finished.append((s, len(self.nextYs) - 1, i)) 208 | 209 | # End condition is when top-of-beam is EOS and no global score. 210 | if self.nextYs[-1][0] == self._eos: 211 | self.eosTop = True 212 | 213 | def done(self): 214 | return self.eosTop and len(self.finished) >=self.size 215 | 216 | def getFinal(self): 217 | if len(self.finished) == 0: 218 | self.finished.append((self.scores[0], len(self.nextYs) - 1, 0)) 219 | self.finished.sort(key=lambda a: -a[0]) 220 | if len(self.finished) != self.size: 221 | unfinished=[] 222 | for i in range(self.nextYs[-1].size(0)): 223 | if self.nextYs[-1][i] != self._eos: 224 | s = self.scores[i] 225 | unfinished.append((s, len(self.nextYs) - 1, i)) 226 | unfinished.sort(key=lambda a: -a[0]) 227 | self.finished+=unfinished[:self.size-len(self.finished)] 228 | return self.finished[:self.size] 229 | 230 | def getHyp(self, beam_res): 231 | """ 232 | Walk back to construct the full hypothesis. 233 | """ 234 | hyps=[] 235 | for _,timestep, k in beam_res: 236 | hyp = [] 237 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 238 | hyp.append(self.nextYs[j+1][k]) 239 | k = self.prevKs[j][k] 240 | hyps.append(hyp[::-1]) 241 | return hyps 242 | 243 | def buildTargetTokens(self, preds): 244 | sentence=[] 245 | for pred in preds: 246 | tokens = [] 247 | for tok in pred: 248 | if tok==self._eos: 249 | break 250 | tokens.append(tok) 251 | sentence.append(tokens) 252 | return sentence 253 | 254 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). 18 | GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned 19 | using a masked language modeling (MLM) loss. 20 | """ 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import os 24 | import logging 25 | import argparse 26 | import math 27 | import numpy as np 28 | # from tqdm import tqdm 29 | import sys 30 | import multiprocessing 31 | import time 32 | if sys.stderr.isatty(): 33 | from tqdm import tqdm 34 | else: 35 | def tqdm(iterable, **kwargs): 36 | return iterable 37 | import torch 38 | from torch.utils.tensorboard import SummaryWriter 39 | from torch.utils.data import DataLoader, SequentialSampler, RandomSampler 40 | from torch.utils.data.distributed import DistributedSampler 41 | from transformers import AdamW, get_linear_schedule_with_warmup 42 | from model import ECMGModel 43 | from metric import smooth_bleu 44 | import random 45 | from util import get_elapse_time, load_and_cache_commit_data, load_and_commit_data_with_retrieved_result, save_json_data 46 | 47 | 48 | from transformers import (RobertaConfig, RobertaModel, RobertaTokenizer, 49 | BartConfig, BartForConditionalGeneration, BartTokenizer, 50 | T5Config, T5ForConditionalGeneration, T5Tokenizer) 51 | MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer), 52 | 't5': (T5Config, T5ForConditionalGeneration, T5Tokenizer), 53 | 'codet5': (T5Config, T5ForConditionalGeneration, RobertaTokenizer), 54 | 'bart': (BartConfig, BartForConditionalGeneration, BartTokenizer)} 55 | 56 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 57 | datefmt='%m/%d/%Y %H:%M:%S', 58 | level=logging.INFO) 59 | logger = logging.getLogger(__name__) 60 | 61 | 62 | sys.path.append("data/commit_msg") 63 | from util import REPLACE, REPLACE_OLD, REPLACE_NEW,REPLACE_END,INSERT,INSERT_OLD,INSERT_NEW ,INSERT_END,DELETE,DELETE_END,KEEP,KEEP_END 64 | 65 | 66 | 67 | def eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag,beam_size=1): 68 | logger.info(" ***** Running bleu evaluation on {} data*****".format(split_tag)) 69 | logger.info(" Num examples = %d", len(eval_examples)) 70 | logger.info(" Batch size = %d", args.eval_batch_size) 71 | eval_sampler = SequentialSampler(eval_data) 72 | if args.data_num == -1: 73 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size, 74 | num_workers=4, pin_memory=True) 75 | else: 76 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 77 | 78 | model.eval() 79 | pred_ids = [] 80 | bleu, codebleu = 0.0, 0.0 81 | for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval bleu for {} set".format(split_tag)): 82 | source_ids = batch[0].to(args.device) 83 | source_mask = source_ids.ne(tokenizer.pad_token_id) 84 | 85 | with torch.no_grad(): 86 | # if args.model_type == 'roberta': 87 | # preds = model(source_ids=source_ids, 
source_mask=source_mask) 88 | 89 | # top_preds = [pred[0].cpu().numpy() for pred in preds] 90 | if hasattr(model, 'module'): 91 | preds = model.module.generate(source_ids, 92 | attention_mask=source_mask, 93 | use_cache=True, 94 | num_beams=beam_size, 95 | early_stopping=args.task == 'summarize', 96 | max_length=args.max_target_length) 97 | 98 | else: 99 | preds = model.generate(source_ids, 100 | attention_mask=source_mask, 101 | use_cache=True, 102 | num_beams=beam_size, 103 | early_stopping=args.task == 'summarize', 104 | max_length=args.max_target_length) 105 | top_preds = list(preds.cpu().numpy()) 106 | pred_ids.extend(top_preds) 107 | 108 | pred_nls = [tokenizer.decode(id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in pred_ids] 109 | 110 | 111 | return pred_nls 112 | 113 | def parse_args(): 114 | parser = argparse.ArgumentParser() 115 | # debug 116 | parser.add_argument('--is_cosine_space', action='store_true', help='is_cosine_space', required=False) 117 | parser.add_argument('--debug', action='store_true', help='debug mode', required=False) 118 | parser.add_argument('--n_debug_samples', type=int, default=100, required=False) 119 | parser.add_argument("--eval_frequency", default=1, type=int, required=False) 120 | 121 | parser.add_argument("--ECMG_type", default="shared_encoders", choices=["shared_encoders"], type=str, required=False, help="the type of ECMG model") 122 | parser.add_argument("--base_model_type", default="codet5", choices=["codet5","codet5Siamese","ECMG"], type=str, required=False, help="the type of base model, like codet5, siamese network") 123 | parser.add_argument("--model_type", default="codet5", type=str, choices=['roberta', 'bart', 'codet5'], help="the type of pretrain model") 124 | 125 | parser.add_argument("--diff_type", type=str,default="contextual-medit", 126 | choices=['plain_diff',"old-plain-diff", "medit", "old-medit", "contextual-palin_diff", "contextual-medit"], 127 | help="plain_diff: only plain diff text; old-plain-diff: before and after three lines + plain diff \ 128 | medit :using medit to represent plain diff; \ 129 | old-medit: old verison + medit, oversion is the before and after three lines + old diff , \ 130 | contextual-palin_diff: .diff; contextual-medit: " ) 131 | 132 | parser.add_argument("--task", type=str,default="summarize", 133 | choices=['summarize', 'concode', 'translate', 'refine', 'defect', 'clone']) 134 | # parser.add_argument("--sub_task", type=str, default='') 135 | parser.add_argument("--lang", type=str, default='java') 136 | parser.add_argument("--eval_task", type=str, default='') 137 | parser.add_argument("--add_lang_ids", action='store_true') 138 | parser.add_argument("--data_num", default=-1, type=int, help="DATA_NUM == -1 means all data") 139 | parser.add_argument("--start_epoch", default=0, type=int) 140 | parser.add_argument("--num_train_epochs", default=10, type=int) 141 | parser.add_argument("--patience", default=5, type=int) 142 | # parser.add_argument("--tokenizer_path", default="tokenizer/salesforce", type=str) 143 | parser.add_argument("--cache_path", type=str, default="cache/codesum/java",required=False) 144 | parser.add_argument("--summary_dir", type=str, default="saved_model/codesum/tmp") 145 | # parser.add_argument("--data_dir", type=str, default="data/summarize/java") 146 | # parser.add_argument("--res_dir", type=str,default="saved_model/codesum/tmp") 147 | # parser.add_argument("--res_fn", type=str, default='') 148 | parser.add_argument("--add_task_prefix", action='store_true', 
help="Whether to add task prefix for t5 and codet5") 149 | parser.add_argument("--save_last_checkpoints", action='store_true') 150 | parser.add_argument("--always_save_model", action='store_true') 151 | parser.add_argument("--do_eval_bleu", action='store_true', help="Whether to evaluate bleu on dev set.") 152 | 153 | 154 | parser.add_argument("--model_name_or_path", default="Salesforce/codet5-base", type=str, 155 | help="Path to pre-trained model: e.g. roberta-base,Salesforce/codet5-small,Salesforce/codet5-base") 156 | parser.add_argument("--output_dir", type=str, default="../saved_model/commit_msg_generation/java/tmp", 157 | help="The output directory where the model predictions and checkpoints will be written.") 158 | parser.add_argument("--load_model_path", default=None, type=str, 159 | help="Path to trained model: Should contain the .bin files") 160 | parser.add_argument("--load_finetuned_model_path", default=None, type=str, 161 | help="Path to fine tuned trained model: Should contain the .bin files") 162 | ## Other parameters 163 | parser.add_argument("--train_filename", default="../data/commit_msg/java/contextual_medits/train.jsonl", type=str, 164 | help="The train filename. Should contain the .jsonl files for this task.") 165 | parser.add_argument("--dev_filename", default="../data/commit_msg/java/contextual_medits/valid.jsonl", type=str, 166 | help="The dev filename. Should contain the .jsonl files for this task.") 167 | parser.add_argument("--test_filename", default="../data/commit_msg/java/contextual_medits/test.jsonl", type=str, 168 | help="The test filename. Should contain the .jsonl files for this task.") 169 | parser.add_argument("--train_retireved_filename", default="../data/commit_msg/java/contextual_medits/train.jsonl", type=str, 170 | help="The train filename. Should contain the .jsonl files for this task.") 171 | parser.add_argument("--dev_retireved_filename", default="../data/commit_msg/java/contextual_medits/valid.jsonl", type=str, 172 | help="The dev filename. Should contain the .jsonl files for this task.") 173 | parser.add_argument("--test_retireved_filename", default="../data/commit_msg/java/contextual_medits/test.jsonl", type=str, 174 | help="The test filename. Should contain the .jsonl files for this task.") 175 | 176 | parser.add_argument("--config_name", default="", type=str, 177 | help="Pretrained config name or path if not the same as model_name") 178 | parser.add_argument("--tokenizer_name", default="Salesforce/codet5-small", type=str, 179 | help="Pretrained tokenizer name or path if not the same as model_name") 180 | parser.add_argument("--max_source_length", default=64, type=int, 181 | help="The maximum total source sequence length after tokenization. Sequences longer " 182 | "than this will be truncated, sequences shorter will be padded.") 183 | parser.add_argument("--max_target_length", default=32, type=int, 184 | help="The maximum total target sequence length after tokenization. 
Sequences longer " 185 | "than this will be truncated, sequences shorter will be padded.") 186 | 187 | parser.add_argument("--do_train", action='store_true', 188 | help="Whether to run eval on the train set.") 189 | parser.add_argument("--do_eval", action='store_true', 190 | help="Whether to run eval on the dev set.") 191 | parser.add_argument("--do_test", action='store_true', 192 | help="Whether to run eval on the dev set.") 193 | parser.add_argument("--do_lower_case", action='store_true', 194 | help="Set this flag if you are using an uncased model.") 195 | parser.add_argument("--no_cuda", action='store_true', 196 | help="Avoid using CUDA when available") 197 | 198 | parser.add_argument('--do_retrieval', action='store_true', help='retrieval mode', required=False) 199 | parser.add_argument('--run_codet5', action='store_true', help='run codet5', required=False) 200 | parser.add_argument("--retrieval_filename", default="../data/commit_msg/java/contextual_medits/train.jsonl", type=str, 201 | help="The test filename. Should contain the .jsonl files for this task.") 202 | parser.add_argument("--retrieval_result_dir", default="../data/commit_msg/java/tmp/codet5_retrieval_result", type=str, 203 | help="The test filename. Should contain the .jsonl files for this task.") 204 | parser.add_argument("--retrieval_result_filename", default="train.jsonl", type=str, 205 | help="The test filename. Should contain the .jsonl files for this task.") 206 | 207 | 208 | parser.add_argument("--train_batch_size", default=8, type=int, 209 | help="Batch size per GPU/CPU for training.") 210 | parser.add_argument("--eval_batch_size", default=8, type=int, 211 | help="Batch size per GPU/CPU for evaluation.") 212 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 213 | help="Number of updates steps to accumulate before performing a backward/update pass.") 214 | parser.add_argument("--learning_rate", default=5e-5, type=float, 215 | help="The initial learning rate for Adam.") 216 | parser.add_argument("--beam_size", default=10, type=int, 217 | help="beam size for beam search") 218 | parser.add_argument("--weight_decay", default=0.0, type=float, 219 | help="Weight deay if we apply some.") 220 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 221 | help="Epsilon for Adam optimizer.") 222 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 223 | help="Max gradient norm.") 224 | 225 | parser.add_argument("--save_steps", default=-1, type=int, ) 226 | parser.add_argument("--log_steps", default=-1, type=int, ) 227 | parser.add_argument("--max_steps", default=-1, type=int, 228 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.") 229 | parser.add_argument("--eval_steps", default=-1, type=int, 230 | help="") 231 | parser.add_argument("--train_steps", default=-1, type=int, 232 | help="") 233 | parser.add_argument("--warmup_steps", default=100, type=int, 234 | help="Linear warmup over warmup_steps.") 235 | parser.add_argument("--local_rank", type=int, default=-1, 236 | help="For distributed training: local_rank") 237 | parser.add_argument('--seed', type=int, default=3407, 238 | help="random seed for initialization") 239 | args = parser.parse_args() 240 | return args 241 | 242 | def build_or_load_gen_model(args): 243 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 244 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) 245 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name) 246 | model = model_class.from_pretrained(args.model_name_or_path) 247 | 248 | special_tokens_dict = {'additional_special_tokens': [REPLACE, REPLACE_OLD, REPLACE_NEW,REPLACE_END,INSERT,INSERT_OLD,INSERT_NEW ,INSERT_END,DELETE,DELETE_END,KEEP,KEEP_END]} 249 | logger.info("adding new token %s"%str(special_tokens_dict)) 250 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 251 | model.resize_token_embeddings(len(tokenizer)) 252 | if args.load_finetuned_model_path is not None: 253 | logger.info("Reload fine tuned model from {}".format(args.load_finetuned_model_path)) 254 | model.load_state_dict(torch.load(args.load_finetuned_model_path)) 255 | if args.base_model_type == "codet5": 256 | pass 257 | elif args.base_model_type == "ECMG": 258 | model = ECMGModel(model, config, args,sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id) 259 | else: 260 | raise RuntimeError 261 | 262 | if args.load_model_path is not None: 263 | logger.info("Reload model from {}".format(args.load_model_path)) 264 | model.load_state_dict(torch.load(args.load_model_path)) 265 | 266 | return config, model, tokenizer 267 | 268 | def set_dist(args): 269 | # Setup CUDA, GPU & distributed training 270 | if args.local_rank == -1 or args.no_cuda: 271 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 272 | args.n_gpu = torch.cuda.device_count() 273 | else: 274 | # Setup for distributed data parallel 275 | torch.cuda.set_device(args.local_rank) 276 | device = torch.device("cuda", args.local_rank) 277 | torch.distributed.init_process_group(backend='nccl') 278 | args.n_gpu = 1 279 | cpu_cont = multiprocessing.cpu_count() 280 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, cpu count: %d", 281 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), cpu_cont) 282 | args.device = device 283 | args.cpu_cont = cpu_cont 284 | 285 | 286 | def set_seed(args): 287 | """set random seed.""" 288 | random.seed(args.seed) 289 | np.random.seed(args.seed) 290 | torch.manual_seed(args.seed) 291 | if args.n_gpu > 0: 292 | torch.cuda.manual_seed_all(args.seed) 293 | 294 | def main(args): 295 | 296 | t0 = time.time() 297 | 298 | set_dist(args) 299 | set_seed(args) 300 | config, model, tokenizer = build_or_load_gen_model(args) 301 | 302 | model.to(args.device) 303 | if args.n_gpu > 1: 304 | # for DataParallel 305 | model = torch.nn.DataParallel(model) 306 | pool = multiprocessing.Pool(args.cpu_cont) 307 | # args.train_filename, args.dev_filename, args.test_filename = get_filenames(args.data_dir, args.task, args.sub_task) 308 | os.makedirs(args.output_dir, exist_ok=True) 
309 | os.makedirs(args.cache_path, exist_ok=True) 310 | fa = open(os.path.join(args.output_dir, 'summary.log'), 'a+') 311 | 312 | if args.do_train: 313 | if args.local_rank in [-1, 0] and args.data_num == -1: 314 | summary_fn = '{}/{}'.format(args.summary_dir, '/'.join(args.output_dir.split('/')[1:])) 315 | tb_writer = SummaryWriter(summary_fn) 316 | 317 | # Prepare training data loader 318 | train_examples, train_data = load_and_cache_commit_data(args, args.train_filename, pool, tokenizer, 'train', is_sample=args.debug) 319 | train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data) 320 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size, 321 | num_workers=4, pin_memory=True) 322 | 323 | # Prepare optimizer and schedule (linear warmup and decay) 324 | no_decay = ['bias', 'LayerNorm.weight'] 325 | optimizer_grouped_parameters = [ 326 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 327 | 'weight_decay': args.weight_decay}, 328 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 329 | ] 330 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 331 | num_train_optimization_steps = args.num_train_epochs * len(train_dataloader) 332 | scheduler = get_linear_schedule_with_warmup(optimizer, 333 | num_warmup_steps=args.warmup_steps, 334 | num_training_steps=num_train_optimization_steps) 335 | 336 | # Start training 337 | train_example_num = len(train_data) 338 | logger.info("***** Running training *****") 339 | logger.info(" Num examples = %d", train_example_num) 340 | logger.info(" Batch size = %d", args.train_batch_size) 341 | logger.info(" Batch num = %d", math.ceil(train_example_num / args.train_batch_size)) 342 | logger.info(" Num epoch = %d", args.num_train_epochs) 343 | 344 | dev_dataset = {} 345 | global_step, best_bleu, best_ppl = 0, -1, 1e6 346 | not_loss_dec_cnt, not_bleu_em_inc_cnt = 0, 0 if args.do_eval_bleu else 1e6 347 | 348 | for cur_epoch in range(args.start_epoch, int(args.num_train_epochs)): 349 | bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training") 350 | nb_tr_examples, nb_tr_steps, tr_loss = 0, 0, 0 351 | model.train() 352 | for step, batch in enumerate(bar): 353 | # print(step) 354 | batch = tuple(t.to(args.device) for t in batch) 355 | source_ids, target_ids = batch 356 | source_mask = source_ids.ne(tokenizer.pad_token_id) 357 | target_mask = target_ids.ne(tokenizer.pad_token_id) 358 | 359 | if args.base_model_type == "codet5" and args.model_type == 'codet5': 360 | outputs = model(input_ids=source_ids, attention_mask=source_mask, 361 | labels=target_ids, decoder_attention_mask=target_mask) 362 | loss = outputs.loss 363 | 364 | if args.n_gpu > 1: 365 | loss = loss.mean() # mean() to average on multi-gpu. 
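# When gradient accumulation is enabled, the loss of each micro-batch is divided by gradient_accumulation_steps below, so the gradients summed over the accumulated micro-batches match those of a single batch that is gradient_accumulation_steps times larger.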
366 | if args.gradient_accumulation_steps > 1: 367 | loss = loss / args.gradient_accumulation_steps 368 | tr_loss += loss.item() 369 | 370 | nb_tr_examples += source_ids.size(0) 371 | nb_tr_steps += 1 372 | loss.backward() 373 | 374 | if nb_tr_steps % args.gradient_accumulation_steps == 0: 375 | # Update parameters 376 | optimizer.step() 377 | optimizer.zero_grad() 378 | scheduler.step() 379 | global_step += 1 380 | train_loss = round(tr_loss * args.gradient_accumulation_steps / (nb_tr_steps + 1), 4) 381 | if sys.stderr.isatty(): 382 | bar.set_description("[{}] Train loss {}".format(cur_epoch, round(train_loss, 3))) 383 | if (step+1)% args.eval_frequency ==0 and (not sys.stderr.isatty()): 384 | logger.info("epoch {} loss {}".format(cur_epoch,train_loss)) 385 | if args.do_eval: 386 | eval_examples, eval_data = load_and_cache_commit_data(args, args.dev_filename, pool, tokenizer, 'dev', 387 | only_src=True, is_sample=args.debug) 388 | 389 | pred_nls = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'dev') 390 | output_fn = os.path.join(args.output_dir, "dev.output") 391 | gold_fn = os.path.join(args.output_dir, "dev.gold") 392 | dev_accs, predictions = [], [] 393 | with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1: 394 | for pred_nl, gold in zip(pred_nls, eval_examples): 395 | dev_accs.append(pred_nl.strip() == gold.target.strip()) 396 | predictions.append(str(gold.idx) + '\t' + pred_nl) 397 | f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n') 398 | f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n') 399 | 400 | (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn) 401 | dev_bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2) 402 | logger.info(" %s = %s "%("codenn_bleu",str(dev_bleu))) 403 | logger.info(" "+"*"*20) 404 | 405 | #save last checkpoint 406 | last_output_dir = os.path.join(args.output_dir, 'checkpoint-last') 407 | if not os.path.exists(last_output_dir): 408 | os.makedirs(last_output_dir) 409 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 410 | output_model_file = os.path.join(last_output_dir, "pytorch_model.bin") 411 | torch.save(model_to_save.state_dict(), output_model_file) 412 | 413 | if dev_bleu>best_bleu: 414 | logger.info(" Best bleu:%s",dev_bleu) 415 | logger.info(" "+"*"*20) 416 | best_bleu=dev_bleu 417 | # Save best checkpoint for best bleu 418 | output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu') 419 | if not os.path.exists(output_dir): 420 | os.makedirs(output_dir) 421 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 422 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 423 | torch.save(model_to_save.state_dict(), output_model_file) 424 | 425 | # logger.info("***** CUDA.empty_cache() *****") 426 | torch.cuda.empty_cache() 427 | if args.local_rank in [-1, 0] and args.data_num == -1: 428 | tb_writer.close() 429 | logger.info("Finish training and take %s", get_elapse_time(t0)) 430 | 431 | if args.do_test: 432 | logger.info(" " + "***** Testing *****") 433 | logger.info(" Batch size = %d", args.eval_batch_size) 434 | model = model.module if hasattr(model, 'module') else model 435 | if args.load_model_path is not None: 436 | logger.info("reload model from {}".format(args.load_model_path)) 437 | model.load_state_dict(torch.load(args.load_model_path)) 438 | eval_examples, eval_data = load_and_cache_commit_data(args, args.test_filename, pool, tokenizer, 'test', 439 | 
only_src=True, is_sample=args.debug) 440 | 441 | pred_nls = eval_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test',beam_size=args.beam_size) 442 | output_fn = os.path.join(args.output_dir, "test.output") 443 | gold_fn = os.path.join(args.output_dir, "test.gold") 444 | dev_accs, predictions = [], [] 445 | with open(output_fn, 'w') as f, open(gold_fn, 'w') as f1: 446 | for pred_nl, gold in zip(pred_nls, eval_examples): 447 | dev_accs.append(pred_nl.strip() == gold.target.strip()) 448 | predictions.append(str(gold.idx) + '\t' + pred_nl) 449 | f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n') 450 | f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n') 451 | 452 | (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn) 453 | dev_bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2) 454 | logger.info(" %s = %s "%("codenn_bleu",str(dev_bleu))) 455 | 456 | if args.do_retrieval: 457 | logger.info(" " + "***** Retrieval *****") 458 | logger.info(" Batch size = %d", args.eval_batch_size) 459 | model = model.module if hasattr(model, 'module') else model 460 | train_examples, train_data = load_and_cache_commit_data(args, args.train_filename, pool, tokenizer, 'train', is_sample=args.debug) 461 | train_sampler = SequentialSampler(train_data) 462 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size= args.eval_batch_size, 463 | num_workers=4, pin_memory=True) 464 | 465 | eval_examples, eval_data = load_and_cache_commit_data(args, args.retrieval_filename, pool, tokenizer, 'train', 466 | only_src=True, is_sample=args.debug) 467 | eval_sampler = SequentialSampler(eval_data) 468 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 469 | 470 | model.eval() 471 | 472 | model = model.module.encoder if hasattr(model, "module") else model.encoder 473 | 474 | train_code_vecs=[] 475 | eval_code_vecs=[] 476 | logger.info(" Num examples of Corpus = %d", len(train_data)) 477 | for batch in train_dataloader: 478 | with torch.no_grad(): 479 | source_ids = batch[0].to(args.device) 480 | source_mask = source_ids.ne(tokenizer.pad_token_id) 481 | # [bs, hid_dim] 482 | train_code_vec = model(input_ids=source_ids, attention_mask=source_mask).last_hidden_state #[bs, sequence_length, dim] 483 | train_code_vec = torch.mean(train_code_vec, dim=1) #[bs, dim] 484 | if args.is_cosine_space: 485 | train_code_vec = F.normalize(train_code_vec, p=2, dim=1) 486 | 487 | train_code_vecs.append( train_code_vec.cpu().numpy()) 488 | 489 | logger.info(" Num examples to retrieve = %d", len(eval_data)) 490 | for batch in eval_dataloader: 491 | with torch.no_grad(): 492 | # batch = tuple(t.to(args.device) for t in batch) 493 | source_ids = batch[0].to(args.device) 494 | source_mask = source_ids.ne(tokenizer.pad_token_id) 495 | # [bs, hid_dim] 496 | eval_code_vec = model(input_ids=source_ids, attention_mask=source_mask).last_hidden_state 497 | eval_code_vec = torch.mean(eval_code_vec, dim=1) 498 | if args.is_cosine_space: 499 | eval_code_vec = F.normalize(eval_code_vec, p=2, dim=1) 500 | eval_code_vecs.append( eval_code_vec.cpu().numpy()) 501 | 502 | train_code_vecs = np.concatenate(train_code_vecs,0) # [num_of_train_samples, hid_dim] 503 | eval_code_vecs = np.concatenate(eval_code_vecs,0) # [num_of_eval_samples, hid_dim] 504 | 505 | scores=np.matmul(eval_code_vecs, train_code_vecs.T) 506 | sort_ids=np.argsort(scores, axis=-1, kind='quicksort', order=None)[:,::-1] # [num_of_eval_samples,num_of_train_samples] 507 | 
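# With --is_cosine_space the mean-pooled encoder vectors are L2-normalized, so the dot products in scores are cosine similarities between every query code diff and every training code diff.
# sort_ids orders the training corpus by descending similarity for each query; the top-ranked entry (or the second-ranked one when the queries are the training set itself) provides the exemplar commit message.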
if "train" in args.retrieval_result_filename: 508 | logger.info("return 2nd ranked result") 509 | rank1_result = sort_ids[:,1] # [num_of_eval_samples] 510 | else: 511 | rank1_result = sort_ids[:,0] # [num_of_eval_samples] 512 | logger.info("ranked list %s"%str(rank1_result[:30])) 513 | retrieval_results = [] 514 | pred_nls = [] 515 | for idx in rank1_result: 516 | retrieval_result = train_examples[idx] 517 | retrieval_results.append({ "diff":retrieval_result.source.split(), 518 | "msg_token":retrieval_result.target.split()}) 519 | pred_nls.append(retrieval_result.target) 520 | 521 | predictions =[] 522 | output_fn = os.path.join(args.retrieval_result_dir, "%s.retireval.output"%args.retrieval_result_filename) 523 | gold_fn = os.path.join(args.retrieval_result_dir, "%s.gold"%args.retrieval_result_filename) 524 | with open(output_fn, 'w', encoding="utf-8") as f, open(gold_fn, 'w', encoding="utf-8") as f1: 525 | for pred_nl, gold in zip(pred_nls, eval_examples): 526 | predictions.append(str(gold.idx) + '\t' + pred_nl) 527 | f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n') 528 | f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n') 529 | 530 | (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn) 531 | dev_bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2) 532 | logger.info(" %s = %s "%("codenn_bleu",str(dev_bleu))) 533 | logger.info(" save predict result in %s"%(output_fn )) 534 | 535 | save_json_data(args.retrieval_result_dir, args.retrieval_result_filename, retrieval_results) 536 | 537 | logger.info("Finish and take {}".format(get_elapse_time(t0))) 538 | fa.write("Finish and take {}".format(get_elapse_time(t0))) 539 | fa.close() 540 | 541 | def eval_ecmg_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, split_tag,beam_size=1): 542 | logger.info(" ***** Running bleu evaluation on {} data*****".format(split_tag)) 543 | logger.info(" Num examples = %d", len(eval_examples)) 544 | logger.info(" Batch size = %d", args.eval_batch_size) 545 | eval_sampler = SequentialSampler(eval_data) 546 | if args.data_num == -1: 547 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size, 548 | num_workers=4, pin_memory=True) 549 | else: 550 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 551 | 552 | model.eval() 553 | pred_ids = [] 554 | bleu, codebleu = 0.0, 0.0 555 | for batch in tqdm(eval_dataloader, total=len(eval_dataloader), desc="Eval bleu for {} set".format(split_tag)): 556 | # source_ids = batch[0].to(args.device) 557 | # source_mask = source_ids.ne(tokenizer.pad_token_id) 558 | 559 | batch = tuple(t.to(args.device) for t in batch) 560 | input_source_ids, retrieved_source_ids, retrieved_target_ids = batch 561 | 562 | source_mask = input_source_ids.ne(tokenizer.pad_token_id) 563 | 564 | retrieved_source_mask = retrieved_source_ids.ne(tokenizer.pad_token_id) 565 | retrieved_target_mask = retrieved_target_ids.ne(tokenizer.pad_token_id) 566 | 567 | with torch.no_grad(): 568 | # if args.model_type == 'roberta': 569 | # preds = model(source_ids=source_ids, source_mask=source_mask) 570 | 571 | # top_preds = [pred[0].cpu().numpy() for pred in preds] 572 | # if hasattr(model, 'module'): 573 | # preds = model.module.generate(source_ids, 574 | # attention_mask=source_mask, 575 | # use_cache=True, 576 | # num_beams=beam_size, 577 | # early_stopping=args.task == 'summarize', 578 | # max_length=args.max_target_length) 579 | 580 | # else: 581 | # preds = 
model.generate(source_ids, 582 | # attention_mask=source_mask, 583 | # use_cache=True, 584 | # num_beams=beam_size, 585 | # early_stopping=args.task == 'summarize', 586 | # max_length=args.max_target_length) 587 | # top_preds = list(preds.cpu().numpy()) 588 | # pred_ids.extend(top_preds) 589 | # preds = model(source_ids=source_ids, source_mask=source_mask) 590 | preds = model(source_ids=input_source_ids, source_mask=source_mask, 591 | retrieved_source_ids=retrieved_source_ids, retrieved_source_mask=retrieved_source_mask, 592 | retrieved_target_ids=retrieved_target_ids, retrieved_target_mask=retrieved_target_mask 593 | ) 594 | top_preds = [pred[0].cpu().numpy() for pred in preds] 595 | pred_ids.extend(top_preds) 596 | pred_nls = [tokenizer.decode(id, skip_special_tokens=True, clean_up_tokenization_spaces=False) for id in pred_ids] 597 | 598 | 599 | return pred_nls 600 | 601 | 602 | def ECMG(args): 603 | t0 = time.time() 604 | set_dist(args) 605 | set_seed(args) 606 | # config, model, tokenizer = build_or_load_gen_model(args) 607 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 608 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) 609 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name) 610 | model = model_class.from_pretrained(args.model_name_or_path) 611 | 612 | special_tokens_dict = {'additional_special_tokens': [REPLACE, REPLACE_OLD, REPLACE_NEW,REPLACE_END,INSERT,INSERT_OLD,INSERT_NEW ,INSERT_END,DELETE,DELETE_END,KEEP,KEEP_END]} 613 | logger.info("adding new token %s"%str(special_tokens_dict)) 614 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 615 | model.resize_token_embeddings(len(tokenizer)) 616 | if args.load_finetuned_model_path is not None: 617 | logger.info("Reload fine tuned model from {}".format(args.load_finetuned_model_path)) 618 | model.load_state_dict(torch.load(args.load_finetuned_model_path)) 619 | 620 | decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads) 621 | decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 622 | # model = ECMGModel(model,decoder, config, args,sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id) 623 | model = ECMGModel(model, config, args,sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id) 624 | 625 | if args.load_model_path is not None: 626 | logger.info("Reload model from {}".format(args.load_model_path)) 627 | model.load_state_dict(torch.load(args.load_model_path)) 628 | 629 | 630 | model.to(args.device) 631 | if args.n_gpu > 1: 632 | # for DataParallel 633 | model = torch.nn.DataParallel(model) 634 | pool = multiprocessing.Pool(args.cpu_cont) 635 | # args.train_filename, args.dev_filename, args.test_filename = get_filenames(args.data_dir, args.task, args.sub_task) 636 | os.makedirs(args.output_dir, exist_ok=True) 637 | os.makedirs(args.cache_path, exist_ok=True) 638 | fa = open(os.path.join(args.output_dir, 'summary.log'), 'a+') 639 | 640 | if args.do_train: 641 | if args.local_rank in [-1, 0] and args.data_num == -1: 642 | summary_fn = '{}/{}'.format(args.summary_dir, '/'.join(args.output_dir.split('/')[1:])) 643 | tb_writer = SummaryWriter(summary_fn) 644 | 645 | # Prepare training data loader 646 | train_input_examples, train_retrieved_examples, train_data = load_and_commit_data_with_retrieved_result(args, args.train_filename, args.train_retireved_filename, pool, tokenizer, 'train', is_sample=args.debug) 647 | train_sampler = 
RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data) 648 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size, 649 | num_workers=4, pin_memory=True) 650 | 651 | # Prepare optimizer and schedule (linear warmup and decay) 652 | no_decay = ['bias', 'LayerNorm.weight'] 653 | optimizer_grouped_parameters = [ 654 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 655 | 'weight_decay': args.weight_decay}, 656 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 657 | ] 658 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 659 | num_train_optimization_steps = args.num_train_epochs * len(train_dataloader) 660 | scheduler = get_linear_schedule_with_warmup(optimizer, 661 | num_warmup_steps=args.warmup_steps, 662 | num_training_steps=num_train_optimization_steps) 663 | 664 | # Start training 665 | train_example_num = len(train_data) 666 | logger.info("***** Running training *****") 667 | logger.info(" Num examples = %d", train_example_num) 668 | logger.info(" Batch size = %d", args.train_batch_size) 669 | logger.info(" Batch num = %d", math.ceil(train_example_num / args.train_batch_size)) 670 | logger.info(" Num epoch = %d", args.num_train_epochs) 671 | 672 | dev_dataset = {} 673 | global_step, best_bleu, best_ppl = 0, -1, 1e6 674 | not_loss_dec_cnt, not_bleu_em_inc_cnt = 0, 0 if args.do_eval_bleu else 1e6 675 | 676 | for cur_epoch in range(args.start_epoch, int(args.num_train_epochs)): 677 | bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training") 678 | nb_tr_examples, nb_tr_steps, tr_loss = 0, 0, 0 679 | model.train() 680 | for step, batch in enumerate(bar): 681 | # print(step) 682 | batch = tuple(t.to(args.device) for t in batch) 683 | input_source_ids, input_target_ids, retrieved_source_ids, retrieved_target_ids = batch 684 | 685 | source_mask = input_source_ids.ne(tokenizer.pad_token_id) 686 | target_mask = input_target_ids.ne(tokenizer.pad_token_id) 687 | 688 | retrieved_source_mask = retrieved_source_ids.ne(tokenizer.pad_token_id) 689 | retrieved_target_mask = retrieved_target_ids.ne(tokenizer.pad_token_id) 690 | 691 | # if args.model_type == 'roberta': 692 | # loss, _, _ = model(source_ids=input_source_ids, source_mask=source_mask, 693 | # target_ids=input_target_ids, target_mask=target_mask) 694 | # else: 695 | # outputs = model(input_ids=input_source_ids, attention_mask=source_mask, 696 | # labels=input_target_ids, decoder_attention_mask=target_mask) 697 | # loss = outputs.loss 698 | 699 | loss, _, _ = model(source_ids=input_source_ids, source_mask=source_mask, 700 | target_ids=input_target_ids, target_mask=target_mask, 701 | retrieved_source_ids=retrieved_source_ids, retrieved_source_mask=retrieved_source_mask, 702 | retrieved_target_ids=retrieved_target_ids, retrieved_target_mask=retrieved_target_mask 703 | ) 704 | # loss = outputs.loss 705 | 706 | if args.n_gpu > 1: 707 | loss = loss.mean() # mean() to average on multi-gpu. 
708 | if args.gradient_accumulation_steps > 1: 709 | loss = loss / args.gradient_accumulation_steps 710 | tr_loss += loss.item() 711 | 712 | nb_tr_examples += input_source_ids.size(0) 713 | nb_tr_steps += 1 714 | loss.backward() 715 | 716 | if nb_tr_steps % args.gradient_accumulation_steps == 0: 717 | # Update parameters 718 | optimizer.step() 719 | optimizer.zero_grad() 720 | scheduler.step() 721 | global_step += 1 722 | train_loss = round(tr_loss * args.gradient_accumulation_steps / (nb_tr_steps + 1), 4) 723 | if sys.stderr.isatty(): 724 | bar.set_description("[{}] Train loss {}".format(cur_epoch, round(train_loss, 3))) 725 | if (step+1)% args.eval_frequency ==0 and (not sys.stderr.isatty()): 726 | logger.info("epoch {} loss {}".format(cur_epoch,train_loss)) 727 | if args.do_eval: 728 | # eval_examples, eval_data = load_and_cache_commit_data(args, args.dev_filename, pool, tokenizer, 'dev', 729 | # only_src=True, is_sample=args.debug) 730 | eval_examples, eval_retrieved_examples, eval_data = load_and_commit_data_with_retrieved_result(args, args.dev_filename, args.dev_retireved_filename, pool, tokenizer, 'dev', only_src=True, is_sample=args.debug) 731 | 732 | 733 | pred_nls = eval_ecmg_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'dev') 734 | output_fn = os.path.join(args.output_dir, "dev.output") 735 | gold_fn = os.path.join(args.output_dir, "dev.gold") 736 | dev_accs, predictions = [], [] 737 | with open(output_fn, 'w', encoding="utf-8") as f, open(gold_fn, 'w', encoding="utf-8") as f1: 738 | for pred_nl, gold in zip(pred_nls, eval_examples): 739 | dev_accs.append(pred_nl.strip() == gold.target.strip()) 740 | predictions.append(str(gold.idx) + '\t' + pred_nl) 741 | f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n') 742 | f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n') 743 | 744 | (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn) 745 | dev_bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2) 746 | logger.info(" %s = %s "%("codenn_bleu",str(dev_bleu))) 747 | logger.info(" "+"*"*20) 748 | logger.info(" save predict result in %s"%(output_fn )) 749 | #save last checkpoint 750 | last_output_dir = os.path.join(args.output_dir, 'checkpoint-last') 751 | if not os.path.exists(last_output_dir): 752 | os.makedirs(last_output_dir) 753 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 754 | output_model_file = os.path.join(last_output_dir, "pytorch_model.bin") 755 | torch.save(model_to_save.state_dict(), output_model_file) 756 | 757 | if dev_bleu>best_bleu: 758 | logger.info(" Best bleu:%s",dev_bleu) 759 | logger.info(" "+"*"*20) 760 | best_bleu=dev_bleu 761 | # Save best checkpoint for best bleu 762 | output_dir = os.path.join(args.output_dir, 'checkpoint-best-bleu') 763 | if not os.path.exists(output_dir): 764 | os.makedirs(output_dir) 765 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 766 | output_model_file = os.path.join(output_dir, "pytorch_model.bin") 767 | torch.save(model_to_save.state_dict(), output_model_file) 768 | 769 | # logger.info("***** CUDA.empty_cache() *****") 770 | torch.cuda.empty_cache() 771 | if args.local_rank in [-1, 0] and args.data_num == -1: 772 | tb_writer.close() 773 | logger.info("Finish training and take %s", get_elapse_time(t0)) 774 | 775 | if args.do_test: 776 | logger.info(" " + "***** Testing *****") 777 | logger.info(" Batch size = %d", args.eval_batch_size) 778 | model = model.module if 
hasattr(model, 'module') else model 779 | if args.load_model_path is not None: 780 | logger.info("reload model from {}".format(args.load_model_path)) 781 | model.load_state_dict(torch.load(args.load_model_path)) 782 | # eval_examples, eval_data = load_and_cache_commit_data(args, args.test_filename, pool, tokenizer, 'test', 783 | # only_src=True, is_sample=args.debug) 784 | eval_examples, eval_retrieved_examples, eval_data = load_and_commit_data_with_retrieved_result(args, args.test_filename, args.test_retireved_filename, pool, tokenizer, 'test', only_src=True, is_sample=args.debug) 785 | 786 | pred_nls = eval_ecmg_bleu_epoch(args, eval_data, eval_examples, model, tokenizer, 'test',beam_size=args.beam_size) 787 | output_fn = os.path.join(args.output_dir, "test.output") 788 | gold_fn = os.path.join(args.output_dir, "test.gold") 789 | dev_accs, predictions = [], [] 790 | with open(output_fn, 'w', encoding="utf-8") as f, open(gold_fn, 'w', encoding="utf-8") as f1: 791 | for pred_nl, gold in zip(pred_nls, eval_examples): 792 | dev_accs.append(pred_nl.strip() == gold.target.strip()) 793 | predictions.append(str(gold.idx) + '\t' + pred_nl) 794 | f.write(str(gold.idx) + '\t' + pred_nl.strip() + '\n') 795 | f1.write(str(gold.idx) + '\t' + gold.target.strip() + '\n') 796 | 797 | (goldMap, predictionMap) = smooth_bleu.computeMaps(predictions, gold_fn) 798 | dev_bleu = round(smooth_bleu.bleuFromMaps(goldMap, predictionMap)[0], 2) 799 | logger.info(" %s = %s "%("codenn_bleu",str(dev_bleu))) 800 | logger.info(" save predict result in %s"%(output_fn )) 801 | 802 | logger.info("Finish and take {}".format(get_elapse_time(t0))) 803 | fa.write("Finish and take {}".format(get_elapse_time(t0))) 804 | fa.close() 805 | 806 | if __name__ == "__main__": 807 | args = parse_args() 808 | logger.info(args) 809 | if args.run_codet5: 810 | main(args) 811 | else: 812 | ECMG(args) 813 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | lang=$1 2 | 3 | # optimizer 4 | lr=5e-5 5 | batch_size=32 6 | beam_size=10 7 | epochs=10 8 | 9 | # model 10 | source_length=200 11 | target_length=30 12 | 13 | # data 14 | data_dir=dataset/$lang/contextual_medits 15 | train_file=$data_dir/train.jsonl 16 | dev_file=$data_dir/valid.jsonl 17 | test_file=$data_dir/test.jsonl 18 | 19 | 20 | pretrained_model=Salesforce/codet5-base 21 | 22 | # ============ Step 1 Training ============== 23 | 24 | function train_codet5 () { 25 | 26 | output_dir=saved_model/codet5/${lang}/ 27 | mkdir -p $output_dir 28 | echo $output_dir 29 | echo "============TRAINING============" 30 | CUDA_VISIBLE_DEVICES=0 python run.py --do_train --do_eval --do_test --eval_frequency 100 \ 31 | --run_codet5 \ 32 | --model_name_or_path $pretrained_model \ 33 | --train_filename $train_file \ 34 | --dev_filename $dev_file \ 35 | --test_filename ${test_file} \ 36 | --output_dir $output_dir \ 37 | --max_source_length $source_length \ 38 | --max_target_length $target_length \ 39 | --do_lower_case \ 40 | --beam_size $beam_size --train_batch_size $batch_size \ 41 | --eval_batch_size $batch_size --learning_rate $lr \ 42 | --num_train_epochs $epochs --seed 0 2>&1|tee $output_dir/train.log 43 | } 44 | 45 | 46 | # 47 | train_codet5 48 | 49 | 50 | # ============ Step 2 Retrieval ============== 51 | 52 | retrieval_result_dir=${data_dir}/codet5_retrieval_result 53 | mkdir -p ${retrieval_result_dir} 54 | 55 | function retrieval () { 56 | echo "============retrieval 
============" 57 | retrieval_filename=$1 58 | load_model_path=saved_model/codet5/${lang}/checkpoint-best-bleu/pytorch_model.bin 59 | CUDA_VISIBLE_DEVICES=0 python run.py --do_retrieval \ 60 | --run_codet5 \ 61 | --is_cosine_space \ 62 | --train_filename ${train_file} \ 63 | --max_source_length $source_length \ 64 | --max_target_length $target_length \ 65 | --train_batch_size $batch_size \ 66 | --eval_batch_size $batch_size \ 67 | --retrieval_filename ${data_dir}/${retrieval_filename}.jsonl \ 68 | --retrieval_result_dir ${retrieval_result_dir} \ 69 | --retrieval_result_filename ${retrieval_filename}.jsonl \ 70 | --load_model_path ${load_model_path} 2>&1 |tee ${retrieval_result_dir}/${retrieval_filename}.log.txt 71 | } 72 | 73 | 74 | 75 | 76 | retrieval "train" 77 | retrieval "valid" 78 | retrieval "test" 79 | 80 | # ============ Step 3 Refine =============== 81 | 82 | train_retireved_file=${retrieval_result_dir}/train.jsonl 83 | dev_retireved_file=${retrieval_result_dir}/valid.jsonl 84 | test_retireved_file=${retrieval_result_dir}/test.jsonl 85 | 86 | function refine () { 87 | # --debug 88 | load_model_path=saved_model/codet5/${lang}/checkpoint-best-bleu/pytorch_model.bin 89 | output_dir=saved_model/ECMG/${lang}/ 90 | mkdir -p $output_dir 91 | echo $output_dir 92 | 93 | echo "============Refining============" 94 | 95 | CUDA_VISIBLE_DEVICES=0 python run.py --do_train --do_eval --do_test --eval_frequency 100 \ 96 | --load_finetuned_model_path ${load_model_path} \ 97 | --model_name_or_path $pretrained_model \ 98 | --train_filename $train_file \ 99 | --dev_filename $dev_file \ 100 | --test_filename ${test_file} \ 101 | --train_retireved_filename $train_retireved_file \ 102 | --dev_retireved_filename $dev_retireved_file \ 103 | --test_retireved_filename ${test_retireved_file} \ 104 | --output_dir $output_dir \ 105 | --max_source_length $source_length \ 106 | --max_target_length $target_length \ 107 | --do_lower_case \ 108 | --beam_size $beam_size --train_batch_size $batch_size \ 109 | --eval_batch_size $batch_size --learning_rate $lr \ 110 | --num_train_epochs $epochs --seed 0 2>&1|tee $output_dir/refine.log 111 | } 112 | 113 | 114 | refine -------------------------------------------------------------------------------- /run_small.sh: -------------------------------------------------------------------------------- 1 | lang=$1 2 | 3 | # optimizer 4 | lr=5e-5 5 | batch_size=32 6 | beam_size=10 7 | epochs=10 8 | 9 | # model 10 | source_length=200 11 | target_length=30 12 | 13 | # data 14 | data_dir=dataset/$lang/contextual_medits 15 | train_file=$data_dir/train.jsonl 16 | dev_file=$data_dir/valid.jsonl 17 | test_file=$data_dir/test.jsonl 18 | 19 | 20 | pretrained_model=Salesforce/codet5-base 21 | 22 | # ============ Step 1 Training ============== 23 | 24 | 25 | function train_codet5_debug () { 26 | output_dir=saved_model/tmp/${lang} 27 | mkdir -p $output_dir 28 | echo $output_dir 29 | echo "============TRAINING Debugging============" 30 | 31 | CUDA_VISIBLE_DEVICES=0 python run.py --debug --n_debug_samples 100 --do_train --do_eval --do_test --eval_frequency 1 \ 32 | --run_codet5 \ 33 | --model_name_or_path $pretrained_model \ 34 | --train_filename $train_file \ 35 | --dev_filename $dev_file \ 36 | --test_filename ${test_file} \ 37 | --output_dir $output_dir \ 38 | --max_source_length $source_length \ 39 | --max_target_length $target_length \ 40 | --do_lower_case \ 41 | --beam_size $beam_size --train_batch_size $batch_size \ 42 | --eval_batch_size 8 --learning_rate $lr \ 43 | --num_train_epochs 3 
--seed 0 2>&1|tee $output_dir/train.log 44 | } 45 | 46 | train_codet5_debug 47 | 48 | # ============ Step 2 Retrieval ============== 49 | 50 | retrieval_result_dir=${data_dir}/codet5_retrieval_result 51 | mkdir -p ${retrieval_result_dir} 52 | 53 | function retrieval_debug(){ 54 | echo "============retrieval Debugging============" 55 | retrieval_filename=$1 56 | load_model_path=saved_model/tmp/${lang}/checkpoint-best-bleu/pytorch_model.bin 57 | CUDA_VISIBLE_DEVICES=0 python run.py --debug --do_retrieval \ 58 | --run_codet5 \ 59 | --is_cosine_space \ 60 | --train_filename ${train_file} \ 61 | --max_source_length $source_length \ 62 | --max_target_length $target_length \ 63 | --train_batch_size $batch_size \ 64 | --eval_batch_size $batch_size \ 65 | --retrieval_filename ${data_dir}/${retrieval_filename}.jsonl \ 66 | --retrieval_result_dir ${retrieval_result_dir} \ 67 | --retrieval_result_filename ${retrieval_filename}.jsonl \ 68 | --load_model_path ${load_model_path} 2>&1 |tee ${retrieval_result_dir}/${retrieval_filename}.log.txt 69 | } 70 | 71 | 72 | retrieval_debug "train" 73 | retrieval_debug "valid" 74 | retrieval_debug "test" 75 | 76 | # ============ Step 3 Refine =============== 77 | 78 | train_retireved_file=${retrieval_result_dir}/train.jsonl 79 | dev_retireved_file=${retrieval_result_dir}/valid.jsonl 80 | test_retireved_file=${retrieval_result_dir}/test.jsonl 81 | 82 | function refine_debug () { 83 | # --debug 84 | load_model_path=saved_model/tmp/${lang}/checkpoint-best-bleu/pytorch_model.bin 85 | output_dir=saved_model/debug/ECMG/${lang}/ 86 | mkdir -p $output_dir 87 | echo $output_dir 88 | 89 | echo "============Refining Debug============" 90 | 91 | CUDA_VISIBLE_DEVICES=0 python run.py --debug --do_train --do_eval --do_test --eval_frequency 100 \ 92 | --load_finetuned_model_path ${load_model_path} \ 93 | --model_name_or_path $pretrained_model \ 94 | --train_filename $train_file \ 95 | --dev_filename $dev_file \ 96 | --test_filename ${test_file} \ 97 | --train_retireved_filename $train_retireved_file \ 98 | --dev_retireved_filename $dev_retireved_file \ 99 | --test_retireved_filename ${test_retireved_file} \ 100 | --output_dir $output_dir \ 101 | --max_source_length $source_length \ 102 | --max_target_length $target_length \ 103 | --do_lower_case \ 104 | --beam_size $beam_size --train_batch_size $batch_size \ 105 | --eval_batch_size $batch_size --learning_rate $lr \ 106 | --num_train_epochs 3 --seed 0 2>&1|tee $output_dir/refine.log 107 | } 108 | 109 | refine_debug -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import json 4 | import prettytable as pt 5 | import math 6 | from torch.utils.data import TensorDataset 7 | import numpy as np 8 | import os 9 | import random 10 | import time 11 | from tqdm import tqdm 12 | import torch 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | REPLACE = '' 17 | REPLACE_OLD = '' 18 | REPLACE_NEW = '' 19 | REPLACE_END = '' 20 | 21 | INSERT = '' 22 | INSERT_OLD = '' 23 | INSERT_NEW = '' 24 | INSERT_END = '' 25 | 26 | DELETE = '' 27 | DELETE_END = '' 28 | 29 | KEEP = '' 30 | KEEP_END = '' 31 | 32 | def add_lang_by_task(target_str, task, sub_task): 33 | if task == 'summarize': 34 | target_str = ' ' + target_str 35 | elif task == 'refine': 36 | target_str = ' ' + target_str 37 | elif task == 'translate': 38 | if sub_task == 'java-cs': 39 | target_str = ' ' + target_str 40 | 
else: 41 | target_str = ' ' + target_str 42 | elif task == 'concode': 43 | target_str = ' ' + target_str 44 | elif task == 'defect': 45 | target_str = target_str 46 | return target_str 47 | 48 | 49 | def convert_examples_to_features(item): 50 | example, example_index, tokenizer, args, stage = item 51 | 52 | if args.model_type in ['t5', 'codet5'] and args.add_task_prefix: 53 | if args.sub_task != 'none': 54 | source_str = "{} {}: {}".format(args.task, args.sub_task, example.source) 55 | else: 56 | source_str = "{}: {}".format(args.task, example.source) 57 | else: 58 | source_str = example.source 59 | # https://blog.csdn.net/qq_33293040/article/details/105439750 60 | source_str = source_str.replace('', '') 61 | source_ids = tokenizer.encode(source_str, max_length=args.max_source_length, padding='max_length', truncation=True) 62 | assert source_ids.count(tokenizer.eos_token_id) == 1 63 | if stage == 'test': 64 | target_ids = [] 65 | else: 66 | target_str = example.target 67 | if args.add_lang_ids: 68 | target_str = add_lang_by_task(example.target, args.task, args.sub_task) 69 | if args.task in ['defect', 'clone']: 70 | if target_str == 0: 71 | target_str = 'false' 72 | elif target_str == 1: 73 | target_str = 'true' 74 | else: 75 | raise NameError 76 | target_str = target_str.replace('', '') 77 | target_ids = tokenizer.encode(target_str, max_length=args.max_target_length, padding='max_length', 78 | truncation=True) 79 | assert target_ids.count(tokenizer.eos_token_id) == 1 80 | 81 | return InputFeatures( 82 | example_index, 83 | source_ids, 84 | target_ids, 85 | url=example.url 86 | ) 87 | 88 | 89 | def convert_clone_examples_to_features(item): 90 | example, example_index, tokenizer, args = item 91 | if args.model_type in ['t5', 'codet5'] and args.add_task_prefix: 92 | source_str = "{}: {}".format(args.task, example.source) 93 | target_str = "{}: {}".format(args.task, example.target) 94 | else: 95 | source_str = example.source 96 | target_str = example.target 97 | code1 = tokenizer.encode(source_str, max_length=args.max_source_length, padding='max_length', truncation=True) 98 | code2 = tokenizer.encode(target_str, max_length=args.max_source_length, padding='max_length', truncation=True) 99 | source_ids = code1 + code2 100 | return CloneInputFeatures(example_index, source_ids, example.label, example.url1, example.url2) 101 | 102 | 103 | def convert_defect_examples_to_features(item): 104 | example, example_index, tokenizer, args = item 105 | if args.model_type in ['t5', 'codet5'] and args.add_task_prefix: 106 | source_str = "{}: {}".format(args.task, example.source) 107 | else: 108 | source_str = example.source 109 | code = tokenizer.encode(source_str, max_length=args.max_source_length, padding='max_length', truncation=True) 110 | return DefectInputFeatures(example_index, code, example.target) 111 | 112 | 113 | class CloneInputFeatures(object): 114 | """A single training/test features for a example.""" 115 | 116 | def __init__(self, 117 | example_id, 118 | source_ids, 119 | label, 120 | url1, 121 | url2 122 | ): 123 | self.example_id = example_id 124 | self.source_ids = source_ids 125 | self.label = label 126 | self.url1 = url1 127 | self.url2 = url2 128 | 129 | 130 | class DefectInputFeatures(object): 131 | """A single training/test features for a example.""" 132 | 133 | def __init__(self, 134 | example_id, 135 | source_ids, 136 | label 137 | ): 138 | self.example_id = example_id 139 | self.source_ids = source_ids 140 | self.label = label 141 | 142 | 143 | class InputFeatures(object): 144 | """A 
single training/test features for a example.""" 145 | 146 | def __init__(self, 147 | example_id, 148 | source_ids, 149 | target_ids, 150 | url=None 151 | ): 152 | self.example_id = example_id 153 | self.source_ids = source_ids 154 | self.target_ids = target_ids 155 | self.url = url 156 | 157 | 158 | class Example(object): 159 | """A single training/test example.""" 160 | 161 | def __init__(self, 162 | idx, 163 | source, 164 | target, 165 | url=None, 166 | task='', 167 | sub_task='' 168 | ): 169 | self.idx = idx 170 | self.source = source 171 | self.target = target 172 | self.url = url 173 | self.task = task 174 | self.sub_task = sub_task 175 | 176 | 177 | class CloneExample(object): 178 | """A single training/test example.""" 179 | 180 | def __init__(self, 181 | code1, 182 | code2, 183 | label, 184 | url1, 185 | url2 186 | ): 187 | self.source = code1 188 | self.target = code2 189 | self.label = label 190 | self.url1 = url1 191 | self.url2 = url2 192 | 193 | 194 | def read_translate_examples(filename, data_num): 195 | """Read examples from filename.""" 196 | examples = [] 197 | assert len(filename.split(',')) == 2 198 | src_filename = filename.split(',')[0] 199 | trg_filename = filename.split(',')[1] 200 | idx = 0 201 | with open(src_filename) as f1, open(trg_filename) as f2: 202 | for line1, line2 in zip(f1, f2): 203 | src = line1.strip() 204 | trg = line2.strip() 205 | examples.append( 206 | Example( 207 | idx=idx, 208 | source=src, 209 | target=trg, 210 | ) 211 | ) 212 | idx += 1 213 | if idx == data_num: 214 | break 215 | return examples 216 | 217 | 218 | def read_refine_examples(filename, data_num): 219 | """Read examples from filename.""" 220 | examples = [] 221 | assert len(filename.split(',')) == 2 222 | src_filename = filename.split(',')[0] 223 | trg_filename = filename.split(',')[1] 224 | idx = 0 225 | 226 | with open(src_filename) as f1, open(trg_filename) as f2: 227 | for line1, line2 in zip(f1, f2): 228 | examples.append( 229 | Example( 230 | idx=idx, 231 | source=line1.strip(), 232 | target=line2.strip(), 233 | ) 234 | ) 235 | idx += 1 236 | if idx == data_num: 237 | break 238 | return examples 239 | 240 | 241 | def read_concode_examples(filename, data_num): 242 | """Read examples from filename.""" 243 | examples = [] 244 | 245 | with open(filename) as f: 246 | for idx, line in enumerate(f): 247 | x = json.loads(line) 248 | examples.append( 249 | Example( 250 | idx=idx, 251 | source=x["nl"].strip(), 252 | target=x["code"].strip() 253 | ) 254 | ) 255 | idx += 1 256 | if idx == data_num: 257 | break 258 | return examples 259 | 260 | 261 | def read_summarize_examples(filename, data_num): 262 | """Read examples from filename.""" 263 | examples = [] 264 | with open(filename, encoding="utf-8") as f: 265 | for idx, line in enumerate(f): 266 | line = line.strip() 267 | js = json.loads(line) 268 | if 'idx' not in js: 269 | js['idx'] = idx 270 | code = ' '.join(js['code_tokens']).replace('\n', ' ') 271 | code = ' '.join(code.strip().split()) 272 | nl = ' '.join(js['docstring_tokens']).replace('\n', '') 273 | nl = ' '.join(nl.strip().split()) 274 | examples.append( 275 | Example( 276 | idx=idx, 277 | source=code, 278 | target=nl, 279 | ) 280 | ) 281 | if idx + 1 == data_num: 282 | break 283 | return examples 284 | 285 | 286 | def read_defect_examples(filename, data_num): 287 | """Read examples from filename.""" 288 | examples = [] 289 | with open(filename, encoding="utf-8") as f: 290 | for idx, line in enumerate(f): 291 | line = line.strip() 292 | js = json.loads(line) 293 | 294 | code 
= ' '.join(js['func'].split()) 295 | examples.append( 296 | Example( 297 | idx=js['idx'], 298 | source=code, 299 | target=js['target'] 300 | ) 301 | ) 302 | if idx + 1 == data_num: 303 | break 304 | return examples 305 | 306 | 307 | def read_clone_examples(filename, data_num): 308 | """Read examples from filename.""" 309 | index_filename = filename 310 | url_to_code = {} 311 | with open('/'.join(index_filename.split('/')[:-1]) + '/data.jsonl') as f: 312 | for line in f: 313 | line = line.strip() 314 | js = json.loads(line) 315 | code = ' '.join(js['func'].split()) 316 | url_to_code[js['idx']] = code 317 | 318 | data = [] 319 | with open(index_filename) as f: 320 | idx = 0 321 | for line in f: 322 | line = line.strip() 323 | url1, url2, label = line.split('\t') 324 | if url1 not in url_to_code or url2 not in url_to_code: 325 | continue 326 | if label == '0': 327 | label = 0 328 | else: 329 | label = 1 330 | data.append(CloneExample(url_to_code[url1], url_to_code[url2], label, url1, url2)) 331 | idx += 1 332 | if idx == data_num: 333 | break 334 | return data 335 | 336 | 337 | def load_and_cache_gen_data(args, filename, pool, tokenizer, split_tag, only_src=False, is_sample=False): 338 | # cache the data into args.cache_path except it is sampled 339 | # only_src: control whether to return only source ids for bleu evaluating (dev/test) 340 | # return: examples (Example object), data (TensorDataset) 341 | data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num 342 | cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + ('_src' if only_src else '') + data_tag) 343 | 344 | examples = read_examples(filename, args.data_num, args.task) 345 | 346 | if is_sample: 347 | examples = random.sample(examples, min(args.n_debug_samples, len(examples))) 348 | if split_tag == 'train': 349 | calc_stats(examples, tokenizer, is_tokenize=True) 350 | else: 351 | calc_stats(examples) 352 | if os.path.exists(cache_fn) and not is_sample: 353 | logger.info("Load cache data from %s", cache_fn) 354 | data = torch.load(cache_fn) 355 | else: 356 | if is_sample: 357 | logger.info("Sample 5k data for computing bleu from %s", filename) 358 | else: 359 | logger.info("Create cache data into %s", cache_fn) 360 | tuple_examples = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(examples)] 361 | features = pool.map(convert_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples))) 362 | all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long) 363 | if split_tag == 'test' or only_src: 364 | data = TensorDataset(all_source_ids) 365 | else: 366 | all_target_ids = torch.tensor([f.target_ids for f in features], dtype=torch.long) 367 | data = TensorDataset(all_source_ids, all_target_ids) 368 | # if args.local_rank in [-1, 0] and not is_sample: 369 | # torch.save(data, cache_fn) 370 | return examples, data 371 | 372 | 373 | def read_commit_examples(filename, data_num): 374 | """Read examples from filename.""" 375 | examples = [] 376 | with open(filename, encoding="utf-8") as f: 377 | for idx, line in enumerate(f): 378 | line = line.strip() 379 | js = json.loads(line) 380 | if 'idx' not in js: 381 | js['idx'] = idx 382 | code = ' '.join(js['diff_tokens']).replace('\n', ' ') 383 | code = ' '.join(code.strip().split()) 384 | nl = ' '.join(js['msg_tokens']).replace('\n', '') 385 | nl = ' '.join(nl.strip().split()) 386 | examples.append( 387 | Example( 388 | idx=idx, 389 | source=code, 390 | target=nl, 391 | ) 392 | ) 393 | if idx + 1 == data_num: 394 | break 395 | 
return examples 396 | 397 | def read_plain_diff_examples(filename, data_num): 398 | """Read examples from filename.""" 399 | examples = [] 400 | with open(filename, encoding="utf-8") as f: 401 | for idx, line in enumerate(f): 402 | line = line.strip() 403 | js = json.loads(line) 404 | if 'idx' not in js: 405 | js['idx'] = idx 406 | # code = ' '.join(js['diff_tokens']).replace('\n', ' ') 407 | # code = ' '.join(code.strip().split()) 408 | chunks = js["chunks"] 409 | plain_diff="" 410 | for chunk in chunks: 411 | plain_diff += " - " + " - ".join(chunk["old"]) + " + " + " + ".join(chunk["new"]) # each chunk is assumed to hold the pre-/post-edit token lists under "old"/"new" 412 | 413 | nl = ' '.join(js['msg_token']).replace('\n', '') 414 | nl = ' '.join(nl.strip().split()) 415 | examples.append( 416 | Example( 417 | idx=idx, 418 | source= plain_diff, 419 | target=nl, 420 | ) 421 | ) 422 | if idx + 1 == data_num: 423 | break 424 | return examples 425 | 426 | def read_old_context_plain_diff_examples(filename, data_num, sep_token): 427 | """Read examples from filename.""" 428 | examples = [] 429 | with open(filename, encoding="utf-8") as f: 430 | for idx, line in enumerate(f): 431 | line = line.strip() 432 | js = json.loads(line) 433 | if 'idx' not in js: 434 | js['idx'] = idx 435 | # code = ' '.join(js['diff_tokens']).replace('\n', ' ') 436 | # code = ' '.join(code.strip().split()) 437 | chunks = js["chunks"] 438 | plain_diff="" 439 | for chunk in chunks: 440 | plain_diff += " - " + " - ".join(chunk["old"]) + " + " + " + ".join(chunk["new"]) 441 | old_verison = " ".join(js["old"]) 442 | code_diff = old_verison + " " + sep_token + " " + plain_diff 443 | nl = ' '.join(js['msg_token']).replace('\n', '') 444 | nl = ' '.join(nl.strip().split()) 445 | examples.append( 446 | Example( 447 | idx=idx, 448 | source= code_diff, 449 | target=nl, 450 | ) 451 | ) 452 | if idx + 1 == data_num: 453 | break 454 | return examples 455 | 456 | def read_medit_examples(filename, data_num): 457 | """Read examples from filename.""" 458 | examples = [] 459 | with open(filename, encoding="utf-8") as f: 460 | for idx, line in enumerate(f): 461 | line = line.strip() 462 | js = json.loads(line) 463 | if 'idx' not in js: 464 | js['idx'] = idx 465 | # code = ' '.join(js['diff_tokens']).replace('\n', ' ') 466 | # code = ' '.join(code.strip().split()) 467 | chunks = js["chunks_diff"] 468 | medit= " " 469 | for chunk in chunks: 470 | medit += " ".join(chunk) 471 | 472 | nl = ' '.join(js['msg_token']).replace('\n', '') 473 | nl = ' '.join(nl.strip().split()) 474 | examples.append( 475 | Example( 476 | idx=idx, 477 | source= medit, 478 | target=nl, 479 | ) 480 | ) 481 | if idx + 1 == data_num: 482 | break 483 | return examples 484 | 485 | 486 | def read_contextual_medit_examples(filename, data_num): 487 | """Read examples from filename.""" 488 | examples = [] 489 | with open(filename, encoding="utf-8") as f: 490 | for idx, line in enumerate(f): 491 | line = line.strip() 492 | js = json.loads(line) 493 | if 'idx' not in js: 494 | js['idx'] = idx 495 | # code = ' '.join(js['diff_tokens']).replace('\n', ' ') 496 | # code = ' '.join(code.strip().split()) 497 | code_diff = ' '.join( js["diff"]) 498 | # chunks = js["chunks_diff"] 499 | # medit= " " 500 | # for chunk in chunks: 501 | # medit += " ".join(chunk) 502 | 503 | nl = ' '.join(js['msg_token']).replace('\n', '') 504 | nl = ' '.join(nl.strip().split()) 505 | examples.append( 506 | Example( 507 | idx=idx, 508 | source= code_diff, 509 | target=nl, 510 | ) 511 | ) 512 | if idx + 1 == data_num: 513 | break 514 | return examples 515 | 516 | 517 | def 
read_old_context_medit_examples(filename, data_num, sep_token): 518 | """Read examples from filename.""" 519 | examples = [] 520 | with open(filename, encoding="utf-8") as f: 521 | for idx, line in enumerate(f): 522 | line = line.strip() 523 | js = json.loads(line) 524 | if 'idx' not in js: 525 | js['idx'] = idx 526 | # code = ' '.join(js['diff_tokens']).replace('\n', ' ') 527 | # code = ' '.join(code.strip().split()) 528 | chunks = js["chunks_diff"] 529 | medit= " " 530 | for chunk in chunks: 531 | medit += " ".join(chunk) 532 | old_verison = " ".join(js["old"]) 533 | code_diff = old_verison + " " + sep_token + " " + medit 534 | 535 | nl = ' '.join(js['msg_token']).replace('\n', '') 536 | nl = ' '.join(nl.strip().split()) 537 | examples.append( 538 | Example( 539 | idx=idx, 540 | source= code_diff, 541 | target=nl, 542 | ) 543 | ) 544 | if idx + 1 == data_num: 545 | break 546 | return examples 547 | 548 | def load_and_cache_commit_data(args, filename, pool, tokenizer, split_tag, only_src=False, is_sample=False): 549 | # cache the data into args.cache_path except it is sampled 550 | # only_src: control whether to return only source ids for bleu evaluating (dev/test) 551 | # return: examples (Example object), data (TensorDataset) 552 | data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num 553 | cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + ('_src' if only_src else '') + data_tag) 554 | if is_sample: 555 | data_num = args.n_debug_samples 556 | else: 557 | data_num = args.data_num 558 | 559 | if args.diff_type in ["plain_diff_context"]: 560 | examples = read_commit_examples(filename, data_num) 561 | elif args.diff_type in ["plain_diff"]: 562 | examples =read_plain_diff_examples(filename, data_num) 563 | elif args.diff_type in ["old-plain-diff"]: 564 | examples =read_old_context_plain_diff_examples(filename, data_num, tokenizer.sep_token) 565 | elif args.diff_type in ["medit"]: 566 | examples =read_medit_examples(filename, data_num) 567 | elif args.diff_type in ["old-medit"]: 568 | examples =read_old_context_medit_examples(filename, data_num, tokenizer.sep_token) 569 | elif args.diff_type in ["contextual-medit"]: 570 | examples =read_contextual_medit_examples(filename, data_num) 571 | else: 572 | raise RuntimeError("no such diff type") 573 | 574 | # if is_sample: 575 | # # examples = random.sample(examples, min(args.n_debug_samples, len(examples))) 576 | # examples = examples[:args.n_debug_samples] 577 | if split_tag == 'train': 578 | calc_stats(examples, tokenizer, is_tokenize=True) 579 | else: 580 | calc_stats(examples) 581 | # if os.path.exists(cache_fn) and not is_sample: 582 | # logger.info("Load cache data from %s", cache_fn) 583 | # data = torch.load(cache_fn) 584 | # else: 585 | if is_sample: 586 | logger.info("Sample some data for computing bleu from %s", filename) 587 | else: 588 | logger.info("Create cache data into %s", cache_fn) 589 | tuple_examples = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(examples)] 590 | features = pool.map(convert_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples))) 591 | all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long) 592 | if split_tag == 'test' or only_src: 593 | data = TensorDataset(all_source_ids) 594 | else: 595 | all_target_ids = torch.tensor([f.target_ids for f in features], dtype=torch.long) 596 | data = TensorDataset(all_source_ids, all_target_ids) 597 | # if args.local_rank in [-1, 0] and not is_sample: 598 | # torch.save(data, cache_fn) 
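# Note that the cache-load/cache-save logic used by load_and_cache_gen_data is commented out here, so commit features are rebuilt from the .jsonl file on every run.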
599 | if args.debug: 600 | logger.info("*** Example ***") 601 | logger.info("idx: {}".format(examples[0].idx)) 602 | 603 | logger.info("source_tokens: {}".format( examples[0].source )) 604 | logger.info("source_ids: {}".format(' '.join(map(str, features[0].source_ids)))) 605 | 606 | 607 | logger.info("target_tokens: {}".format(examples[0].target)) 608 | logger.info("target_ids: {}".format(' '.join(map(str, features[0].target_ids)))) 609 | 610 | return examples, data 611 | 612 | def load_and_commit_data_with_retrieved_result(args, input_filename, retireved_filename, pool, tokenizer, split_tag, only_src=False, is_sample=False): 613 | # cache the data into args.cache_path except it is sampled 614 | # only_src: control whether to return only source ids for bleu evaluating (dev/test) 615 | # return: examples (Example object), data (TensorDataset) 616 | # data_tag = '_all' if args.data_num == -1 else '_%d' % args.data_num 617 | # cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + ('_src' if only_src else '') + data_tag) 618 | # if args.diff_type in ["plain_diff_context"]: 619 | # examples = read_commit_examples(filename, args.data_num) 620 | # elif args.diff_type in ["plain_diff"]: 621 | # examples =read_plain_diff_examples(filename, args.data_num) 622 | # elif args.diff_type in ["old-plain-diff"]: 623 | # examples =read_old_context_plain_diff_examples(filename, args.data_num, tokenizer.sep_token) 624 | # elif args.diff_type in ["medit"]: 625 | # examples =read_medit_examples(filename, args.data_num) 626 | # elif args.diff_type in ["old-medit"]: 627 | # examples =read_old_context_medit_examples(filename, args.data_num, tokenizer.sep_token) 628 | if args.diff_type in ["contextual-medit"]: 629 | if is_sample: 630 | input_examples =read_contextual_medit_examples(input_filename, args.n_debug_samples) 631 | retrieved_examples =read_contextual_medit_examples(retireved_filename, args.n_debug_samples) 632 | else: 633 | if "dev" in split_tag: 634 | data_num = 2000 635 | else: 636 | data_num = args.data_num 637 | input_examples =read_contextual_medit_examples(input_filename, data_num) 638 | retrieved_examples =read_contextual_medit_examples(retireved_filename, data_num) 639 | else: 640 | raise RuntimeError("no such diff type") 641 | 642 | if is_sample: 643 | input_examples = random.sample(input_examples, min(args.n_debug_samples, len(input_examples))) 644 | retrieved_examples = random.sample(retrieved_examples, min(args.n_debug_samples, len(input_examples))) 645 | # if split_tag == 'train': 646 | # calc_stats(examples, tokenizer, is_tokenize=True) 647 | # else: 648 | # calc_stats(examples) 649 | # if os.path.exists(cache_fn) and not is_sample: 650 | # logger.info("Load cache data from %s", cache_fn) 651 | # data = torch.load(cache_fn) 652 | # else: 653 | if is_sample: 654 | logger.info("Sample some data for computing bleu from %s", input_filename) 655 | logger.info("Sample some data for computing bleu from %s", retireved_filename) 656 | 657 | tuple_input_examples = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(input_examples)] 658 | tuple_retrieved_examples = [(example, idx, tokenizer, args, "train") for idx, example in enumerate(retrieved_examples)] 659 | 660 | input_features = pool.map(convert_examples_to_features, tqdm(tuple_input_examples, total=len(tuple_input_examples))) 661 | retrieved_features = pool.map(convert_examples_to_features, tqdm(tuple_retrieved_examples, total=len(tuple_retrieved_examples))) 662 | 663 | all_input_source_ids = torch.tensor([f.source_ids 
for f in input_features], dtype=torch.long) 664 | 665 | all_retrieved_source_ids = torch.tensor([f.source_ids for f in retrieved_features], dtype=torch.long) 666 | all_retrieved_target_ids = torch.tensor([f.target_ids for f in retrieved_features], dtype=torch.long) 667 | 668 | if split_tag == 'test' or only_src: 669 | data = TensorDataset(all_input_source_ids, all_retrieved_source_ids,all_retrieved_target_ids) 670 | else: 671 | all_target_ids = torch.tensor([f.target_ids for f in input_features], dtype=torch.long) 672 | data = TensorDataset(all_input_source_ids, all_target_ids, all_retrieved_source_ids,all_retrieved_target_ids) 673 | # if args.local_rank in [-1, 0] and not is_sample: 674 | # torch.save(data, cache_fn) 675 | if args.debug: 676 | logger.info("*** Example ***") 677 | logger.info("idx: {}".format(input_examples[0].idx)) 678 | 679 | logger.info("source_tokens: {}".format( input_examples[0].source )) 680 | logger.info("source_ids: {}".format(' '.join(map(str, input_features[0].source_ids)))) 681 | 682 | logger.info("target_tokens: {}".format(input_examples[0].target)) 683 | logger.info("target_ids: {}".format(' '.join(map(str, input_features[0].target_ids)))) 684 | 685 | logger.info("retrieved_source_tokens: {}".format( retrieved_examples[0].source )) 686 | logger.info("retrieved_source_ids: {}".format(' '.join(map(str, retrieved_features[0].source_ids)))) 687 | 688 | logger.info("retrieved_target_tokens: {}".format(retrieved_examples[0].target)) 689 | logger.info("retrieved_target_ids: {}".format(' '.join(map(str, retrieved_features[0].target_ids)))) 690 | 691 | return input_examples, retrieved_examples, data 692 | 693 | def load_and_cache_clone_data(args, filename, pool, tokenizer, split_tag, is_sample=False): 694 | cache_fn = '{}/{}.pt'.format(args.cache_path, split_tag + '_all' if args.data_num == -1 else '_%d' % args.data_num) 695 | examples = read_examples(filename, args.data_num, args.task) 696 | if is_sample: 697 | examples = random.sample(examples, int(len(examples) * 0.1)) 698 | 699 | calc_stats(examples, tokenizer, is_tokenize=True) 700 | if os.path.exists(cache_fn): 701 | logger.info("Load cache data from %s", cache_fn) 702 | data = torch.load(cache_fn) 703 | else: 704 | if is_sample: 705 | logger.info("Sample 10 percent of data from %s", filename) 706 | elif args.data_num == -1: 707 | logger.info("Create cache data into %s", cache_fn) 708 | tuple_examples = [(example, idx, tokenizer, args) for idx, example in enumerate(examples)] 709 | features = pool.map(convert_clone_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples))) 710 | all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long) 711 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 712 | data = TensorDataset(all_source_ids, all_labels) 713 | 714 | if args.local_rank in [-1, 0] and args.data_num == -1: 715 | torch.save(data, cache_fn) 716 | return examples, data 717 | 718 | 719 | def load_and_cache_defect_data(args, filename, pool, tokenizer, split_tag, is_sample=False): 720 | cache_fn = os.path.join(args.cache_path, split_tag) 721 | examples = read_examples(filename, args.data_num, args.task) 722 | if is_sample: 723 | examples = random.sample(examples, int(len(examples) * 0.1)) 724 | 725 | calc_stats(examples, tokenizer, is_tokenize=True) 726 | if os.path.exists(cache_fn): 727 | logger.info("Load cache data from %s", cache_fn) 728 | data = torch.load(cache_fn) 729 | else: 730 | if is_sample: 731 | logger.info("Sample 10 percent of data from %s", 
filename) 732 | elif args.data_num == -1: 733 | logger.info("Create cache data into %s", cache_fn) 734 | tuple_examples = [(example, idx, tokenizer, args) for idx, example in enumerate(examples)] 735 | features = pool.map(convert_defect_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples))) 736 | # features = [convert_clone_examples_to_features(x) for x in tuple_examples] 737 | all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long) 738 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 739 | data = TensorDataset(all_source_ids, all_labels) 740 | 741 | if args.local_rank in [-1, 0] and args.data_num == -1: 742 | torch.save(data, cache_fn) 743 | return examples, data 744 | 745 | 746 | def get_filenames(data_root, task, sub_task, split=''): 747 | if task == 'concode': 748 | data_dir = '{}/{}'.format(data_root, task) 749 | train_fn = '{}/train.json'.format(data_dir) 750 | dev_fn = '{}/dev.json'.format(data_dir) 751 | test_fn = '{}/test.json'.format(data_dir) 752 | elif task == 'summarize': 753 | data_dir = '{}/{}/{}'.format(data_root, task, sub_task) 754 | train_fn = '{}/train.jsonl'.format(data_dir) 755 | dev_fn = '{}/valid.jsonl'.format(data_dir) 756 | test_fn = '{}/test.jsonl'.format(data_dir) 757 | elif task == 'refine': 758 | data_dir = '{}/{}/{}'.format(data_root, task, sub_task) 759 | train_fn = '{}/train.buggy-fixed.buggy,{}/train.buggy-fixed.fixed'.format(data_dir, data_dir) 760 | dev_fn = '{}/valid.buggy-fixed.buggy,{}/valid.buggy-fixed.fixed'.format(data_dir, data_dir) 761 | test_fn = '{}/test.buggy-fixed.buggy,{}/test.buggy-fixed.fixed'.format(data_dir, data_dir) 762 | elif task == 'translate': 763 | data_dir = '{}/{}'.format(data_root, task) 764 | if sub_task == 'cs-java': 765 | train_fn = '{}/train.java-cs.txt.cs,{}/train.java-cs.txt.java'.format(data_dir, data_dir) 766 | dev_fn = '{}/valid.java-cs.txt.cs,{}/valid.java-cs.txt.java'.format(data_dir, data_dir) 767 | test_fn = '{}/test.java-cs.txt.cs,{}/test.java-cs.txt.java'.format(data_dir, data_dir) 768 | else: 769 | train_fn = '{}/train.java-cs.txt.java,{}/train.java-cs.txt.cs'.format(data_dir, data_dir) 770 | dev_fn = '{}/valid.java-cs.txt.java,{}/valid.java-cs.txt.cs'.format(data_dir, data_dir) 771 | test_fn = '{}/test.java-cs.txt.java,{}/test.java-cs.txt.cs'.format(data_dir, data_dir) 772 | elif task == 'clone': 773 | data_dir = '{}/{}'.format(data_root, task) 774 | train_fn = '{}/train.txt'.format(data_dir) 775 | dev_fn = '{}/valid.txt'.format(data_dir) 776 | test_fn = '{}/test.txt'.format(data_dir) 777 | elif task == 'defect': 778 | data_dir = '{}/{}'.format(data_root, task) 779 | train_fn = '{}/train.jsonl'.format(data_dir) 780 | dev_fn = '{}/valid.jsonl'.format(data_dir) 781 | test_fn = '{}/test.jsonl'.format(data_dir) 782 | if split == 'train': 783 | return train_fn 784 | elif split == 'dev': 785 | return dev_fn 786 | elif split == 'test': 787 | return test_fn 788 | else: 789 | return train_fn, dev_fn, test_fn 790 | 791 | 792 | def read_examples(filename, data_num, task): 793 | read_example_dict = { 794 | 'summarize': read_summarize_examples, 795 | 'refine': read_refine_examples, 796 | 'translate': read_translate_examples, 797 | 'concode': read_concode_examples, 798 | 'clone': read_clone_examples, 799 | 'defect': read_defect_examples, 800 | } 801 | return read_example_dict[task](filename, data_num) 802 | 803 | 804 | def calc_stats(examples, tokenizer=None, is_tokenize=False): 805 | avg_src_len = [] 806 | avg_trg_len = [] 807 | avg_src_len_tokenize = [] 
808 | avg_trg_len_tokenize = [] 809 | for ex in examples: 810 | if is_tokenize: 811 | avg_src_len.append(len(ex.source.split())) 812 | avg_trg_len.append(len(str(ex.target).split())) 813 | avg_src_len_tokenize.append(len(tokenizer.tokenize(ex.source))) 814 | avg_trg_len_tokenize.append(len(tokenizer.tokenize(str(ex.target)))) 815 | else: 816 | avg_src_len.append(len(ex.source.split())) 817 | avg_trg_len.append(len(str(ex.target).split())) 818 | if is_tokenize: 819 | logger.info("Read %d examples, avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d", 820 | len(examples), np.mean(avg_src_len), np.mean(avg_trg_len), max(avg_src_len), max(avg_trg_len)) 821 | logger.info("[TOKENIZE] avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d", 822 | np.mean(avg_src_len_tokenize), np.mean(avg_trg_len_tokenize), max(avg_src_len_tokenize), 823 | max(avg_trg_len_tokenize)) 824 | else: 825 | logger.info("Read %d examples, avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d", 826 | len(examples), np.mean(avg_src_len), np.mean(avg_trg_len), max(avg_src_len), max(avg_trg_len)) 827 | 828 | 829 | def get_elapse_time(t0): 830 | elapse_time = time.time() - t0 831 | if elapse_time > 3600: 832 | hour = int(elapse_time // 3600) 833 | minute = int((elapse_time % 3600) // 60) 834 | return "{}h{}m".format(hour, minute) 835 | else: 836 | minute = int((elapse_time % 3600) // 60) 837 | return "{}m".format(minute) 838 | 839 | 840 | def save_pickle_data(path_dir, filename, data): 841 | full_path = path_dir + '/' + filename 842 | print("Save dataset to: %s" % full_path) 843 | if not os.path.exists(path_dir): 844 | os.makedirs(path_dir) 845 | 846 | with open(full_path, 'wb') as output: 847 | pickle.dump(data, output,protocol=4) 848 | 849 | 850 | def read_json_file(filename): 851 | with open(filename, 'r') as fp: 852 | data = fp.readlines() 853 | if len(data) == 1: 854 | data = json.loads(data[0]) 855 | else: 856 | data = [json.loads(line) for line in data] 857 | return data 858 | 859 | def save_json_data(data_dir, filename, data): 860 | os.makedirs(data_dir, exist_ok=True) 861 | file_name = os.path.join(data_dir, filename) 862 | with open(file_name, 'w') as output: 863 | if type(data) == list: 864 | if type(data[0]) in [str, list,dict]: 865 | for item in data: 866 | output.write(json.dumps(item)) 867 | output.write('\n') 868 | 869 | else: 870 | json.dump(data, output) 871 | elif type(data) == dict: 872 | json.dump(data, output) 873 | else: 874 | raise RuntimeError('Unsupported type: %s' % type(data)) 875 | print("saved dataset in " + file_name) 876 | 877 | def percent_len(all_len,percentiles=None): 878 | if percentiles is None: 879 | percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 880 | ptiles_vers = list(np.percentile(all_len, np.array(percentiles))) 881 | ptiles_vers =[str(round(item,4)) for item in ptiles_vers] 882 | tb = pt.PrettyTable() 883 | tb.field_names = ['mean'] + percentiles 884 | mean_value = round(np.mean(all_len), 1) 885 | tb.add_row([mean_value] + ptiles_vers) 886 | print(tb) 887 | latex_output = "& %.2f &"% float(mean_value) + " &".join(ptiles_vers) 888 | print(latex_output) 889 | 890 | def cal_r1_r5_r10(ranks): 891 | r1,r5,r10= 0,0,0 892 | data_len= len(ranks) 893 | for item in ranks: 894 | if item >=1: 895 | r1 +=1 896 | r5 += 1 897 | r10 += 1 898 | elif item >=0.2: 899 | r5+= 1 900 | r10+=1 901 | elif item >=0.1: 902 | r10 +=1 903 | # print("& %.3f &%.3f &%.3f "%(round(r1/data_len,4), round(r5/data_len,4), round(r10/data_len,4))) 904 | result = 
{"R@1":round(r1/data_len,3), "R@5": round(r5/data_len,3), "R@10": round(r10/data_len,3)} 905 | return result 906 | 907 | def time_format(time_cost): 908 | m, s = divmod(time_cost, 60) 909 | h, m = divmod(m, 60) 910 | # print("time_cost: %d" % (time_cost)) 911 | return "%02d:%02d:%02d" % (h, m, s) 912 | 913 | def array_split(original_data, core_num): 914 | data = [] 915 | total_size = len(original_data) 916 | per_core_size = math.ceil(total_size / core_num) 917 | for i in range(core_num): 918 | lower_bound = i * per_core_size 919 | upper_bound = min((i + 1) * per_core_size, total_size) 920 | data.append(original_data[lower_bound:upper_bound]) 921 | return data --------------------------------------------------------------------------------