├── ChatGPT ├── README.md ├── data │ ├── data_intro.json │ └── examples.json ├── execute_scripts.py ├── images │ ├── debug.png │ ├── one_shot_cot.png │ └── zero_shot.png ├── inference_scripts.py └── output │ └── debug.json ├── CodeT5+ ├── README.md ├── evaluator │ ├── CodeBLEU │ │ ├── bleu.py │ │ ├── calc_code_bleu.py │ │ ├── dataflow_match.py │ │ ├── keywords │ │ │ ├── c_sharp.txt │ │ │ ├── java.txt │ │ │ └── python.txt │ │ ├── parser │ │ │ ├── DFG.py │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── build.sh │ │ │ ├── my-languages.so │ │ │ └── utils.py │ │ ├── readme.txt │ │ ├── syntax_match.py │ │ ├── utils.py │ │ └── weighted_ngram_match.py │ ├── bleu.py │ └── smooth_bleu.py ├── images │ ├── DLTrans.png │ ├── MultilingualTrans.png │ └── NicheTrans.png ├── run_preprocess.py ├── run_score.py ├── run_train_DLTrans.sh ├── run_train_MultilingualTrans_many_to_many.sh ├── run_train_MultilingualTrans_many_to_one.sh ├── run_train_MultilingualTrans_one_to_many.sh ├── run_train_MultilingualTrans_one_to_one.sh ├── run_train_RareTrans_many_to_many.sh ├── run_train_RareTrans_many_to_many_only_rare_to_popular.sh └── run_translation.py ├── LICENSE ├── README.md └── images ├── Google_Drive_Logo_16px.png ├── codetransocean.png └── leaderboard6.png /ChatGPT/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/ChatGPT/README.md -------------------------------------------------------------------------------- /ChatGPT/execute_scripts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import re 4 | import subprocess 5 | import argparse 6 | import os 7 | 8 | def save_to_folder(base_dir, type): 9 | type_dir = os.path.join(base_dir, type) 10 | 11 | if not os.path.exists(type_dir): 12 | os.makedirs(type_dir) 13 | 14 | return type_dir 15 | 16 | def run_python_script(script_path): 17 | if not script_path: 18 | return "Please provide a valid Python script path." 
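# Run the translated script in a subprocess and poll it: once it exits it is re-run with
# check_output to capture stdout/stderr; if it runs longer than ~20 seconds it is terminated
# and marked as a timeout, and any CalledProcessError also sets the failure flag k = 1 so the
# caller can distinguish crashes/timeouts from code that runs but prints the wrong output.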
19 | command = f"python {script_path}" 20 | k = 0 21 | try: 22 | start_time = time.time() 23 | process = subprocess.Popen(command, shell=True) 24 | while True: 25 | if process.poll() is not None: 26 | output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT, text=True) 27 | break 28 | elif time.time() - start_time > 20: 29 | k = 1 30 | process.terminate() 31 | output = "timeout" 32 | break 33 | time.sleep(0.1) 34 | except subprocess.CalledProcessError as e: 35 | k = 1 36 | output = str(e.output) 37 | return output, k 38 | 39 | def deal_code(code): 40 | code = re.sub(r'\n', ' ', code) 41 | code = re.sub(r'\s+', ' ', code) 42 | code = re.sub(r'^[\s\n]+|[\s\n]+$', '', code) 43 | return code 44 | 45 | def save_result_to_json(data, file_path): 46 | with open(file_path, "a",encoding='utf-8') as file: 47 | json.dump(data, file, ensure_ascii=False) 48 | file.write('\n') 49 | 50 | def process_scripts_and_save(ref_path,type): 51 | contents = [] 52 | with open(ref_path, 'r', encoding='utf-8') as f: 53 | for i in f.readlines(): 54 | content = json.loads(i) 55 | contents.append(content) 56 | 57 | for j, content in enumerate(contents): 58 | key_source = list(content.keys())[3] 59 | source = content[key_source] 60 | target = deal_code(content["output"]) 61 | output, k = run_python_script('output/'+type + f"/output{j}.py") 62 | output = deal_code(output) 63 | 64 | with open('output/'+type + f"/output{j}.py", "r", encoding="utf-8") as f: 65 | python_code = f.read() 66 | 67 | if output == target: 68 | dic = {'id': j, "label": 1, "output": output, "Python": python_code, key_source: source} 69 | elif k == 0 and output != target: 70 | dic = {'id': j, "label": 0, "output": "The code compiles but the output is incorrect.", "Python": python_code, key_source: source} 71 | else: 72 | if len(output) > 2000: 73 | output = output[:2000] 74 | dic = {'id': j, "label": 0, "output": output, "Python": python_code, key_source: source} 75 | 76 | type_dir = save_to_folder("executed_output", type) 77 | 78 | save_result_to_json(dic, os.path.join(type_dir, f"executed_result.json")) 79 | 80 | def calculate_accuracy(type): 81 | contents = [] 82 | dsr = 0 83 | type_dir = save_to_folder("executed_output", type) 84 | with open(os.path.join(type_dir, f"executed_result.json"), 'r', encoding='utf-8') as f: 85 | for i in f.readlines(): 86 | content = json.loads(i) 87 | contents.append(content) 88 | if content['label'] == 1: 89 | dsr += 1 90 | 91 | return dsr / len(contents) 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser(description='Get py file.') 95 | parser.add_argument('--type', '-type', help="Experiment type.") 96 | parser.add_argument('--ref_path', '-ref_path', help="The ground truth of the execution results.") 97 | args = parser.parse_args() 98 | 99 | process_scripts_and_save(args.ref_path, args.type) 100 | dsr = calculate_accuracy(args.type) 101 | print("DSR:", dsr) 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /ChatGPT/images/debug.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/ChatGPT/images/debug.png -------------------------------------------------------------------------------- /ChatGPT/images/one_shot_cot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/ChatGPT/images/one_shot_cot.png -------------------------------------------------------------------------------- /ChatGPT/images/zero_shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/ChatGPT/images/zero_shot.png -------------------------------------------------------------------------------- /ChatGPT/inference_scripts.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import json 3 | import time 4 | import argparse 5 | import os 6 | 7 | def save_to_folder(base_dir, type): 8 | 9 | 10 | type_dir = os.path.join(base_dir, type) 11 | 12 | if not os.path.exists(type_dir): 13 | os.makedirs(type_dir) 14 | 15 | return type_dir 16 | 17 | def process_content_keys(key_source): 18 | if key_source == 'VB': 19 | key_source = 'Visual Basic' 20 | return key_source 21 | 22 | def handle_debug(): 23 | contents = [] 24 | with open("output/final.json", 'r',encoding = 'utf-8') as f: 25 | for i in f.readlines(): 26 | content = json.loads(i) 27 | contents.append(content) 28 | f.close() 29 | 30 | type_dir = save_to_folder('output', 'debug') 31 | 32 | 33 | 34 | for j, content in enumerate(contents): 35 | label = content["label"] 36 | if label == 1: 37 | with open(type_dir +f"/output{j}.py", "w", encoding="utf-8") as f: 38 | f.write(content["Python"]) 39 | else: 40 | corrected_code = debug_code(content) 41 | with open(type_dir +f"/output{j}.py", "w", encoding="utf-8") as f: 42 | f.write(corrected_code) 43 | 44 | def debug_code(content): 45 | python_code = content["Python"] 46 | key_source = list(content.keys())[4] 47 | source = content[key_source] 48 | error = content['output'] 49 | text = f"Translate {key_source} to Python :{source}.\n\nChatGPT:{python_code}.\n\nUser: The above python code compiles with the following errors, please correct them.{error}" 50 | corrected_code = chatgpt(text) 51 | return corrected_code 52 | 53 | def prepare_prompt(type,file_path): 54 | if type == 'cot_1': 55 | file_path = "data/data_intro.json" 56 | texts = [] 57 | targets = [] 58 | with open(file_path, 'r', encoding='utf-8') as f: 59 | for i in f.readlines(): 60 | content = json.loads(i) 61 | target = content["output"] 62 | key_source = list(content.keys())[3] 63 | source = content[key_source] 64 | key_source = process_content_keys(list(content.keys())[3]) 65 | key_intro = list(content.keys())[4] 66 | intro = content[key_intro] 67 | text = f"Function description:{intro}\nPlease translate into Python code according to the following {key_source} code and its functional description:{source}\nDo not return anything including notes and the like except for one translated Python code." 
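# cot_1 prompts pair each source snippet with its natural-language function description
# (read from data/data_intro.json) before asking the model for a single Python translation.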
68 | print(text) 69 | texts.append(text) 70 | targets.append(target) 71 | return texts, targets 72 | else: 73 | examples=[] 74 | with open("data/examples.json", 'r', encoding='utf-8') as f_new: 75 | for z in f_new.readlines(): 76 | example = json.loads(z) 77 | examples.append(example) 78 | f_new.close() 79 | texts = [] 80 | targets = [] 81 | with open(file_path, 'r', encoding='utf-8') as f: 82 | for i in f.readlines(): 83 | content = json.loads(i) 84 | target = content["output"] 85 | key_source = list(content.keys())[3] 86 | source = content[key_source] 87 | key_source = process_content_keys(list(content.keys())[3]) 88 | if(type == "zero_shot_1"): 89 | text = f"Translate {key_source} to Python:{source}\nDo not return anything including notes and the like except for one translated Python code." 90 | elif(type == "zero_shot_2" or type == "0shot_3" or type == "0shot_4" or type == "0shot_5"): 91 | text = f"Please provide the Python translation for the following {key_source} code:\n{source}\nDo not return anything including notes and the like except for one translated Python code." 92 | elif(type == "zero_shot_6"): 93 | text = f"Please translate the following {key_source} code into Python code:{source}\nDo not return anything including notes and the like except for one translated Python code." 94 | elif(type == "zero_shot_7"): 95 | text = f"Translating {key_source} to Python ensures that Python code can be compiled:{source}\nDo not return anything including notes and the like except for one translated Python code." 96 | elif(type == "zero_shot_8"): 97 | text = f"Can you rewrite this {key_source} code in Python? {source}" 98 | elif(type == "one_shot_1"): 99 | text = '''Here is an example of a translation from Java to Python.\n"Java": "import java.util.ArrayList;\nimport java.util.Arrays;\nimport java.util.LinkedList;\nimport java.util.List;\nimport java.util.Queue;\n\npublic class WordBreak {\n\n public static void main(String[] args) {\n List dict = Arrays.asList(\"a\", \"aa\", \"b\", \"ab\", \"aab\");\n for ( String testString : Arrays.asList(\"aab\", \"aa b\") ) {\n List> matches = wordBreak(testString, dict);\n System.out.printf(\"String = %s, Dictionary = %s. Solutions = %d:%n\", testString, dict, matches.size());\n for ( List match : matches ) {\n System.out.printf(\" Word Break = %s%n\", match);\n }\n System.out.printf(\"%n\");\n }\n dict = Arrays.asList(\"abc\", \"a\", \"ac\", \"b\", \"c\", \"cb\", \"d\");\n for ( String testString : Arrays.asList(\"abcd\", \"abbc\", \"abcbcd\", \"acdbc\", \"abcdd\") ) {\n List> matches = wordBreak(testString, dict);\n System.out.printf(\"String = %s, Dictionary = %s. Solutions = %d:%n\", testString, dict, matches.size());\n for ( List match : matches ) {\n System.out.printf(\" Word Break = %s%n\", match);\n }\n System.out.printf(\"%n\");\n }\n }\n \n private static List> wordBreak(String s, List dictionary) {\n List> matches = new ArrayList<>();\n Queue queue = new LinkedList<>();\n queue.add(new Node(s));\n while ( ! 
queue.isEmpty() ) {\n Node node = queue.remove();\n \n if ( node.val.length() == 0 ) {\n matches.add(node.parsed);\n }\n else {\n for ( String word : dictionary ) {\n \n if ( node.val.startsWith(word) ) {\n String valNew = node.val.substring(word.length(), node.val.length());\n List parsedNew = new ArrayList<>();\n parsedNew.addAll(node.parsed);\n parsedNew.add(word);\n queue.add(new Node(valNew, parsedNew));\n }\n }\n }\n }\n return matches;\n }\n \n private static class Node {\n private String val; \n private List parsed; \n public Node(String initial) {\n val = initial;\n parsed = new ArrayList<>();\n }\n public Node(String s, List p) {\n val = s;\n parsed = p;\n }\n }\n\n}\n", "Python": "from itertools import (chain)\n\n\n\ndef stringParse(lexicon):\n \n return lambda s: Node(s)(\n tokenTrees(lexicon)(s)\n )\n\n\n\ndef tokenTrees(wds):\n \n def go(s):\n return [Node(s)([])] if s in wds else (\n concatMap(nxt(s))(wds)\n )\n\n def nxt(s):\n return lambda w: parse(\n w, go(s[len(w):])\n ) if s.startswith(w) else []\n\n def parse(w, xs):\n return [Node(w)(xs)] if xs else xs\n\n return lambda s: go(s)\n\n\n\ndef showParse(tree):\n \n def showTokens(x):\n xs = x['nest']\n return ' ' + x['root'] + (showTokens(xs[0]) if xs else '')\n parses = tree['nest']\n return tree['root'] + ':\\n' + (\n '\\n'.join(\n map(showTokens, parses)\n ) if parses else ' ( Not parseable in terms of these words )'\n )\n\n\n\n\ndef main():\n \n\n lexicon = 'a bc abc cd b'.split()\n testSamples = 'abcd abbc abcbcd acdbc abcdd'.split()\n\n print(unlines(\n map(\n showParse,\n map(\n stringParse(lexicon),\n testSamples\n )\n )\n ))\n\n\n\n\n\ndef Node(v):\n \n return lambda xs: {'type': 'Node', 'root': v, 'nest': xs}\n\n\n\ndef concatMap(f):\n \n return lambda xs: list(\n chain.from_iterable(map(f, xs))\n )\n\n\n\ndef unlines(xs):\n \n return '\\n'.join(xs)\n\n\n\nif __name__ == '__main__':\n main()\n" '''+ f"Please imitate this example to translate following code from {key_source} to Python:{source}\nDo not return anything including notes and the like except for one translated Python code." 100 | elif(type == "one_shot_2"): 101 | for j in range(len(examples)): 102 | example = examples[j] 103 | example_source = list(example.keys())[1] 104 | example_source = process_content_keys(example_source) 105 | example_target = list(example.keys())[2] 106 | if(example_source==key_source and example_target=="Python"): 107 | t = str(example) 108 | break 109 | text = f'''Here is an example of a translation from {key_source} to Python. 110 | '''+ t+f"\nPlease imitate this example to translate following code from {key_source} to Python:{source}Do not return anything including notes and the like except for one translated Python code." 111 | elif(type == "one_shot_3"): 112 | text = '''Here is an example of a translation from Go to C++. 
"Go": "package main\n\nimport (\n \"errors\"\n \"fmt\"\n \"log\"\n)\n\nvar (\n v1 = []int{1, 3, -5}\n v2 = []int{4, -2, -1}\n)\n\nfunc dot(x, y []int) (r int, err error) {\n if len(x) != len(y) {\n return 0, errors.New(\"incompatible lengths\")\n }\n for i, xi := range x {\n r += xi * y[i]\n }\n return\n}\n\nfunc main() {\n d, err := dot([]int{1, 3, -5}, []int{4, -2, -1})\n if err != nil {\n log.Fatal(err)\n }\n fmt.Println(d)\n}\n", "C++": "#include \n#include \n\nint main()\n{\n int a[] = { 1, 3, -5 };\n int b[] = { 4, -2, -1 };\n\n std::cout << std::inner_product(a, a + sizeof(a) / sizeof(a[0]), b, 0) << std::endl;\n\n return 0;\n}" 113 | '''+ f"Please imitate this example to translate following code from {key_source} to Python:{source}\nDo not return anything including notes and the like except for one translated Python code." 114 | elif(type == "cot_2"): 115 | text = f"First, understand the function of the following {key_source} code. Then, translate the {key_source} code into Python code while keeping the function unchanged.\n{source}\nDo not return anything including notes and the like except for one translated Python code." 116 | elif(type == "cot_3"): 117 | text = f"First, understand the functionality of the following {key_source} code and predict the compilation output. Then, translate the {key_source} code into Python while maintaining the same functionality, ensuring that the translated code can be successfully compiled.\n{source}\nDo not return anything including notes and the like except for one translated Python code." 118 | elif(type == "cot_4"): 119 | for j in range(len(examples)): 120 | example = examples[j] 121 | example_source = list(example.keys())[1] 122 | example_source = process_content_keys(example_source) 123 | example_target = list(example.keys())[2] 124 | if(example_source==key_source and example_target=="Python"): 125 | t = str(example) 126 | break 127 | text = f''' 128 | First, learn how to translate {key_source} code to Python based on the example, '''+t \ 129 | +f'''. Then, understand the functionality of the following {key_source} code and predict the compilation output, 130 | {key_source}: {source}. Finally, translate the {key_source} code into Python while maintaining the same functionality, ensuring that the translated code can be successfully compiled. 131 | '''+ f"Do not return anything including notes and the like except for one translated Python code." 132 | print(text) 133 | texts.append(text) 134 | targets.append(target) 135 | return texts, targets 136 | 137 | def chatgpt(text='', type='', i=0): 138 | messages = [] 139 | # Check if a system role is needed based on the type 140 | if type in ["zero_shot_3", "zero_shot_4", "zero_shot_5"]: 141 | key_sources = [] 142 | with open("data/LLM_trans.json", 'r', encoding='utf-8') as f: 143 | key_sources = [list(json.loads(j).keys())[3] for j in f.readlines()] 144 | text_system = { 145 | "zero_shot_3": f"You are a code translation system that specializes in {key_sources[i]} and Python programming languages.", 146 | "zero_shot_4": "You are a programmer proficient in multiple programming languages.", 147 | "zero_shot_5": f"You are a programmer proficient in {key_sources[i]} and Python programming languages." 
148 | }.get(type) 149 | messages.append({"role": "system", "content": text_system}) 150 | messages.append({"role": "user", "content": text}) 151 | while True: 152 | try: 153 | response = openai.ChatCompletion.create( 154 | model="gpt-3.5-turbo", 155 | temperature=0, 156 | top_p=0, 157 | messages=messages 158 | ) 159 | return response['choices'][0]['message']['content'] 160 | 161 | except Exception as e: 162 | print("An error occurred: ", str(e)) 163 | time.sleep(5) 164 | 165 | def main(): 166 | parser = argparse.ArgumentParser(description='Get py file.') 167 | parser.add_argument('--key', '-key',help="Key of your chatgpt api.") 168 | parser.add_argument('--type', '-type',help="Experiment type. Options available:['zero_shot_1', 'zero_shot_2', 'zero_shot_3', 'zero_shot_4', 'zero_shot_5', 'zero_shot_6', 'zero_shot_7', 'zero_shot_8', 'one_shot_1', 'one_shot_2', 'one_shot_3', 'cot_1', 'cot_2', cot_3', 'debug']. Please refer to the explanations of the strategists in the paper.") 169 | parser.add_argument('--path', '-path',help="path of dataset.") 170 | 171 | args = parser.parse_args() 172 | 173 | openai.api_key = args.key 174 | if args.type == 'debug': 175 | handle_debug() 176 | else: 177 | texts, targets = prepare_prompt(args.type, args.path) 178 | restexts = [chatgpt(text, args.type, i) for i, text in enumerate(texts)] 179 | 180 | type_dir = save_to_folder('output', args.type) 181 | 182 | for j, (restext) in enumerate(restexts): 183 | with open(type_dir + f"/output{j}.py", "w", encoding="utf-8") as f: 184 | f.write(restext) 185 | 186 | if __name__ == '__main__': 187 | main() -------------------------------------------------------------------------------- /ChatGPT/output/debug.json: -------------------------------------------------------------------------------- 1 | {"id": 0, "label": 0, "output": "File \"output/output0.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\nimport ctypes\n\ndef can_make_words(blocks, word):\n i = 0\n ret = 0\n c = word[0].upper()\n\n def swap(a, b):\n if a != b:\n tmp = a\n a = b\n b = tmp\n\n if not c:\n return 1\n if not blocks[0]:\n return 0\n\n while blocks[i] and not ret:\n if blocks[i][0] != c and blocks[i][1] != c:\n i += 1\n continue\n swap(blocks[i], blocks[0])\n ret = can_make_words(blocks[1:], word[1:])\n swap(blocks[i], blocks[0])\n i += 1\n\n return ret\n\ndef main():\n blocks = (ctypes.c_char_p * 21)(\n b\"BO\", b\"XK\", b\"DQ\", b\"CP\", b\"NA\",\n b\"GT\", b\"RE\", b\"TG\", b\"QD\", b\"FS\",\n b\"JW\", b\"HU\", b\"VI\", b\"AN\", b\"OB\",\n b\"ER\", b\"FS\", b\"LY\", b\"PC\", b\"ZM\",\n 0\n )\n\n words = (ctypes.c_char_p * 8)(\n b\"\", b\"A\", b\"BARK\", b\"BOOK\", b\"TREAT\", b\"COMMON\", b\"SQUAD\", b\"Confuse\", 0\n )\n\n w = words\n while w.contents.value:\n print(w.contents.value.decode(), can_make_words(blocks, w.contents.value.decode()))\n w += 1\n\nif __name__ == \"__main__\":\n main()\n```\n\nNote: The translation uses the `ctypes` module to handle the C-style arrays in Python.", "C": "#include \n#include \n\nint can_make_words(char **b, char *word)\n{\n\tint i, ret = 0, c = toupper(*word);\n\n#define SWAP(a, b) if (a != b) { char * tmp = a; a = b; b = tmp; }\n\n\tif (!c) return 1;\n\tif (!b[0]) return 0;\n\n\tfor (i = 0; b[i] && !ret; i++) {\n\t\tif (b[i][0] != c && b[i][1] != c) continue;\n\t\tSWAP(b[i], b[0]);\n\t\tret = can_make_words(b + 1, word + 1);\n\t\tSWAP(b[i], b[0]);\n\t}\n\n\treturn ret;\n}\n\nint 
main(void)\n{\n\tchar* blocks[] = {\n\t\t\"BO\", \"XK\", \"DQ\", \"CP\", \"NA\", \n\t\t\"GT\", \"RE\", \"TG\", \"QD\", \"FS\", \n\t\t\"JW\", \"HU\", \"VI\", \"AN\", \"OB\", \n\t\t\"ER\", \"FS\", \"LY\", \"PC\", \"ZM\",\n\t\t0 };\n\n\tchar *words[] = {\n\t\t\"\", \"A\", \"BARK\", \"BOOK\", \"TREAT\", \"COMMON\", \"SQUAD\", \"Confuse\", 0\n\t};\n\n\tchar **w;\n\tfor (w = words; *w; w++)\n\t\tprintf(\"%s\\t%d\\n\", *w, can_make_words(blocks, *w));\n\n\treturn 0;\n}"} 2 | {"id": 1, "label": 0, "output": "File \"output/output1.py\", line 1 Here is the translated Python code from the given C code: ^ SyntaxError: invalid syntax", "Python": "Here is the translated Python code from the given C code:\n\n```python\ndef droot(x, base, pers):\n d = 0\n if pers:\n *pers = 0\n while x >= base:\n for d = 0, x:\n d += x % base\n x /= base\n *pers += 1\n elif x and not (d = x % (base - 1)):\n d = base - 1\n\n return d\n\ndef main():\n x = [627615, 39390, 588225, 393900588225]\n for i in range(4):\n pers = 0\n d = droot(x[i], 10, pers)\n print(f\"{x[i]}: pers {pers}, root {d}\")\n\nif __name__ == '__main__':\n main()\n```\n\nNote: The code may need further adjustments or modifications based on the specific requirements and context.", "C": "#include \n\nint droot(long long int x, int base, int *pers)\n{\n\tint d = 0;\n\tif (pers)\n\t\tfor (*pers = 0; x >= base; x = d, (*pers)++)\n\t\t\tfor (d = 0; x; d += x % base, x /= base);\n\telse if (x && !(d = x % (base - 1)))\n\t\t\td = base - 1;\n\n\treturn d;\n}\n\nint main(void)\n{\n\tint i, d, pers;\n\tlong long x[] = {627615, 39390, 588225, 393900588225LL};\n\n\tfor (i = 0; i < 4; i++) {\n\t\td = droot(x[i], 10, &pers);\n\t\tprintf(\"%lld: pers %d, root %d\\n\", x[i], pers, d);\n\t}\n\n\treturn 0;\n}"} 3 | {"id": 2, "label": 0, "output": "File \"output/output2.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\ndef mul_inv(a, b):\n b0 = b\n t, q = 0, 0\n x0, x1 = 0, 1\n if b == 1:\n return 1\n while a > 1:\n q = a // b\n t = b\n b = a % b\n a = t\n t = x0\n x0 = x1 - q * x0\n x1 = t\n if x1 < 0:\n x1 += b0\n return x1\n\ndef chinese_remainder(n, a):\n prod = 1\n sum = 0\n for i in range(len(n)):\n prod *= n[i]\n for i in range(len(n)):\n p = prod // n[i]\n sum += a[i] * mul_inv(p, n[i]) * p\n return sum % prod\n\nn = [3, 5, 7]\na = [2, 3, 2]\nprint(chinese_remainder(n, a))\n```\n\nThis Python code performs the same calculations as the given C code and prints the result.", "C": "#include \n\n// returns x where (a * x) % b == 1\nint mul_inv(int a, int b)\n{\n\tint b0 = b, t, q;\n\tint x0 = 0, x1 = 1;\n\tif (b == 1) return 1;\n\twhile (a > 1) {\n\t\tq = a / b;\n\t\tt = b, b = a % b, a = t;\n\t\tt = x0, x0 = x1 - q * x0, x1 = t;\n\t}\n\tif (x1 < 0) x1 += b0;\n\treturn x1;\n}\n\nint chinese_remainder(int *n, int *a, int len)\n{\n\tint p, i, prod = 1, sum = 0;\n\n\tfor (i = 0; i < len; i++) prod *= n[i];\n\n\tfor (i = 0; i < len; i++) {\n\t\tp = prod / n[i];\n\t\tsum += a[i] * mul_inv(p, n[i]) * p;\n\t}\n\n\treturn sum % prod;\n}\n\nint main(void)\n{\n\tint n[] = { 3, 5, 7 };\n\tint a[] = { 2, 3, 2 };\n\n\tprintf(\"%d\\n\", chinese_remainder(n, a, sizeof(n)/sizeof(n[0])));\n\treturn 0;\n}"} 4 | {"id": 3, "label": 0, "output": "File \"output/output3.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\ndef dot_product(a, 
b):\n sum = 0\n for i in range(len(a)):\n sum += a[i] * b[i]\n return sum\n\na = [1, 3, -5]\nb = [4, -2, -1]\nprint(dot_product(a, b))\n```\n\nNote: In Python, there is no need to specify the size of the arrays explicitly. The `len()` function is used to get the length of the arrays.", "C": "#include \n#include \n\nint dot_product(int *, int *, size_t);\n\nint\nmain(void)\n{\n int a[3] = {1, 3, -5};\n int b[3] = {4, -2, -1};\n\n printf(\"%d\\n\", dot_product(a, b, sizeof(a) / sizeof(a[0])));\n\n return EXIT_SUCCESS;\n}\n\nint\ndot_product(int *a, int *b, size_t n)\n{\n int sum = 0;\n size_t i;\n\n for (i = 0; i < n; i++) {\n sum += a[i] * b[i];\n }\n\n return sum;\n}"} 5 | {"id": 4, "label": 0, "output": "File \"output/output4.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\nimport sys\n\ndef main():\n current = 0 # the current number\n square = 0 # the square of the current number\n\n # the strategy of taking the rest of division by 1e06 is\n # to take a number whose 6 last digits are 269696\n while ((square := current * current) % 1000000 != 269696) and (square < sys.maxsize):\n current += 1\n\n # output\n if square >= sys.maxsize:\n print(\"Condition not satisfied before INT_MAX reached.\")\n else:\n print(f\"The smallest number whose square ends in 269696 is {current}\")\n\n # the end\n return 0\n\n\nif __name__ == '__main__':\n main()\n```\n\nNote: In Python, the `sys.maxsize` constant is used to represent the maximum value for integers.", "C": "#include \n#include \n#include \n \nint main() {\n\tint current = 0, \t//the current number \n\t square;\t\t//the square of the current number\n\n\t//the strategy of take the rest of division by 1e06 is\n\t//to take the a number how 6 last digits are 269696\n\twhile (((square=current*current) % 1000000 != 269696) && (square+INT_MAX)\n\t printf(\"Condition not satisfied before INT_MAX reached.\");\n\telse\t\t \n\t printf (\"The smallest number whose square ends in 269696 is %d\\n\", current);\n\t \n //the end\n\treturn 0 ;\n}"} 6 | {"id": 5, "label": 0, "output": "File \"output/output5.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\nimport random\n\ndef isBal(s, l):\n c = 0\n while l > 0:\n if s[l-1] == ']':\n c += 1\n elif s[l-1] == '[':\n c -= 1\n if c < 0:\n break\n l -= 1\n return c == 0\n\ndef shuffle(s, h):\n for i in range(h):\n x = random.randint(0, h-1)\n t = s[x]\n s[x] = s[i]\n s[i] = t\n\ndef genSeq(s, n):\n if n > 0:\n s[:n] = '[' * n\n s[n:2*n] = ']' * n\n shuffle(s, 2*n)\n s[2*n] = '\\0'\n\ndef doSeq(n):\n s = [''] * 64\n o = \"False\"\n genSeq(s, n)\n if isBal(s, 2*n):\n o = \"True\"\n print(\"'{}': {}\".format(''.join(s), o))\n\ndef main():\n n = 0\n while n < 9:\n doSeq(n)\n n += 1\n\nif __name__ == \"__main__\":\n main()\n```\n\nNote: The translation may not be exact and may require further adjustments based on the specific requirements and conventions of the Python programming language.", "C": "#include\n#include\n#include\n\nint isBal(const char*s,int l){\n signed c=0;\n while(l--)\n\tif(s[l]==']') ++c;\n\telse if(s[l]=='[') if(--c<0) break;\n return !c;\n}\n\nvoid shuffle(char*s,int h){\n int x,t,i=h;\n while(i--){\n\tt=s[x=rand()%h];\n\ts[x]=s[i];\n\ts[i]=t;\n }\n}\n\nvoid genSeq(char*s,int n){\n 
if(n){\n\tmemset(s,'[',n);\n\tmemset(s+n,']',n);\n\tshuffle(s,n*2);\n }\n s[n*2]=0;\n}\n\nvoid doSeq(int n){\n char s[64];\n const char *o=\"False\";\n genSeq(s,n);\n if(isBal(s,n*2)) o=\"True\";\n printf(\"'%s': %s\\n\",s,o);\n}\n\nint main(){\n int n=0;\n while(n<9) doSeq(n++);\n return 0;\n}"} 7 | {"id": 6, "label": 0, "output": "File \"output/output6.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\ndef main():\n fact = [1] * 12\n\n for n in range(1, 12):\n fact[n] = fact[n-1] * n\n\n for b in range(9, 13):\n print(f\"The factorions for base {b} are:\")\n for i in range(1, 1500000):\n sum = 0\n j = i\n while j > 0:\n d = j % b\n sum += fact[d]\n j //= b\n if sum == i:\n print(i, end=\" \")\n print(\"\\n\")\n\nif __name__ == '__main__':\n main()\n```\n\nNote: The translation assumes that the `printf` function is replaced with the `print` function in Python.", "C": "#include \n\nint main() { \n int n, b, d;\n unsigned long long i, j, sum, fact[12];\n // cache factorials from 0 to 11\n fact[0] = 1;\n for (n = 1; n < 12; ++n) {\n fact[n] = fact[n-1] * n;\n }\n\n for (b = 9; b <= 12; ++b) {\n printf(\"The factorions for base %d are:\\n\", b);\n for (i = 1; i < 1500000; ++i) {\n sum = 0;\n j = i;\n while (j > 0) {\n d = j % b;\n sum += fact[d];\n j /= b;\n }\n if (sum == i) printf(\"%llu \", i);\n }\n printf(\"\\n\\n\");\n }\n return 0;\n}"} 8 | {"id": 7, "label": 0, "output": "File \"output/output7.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\ndef feigenbaum():\n max_it = 13\n max_it_j = 10\n a1 = 1.0\n a2 = 0.0\n d1 = 3.2\n print(\" i d\")\n for i in range(2, max_it + 1):\n a = a1 + (a1 - a2) / d1\n for j in range(1, max_it_j + 1):\n x = 0.0\n y = 0.0\n for k in range(1, 1 << i + 1):\n y = 1.0 - 2.0 * y * x\n x = a - x * x\n a -= x / y\n d = (a1 - a2) / (a - a1)\n print(\"%2d %.8f\" % (i, d))\n d1 = d\n a2 = a1\n a1 = a\n\ndef main():\n feigenbaum()\n return 0\n\nif __name__ == '__main__':\n main()\n```\n\nNote: The translation assumes that the necessary libraries have been imported before this code snippet.", "C": "#include \n\nvoid feigenbaum() {\n int i, j, k, max_it = 13, max_it_j = 10;\n double a, x, y, d, a1 = 1.0, a2 = 0.0, d1 = 3.2;\n printf(\" i d\\n\");\n for (i = 2; i <= max_it; ++i) {\n a = a1 + (a1 - a2) / d1;\n for (j = 1; j <= max_it_j; ++j) {\n x = 0.0;\n y = 0.0;\n for (k = 1; k <= 1 << i; ++k) {\n y = 1.0 - 2.0 * y * x;\n x = a - x * x;\n }\n a -= x / y;\n }\n d = (a1 - a2) / (a - a1);\n printf(\"%2d %.8f\\n\", i, d);\n d1 = d;\n a2 = a1;\n a1 = a;\n }\n}\n\nint main() {\n feigenbaum();\n return 0;\n}"} 9 | {"id": 8, "label": 0, "output": "File \"output/output8.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\ndef damm(input):\n table = [\n [0, 3, 1, 7, 5, 9, 8, 6, 4, 2],\n [7, 0, 9, 2, 1, 5, 4, 8, 6, 3],\n [4, 2, 0, 6, 8, 7, 1, 3, 5, 9],\n [1, 7, 5, 0, 9, 8, 3, 4, 2, 6],\n [6, 1, 2, 3, 0, 4, 5, 9, 7, 8],\n [3, 6, 7, 4, 2, 0, 9, 5, 8, 1],\n [5, 8, 6, 9, 7, 2, 0, 1, 3, 4],\n [8, 9, 4, 5, 3, 6, 2, 0, 1, 7],\n [9, 4, 3, 8, 6, 1, 7, 2, 0, 5],\n [2, 5, 8, 1, 4, 3, 6, 7, 9, 0]\n ]\n\n interim = 0\n for i in input:\n interim = table[interim][i]\n return interim == 0\n\ninput = [5, 7, 2, 
4]\nprint(\"Checksum correct\" if damm(input) else \"Checksum incorrect\")\n```\n\nNote: In the C code, the `size_t` data type is used for the length of the input array. In Python, we don't need to specify the data type explicitly, so I removed the `size_t` and used the `len()` function to get the length of the input array.", "C": "#include \n#include \n#include \n\nbool damm(unsigned char *input, size_t length) {\n static const unsigned char table[10][10] = {\n {0, 3, 1, 7, 5, 9, 8, 6, 4, 2},\n {7, 0, 9, 2, 1, 5, 4, 8, 6, 3},\n {4, 2, 0, 6, 8, 7, 1, 3, 5, 9},\n {1, 7, 5, 0, 9, 8, 3, 4, 2, 6},\n {6, 1, 2, 3, 0, 4, 5, 9, 7, 8},\n {3, 6, 7, 4, 2, 0, 9, 5, 8, 1},\n {5, 8, 6, 9, 7, 2, 0, 1, 3, 4},\n {8, 9, 4, 5, 3, 6, 2, 0, 1, 7},\n {9, 4, 3, 8, 6, 1, 7, 2, 0, 5},\n {2, 5, 8, 1, 4, 3, 6, 7, 9, 0},\n };\n \n unsigned char interim = 0;\n for (size_t i = 0; i < length; i++) {\n interim = table[interim][input[i]];\n }\n return interim == 0;\n}\n\nint main() {\n unsigned char input[4] = {5, 7, 2, 4};\n puts(damm(input, 4) ? \"Checksum correct\" : \"Checksum incorrect\");\n return 0;\n}"} 10 | {"id": 9, "label": 0, "output": "File \"output/output9.py\", line 1 Here is the translation of the given C code to Python: ^ SyntaxError: invalid syntax", "Python": "Here is the translation of the given C code to Python:\n\n```python\ndef modpow(b, e, m):\n p = 1\n while e:\n if e & 1:\n p = p * b % m\n b = b * b % m\n e >>= 1\n return p\n\ndef is_deceptive(n):\n if n & 1 and n % 3 and n % 5:\n x = 7\n while x * x <= n:\n if not (n % x and n % (x + 4)):\n return modpow(10, n - 1, n) == 1\n x += 6\n return 0\n\ndef main():\n c = 0\n i = 20\n while c != 20:\n if is_deceptive(i):\n print(i, end=\" \")\n c += 1\n i += 1\n\nif __name__ == \"__main__\":\n main()\n```\n\nThis Python code performs the same functionality as the given C code.", "C": "#include \n\nunsigned modpow(unsigned b, unsigned e, unsigned m)\n{\n unsigned p;\n for (p = 1; e; e >>= 1) {\n if (e & 1)\n p = p * b % m;\n b = b * b % m;\n }\n return p;\n}\n\nint is_deceptive(unsigned n)\n{\n unsigned x;\n if (n & 1 && n % 3 && n % 5) {\n for (x = 7; x * x <= n; x += 6) {\n if (!(n % x && n % (x + 4)))\n return modpow(10, n - 1, n) == 1;\n }\n }\n return 0;\n}\n\nint main(void)\n{\n unsigned c, i = 20;\n for (c = 0; c != 20; ++i) {\n if (is_deceptive(i)) {\n printf(\" %u\", i);\n ++c;\n }\n }\n return 0;\n}"} -------------------------------------------------------------------------------- /CodeT5+/README.md: -------------------------------------------------------------------------------- 1 | # Experiments based on CodeT5+ 2 | 3 | We performed all experiments on the CodeT5+ `220M` model. 4 | 5 | # How to Use? 6 | 7 | ## Environment 8 | 9 | ```bash 10 | conda create -n codetransocean python=3.9 11 | conda activate codetransocean 12 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia # Please check your CUDA version. 13 | pip install transformers==4.25.1 14 | pip install datasets 15 | pip install tensorboard 16 | pip install tree_sitter 17 | pip install evaluate 18 | ``` 19 | 20 | ## Finetuning & Inference & Evaluation 21 | 22 | ``` run_preprocess.py ``` is used to pre-process data. 23 | 24 | ``` run_translation.py ``` is used for training and inference on a specified dataset. 25 | 26 | ``` run_score.py ``` and ``` evaluator ``` are used to calculate inference results in the BLEU score. 
27 | 28 | The ```.sh``` scripts specify which multilingual modeling method is used on which dataset when fine-tuning CodeT5+ and running inference. 29 | 30 | 31 | ## Experimental results 32 | 33 |
34 | (Result figures for DLTrans, MultilingualTrans, and NicheTrans appear here in the rendered README; see the CodeT5+/images/ directory.) 35 | 36 | 37 |
38 | 39 | For more detailed experimental results, please see our paper. 40 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/bleu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Natural Language Toolkit: BLEU Score 3 | # 4 | # Copyright (C) 2001-2020 NLTK Project 5 | # Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim 6 | # Contributors: Björn Mattsson, Dmitrijs Milajevs, Liling Tan 7 | # URL: 8 | # For license information, see LICENSE.TXT 9 | 10 | """BLEU score implementation.""" 11 | 12 | import math 13 | import sys 14 | from fractions import Fraction 15 | import warnings 16 | from collections import Counter 17 | 18 | from evaluator.CodeBLEU.utils import ngrams 19 | 20 | 21 | def sentence_bleu( 22 | references, 23 | hypothesis, 24 | weights=(0.25, 0.25, 0.25, 0.25), 25 | smoothing_function=None, 26 | auto_reweigh=False, 27 | ): 28 | """ 29 | Calculate BLEU score (Bilingual Evaluation Understudy) from 30 | Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu. 2002. 31 | "BLEU: a method for automatic evaluation of machine translation." 32 | In Proceedings of ACL. http://www.aclweb.org/anthology/P02-1040.pdf 33 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 34 | ... 'ensures', 'that', 'the', 'military', 'always', 35 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 36 | >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops', 37 | ... 'forever', 'hearing', 'the', 'activity', 'guidebook', 38 | ... 'that', 'party', 'direct'] 39 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 40 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 41 | ... 'heed', 'Party', 'commands'] 42 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 43 | ... 'guarantees', 'the', 'military', 'forces', 'always', 44 | ... 'being', 'under', 'the', 'command', 'of', 'the', 45 | ... 'Party'] 46 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 47 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 48 | ... 'of', 'the', 'party'] 49 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1) # doctest: +ELLIPSIS 50 | 0.5045... 51 | If there is no ngrams overlap for any order of n-grams, BLEU returns the 52 | value 0. This is because the precision for the order of n-grams without 53 | overlap is 0, and the geometric mean in the final BLEU score computation 54 | multiplies the 0 with the precision of other n-grams. This results in 0 55 | (independently of the precision of the othe n-gram orders). The following 56 | example has zero 3-gram and 4-gram overlaps: 57 | >>> round(sentence_bleu([reference1, reference2, reference3], hypothesis2),4) # doctest: +ELLIPSIS 58 | 0.0 59 | To avoid this harsh behaviour when no ngram overlaps are found a smoothing 60 | function can be used. 61 | >>> chencherry = SmoothingFunction() 62 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis2, 63 | ... smoothing_function=chencherry.method1) # doctest: +ELLIPSIS 64 | 0.0370... 65 | The default BLEU calculates a score for up to 4-grams using uniform 66 | weights (this is called BLEU-4). To evaluate your translations with 67 | higher/lower order ngrams, use customized weights. E.g. when accounting 68 | for up to 5-grams with uniform weights (this is called BLEU-5) use: 69 | >>> weights = (1./5., 1./5., 1./5., 1./5., 1./5.) 
70 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1, weights) # doctest: +ELLIPSIS 71 | 0.3920... 72 | :param references: reference sentences 73 | :type references: list(list(str)) 74 | :param hypothesis: a hypothesis sentence 75 | :type hypothesis: list(str) 76 | :param weights: weights for unigrams, bigrams, trigrams and so on 77 | :type weights: list(float) 78 | :param smoothing_function: 79 | :type smoothing_function: SmoothingFunction 80 | :param auto_reweigh: Option to re-normalize the weights uniformly. 81 | :type auto_reweigh: bool 82 | :return: The sentence-level BLEU score. 83 | :rtype: float 84 | """ 85 | return corpus_bleu( 86 | [references], [hypothesis], weights, smoothing_function, auto_reweigh 87 | ) 88 | 89 | 90 | def corpus_bleu( 91 | list_of_references, 92 | hypotheses, 93 | weights=(0.25, 0.25, 0.25, 0.25), 94 | smoothing_function=None, 95 | auto_reweigh=False, 96 | ): 97 | """ 98 | Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all 99 | the hypotheses and their respective references. 100 | Instead of averaging the sentence level BLEU scores (i.e. marco-average 101 | precision), the original BLEU metric (Papineni et al. 2002) accounts for 102 | the micro-average precision (i.e. summing the numerators and denominators 103 | for each hypothesis-reference(s) pairs before the division). 104 | >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 105 | ... 'ensures', 'that', 'the', 'military', 'always', 106 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 107 | >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 108 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 109 | ... 'heed', 'Party', 'commands'] 110 | >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which', 111 | ... 'guarantees', 'the', 'military', 'forces', 'always', 112 | ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party'] 113 | >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 114 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 115 | ... 'of', 'the', 'party'] 116 | >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was', 117 | ... 'interested', 'in', 'world', 'history'] 118 | >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history', 119 | ... 'because', 'he', 'read', 'the', 'book'] 120 | >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]] 121 | >>> hypotheses = [hyp1, hyp2] 122 | >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS 123 | 0.5920... 124 | The example below show that corpus_bleu() is different from averaging 125 | sentence_bleu() for hypotheses 126 | >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1) 127 | >>> score2 = sentence_bleu([ref2a], hyp2) 128 | >>> (score1 + score2) / 2 # doctest: +ELLIPSIS 129 | 0.6223... 130 | :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses 131 | :type list_of_references: list(list(list(str))) 132 | :param hypotheses: a list of hypothesis sentences 133 | :type hypotheses: list(list(str)) 134 | :param weights: weights for unigrams, bigrams, trigrams and so on 135 | :type weights: list(float) 136 | :param smoothing_function: 137 | :type smoothing_function: SmoothingFunction 138 | :param auto_reweigh: Option to re-normalize the weights uniformly. 139 | :type auto_reweigh: bool 140 | :return: The corpus-level BLEU score. 141 | :rtype: float 142 | """ 143 | # Before proceeding to compute BLEU, perform sanity checks. 
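# Clipped n-gram counts are accumulated over the whole corpus and a single brevity penalty is
# applied at the end (micro-average precision), rather than averaging per-sentence BLEU scores.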
144 | 145 | p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches. 146 | p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref. 147 | hyp_lengths, ref_lengths = 0, 0 148 | 149 | assert len(list_of_references) == len(hypotheses), ( 150 | "The number of hypotheses and their reference(s) should be the " "same " 151 | ) 152 | 153 | # Iterate through each hypothesis and their corresponding references. 154 | for references, hypothesis in zip(list_of_references, hypotheses): 155 | # For each order of ngram, calculate the numerator and 156 | # denominator for the corpus-level modified precision. 157 | for i, _ in enumerate(weights, start=1): 158 | p_i = modified_precision(references, hypothesis, i) 159 | p_numerators[i] += p_i.numerator 160 | p_denominators[i] += p_i.denominator 161 | 162 | # Calculate the hypothesis length and the closest reference length. 163 | # Adds them to the corpus-level hypothesis and reference counts. 164 | hyp_len = len(hypothesis) 165 | hyp_lengths += hyp_len 166 | ref_lengths += closest_ref_length(references, hyp_len) 167 | 168 | # Calculate corpus-level brevity penalty. 169 | bp = brevity_penalty(ref_lengths, hyp_lengths) 170 | 171 | # Uniformly re-weighting based on maximum hypothesis lengths if largest 172 | # order of n-grams < 4 and weights is set at default. 173 | if auto_reweigh: 174 | if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): 175 | weights = (1 / hyp_lengths,) * hyp_lengths 176 | 177 | # Collects the various precision values for the different ngram orders. 178 | p_n = [ 179 | Fraction(p_numerators[i], p_denominators[i], _normalize=False) 180 | for i, _ in enumerate(weights, start=1) 181 | ] 182 | 183 | # Returns 0 if there's no matching n-grams 184 | # We only need to check for p_numerators[1] == 0, since if there's 185 | # no unigrams, there won't be any higher order ngrams. 186 | if p_numerators[1] == 0: 187 | return 0 188 | 189 | # If there's no smoothing, set use method0 from SmoothinFunction class. 190 | if not smoothing_function: 191 | smoothing_function = SmoothingFunction().method1 192 | # Smoothen the modified precision. 193 | # Note: smoothing_function() may convert values into floats; 194 | # it tries to retain the Fraction object as much as the 195 | # smoothing method allows. 196 | p_n = smoothing_function( 197 | p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths 198 | ) 199 | s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n)) 200 | s = bp * math.exp(math.fsum(s)) 201 | return s 202 | 203 | 204 | def modified_precision(references, hypothesis, n): 205 | """ 206 | Calculate modified ngram precision. 207 | The normal precision method may lead to some wrong translations with 208 | high-precision, e.g., the translation, in which a word of reference 209 | repeats several times, has very high precision. 210 | This function only returns the Fraction object that contains the numerator 211 | and denominator necessary to calculate the corpus-level precision. 212 | To calculate the modified precision for a single pair of hypothesis and 213 | references, cast the Fraction object into a float. 214 | The famous "the the the ... " example shows that you can get BLEU precision 215 | by duplicating high frequency words. 
216 | >>> reference1 = 'the cat is on the mat'.split() 217 | >>> reference2 = 'there is a cat on the mat'.split() 218 | >>> hypothesis1 = 'the the the the the the the'.split() 219 | >>> references = [reference1, reference2] 220 | >>> float(modified_precision(references, hypothesis1, n=1)) # doctest: +ELLIPSIS 221 | 0.2857... 222 | In the modified n-gram precision, a reference word will be considered 223 | exhausted after a matching hypothesis word is identified, e.g. 224 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 225 | ... 'ensures', 'that', 'the', 'military', 'will', 226 | ... 'forever', 'heed', 'Party', 'commands'] 227 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 228 | ... 'guarantees', 'the', 'military', 'forces', 'always', 229 | ... 'being', 'under', 'the', 'command', 'of', 'the', 230 | ... 'Party'] 231 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 232 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 233 | ... 'of', 'the', 'party'] 234 | >>> hypothesis = 'of the'.split() 235 | >>> references = [reference1, reference2, reference3] 236 | >>> float(modified_precision(references, hypothesis, n=1)) 237 | 1.0 238 | >>> float(modified_precision(references, hypothesis, n=2)) 239 | 1.0 240 | An example of a normal machine translation hypothesis: 241 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 242 | ... 'ensures', 'that', 'the', 'military', 'always', 243 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 244 | >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops', 245 | ... 'forever', 'hearing', 'the', 'activity', 'guidebook', 246 | ... 'that', 'party', 'direct'] 247 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 248 | ... 'ensures', 'that', 'the', 'military', 'will', 249 | ... 'forever', 'heed', 'Party', 'commands'] 250 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 251 | ... 'guarantees', 'the', 'military', 'forces', 'always', 252 | ... 'being', 'under', 'the', 'command', 'of', 'the', 253 | ... 'Party'] 254 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 255 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 256 | ... 'of', 'the', 'party'] 257 | >>> references = [reference1, reference2, reference3] 258 | >>> float(modified_precision(references, hypothesis1, n=1)) # doctest: +ELLIPSIS 259 | 0.9444... 260 | >>> float(modified_precision(references, hypothesis2, n=1)) # doctest: +ELLIPSIS 261 | 0.5714... 262 | >>> float(modified_precision(references, hypothesis1, n=2)) # doctest: +ELLIPSIS 263 | 0.5882352941176471 264 | >>> float(modified_precision(references, hypothesis2, n=2)) # doctest: +ELLIPSIS 265 | 0.07692... 266 | :param references: A list of reference translations. 267 | :type references: list(list(str)) 268 | :param hypothesis: A hypothesis translation. 269 | :type hypothesis: list(str) 270 | :param n: The ngram order. 271 | :type n: int 272 | :return: BLEU's modified precision for the nth order ngram. 273 | :rtype: Fraction 274 | """ 275 | # Extracts all ngrams in hypothesis 276 | # Set an empty Counter if hypothesis is empty. 277 | 278 | counts = Counter(ngrams(hypothesis, n)) if len(hypothesis) >= n else Counter() 279 | # Extract a union of references' counts. 
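# For every n-gram in the hypothesis, keep the largest count observed in any single reference;
# hypothesis counts are then clipped to that maximum before summing (modified precision).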
280 | # max_counts = reduce(or_, [Counter(ngrams(ref, n)) for ref in references]) 281 | max_counts = {} 282 | for reference in references: 283 | reference_counts = ( 284 | Counter(ngrams(reference, n)) if len(reference) >= n else Counter() 285 | ) 286 | for ngram in counts: 287 | max_counts[ngram] = max(max_counts.get(ngram, 0), reference_counts[ngram]) 288 | 289 | # Assigns the intersection between hypothesis and references' counts. 290 | clipped_counts = { 291 | ngram: min(count, max_counts[ngram]) for ngram, count in counts.items() 292 | } 293 | 294 | numerator = sum(clipped_counts.values()) 295 | # Ensures that denominator is minimum 1 to avoid ZeroDivisionError. 296 | # Usually this happens when the ngram order is > len(reference). 297 | denominator = max(1, sum(counts.values())) 298 | 299 | return Fraction(numerator, denominator, _normalize=False) 300 | 301 | 302 | def closest_ref_length(references, hyp_len): 303 | """ 304 | This function finds the reference that is the closest length to the 305 | hypothesis. The closest reference length is referred to as *r* variable 306 | from the brevity penalty formula in Papineni et. al. (2002) 307 | :param references: A list of reference translations. 308 | :type references: list(list(str)) 309 | :param hyp_len: The length of the hypothesis. 310 | :type hyp_len: int 311 | :return: The length of the reference that's closest to the hypothesis. 312 | :rtype: int 313 | """ 314 | ref_lens = (len(reference) for reference in references) 315 | closest_ref_len = min( 316 | ref_lens, key=lambda ref_len: (abs(ref_len - hyp_len), ref_len) 317 | ) 318 | return closest_ref_len 319 | 320 | 321 | def brevity_penalty(closest_ref_len, hyp_len): 322 | """ 323 | Calculate brevity penalty. 324 | As the modified n-gram precision still has the problem from the short 325 | length sentence, brevity penalty is used to modify the overall BLEU 326 | score according to length. 327 | An example from the paper. There are three references with length 12, 15 328 | and 17. And a concise hypothesis of the length 12. The brevity penalty is 1. 329 | >>> reference1 = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 330 | >>> reference2 = list('aaaaaaaaaaaaaaa') # i.e. ['a'] * 15 331 | >>> reference3 = list('aaaaaaaaaaaaaaaaa') # i.e. ['a'] * 17 332 | >>> hypothesis = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 333 | >>> references = [reference1, reference2, reference3] 334 | >>> hyp_len = len(hypothesis) 335 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 336 | >>> brevity_penalty(closest_ref_len, hyp_len) 337 | 1.0 338 | In case a hypothesis translation is shorter than the references, penalty is 339 | applied. 340 | >>> references = [['a'] * 28, ['a'] * 28] 341 | >>> hypothesis = ['a'] * 12 342 | >>> hyp_len = len(hypothesis) 343 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 344 | >>> brevity_penalty(closest_ref_len, hyp_len) 345 | 0.2635971381157267 346 | The length of the closest reference is used to compute the penalty. If the 347 | length of a hypothesis is 12, and the reference lengths are 13 and 2, the 348 | penalty is applied because the hypothesis length (12) is less then the 349 | closest reference length (13). 350 | >>> references = [['a'] * 13, ['a'] * 2] 351 | >>> hypothesis = ['a'] * 12 352 | >>> hyp_len = len(hypothesis) 353 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 354 | >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS 355 | 0.9200... 356 | The brevity penalty doesn't depend on reference order. 
More importantly, 357 | when two reference sentences are at the same distance, the shortest 358 | reference sentence length is used. 359 | >>> references = [['a'] * 13, ['a'] * 11] 360 | >>> hypothesis = ['a'] * 12 361 | >>> hyp_len = len(hypothesis) 362 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 363 | >>> bp1 = brevity_penalty(closest_ref_len, hyp_len) 364 | >>> hyp_len = len(hypothesis) 365 | >>> closest_ref_len = closest_ref_length(reversed(references), hyp_len) 366 | >>> bp2 = brevity_penalty(closest_ref_len, hyp_len) 367 | >>> bp1 == bp2 == 1 368 | True 369 | A test example from mteval-v13a.pl (starting from the line 705): 370 | >>> references = [['a'] * 11, ['a'] * 8] 371 | >>> hypothesis = ['a'] * 7 372 | >>> hyp_len = len(hypothesis) 373 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 374 | >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS 375 | 0.8668... 376 | >>> references = [['a'] * 11, ['a'] * 8, ['a'] * 6, ['a'] * 7] 377 | >>> hypothesis = ['a'] * 7 378 | >>> hyp_len = len(hypothesis) 379 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 380 | >>> brevity_penalty(closest_ref_len, hyp_len) 381 | 1.0 382 | :param hyp_len: The length of the hypothesis for a single sentence OR the 383 | sum of all the hypotheses' lengths for a corpus 384 | :type hyp_len: int 385 | :param closest_ref_len: The length of the closest reference for a single 386 | hypothesis OR the sum of all the closest references for every hypotheses. 387 | :type closest_ref_len: int 388 | :return: BLEU's brevity penalty. 389 | :rtype: float 390 | """ 391 | if hyp_len > closest_ref_len: 392 | return 1 393 | # If hypothesis is empty, brevity penalty = 0 should result in BLEU = 0.0 394 | elif hyp_len == 0: 395 | return 0 396 | else: 397 | return math.exp(1 - closest_ref_len / hyp_len) 398 | 399 | 400 | class SmoothingFunction: 401 | """ 402 | This is an implementation of the smoothing techniques 403 | for segment-level BLEU scores that was presented in 404 | Boxing Chen and Collin Cherry (2014) A Systematic Comparison of 405 | Smoothing Techniques for Sentence-Level BLEU. In WMT14. 406 | http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf 407 | """ 408 | 409 | def __init__(self, epsilon=0.1, alpha=5, k=5): 410 | """ 411 | This will initialize the parameters required for the various smoothing 412 | techniques, the default values are set to the numbers used in the 413 | experiments from Chen and Cherry (2014). 414 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures', 415 | ... 'that', 'the', 'military', 'always', 'obeys', 'the', 416 | ... 'commands', 'of', 'the', 'party'] 417 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures', 418 | ... 'that', 'the', 'military', 'will', 'forever', 'heed', 419 | ... 'Party', 'commands'] 420 | >>> chencherry = SmoothingFunction() 421 | >>> print(sentence_bleu([reference1], hypothesis1)) # doctest: +ELLIPSIS 422 | 0.4118... 423 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method0)) # doctest: +ELLIPSIS 424 | 0.4118... 425 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method1)) # doctest: +ELLIPSIS 426 | 0.4118... 427 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method2)) # doctest: +ELLIPSIS 428 | 0.4489... 429 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method3)) # doctest: +ELLIPSIS 430 | 0.4118... 
431 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method4)) # doctest: +ELLIPSIS 432 | 0.4118... 433 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method5)) # doctest: +ELLIPSIS 434 | 0.4905... 435 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method6)) # doctest: +ELLIPSIS 436 | 0.4135... 437 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method7)) # doctest: +ELLIPSIS 438 | 0.4905... 439 | :param epsilon: the epsilon value use in method 1 440 | :type epsilon: float 441 | :param alpha: the alpha value use in method 6 442 | :type alpha: int 443 | :param k: the k value use in method 4 444 | :type k: int 445 | """ 446 | self.epsilon = epsilon 447 | self.alpha = alpha 448 | self.k = k 449 | 450 | def method0(self, p_n, *args, **kwargs): 451 | """ 452 | No smoothing. 453 | """ 454 | p_n_new = [] 455 | for i, p_i in enumerate(p_n): 456 | if p_i.numerator != 0: 457 | p_n_new.append(p_i) 458 | else: 459 | _msg = str( 460 | "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n" 461 | "Therefore the BLEU score evaluates to 0, independently of\n" 462 | "how many N-gram overlaps of lower order it contains.\n" 463 | "Consider using lower n-gram order or use " 464 | "SmoothingFunction()" 465 | ).format(i + 1) 466 | warnings.warn(_msg) 467 | # When numerator==0 where denonminator==0 or !=0, the result 468 | # for the precision score should be equal to 0 or undefined. 469 | # Due to BLEU geometric mean computation in logarithm space, 470 | # we we need to take the return sys.float_info.min such that 471 | # math.log(sys.float_info.min) returns a 0 precision score. 472 | p_n_new.append(sys.float_info.min) 473 | return p_n_new 474 | 475 | def method1(self, p_n, *args, **kwargs): 476 | """ 477 | Smoothing method 1: Add *epsilon* counts to precision with 0 counts. 478 | """ 479 | return [ 480 | (p_i.numerator + self.epsilon) / p_i.denominator 481 | if p_i.numerator == 0 482 | else p_i 483 | for p_i in p_n 484 | ] 485 | 486 | def method2(self, p_n, *args, **kwargs): 487 | """ 488 | Smoothing method 2: Add 1 to both numerator and denominator from 489 | Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of 490 | machine translation quality using longest common subsequence and 491 | skip-bigram statistics. In ACL04. 492 | """ 493 | return [ 494 | Fraction(p_i.numerator + 1, p_i.denominator + 1, _normalize=False) 495 | for p_i in p_n 496 | ] 497 | 498 | def method3(self, p_n, *args, **kwargs): 499 | """ 500 | Smoothing method 3: NIST geometric sequence smoothing 501 | The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each 502 | precision score whose matching n-gram count is null. 503 | k is 1 for the first 'n' value for which the n-gram match count is null/ 504 | For example, if the text contains: 505 | - one 2-gram match 506 | - and (consequently) two 1-gram matches 507 | the n-gram count for each individual precision score would be: 508 | - n=1 => prec_count = 2 (two unigrams) 509 | - n=2 => prec_count = 1 (one bigram) 510 | - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1) 511 | - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2) 512 | """ 513 | incvnt = 1 # From the mteval-v13a.pl, it's referred to as k. 
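# Each zero-match precision is replaced by 1 / (2**k * denominator), with k incremented for
# every further n-gram order that also has no matches, so the smoothed value halves each time.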
514 | for i, p_i in enumerate(p_n): 515 | if p_i.numerator == 0: 516 | p_n[i] = 1 / (2 ** incvnt * p_i.denominator) 517 | incvnt += 1 518 | return p_n 519 | 520 | def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 521 | """ 522 | Smoothing method 4: 523 | Shorter translations may have inflated precision values due to having 524 | smaller denominators; therefore, we give them proportionally 525 | smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry 526 | suggests dividing by 1/ln(len(T)), where T is the length of the translation. 527 | """ 528 | hyp_len = hyp_len if hyp_len else len(hypothesis) 529 | for i, p_i in enumerate(p_n): 530 | if p_i.numerator == 0 and hyp_len != 0: 531 | incvnt = i + 1 * self.k / math.log( 532 | hyp_len 533 | ) # Note that this K is different from the K from NIST. 534 | p_n[i] = incvnt / p_i.denominator 535 | return p_n 536 | 537 | def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 538 | """ 539 | Smoothing method 5: 540 | The matched counts for similar values of n should be similar. To a 541 | calculate the n-gram matched count, it averages the n−1, n and n+1 gram 542 | matched counts. 543 | """ 544 | hyp_len = hyp_len if hyp_len else len(hypothesis) 545 | m = {} 546 | # Requires an precision value for an addition ngram order. 547 | p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)] 548 | m[-1] = p_n[0] + 1 549 | for i, p_i in enumerate(p_n): 550 | p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3 551 | m[i] = p_n[i] 552 | return p_n 553 | 554 | def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 555 | """ 556 | Smoothing method 6: 557 | Interpolates the maximum likelihood estimate of the precision *p_n* with 558 | a prior estimate *pi0*. The prior is estimated by assuming that the ratio 559 | between pn and pn−1 will be the same as that between pn−1 and pn−2; from 560 | Gao and He (2013) Training MRF-Based Phrase Translation Models using 561 | Gradient Ascent. In NAACL. 562 | """ 563 | hyp_len = hyp_len if hyp_len else len(hypothesis) 564 | # This smoothing only works when p_1 and p_2 is non-zero. 565 | # Raise an error with an appropriate message when the input is too short 566 | # to use this smoothing technique. 567 | assert p_n[2], "This smoothing method requires non-zero precision for bigrams." 568 | for i, p_i in enumerate(p_n): 569 | if i in [0, 1]: # Skips the first 2 orders of ngrams. 570 | continue 571 | else: 572 | pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2] 573 | # No. of ngrams in translation that matches the reference. 574 | m = p_i.numerator 575 | # No. of ngrams in translation. 576 | l = sum(1 for _ in ngrams(hypothesis, i + 1)) 577 | # Calculates the interpolated precision. 578 | p_n[i] = (m + self.alpha * pi0) / (l + self.alpha) 579 | return p_n 580 | 581 | def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 582 | """ 583 | Smoothing method 7: 584 | Interpolates methods 4 and 5. 585 | """ 586 | hyp_len = hyp_len if hyp_len else len(hypothesis) 587 | p_n = self.method4(p_n, references, hypothesis, hyp_len) 588 | p_n = self.method5(p_n, references, hypothesis, hyp_len) 589 | return p_n 590 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/calc_code_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
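calc_code_bleu.py below blends four component scores (BLEU, keyword-weighted BLEU, syntax match and dataflow match) with weights alpha, beta, gamma, theta taken from --params (0.25 each by default). A toy illustration of that final combination, using made-up component scores:

alpha, beta, gamma, theta = [float(x) for x in '0.25,0.25,0.25,0.25'.split(',')]
ngram, weighted_ngram, syntax, dataflow = 0.62, 0.58, 0.71, 0.66   # hypothetical scores
code_bleu = alpha * ngram + beta * weighted_ngram + gamma * syntax + theta * dataflow
print(round(code_bleu, 4))   # 0.6425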
3 | # https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/code-to-code-trans/evaluator/CodeBLEU 4 | 5 | # -*- coding:utf-8 -*- 6 | import sys 7 | sys.path.append('/mnt/workspace/project/transformers/CodeTansBaselines/CodeT5/CodeT5') 8 | import argparse 9 | import os 10 | from evaluator.CodeBLEU import bleu, weighted_ngram_match, syntax_match, dataflow_match 11 | import json 12 | 13 | def get_codebleu(refs, hyp, lang, params='0.25,0.25,0.25,0.25', naive=False, ): 14 | if not isinstance(refs, list): 15 | refs = [refs] 16 | alpha, beta, gamma, theta = [float(x) for x in params.split(',')] 17 | 18 | 19 | # preprocess inputs 20 | if 'json' in hyp: 21 | if naive: 22 | hypothesis = [json.loads(x)['source'] for x in open(hyp, 'r', encoding='utf-8').readlines()] 23 | else: 24 | hypothesis = [json.loads(x)['prediction'] for x in open(hyp, 'r', encoding='utf-8').readlines()] 25 | pre_references = [[json.loads(x)['target'] for x in open(hyp, 'r', encoding='utf-8').readlines()]] 26 | else: 27 | pre_references = [[x.strip() for x in open(file, 'r', encoding='utf-8').readlines()] for file in refs] 28 | hypothesis = [x.strip() for x in open(hyp, 'r', encoding='utf-8').readlines()] 29 | 30 | for i in range(len(pre_references)): 31 | assert len(hypothesis) == len(pre_references[i]) 32 | 33 | references = [] 34 | for i in range(len(hypothesis)): 35 | ref_for_instance = [] 36 | for j in range(len(pre_references)): 37 | ref_for_instance.append(pre_references[j][i]) 38 | references.append(ref_for_instance) 39 | assert len(references) == len(pre_references) * len(hypothesis) 40 | 41 | # calculate ngram match (BLEU) 42 | tokenized_hyps = [x.split() for x in hypothesis] 43 | tokenized_refs = [[x.split() for x in reference] for reference in references] 44 | 45 | ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps) 46 | 47 | # calculate weighted ngram match 48 | root_dir = os.path.dirname(__file__) 49 | keywords = [x.strip() for x in open(root_dir + '/keywords/' + lang + '.txt', 'r', encoding='utf-8').readlines()] 50 | 51 | def make_weights(reference_tokens, key_word_list): 52 | return {token: 1 if token in key_word_list else 0.2 for token in reference_tokens} 53 | 54 | tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens, keywords)] \ 55 | for reference_tokens in reference] for reference in tokenized_refs] 56 | 57 | weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps) 58 | 59 | # calculate syntax match 60 | syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang) 61 | 62 | # calculate dataflow match 63 | dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang) 64 | 65 | print('ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}'. 
\ 66 | format(ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score)) 67 | 68 | code_bleu_score = alpha * ngram_match_score \ 69 | + beta * weighted_ngram_match_score \ 70 | + gamma * syntax_match_score \ 71 | + theta * dataflow_match_score 72 | 73 | print('code_bleu_score', code_bleu_score) 74 | 75 | return code_bleu_score 76 | 77 | 78 | def get_codebleu_list(refs, hyp, lang, params='0.25,0.25,0.25,0.25'): 79 | alpha, beta, gamma, theta = [float(x) for x in params.split(',')] 80 | pre_references = refs 81 | hypothesis = hyp 82 | # preprocess inputs 83 | for i in range(len(pre_references)): 84 | assert len(hypothesis) == len(pre_references[i]) 85 | 86 | references = [] 87 | for i in range(len(hypothesis)): 88 | ref_for_instance = [] 89 | for j in range(len(pre_references)): 90 | ref_for_instance.append(pre_references[j][i]) 91 | references.append(ref_for_instance) 92 | assert len(references) == len(pre_references) * len(hypothesis) 93 | 94 | # calculate ngram match (BLEU) 95 | tokenized_hyps = [x.split() for x in hypothesis] 96 | tokenized_refs = [[x.split() for x in reference] for reference in references] 97 | 98 | ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps) 99 | 100 | # calculate weighted ngram match 101 | root_dir = os.path.dirname(__file__) 102 | keywords = [x.strip() for x in open(root_dir + '/keywords/' + lang + '.txt', 'r', encoding='utf-8').readlines()] 103 | 104 | def make_weights(reference_tokens, key_word_list): 105 | return {token: 1 if token in key_word_list else 0.2 for token in reference_tokens} 106 | 107 | tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens, keywords)] \ 108 | for reference_tokens in reference] for reference in tokenized_refs] 109 | 110 | weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps) 111 | 112 | # calculate syntax match 113 | syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang) 114 | 115 | # calculate dataflow match 116 | dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang) 117 | 118 | print('ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}'. 
\ 119 | format(ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score)) 120 | 121 | code_bleu_score = alpha * ngram_match_score \ 122 | + beta * weighted_ngram_match_score \ 123 | + gamma * syntax_match_score \ 124 | + theta * dataflow_match_score 125 | 126 | print('code_bleu_score', code_bleu_score) 127 | 128 | return code_bleu_score 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument('--refs', type=str, nargs='+', required=True, 134 | help='reference files') 135 | parser.add_argument('--hyp', type=str, required=True, 136 | help='hypothesis file') 137 | parser.add_argument('--lang', type=str, required=True, 138 | choices=['java', 'js', 'c_sharp', 'php', 'go', 'python', 'ruby'], 139 | help='programming language') 140 | parser.add_argument('--params', type=str, default='0.25,0.25,0.25,0.25', 141 | help='alpha, beta and gamma') 142 | parser.add_argument('--naive', action='store_true') 143 | 144 | args = parser.parse_args() 145 | code_bleu_score = get_codebleu(args.refs, args.hyp, args.lang, args.params, args.naive) 146 | print('CodeBLEU score: ', code_bleu_score) 147 | 148 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/dataflow_match.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from evaluator.CodeBLEU.parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript, DFG_csharp 5 | from evaluator.CodeBLEU.parser import (remove_comments_and_docstrings, 6 | tree_to_token_index, 7 | index_to_code_token, 8 | tree_to_variable_index) 9 | from tree_sitter import Language, Parser 10 | import os 11 | 12 | root_dir = os.path.dirname(__file__) 13 | 14 | dfg_function = { 15 | 'python': DFG_python, 16 | 'java': DFG_java, 17 | 'ruby': DFG_ruby, 18 | 'go': DFG_go, 19 | 'php': DFG_php, 20 | 'javascript': DFG_javascript, 21 | 'c_sharp': DFG_csharp, 22 | } 23 | 24 | 25 | def calc_dataflow_match(references, candidate, lang): 26 | return corpus_dataflow_match([references], [candidate], lang) 27 | 28 | 29 | def corpus_dataflow_match(references, candidates, lang): 30 | LANGUAGE = Language(root_dir + '/parser/my-languages.so', lang) 31 | parser = Parser() 32 | parser.set_language(LANGUAGE) 33 | parser = [parser, dfg_function[lang]] 34 | match_count = 0 35 | total_count = 0 36 | 37 | for i in range(len(candidates)): 38 | references_sample = references[i] 39 | candidate = candidates[i] 40 | for reference in references_sample: 41 | try: 42 | candidate = remove_comments_and_docstrings(candidate, 'java') 43 | except: 44 | pass 45 | try: 46 | reference = remove_comments_and_docstrings(reference, 'java') 47 | except: 48 | pass 49 | 50 | cand_dfg = get_data_flow(candidate, parser) 51 | ref_dfg = get_data_flow(reference, parser) 52 | 53 | normalized_cand_dfg = normalize_dataflow(cand_dfg) 54 | normalized_ref_dfg = normalize_dataflow(ref_dfg) 55 | 56 | if len(normalized_ref_dfg) > 0: 57 | total_count += len(normalized_ref_dfg) 58 | for dataflow in normalized_ref_dfg: 59 | if dataflow in normalized_cand_dfg: 60 | match_count += 1 61 | normalized_cand_dfg.remove(dataflow) 62 | if total_count == 0: 63 | print( 64 | "WARNING: There is no reference data-flows extracted from the whole corpus, and the data-flow match score degenerates to 0. 
Please consider ignoring this score.") 65 | return 0 66 | score = match_count / total_count 67 | return score 68 | 69 | 70 | def get_data_flow(code, parser): 71 | try: 72 | tree = parser[0].parse(bytes(code, 'utf8')) 73 | root_node = tree.root_node 74 | tokens_index = tree_to_token_index(root_node) 75 | code = code.split('\n') 76 | code_tokens = [index_to_code_token(x, code) for x in tokens_index] 77 | index_to_code = {} 78 | for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)): 79 | index_to_code[index] = (idx, code) 80 | try: 81 | DFG, _ = parser[1](root_node, index_to_code, {}) 82 | except: 83 | DFG = [] 84 | DFG = sorted(DFG, key=lambda x: x[1]) 85 | indexs = set() 86 | for d in DFG: 87 | if len(d[-1]) != 0: 88 | indexs.add(d[1]) 89 | for x in d[-1]: 90 | indexs.add(x) 91 | new_DFG = [] 92 | for d in DFG: 93 | if d[1] in indexs: 94 | new_DFG.append(d) 95 | codes = code_tokens 96 | dfg = new_DFG 97 | except: 98 | codes = code.split() 99 | dfg = [] 100 | # merge nodes 101 | dic = {} 102 | for d in dfg: 103 | if d[1] not in dic: 104 | dic[d[1]] = d 105 | else: 106 | dic[d[1]] = (d[0], d[1], d[2], list(set(dic[d[1]][3] + d[3])), list(set(dic[d[1]][4] + d[4]))) 107 | DFG = [] 108 | for d in dic: 109 | DFG.append(dic[d]) 110 | dfg = DFG 111 | return dfg 112 | 113 | 114 | def normalize_dataflow_item(dataflow_item): 115 | var_name = dataflow_item[0] 116 | var_pos = dataflow_item[1] 117 | relationship = dataflow_item[2] 118 | par_vars_name_list = dataflow_item[3] 119 | par_vars_pos_list = dataflow_item[4] 120 | 121 | var_names = list(set(par_vars_name_list + [var_name])) 122 | norm_names = {} 123 | for i in range(len(var_names)): 124 | norm_names[var_names[i]] = 'var_' + str(i) 125 | 126 | norm_var_name = norm_names[var_name] 127 | relationship = dataflow_item[2] 128 | norm_par_vars_name_list = [norm_names[x] for x in par_vars_name_list] 129 | 130 | return (norm_var_name, relationship, norm_par_vars_name_list) 131 | 132 | 133 | def normalize_dataflow(dataflow): 134 | var_dict = {} 135 | i = 0 136 | normalized_dataflow = [] 137 | for item in dataflow: 138 | var_name = item[0] 139 | relationship = item[2] 140 | par_vars_name_list = item[3] 141 | for name in par_vars_name_list: 142 | if name not in var_dict: 143 | var_dict[name] = 'var_' + str(i) 144 | i += 1 145 | if var_name not in var_dict: 146 | var_dict[var_name] = 'var_' + str(i) 147 | i += 1 148 | normalized_dataflow.append((var_dict[var_name], relationship, [var_dict[x] for x in par_vars_name_list])) 149 | return normalized_dataflow 150 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/keywords/c_sharp.txt: -------------------------------------------------------------------------------- 1 | abstract 2 | as 3 | base 4 | bool 5 | break 6 | byte 7 | case 8 | catch 9 | char 10 | checked 11 | class 12 | const 13 | continue 14 | decimal 15 | default 16 | delegate 17 | do 18 | double 19 | else 20 | enum 21 | event 22 | explicit 23 | extern 24 | false 25 | finally 26 | fixed 27 | float 28 | for 29 | foreach 30 | goto 31 | if 32 | implicit 33 | in 34 | int 35 | interface 36 | internal 37 | is 38 | lock 39 | long 40 | namespace 41 | new 42 | null 43 | object 44 | operator 45 | out 46 | override 47 | params 48 | private 49 | protected 50 | public 51 | readonly 52 | ref 53 | return 54 | sbyte 55 | sealed 56 | short 57 | sizeof 58 | stackalloc 59 | static 60 | string 61 | struct 62 | switch 63 | this 64 | throw 65 | true 66 | try 67 | typeof 68 | uint 69 | ulong 70 | unchecked 
71 | unsafe 72 | ushort 73 | using 74 | virtual 75 | void 76 | volatile 77 | while 78 | add 79 | alias 80 | ascending 81 | async 82 | await 83 | by 84 | descending 85 | dynamic 86 | equals 87 | from 88 | get 89 | global 90 | group 91 | into 92 | join 93 | let 94 | nameof 95 | notnull 96 | on 97 | orderby 98 | partial 99 | remove 100 | select 101 | set 102 | unmanaged 103 | value 104 | var 105 | when 106 | where 107 | yield 108 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/keywords/java.txt: -------------------------------------------------------------------------------- 1 | abstract 2 | assert 3 | boolean 4 | break 5 | byte 6 | case 7 | catch 8 | char 9 | class 10 | const 11 | continue 12 | default 13 | do 14 | double 15 | else 16 | enum 17 | extends 18 | final 19 | finally 20 | float 21 | for 22 | goto 23 | if 24 | implements 25 | import 26 | instanceof 27 | int 28 | interface 29 | long 30 | native 31 | new 32 | package 33 | private 34 | protected 35 | public 36 | return 37 | short 38 | static 39 | strictfp 40 | super 41 | switch 42 | synchronized 43 | this 44 | throw 45 | throws 46 | transient 47 | try 48 | void 49 | volatile 50 | while 51 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/keywords/python.txt: -------------------------------------------------------------------------------- 1 | False 2 | None 3 | True 4 | and 5 | as 6 | assert 7 | async 8 | await 9 | break 10 | class 11 | continue 12 | def 13 | del 14 | elif 15 | else 16 | except 17 | finally 18 | for 19 | from 20 | global 21 | if 22 | import 23 | in 24 | is 25 | lambda 26 | nonlocal 27 | not 28 | or 29 | pass 30 | raise 31 | return 32 | try 33 | while 34 | with 35 | yield 36 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/parser/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from .utils import (remove_comments_and_docstrings, 5 | tree_to_token_index, 6 | index_to_code_token, 7 | tree_to_variable_index) 8 | from .DFG import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/parser/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
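build.py below (driven by build.sh) compiles the tree-sitter grammars into parser/my-languages.so. Once built, dataflow_match.py above and syntax_match.py later in this evaluator load the result roughly as follows; a minimal sketch assuming a local path and the older tree_sitter API used throughout this evaluator:

from tree_sitter import Language, Parser

PY_LANGUAGE = Language('my-languages.so', 'python')   # path is an assumption
parser = Parser()
parser.set_language(PY_LANGUAGE)
tree = parser.parse(bytes("def add(a, b):\n    return a + b\n", "utf8"))
print(tree.root_node.sexp())   # S-expression of the parse tree, as used by syntax_match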
3 | 4 | from tree_sitter import Language, Parser 5 | 6 | Language.build_library( 7 | # Store the library in the `build` directory 8 | 'my-languages.so', 9 | 10 | # Include one or more languages 11 | [ 12 | 'tree-sitter-go', 13 | 'tree-sitter-javascript', 14 | 'tree-sitter-python', 15 | 'tree-sitter-php', 16 | 'tree-sitter-java', 17 | 'tree-sitter-ruby', 18 | 'tree-sitter-c-sharp', 19 | ] 20 | ) 21 | 22 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/parser/build.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/tree-sitter/tree-sitter-go 2 | git clone https://github.com/tree-sitter/tree-sitter-javascript 3 | git clone https://github.com/tree-sitter/tree-sitter-python 4 | git clone https://github.com/tree-sitter/tree-sitter-ruby 5 | git clone https://github.com/tree-sitter/tree-sitter-php 6 | git clone https://github.com/tree-sitter/tree-sitter-java 7 | git clone https://github.com/tree-sitter/tree-sitter-c-sharp 8 | python build.py 9 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/parser/my-languages.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/CodeT5+/evaluator/CodeBLEU/parser/my-languages.so -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/parser/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | import re 5 | from io import StringIO 6 | import tokenize 7 | 8 | 9 | def remove_comments_and_docstrings(source, lang): 10 | if lang in ['python']: 11 | """ 12 | Returns 'source' minus comments and docstrings. 
13 | """ 14 | io_obj = StringIO(source) 15 | out = "" 16 | prev_toktype = tokenize.INDENT 17 | last_lineno = -1 18 | last_col = 0 19 | for tok in tokenize.generate_tokens(io_obj.readline): 20 | token_type = tok[0] 21 | token_string = tok[1] 22 | start_line, start_col = tok[2] 23 | end_line, end_col = tok[3] 24 | ltext = tok[4] 25 | if start_line > last_lineno: 26 | last_col = 0 27 | if start_col > last_col: 28 | out += (" " * (start_col - last_col)) 29 | # Remove comments: 30 | if token_type == tokenize.COMMENT: 31 | pass 32 | # This series of conditionals removes docstrings: 33 | elif token_type == tokenize.STRING: 34 | if prev_toktype != tokenize.INDENT: 35 | # This is likely a docstring; double-check we're not inside an operator: 36 | if prev_toktype != tokenize.NEWLINE: 37 | if start_col > 0: 38 | out += token_string 39 | else: 40 | out += token_string 41 | prev_toktype = token_type 42 | last_col = end_col 43 | last_lineno = end_line 44 | temp = [] 45 | for x in out.split('\n'): 46 | if x.strip() != "": 47 | temp.append(x) 48 | return '\n'.join(temp) 49 | elif lang in ['ruby']: 50 | return source 51 | else: 52 | def replacer(match): 53 | s = match.group(0) 54 | if s.startswith('/'): 55 | return " " # note: a space and not an empty string 56 | else: 57 | return s 58 | 59 | pattern = re.compile( 60 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', 61 | re.DOTALL | re.MULTILINE 62 | ) 63 | temp = [] 64 | for x in re.sub(pattern, replacer, source).split('\n'): 65 | if x.strip() != "": 66 | temp.append(x) 67 | return '\n'.join(temp) 68 | 69 | 70 | def tree_to_token_index(root_node): 71 | if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string', 72 | 'character_literal']) and root_node.type != 'comment': 73 | return [(root_node.start_point, root_node.end_point)] 74 | else: 75 | code_tokens = [] 76 | for child in root_node.children: 77 | code_tokens += tree_to_token_index(child) 78 | return code_tokens 79 | 80 | 81 | def tree_to_variable_index(root_node, index_to_code): 82 | if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string', 83 | 'character_literal']) and root_node.type != 'comment': 84 | index = (root_node.start_point, root_node.end_point) 85 | _, code = index_to_code[index] 86 | if root_node.type != code: 87 | return [(root_node.start_point, root_node.end_point)] 88 | else: 89 | return [] 90 | else: 91 | code_tokens = [] 92 | for child in root_node.children: 93 | code_tokens += tree_to_variable_index(child, index_to_code) 94 | return code_tokens 95 | 96 | 97 | def index_to_code_token(index, code): 98 | start_point = index[0] 99 | end_point = index[1] 100 | if start_point[0] == end_point[0]: 101 | s = code[start_point[0]][start_point[1]:end_point[1]] 102 | else: 103 | s = "" 104 | s += code[start_point[0]][start_point[1]:] 105 | for i in range(start_point[0] + 1, end_point[0]): 106 | s += code[i] 107 | s += code[end_point[0]][:end_point[1]] 108 | return s 109 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/readme.txt: -------------------------------------------------------------------------------- 1 | python calc_code_bleu.py --refs reference_files --hyp candidate_file --language java ( or c_sharp) --params 0.25,0.25,0.25,0.25(default) -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/syntax_match.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from evaluator.CodeBLEU.parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript, DFG_csharp 5 | from evaluator.CodeBLEU.parser import (remove_comments_and_docstrings, 6 | tree_to_token_index, 7 | index_to_code_token, 8 | tree_to_variable_index) 9 | from tree_sitter import Language, Parser 10 | import os 11 | 12 | root_dir = os.path.dirname(__file__) 13 | dfg_function = { 14 | 'python': DFG_python, 15 | 'java': DFG_java, 16 | 'ruby': DFG_ruby, 17 | 'go': DFG_go, 18 | 'php': DFG_php, 19 | 'javascript': DFG_javascript, 20 | 'c_sharp': DFG_csharp, 21 | } 22 | 23 | 24 | def calc_syntax_match(references, candidate, lang): 25 | return corpus_syntax_match([references], [candidate], lang) 26 | 27 | 28 | def corpus_syntax_match(references, candidates, lang): 29 | JAVA_LANGUAGE = Language(root_dir + '/parser/my-languages.so', lang) 30 | parser = Parser() 31 | parser.set_language(JAVA_LANGUAGE) 32 | match_count = 0 33 | total_count = 0 34 | 35 | for i in range(len(candidates)): 36 | references_sample = references[i] 37 | candidate = candidates[i] 38 | for reference in references_sample: 39 | try: 40 | candidate = remove_comments_and_docstrings(candidate, 'java') 41 | except: 42 | pass 43 | try: 44 | reference = remove_comments_and_docstrings(reference, 'java') 45 | except: 46 | pass 47 | 48 | candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node 49 | 50 | reference_tree = parser.parse(bytes(reference, 'utf8')).root_node 51 | 52 | def get_all_sub_trees(root_node): 53 | node_stack = [] 54 | sub_tree_sexp_list = [] 55 | depth = 1 56 | node_stack.append([root_node, depth]) 57 | while len(node_stack) != 0: 58 | cur_node, cur_depth = node_stack.pop() 59 | sub_tree_sexp_list.append([cur_node.sexp(), cur_depth]) 60 | for child_node in cur_node.children: 61 | if len(child_node.children) != 0: 62 | depth = cur_depth + 1 63 | node_stack.append([child_node, depth]) 64 | return sub_tree_sexp_list 65 | 66 | cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)] 67 | ref_sexps = get_all_sub_trees(reference_tree) 68 | 69 | # print(cand_sexps) 70 | # print(ref_sexps) 71 | 72 | for sub_tree, depth in ref_sexps: 73 | if sub_tree in cand_sexps: 74 | match_count += 1 75 | total_count += len(ref_sexps) 76 | 77 | score = match_count / total_count 78 | return score 79 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/utils.py: -------------------------------------------------------------------------------- 1 | # Natural Language Toolkit: Utility functions 2 | # 3 | # Copyright (C) 2001-2020 NLTK Project 4 | # Author: Steven Bird 5 | # URL: 6 | # For license information, see LICENSE.TXT 7 | 8 | from itertools import chain 9 | 10 | def pad_sequence( 11 | sequence, 12 | n, 13 | pad_left=False, 14 | pad_right=False, 15 | left_pad_symbol=None, 16 | right_pad_symbol=None, 17 | ): 18 | """ 19 | Returns a padded sequence of items before ngram extraction. 
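corpus_syntax_match above walks every sub-tree of each reference AST and counts those whose S-expression also occurs among the candidate's sub-trees; the score is the number of matches over the total number of reference sub-trees. A toy sketch of that counting step, with plain strings standing in for S-expressions:

ref_subtrees = ['(call (identifier) (arguments))', '(identifier)', '(arguments)']
cand_subtrees = ['(call (identifier) (arguments))', '(identifier)']
match_count = sum(1 for s in ref_subtrees if s in cand_subtrees)
print(match_count / len(ref_subtrees))   # 0.666...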
20 | >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='', right_pad_symbol='')) 21 | ['', 1, 2, 3, 4, 5, ''] 22 | >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='')) 23 | ['', 1, 2, 3, 4, 5] 24 | >>> list(pad_sequence([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='')) 25 | [1, 2, 3, 4, 5, ''] 26 | :param sequence: the source data to be padded 27 | :type sequence: sequence or iter 28 | :param n: the degree of the ngrams 29 | :type n: int 30 | :param pad_left: whether the ngrams should be left-padded 31 | :type pad_left: bool 32 | :param pad_right: whether the ngrams should be right-padded 33 | :type pad_right: bool 34 | :param left_pad_symbol: the symbol to use for left padding (default is None) 35 | :type left_pad_symbol: any 36 | :param right_pad_symbol: the symbol to use for right padding (default is None) 37 | :type right_pad_symbol: any 38 | :rtype: sequence or iter 39 | """ 40 | sequence = iter(sequence) 41 | if pad_left: 42 | sequence = chain((left_pad_symbol,) * (n - 1), sequence) 43 | if pad_right: 44 | sequence = chain(sequence, (right_pad_symbol,) * (n - 1)) 45 | return sequence 46 | 47 | 48 | # add a flag to pad the sequence so we get peripheral ngrams? 49 | 50 | 51 | def ngrams( 52 | sequence, 53 | n, 54 | pad_left=False, 55 | pad_right=False, 56 | left_pad_symbol=None, 57 | right_pad_symbol=None, 58 | ): 59 | """ 60 | Return the ngrams generated from a sequence of items, as an iterator. 61 | For example: 62 | >>> from nltk.util import ngrams 63 | >>> list(ngrams([1,2,3,4,5], 3)) 64 | [(1, 2, 3), (2, 3, 4), (3, 4, 5)] 65 | Wrap with list for a list version of this function. Set pad_left 66 | or pad_right to true in order to get additional ngrams: 67 | >>> list(ngrams([1,2,3,4,5], 2, pad_right=True)) 68 | [(1, 2), (2, 3), (3, 4), (4, 5), (5, None)] 69 | >>> list(ngrams([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='')) 70 | [(1, 2), (2, 3), (3, 4), (4, 5), (5, '')] 71 | >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='')) 72 | [('', 1), (1, 2), (2, 3), (3, 4), (4, 5)] 73 | >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='', right_pad_symbol='')) 74 | [('', 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, '')] 75 | :param sequence: the source data to be converted into ngrams 76 | :type sequence: sequence or iter 77 | :param n: the degree of the ngrams 78 | :type n: int 79 | :param pad_left: whether the ngrams should be left-padded 80 | :type pad_left: bool 81 | :param pad_right: whether the ngrams should be right-padded 82 | :type pad_right: bool 83 | :param left_pad_symbol: the symbol to use for left padding (default is None) 84 | :type left_pad_symbol: any 85 | :param right_pad_symbol: the symbol to use for right padding (default is None) 86 | :type right_pad_symbol: any 87 | :rtype: sequence or iter 88 | """ 89 | sequence = pad_sequence( 90 | sequence, n, pad_left, pad_right, left_pad_symbol, right_pad_symbol 91 | ) 92 | 93 | history = [] 94 | while n > 1: 95 | # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator 96 | try: 97 | next_item = next(sequence) 98 | except StopIteration: 99 | # no more data, terminate the generator 100 | return 101 | history.append(next_item) 102 | n -= 1 103 | for item in sequence: 104 | history.append(item) 105 | yield tuple(history) 106 | del history[0] -------------------------------------------------------------------------------- /CodeT5+/evaluator/CodeBLEU/weighted_ngram_match.py: 
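The weighted n-gram matcher that follows receives, for every reference, a token-weight dict built by make_weights in calc_code_bleu.py: language keywords weigh 1, every other token 0.2. A toy version of that weighting, with a hypothetical keyword set:

keywords = {'if', 'return', 'for'}   # in practice read from keywords/<lang>.txt
reference_tokens = ['if', 'x', '>', '0', ':', 'return', 'x']
weights = {tok: 1 if tok in keywords else 0.2 for tok in reference_tokens}
print(weights['return'], weights['x'])   # 1 0.2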
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT license. 4 | 5 | # Natural Language Toolkit: BLEU Score 6 | # 7 | # Copyright (C) 2001-2020 NLTK Project 8 | # Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim 9 | # Contributors: Björn Mattsson, Dmitrijs Milajevs, Liling Tan 10 | # URL: 11 | # For license information, see LICENSE.TXT 12 | 13 | """BLEU score implementation.""" 14 | 15 | import math 16 | import sys 17 | from fractions import Fraction 18 | import warnings 19 | from collections import Counter 20 | 21 | from evaluator.CodeBLEU.utils import ngrams 22 | import pdb 23 | 24 | 25 | def sentence_bleu( 26 | references, 27 | hypothesis, 28 | weights=(0.25, 0.25, 0.25, 0.25), 29 | smoothing_function=None, 30 | auto_reweigh=False, 31 | ): 32 | """ 33 | Calculate BLEU score (Bilingual Evaluation Understudy) from 34 | Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu. 2002. 35 | "BLEU: a method for automatic evaluation of machine translation." 36 | In Proceedings of ACL. http://www.aclweb.org/anthology/P02-1040.pdf 37 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 38 | ... 'ensures', 'that', 'the', 'military', 'always', 39 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 40 | >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops', 41 | ... 'forever', 'hearing', 'the', 'activity', 'guidebook', 42 | ... 'that', 'party', 'direct'] 43 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 44 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 45 | ... 'heed', 'Party', 'commands'] 46 | >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 47 | ... 'guarantees', 'the', 'military', 'forces', 'always', 48 | ... 'being', 'under', 'the', 'command', 'of', 'the', 49 | ... 'Party'] 50 | >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 51 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 52 | ... 'of', 'the', 'party'] 53 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1) # doctest: +ELLIPSIS 54 | 0.5045... 55 | If there is no ngrams overlap for any order of n-grams, BLEU returns the 56 | value 0. This is because the precision for the order of n-grams without 57 | overlap is 0, and the geometric mean in the final BLEU score computation 58 | multiplies the 0 with the precision of other n-grams. This results in 0 59 | (independently of the precision of the othe n-gram orders). The following 60 | example has zero 3-gram and 4-gram overlaps: 61 | >>> round(sentence_bleu([reference1, reference2, reference3], hypothesis2),4) # doctest: +ELLIPSIS 62 | 0.0 63 | To avoid this harsh behaviour when no ngram overlaps are found a smoothing 64 | function can be used. 65 | >>> chencherry = SmoothingFunction() 66 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis2, 67 | ... smoothing_function=chencherry.method1) # doctest: +ELLIPSIS 68 | 0.0370... 69 | The default BLEU calculates a score for up to 4-grams using uniform 70 | weights (this is called BLEU-4). To evaluate your translations with 71 | higher/lower order ngrams, use customized weights. E.g. when accounting 72 | for up to 5-grams with uniform weights (this is called BLEU-5) use: 73 | >>> weights = (1./5., 1./5., 1./5., 1./5., 1./5.) 74 | >>> sentence_bleu([reference1, reference2, reference3], hypothesis1, weights) # doctest: +ELLIPSIS 75 | 0.3920... 
76 | :param references: reference sentences 77 | :type references: list(list(str)) 78 | :param hypothesis: a hypothesis sentence 79 | :type hypothesis: list(str) 80 | :param weights: weights for unigrams, bigrams, trigrams and so on 81 | :type weights: list(float) 82 | :param smoothing_function: 83 | :type smoothing_function: SmoothingFunction 84 | :param auto_reweigh: Option to re-normalize the weights uniformly. 85 | :type auto_reweigh: bool 86 | :return: The sentence-level BLEU score. 87 | :rtype: float 88 | """ 89 | return corpus_bleu( 90 | [references], [hypothesis], weights, smoothing_function, auto_reweigh 91 | ) 92 | 93 | 94 | def corpus_bleu( 95 | list_of_references, 96 | hypotheses, 97 | weights=(0.25, 0.25, 0.25, 0.25), 98 | smoothing_function=None, 99 | auto_reweigh=False, 100 | ): 101 | """ 102 | Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all 103 | the hypotheses and their respective references. 104 | Instead of averaging the sentence level BLEU scores (i.e. marco-average 105 | precision), the original BLEU metric (Papineni et al. 2002) accounts for 106 | the micro-average precision (i.e. summing the numerators and denominators 107 | for each hypothesis-reference(s) pairs before the division). 108 | >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 109 | ... 'ensures', 'that', 'the', 'military', 'always', 110 | ... 'obeys', 'the', 'commands', 'of', 'the', 'party'] 111 | >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 112 | ... 'ensures', 'that', 'the', 'military', 'will', 'forever', 113 | ... 'heed', 'Party', 'commands'] 114 | >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which', 115 | ... 'guarantees', 'the', 'military', 'forces', 'always', 116 | ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party'] 117 | >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 118 | ... 'army', 'always', 'to', 'heed', 'the', 'directions', 119 | ... 'of', 'the', 'party'] 120 | >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was', 121 | ... 'interested', 'in', 'world', 'history'] 122 | >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history', 123 | ... 'because', 'he', 'read', 'the', 'book'] 124 | >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]] 125 | >>> hypotheses = [hyp1, hyp2] 126 | >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS 127 | 0.5920... 128 | The example below show that corpus_bleu() is different from averaging 129 | sentence_bleu() for hypotheses 130 | >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1) 131 | >>> score2 = sentence_bleu([ref2a], hyp2) 132 | >>> (score1 + score2) / 2 # doctest: +ELLIPSIS 133 | 0.6223... 134 | :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses 135 | :type list_of_references: list(list(list(str))) 136 | :param hypotheses: a list of hypothesis sentences 137 | :type hypotheses: list(list(str)) 138 | :param weights: weights for unigrams, bigrams, trigrams and so on 139 | :type weights: list(float) 140 | :param smoothing_function: 141 | :type smoothing_function: SmoothingFunction 142 | :param auto_reweigh: Option to re-normalize the weights uniformly. 143 | :type auto_reweigh: bool 144 | :return: The corpus-level BLEU score. 145 | :rtype: float 146 | """ 147 | # Before proceeding to compute BLEU, perform sanity checks. 148 | 149 | p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches. 150 | p_denominators = Counter() # Key = ngram order, and value = no. 
of ngram in ref. 151 | hyp_lengths, ref_lengths = 0, 0 152 | 153 | assert len(list_of_references) == len(hypotheses), ( 154 | "The number of hypotheses and their reference(s) should be the " "same " 155 | ) 156 | 157 | # Iterate through each hypothesis and their corresponding references. 158 | for references, hypothesis in zip(list_of_references, hypotheses): 159 | # For each order of ngram, calculate the numerator and 160 | # denominator for the corpus-level modified precision. 161 | for i, _ in enumerate(weights, start=1): 162 | p_i_numeraotr, p_i_denominator = modified_recall(references, hypothesis, i) 163 | p_numerators[i] += p_i_numeraotr 164 | p_denominators[i] += p_i_denominator 165 | 166 | # Calculate the hypothesis length and the closest reference length. 167 | # Adds them to the corpus-level hypothesis and reference counts. 168 | hyp_len = len(hypothesis) 169 | hyp_lengths += hyp_len 170 | ref_lengths += closest_ref_length(references, hyp_len) 171 | 172 | # Calculate corpus-level brevity penalty. 173 | bp = brevity_penalty(ref_lengths, hyp_lengths) 174 | 175 | # Uniformly re-weighting based on maximum hypothesis lengths if largest 176 | # order of n-grams < 4 and weights is set at default. 177 | if auto_reweigh: 178 | if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25): 179 | weights = (1 / hyp_lengths,) * hyp_lengths 180 | 181 | # Collects the various recall values for the different ngram orders. 182 | p_n = [ 183 | (p_numerators[i], p_denominators[i]) 184 | for i, _ in enumerate(weights, start=1) 185 | ] 186 | 187 | # Returns 0 if there's no matching n-grams 188 | # We only need to check for p_numerators[1] == 0, since if there's 189 | # no unigrams, there won't be any higher order ngrams. 190 | if p_numerators[1] == 0: 191 | return 0 192 | 193 | # If there's no smoothing, set use method0 from SmoothinFunction class. 194 | if not smoothing_function: 195 | smoothing_function = SmoothingFunction().method1 196 | # Smoothen the modified precision. 197 | # Note: smoothing_function() may convert values into floats; 198 | # it tries to retain the Fraction object as much as the 199 | # smoothing method allows. 200 | p_n = smoothing_function( 201 | p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths 202 | ) 203 | # pdb.set_trace() 204 | s = (w_i * math.log(p_i[0]/p_i[1]) for w_i, p_i in zip(weights, p_n)) 205 | s = bp * math.exp(math.fsum(s)) 206 | return s 207 | 208 | 209 | def modified_recall(references, hypothesis, n): 210 | """ 211 | Calculate modified ngram recall. 212 | :param references: A list of reference translations. 213 | :type references: list(list(str)) 214 | :param hypothesis: A hypothesis translation. 215 | :type hypothesis: list(str) 216 | :param n: The ngram order. 217 | :type n: int 218 | :return: BLEU's modified precision for the nth order ngram. 219 | :rtype: Fraction 220 | """ 221 | # Extracts all ngrams in hypothesis 222 | # Set an empty Counter if hypothesis is empty. 223 | # pdb.set_trace() 224 | numerator = 0 225 | denominator = 0 226 | 227 | counts = Counter(ngrams(hypothesis, n)) if len(hypothesis) >= n else Counter() 228 | # Extract a union of references' counts. 
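# For each reference below, the reference n-gram counts are clipped by the
# hypothesis counts (so the denominator is the number of reference n-grams,
# i.e. a recall rather than a precision). For unigrams, when a weight dict of
# matching length is supplied, numerator and denominator are accumulated as
# keyword-weighted sums; otherwise the plain sums are used. The totals are
# accumulated over all references and returned as a raw
# (numerator, denominator) pair instead of a Fraction.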
229 | # max_counts = reduce(or_, [Counter(ngrams(ref, n)) for ref in references]) 230 | max_counts = {} 231 | for reference_and_weights in references: 232 | reference = reference_and_weights[0] 233 | weights = reference_and_weights[1] 234 | reference_counts = ( 235 | Counter(ngrams(reference, n)) if len(reference) >= n else Counter() 236 | ) 237 | # for ngram in reference_counts: 238 | # max_counts[ngram] = max(max_counts.get(ngram, 0), counts[ngram]) 239 | clipped_counts = { 240 | ngram: min(count, counts[ngram]) for ngram, count in reference_counts.items() 241 | } 242 | # reweight 243 | if n == 1 and len(weights) == len(reference_counts): 244 | def weighted_sum(weights, counts): 245 | sum_counts = 0 246 | for ngram, count in counts.items(): 247 | sum_counts += count * (weights[ngram[0]] if ngram[0] in weights else 1) 248 | return sum_counts 249 | 250 | numerator += weighted_sum(weights, clipped_counts) 251 | denominator += max(1, weighted_sum(weights, reference_counts)) 252 | 253 | else: 254 | numerator += sum(clipped_counts.values()) 255 | denominator += max(1, sum(reference_counts.values())) 256 | 257 | # # Assigns the intersection between hypothesis and references' counts. 258 | # clipped_counts = { 259 | # ngram: min(count, max_counts[ngram]) for ngram, count in counts.items() 260 | # } 261 | 262 | # numerator += sum(clipped_counts.values()) 263 | # # Ensures that denominator is minimum 1 to avoid ZeroDivisionError. 264 | # # Usually this happens when the ngram order is > len(reference). 265 | # denominator += max(1, sum(counts.values())) 266 | 267 | #return Fraction(numerator, denominator, _normalize=False) 268 | return numerator, denominator 269 | 270 | 271 | def closest_ref_length(references, hyp_len): 272 | """ 273 | This function finds the reference that is the closest length to the 274 | hypothesis. The closest reference length is referred to as *r* variable 275 | from the brevity penalty formula in Papineni et. al. (2002) 276 | :param references: A list of reference translations. 277 | :type references: list(list(str)) 278 | :param hyp_len: The length of the hypothesis. 279 | :type hyp_len: int 280 | :return: The length of the reference that's closest to the hypothesis. 281 | :rtype: int 282 | """ 283 | ref_lens = (len(reference) for reference in references) 284 | closest_ref_len = min( 285 | ref_lens, key=lambda ref_len: (abs(ref_len - hyp_len), ref_len) 286 | ) 287 | return closest_ref_len 288 | 289 | 290 | def brevity_penalty(closest_ref_len, hyp_len): 291 | """ 292 | Calculate brevity penalty. 293 | As the modified n-gram precision still has the problem from the short 294 | length sentence, brevity penalty is used to modify the overall BLEU 295 | score according to length. 296 | An example from the paper. There are three references with length 12, 15 297 | and 17. And a concise hypothesis of the length 12. The brevity penalty is 1. 298 | >>> reference1 = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 299 | >>> reference2 = list('aaaaaaaaaaaaaaa') # i.e. ['a'] * 15 300 | >>> reference3 = list('aaaaaaaaaaaaaaaaa') # i.e. ['a'] * 17 301 | >>> hypothesis = list('aaaaaaaaaaaa') # i.e. ['a'] * 12 302 | >>> references = [reference1, reference2, reference3] 303 | >>> hyp_len = len(hypothesis) 304 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 305 | >>> brevity_penalty(closest_ref_len, hyp_len) 306 | 1.0 307 | In case a hypothesis translation is shorter than the references, penalty is 308 | applied. 
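Concretely, with closest reference length r and hypothesis length c, the penalty is exp(1 - r/c) whenever the hypothesis is non-empty and not longer than the closest reference. A quick standalone arithmetic check (not a doctest of this module), using the lengths from the example below:

import math
closest_ref_len, hyp_len = 28, 12
print(math.exp(1 - closest_ref_len / hyp_len))   # ~0.2636, the doctest value below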
309 | >>> references = [['a'] * 28, ['a'] * 28] 310 | >>> hypothesis = ['a'] * 12 311 | >>> hyp_len = len(hypothesis) 312 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 313 | >>> brevity_penalty(closest_ref_len, hyp_len) 314 | 0.2635971381157267 315 | The length of the closest reference is used to compute the penalty. If the 316 | length of a hypothesis is 12, and the reference lengths are 13 and 2, the 317 | penalty is applied because the hypothesis length (12) is less then the 318 | closest reference length (13). 319 | >>> references = [['a'] * 13, ['a'] * 2] 320 | >>> hypothesis = ['a'] * 12 321 | >>> hyp_len = len(hypothesis) 322 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 323 | >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS 324 | 0.9200... 325 | The brevity penalty doesn't depend on reference order. More importantly, 326 | when two reference sentences are at the same distance, the shortest 327 | reference sentence length is used. 328 | >>> references = [['a'] * 13, ['a'] * 11] 329 | >>> hypothesis = ['a'] * 12 330 | >>> hyp_len = len(hypothesis) 331 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 332 | >>> bp1 = brevity_penalty(closest_ref_len, hyp_len) 333 | >>> hyp_len = len(hypothesis) 334 | >>> closest_ref_len = closest_ref_length(reversed(references), hyp_len) 335 | >>> bp2 = brevity_penalty(closest_ref_len, hyp_len) 336 | >>> bp1 == bp2 == 1 337 | True 338 | A test example from mteval-v13a.pl (starting from the line 705): 339 | >>> references = [['a'] * 11, ['a'] * 8] 340 | >>> hypothesis = ['a'] * 7 341 | >>> hyp_len = len(hypothesis) 342 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 343 | >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS 344 | 0.8668... 345 | >>> references = [['a'] * 11, ['a'] * 8, ['a'] * 6, ['a'] * 7] 346 | >>> hypothesis = ['a'] * 7 347 | >>> hyp_len = len(hypothesis) 348 | >>> closest_ref_len = closest_ref_length(references, hyp_len) 349 | >>> brevity_penalty(closest_ref_len, hyp_len) 350 | 1.0 351 | :param hyp_len: The length of the hypothesis for a single sentence OR the 352 | sum of all the hypotheses' lengths for a corpus 353 | :type hyp_len: int 354 | :param closest_ref_len: The length of the closest reference for a single 355 | hypothesis OR the sum of all the closest references for every hypotheses. 356 | :type closest_ref_len: int 357 | :return: BLEU's brevity penalty. 358 | :rtype: float 359 | """ 360 | if hyp_len > closest_ref_len: 361 | return 1 362 | # If hypothesis is empty, brevity penalty = 0 should result in BLEU = 0.0 363 | elif hyp_len == 0: 364 | return 0 365 | else: 366 | return math.exp(1 - closest_ref_len / hyp_len) 367 | 368 | 369 | class SmoothingFunction: 370 | """ 371 | This is an implementation of the smoothing techniques 372 | for segment-level BLEU scores that was presented in 373 | Boxing Chen and Collin Cherry (2014) A Systematic Comparison of 374 | Smoothing Techniques for Sentence-Level BLEU. In WMT14. 375 | http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf 376 | """ 377 | 378 | def __init__(self, epsilon=0.1, alpha=5, k=5): 379 | """ 380 | This will initialize the parameters required for the various smoothing 381 | techniques, the default values are set to the numbers used in the 382 | experiments from Chen and Cherry (2014). 383 | >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures', 384 | ... 'that', 'the', 'military', 'always', 'obeys', 'the', 385 | ... 
'commands', 'of', 'the', 'party'] 386 | >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures', 387 | ... 'that', 'the', 'military', 'will', 'forever', 'heed', 388 | ... 'Party', 'commands'] 389 | >>> chencherry = SmoothingFunction() 390 | >>> print(sentence_bleu([reference1], hypothesis1)) # doctest: +ELLIPSIS 391 | 0.4118... 392 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method0)) # doctest: +ELLIPSIS 393 | 0.4118... 394 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method1)) # doctest: +ELLIPSIS 395 | 0.4118... 396 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method2)) # doctest: +ELLIPSIS 397 | 0.4489... 398 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method3)) # doctest: +ELLIPSIS 399 | 0.4118... 400 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method4)) # doctest: +ELLIPSIS 401 | 0.4118... 402 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method5)) # doctest: +ELLIPSIS 403 | 0.4905... 404 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method6)) # doctest: +ELLIPSIS 405 | 0.4135... 406 | >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method7)) # doctest: +ELLIPSIS 407 | 0.4905... 408 | :param epsilon: the epsilon value use in method 1 409 | :type epsilon: float 410 | :param alpha: the alpha value use in method 6 411 | :type alpha: int 412 | :param k: the k value use in method 4 413 | :type k: int 414 | """ 415 | self.epsilon = epsilon 416 | self.alpha = alpha 417 | self.k = k 418 | 419 | def method0(self, p_n, *args, **kwargs): 420 | """ 421 | No smoothing. 422 | """ 423 | p_n_new = [] 424 | for i, p_i in enumerate(p_n): 425 | if p_i[0] != 0: 426 | p_n_new.append(p_i) 427 | else: 428 | _msg = str( 429 | "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n" 430 | "Therefore the BLEU score evaluates to 0, independently of\n" 431 | "how many N-gram overlaps of lower order it contains.\n" 432 | "Consider using lower n-gram order or use " 433 | "SmoothingFunction()" 434 | ).format(i + 1) 435 | warnings.warn(_msg) 436 | # When numerator==0 where denonminator==0 or !=0, the result 437 | # for the precision score should be equal to 0 or undefined. 438 | # Due to BLEU geometric mean computation in logarithm space, 439 | # we we need to take the return sys.float_info.min such that 440 | # math.log(sys.float_info.min) returns a 0 precision score. 441 | p_n_new.append(sys.float_info.min) 442 | return p_n_new 443 | 444 | def method1(self, p_n, *args, **kwargs): 445 | """ 446 | Smoothing method 1: Add *epsilon* counts to precision with 0 counts. 447 | """ 448 | return [ 449 | ((p_i[0] + self.epsilon), p_i[1]) 450 | if p_i[0] == 0 451 | else p_i 452 | for p_i in p_n 453 | ] 454 | 455 | def method2(self, p_n, *args, **kwargs): 456 | """ 457 | Smoothing method 2: Add 1 to both numerator and denominator from 458 | Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of 459 | machine translation quality using longest common subsequence and 460 | skip-bigram statistics. In ACL04. 
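In this module the per-order values passed to the smoothers are plain (numerator, denominator) tuples rather than Fractions, so method1 and method2 work directly on the tuple fields (method3 and method4 below still read .numerator / .denominator). A toy illustration of the add-one style of method2, with made-up counts:

p_n = [(3, 4), (0, 3)]                       # (matches, total) per n-gram order
print([(p[0] + 1, p[1] + 1) for p in p_n])   # [(4, 5), (1, 4)]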
461 | """ 462 | return [ 463 | (p_i[0] + 1, p_i[1] + 1) 464 | for p_i in p_n 465 | ] 466 | 467 | def method3(self, p_n, *args, **kwargs): 468 | """ 469 | Smoothing method 3: NIST geometric sequence smoothing 470 | The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each 471 | precision score whose matching n-gram count is null. 472 | k is 1 for the first 'n' value for which the n-gram match count is null/ 473 | For example, if the text contains: 474 | - one 2-gram match 475 | - and (consequently) two 1-gram matches 476 | the n-gram count for each individual precision score would be: 477 | - n=1 => prec_count = 2 (two unigrams) 478 | - n=2 => prec_count = 1 (one bigram) 479 | - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1) 480 | - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2) 481 | """ 482 | incvnt = 1 # From the mteval-v13a.pl, it's referred to as k. 483 | for i, p_i in enumerate(p_n): 484 | if p_i.numerator == 0: 485 | p_n[i] = 1 / (2 ** incvnt * p_i.denominator) 486 | incvnt += 1 487 | return p_n 488 | 489 | def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 490 | """ 491 | Smoothing method 4: 492 | Shorter translations may have inflated precision values due to having 493 | smaller denominators; therefore, we give them proportionally 494 | smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry 495 | suggests dividing by 1/ln(len(T)), where T is the length of the translation. 496 | """ 497 | hyp_len = hyp_len if hyp_len else len(hypothesis) 498 | for i, p_i in enumerate(p_n): 499 | if p_i.numerator == 0 and hyp_len != 0: 500 | incvnt = i + 1 * self.k / math.log( 501 | hyp_len 502 | ) # Note that this K is different from the K from NIST. 503 | p_n[i] = incvnt / p_i.denominator 504 | return p_n 505 | 506 | def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 507 | """ 508 | Smoothing method 5: 509 | The matched counts for similar values of n should be similar. To a 510 | calculate the n-gram matched count, it averages the n−1, n and n+1 gram 511 | matched counts. 512 | """ 513 | hyp_len = hyp_len if hyp_len else len(hypothesis) 514 | m = {} 515 | # Requires an precision value for an addition ngram order. 516 | p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)] 517 | m[-1] = p_n[0] + 1 518 | for i, p_i in enumerate(p_n): 519 | p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3 520 | m[i] = p_n[i] 521 | return p_n 522 | 523 | def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 524 | """ 525 | Smoothing method 6: 526 | Interpolates the maximum likelihood estimate of the precision *p_n* with 527 | a prior estimate *pi0*. The prior is estimated by assuming that the ratio 528 | between pn and pn−1 will be the same as that between pn−1 and pn−2; from 529 | Gao and He (2013) Training MRF-Based Phrase Translation Models using 530 | Gradient Ascent. In NAACL. 531 | """ 532 | hyp_len = hyp_len if hyp_len else len(hypothesis) 533 | # This smoothing only works when p_1 and p_2 is non-zero. 534 | # Raise an error with an appropriate message when the input is too short 535 | # to use this smoothing technique. 536 | assert p_n[2], "This smoothing method requires non-zero precision for bigrams." 537 | for i, p_i in enumerate(p_n): 538 | if i in [0, 1]: # Skips the first 2 orders of ngrams. 539 | continue 540 | else: 541 | pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2] 542 | # No. 
of ngrams in translation that matches the reference. 543 | m = p_i.numerator 544 | # No. of ngrams in translation. 545 | l = sum(1 for _ in ngrams(hypothesis, i + 1)) 546 | # Calculates the interpolated precision. 547 | p_n[i] = (m + self.alpha * pi0) / (l + self.alpha) 548 | return p_n 549 | 550 | def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs): 551 | """ 552 | Smoothing method 7: 553 | Interpolates methods 4 and 5. 554 | """ 555 | hyp_len = hyp_len if hyp_len else len(hypothesis) 556 | p_n = self.method4(p_n, references, hypothesis, hyp_len) 557 | p_n = self.method5(p_n, references, hypothesis, hyp_len) 558 | return p_n 559 | -------------------------------------------------------------------------------- /CodeT5+/evaluator/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | import json 27 | import re 28 | 29 | 30 | def _get_ngrams(segment, max_order): 31 | """Extracts all n-grams upto a given maximum order from an input segment. 32 | 33 | Args: 34 | segment: text segment from which n-grams will be extracted. 35 | max_order: maximum length in tokens of the n-grams returned by this 36 | methods. 37 | 38 | Returns: 39 | The Counter containing all n-grams upto max_order in segment 40 | with a count of how many times each n-gram occurred. 41 | """ 42 | ngram_counts = collections.Counter() 43 | for order in range(1, max_order + 1): 44 | for i in range(0, len(segment) - order + 1): 45 | ngram = tuple(segment[i:i+order]) 46 | ngram_counts[ngram] += 1 47 | return ngram_counts 48 | 49 | 50 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 51 | smooth=False): 52 | """Computes BLEU score of translated segments against one or more references. 53 | 54 | Args: 55 | reference_corpus: list of lists of references for each translation. Each 56 | reference should be tokenized into a list of tokens. 57 | translation_corpus: list of translations to score. Each translation 58 | should be tokenized into a list of tokens. 59 | max_order: Maximum n-gram order to use when computing BLEU score. 60 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 61 | 62 | Returns: 63 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 64 | precisions and brevity penalty. 
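_get_ngrams above returns a Counter over all n-grams up to max_order. A standalone mirror of it for a quick check (not an import of this module):

import collections

def get_ngrams_demo(segment, max_order):
    # same counting as _get_ngrams above
    counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(len(segment) - order + 1):
            counts[tuple(segment[i:i + order])] += 1
    return counts

print(get_ngrams_demo(['a', 'b', 'a', 'b'], 2))
# Counter({('a',): 2, ('b',): 2, ('a', 'b'): 2, ('b', 'a'): 1})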
65 | """ 66 | matches_by_order = [0] * max_order 67 | possible_matches_by_order = [0] * max_order 68 | reference_length = 0 69 | translation_length = 0 70 | for (references, translation) in zip(reference_corpus, 71 | translation_corpus): 72 | reference_length += min(len(r) for r in references) 73 | translation_length += len(translation) 74 | 75 | merged_ref_ngram_counts = collections.Counter() 76 | for reference in references: 77 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 78 | translation_ngram_counts = _get_ngrams(translation, max_order) 79 | overlap = translation_ngram_counts & merged_ref_ngram_counts 80 | for ngram in overlap: 81 | matches_by_order[len(ngram)-1] += overlap[ngram] 82 | for order in range(1, max_order+1): 83 | possible_matches = len(translation) - order + 1 84 | if possible_matches > 0: 85 | possible_matches_by_order[order-1] += possible_matches 86 | 87 | precisions = [0] * max_order 88 | for i in range(0, max_order): 89 | if smooth: 90 | precisions[i] = ((matches_by_order[i] + 1.) / 91 | (possible_matches_by_order[i] + 1.)) 92 | else: 93 | if possible_matches_by_order[i] > 0: 94 | precisions[i] = (float(matches_by_order[i]) / 95 | possible_matches_by_order[i]) 96 | else: 97 | precisions[i] = 0.0 98 | 99 | if min(precisions) > 0: 100 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 101 | geo_mean = math.exp(p_log_sum) 102 | else: 103 | geo_mean = 0 104 | 105 | ratio = float(translation_length) / reference_length 106 | 107 | if ratio > 1.0: 108 | bp = 1. 109 | else: 110 | bp = math.exp(1 - 1. / ratio) 111 | 112 | bleu = geo_mean * bp 113 | 114 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 115 | 116 | 117 | def _bleu(ref_file, trans_file, subword_option=None): 118 | max_order = 4 119 | smooth = True 120 | ref_files = [ref_file] 121 | reference_text = [] 122 | for reference_filename in ref_files: 123 | with open(reference_filename) as fh: 124 | reference_text.append(fh.readlines()) 125 | per_segment_references = [] 126 | for references in zip(*reference_text): 127 | reference_list = [] 128 | for reference in references: 129 | reference_list.append(reference.strip().split()) 130 | per_segment_references.append(reference_list) 131 | translations = [] 132 | with open(trans_file) as fh: 133 | for line in fh: 134 | translations.append(line.strip().split()) 135 | bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) 136 | return round(100 * bleu_score,2) 137 | 138 | def _bleu_json(json_file): 139 | max_order = 4 140 | smooth = True 141 | per_segment_references = [] 142 | translations = [] 143 | with open(json_file, 'r') as f: 144 | for json_string in f: 145 | json_data = json.loads(json_string) 146 | per_segment_references.append([json_data['target'].strip().split()]) 147 | translations.append(json_data['prediction'].strip().split()) 148 | bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) 149 | return round(100 * bleu_score,2) 150 | 151 | def _bleu_json_select(json_file, args, naive=None): 152 | max_order = 4 153 | smooth = True 154 | per_segment_references = [] 155 | translations = [] 156 | with open(json_file, 'r') as f: 157 | for json_string in f: 158 | json_data = json.loads(json_string) 159 | matches = re.search(r"Translate (\S+) to (\S+): ", json_data['source']) 160 | # print(json_string) 161 | source_name = matches.groups()[0] 162 | target_name = matches.groups()[1] 163 | source_code = re.sub(r"^Translate (\S+) 
to (\S+): ", "", json_data['source']) 164 | # print(source_name, target_name) 165 | if source_name in args.source_names.split(',') and target_name in args.target_names.split(','): 166 | per_segment_references.append([json_data['target'].strip().split()]) 167 | if naive: 168 | translations.append(source_code.strip().split()) 169 | else: 170 | translations.append(json_data['prediction'].strip().split()) 171 | bleu_score, _, _, _, _, _ = compute_bleu(per_segment_references, translations, max_order, smooth) 172 | return round(100 * bleu_score,2) -------------------------------------------------------------------------------- /CodeT5+/evaluator/smooth_bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | ''' 4 | This script was adapted from the original version by hieuhoang1972 which is part of MOSES. 5 | ''' 6 | 7 | # $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $ 8 | 9 | '''Provides: 10 | 11 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 12 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 13 | score_cooked(alltest, n=4): Score a list of cooked test sentences. 14 | 15 | score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids. 16 | 17 | The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible. 18 | ''' 19 | 20 | import sys, math, re, xml.sax.saxutils 21 | import subprocess 22 | import os 23 | 24 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 25 | nonorm = 0 26 | 27 | preserve_case = False 28 | eff_ref_len = "shortest" 29 | 30 | normalize1 = [ 31 | ('<skipped>', ''), # strip "skipped" tags 32 | (r'-\n', ''), # strip end-of-line hyphenation and join lines 33 | (r'\n', ' '), # join lines 34 | # (r'(\d)\s+(?=\d)', r'\1'), # join digits 35 | ] 36 | normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1] 37 | 38 | normalize2 = [ 39 | (r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 '), # tokenize punctuation. apostrophe is missing 40 | (r'([^0-9])([\.,])', r'\1 \2 '), # tokenize period and comma unless preceded by a digit 41 | (r'([\.,])([^0-9])', r' \1 \2'), # tokenize period and comma unless followed by a digit 42 | (r'([0-9])(-)', r'\1 \2 ') # tokenize dash when preceded by a digit 43 | ] 44 | normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2] 45 | 46 | 47 | def normalize(s): 48 | '''Normalize and tokenize text.
This is lifted from NIST mteval-v11a.pl.''' 49 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 50 | if (nonorm): 51 | return s.split() 52 | if type(s) is not str: 53 | s = " ".join(s) 54 | # language-independent part: 55 | for (pattern, replace) in normalize1: 56 | s = re.sub(pattern, replace, s) 57 | s = xml.sax.saxutils.unescape(s, {'&quot;': '"'}) 58 | # language-dependent part (assuming Western languages): 59 | s = " %s " % s 60 | if not preserve_case: 61 | s = s.lower() # this might not be identical to the original 62 | for (pattern, replace) in normalize2: 63 | s = re.sub(pattern, replace, s) 64 | return s.split() 65 | 66 | 67 | def count_ngrams(words, n=4): 68 | counts = {} 69 | for k in range(1, n + 1): 70 | for i in range(len(words) - k + 1): 71 | ngram = tuple(words[i:i + k]) 72 | counts[ngram] = counts.get(ngram, 0) + 1 73 | return counts 74 | 75 | 76 | def cook_refs(refs, n=4): 77 | '''Takes a list of reference sentences for a single segment 78 | and returns an object that encapsulates everything that BLEU 79 | needs to know about them.''' 80 | 81 | refs = [normalize(ref) for ref in refs] 82 | maxcounts = {} 83 | for ref in refs: 84 | counts = count_ngrams(ref, n) 85 | for (ngram, count) in counts.items(): 86 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 87 | return ([len(ref) for ref in refs], maxcounts) 88 | 89 | 90 | def cook_test(test, item, n=4): 91 | '''Takes a test sentence and returns an object that 92 | encapsulates everything that BLEU needs to know about it.''' 93 | (reflens, refmaxcounts) = item 94 | test = normalize(test) 95 | result = {} 96 | result["testlen"] = len(test) 97 | 98 | # Calculate effective reference sentence length. 99 | 100 | if eff_ref_len == "shortest": 101 | result["reflen"] = min(reflens) 102 | elif eff_ref_len == "average": 103 | result["reflen"] = float(sum(reflens)) / len(reflens) 104 | elif eff_ref_len == "closest": 105 | min_diff = None 106 | for reflen in reflens: 107 | if min_diff is None or abs(reflen - len(test)) < min_diff: 108 | min_diff = abs(reflen - len(test)) 109 | result['reflen'] = reflen 110 | 111 | result["guess"] = [max(len(test) - k + 1, 0) for k in range(1, n + 1)] 112 | 113 | result['correct'] = [0] * n 114 | counts = count_ngrams(test, n) 115 | for (ngram, count) in counts.items(): 116 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 117 | 118 | return result 119 | 120 | 121 | def score_cooked(allcomps, n=4, ground=0, smooth=1): 122 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 123 | for comps in allcomps: 124 | for key in ['testlen', 'reflen']: 125 | totalcomps[key] += comps[key] 126 | for key in ['guess', 'correct']: 127 | for k in range(n): 128 | totalcomps[key][k] += comps[key][k] 129 | logbleu = 0.0 130 | all_bleus = [] 131 | for k in range(n): 132 | correct = totalcomps['correct'][k] 133 | guess = totalcomps['guess'][k] 134 | addsmooth = 0 135 | if smooth == 1 and k > 0: 136 | addsmooth = 1 137 | logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min) 138 | if guess == 0: 139 | all_bleus.append(-10000000) 140 | else: 141 | all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess)) 142 | 143 | logbleu /= float(n) 144 | all_bleus.insert(0, logbleu) 145 | 146 | brevPenalty = min(0, 1 - float(totalcomps['reflen'] + 1) / (totalcomps['testlen'] + 1)) 147 | for i in range(len(all_bleus)): 148 | if i == 0: 149 | all_bleus[i] += brevPenalty 150 | 
all_bleus[i] = math.exp(all_bleus[i]) 151 | return all_bleus 152 | 153 | 154 | def bleu(refs, candidate, ground=0, smooth=1): 155 | refs = cook_refs(refs) 156 | test = cook_test(candidate, refs) 157 | return score_cooked([test], ground=ground, smooth=smooth) 158 | 159 | 160 | def splitPuncts(line): 161 | return ' '.join(re.findall(r"[\w]+|[^\s\w]", line)) 162 | 163 | 164 | def computeMaps(predictions, goldfile): 165 | predictionMap = {} 166 | goldMap = {} 167 | gf = open(goldfile, 'r') 168 | 169 | for row in predictions: 170 | cols = row.strip().split('\t') 171 | if len(cols) == 1: 172 | (rid, pred) = (cols[0], '') 173 | else: 174 | (rid, pred) = (cols[0], cols[1]) 175 | predictionMap[rid] = [splitPuncts(pred.strip().lower())] 176 | 177 | for row in gf: 178 | (rid, pred) = row.split('\t') 179 | if rid in predictionMap: # Only insert if the id exists for the method 180 | if rid not in goldMap: 181 | goldMap[rid] = [] 182 | goldMap[rid].append(splitPuncts(pred.strip().lower())) 183 | 184 | sys.stderr.write('Total: ' + str(len(goldMap)) + '\n') 185 | return (goldMap, predictionMap) 186 | 187 | 188 | # m1 is the reference map 189 | # m2 is the prediction map 190 | def bleuFromMaps(m1, m2): 191 | score = [0] * 5 192 | num = 0.0 193 | 194 | for key in m1: 195 | if key in m2: 196 | bl = bleu(m1[key], m2[key][0]) 197 | score = [score[i] + bl[i] for i in range(0, len(bl))] 198 | num += 1 199 | return [s * 100.0 / num for s in score] 200 | 201 | 202 | if __name__ == '__main__': 203 | reference_file = sys.argv[1] 204 | predictions = [] 205 | for row in sys.stdin: 206 | predictions.append(row) 207 | (goldMap, predictionMap) = computeMaps(predictions, reference_file) 208 | print(bleuFromMaps(goldMap, predictionMap)[0]) 209 | -------------------------------------------------------------------------------- /CodeT5+/images/DLTrans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/CodeT5+/images/DLTrans.png -------------------------------------------------------------------------------- /CodeT5+/images/MultilingualTrans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/CodeT5+/images/MultilingualTrans.png -------------------------------------------------------------------------------- /CodeT5+/images/NicheTrans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/CodeT5+/images/NicheTrans.png -------------------------------------------------------------------------------- /CodeT5+/run_preprocess.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | 18 | """ 19 | 20 | import logging 21 | import argparse 22 | import numpy as np 23 | import json 24 | import re 25 | 26 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 27 | datefmt='%m/%d/%Y %H:%M:%S', 28 | level=logging.INFO) 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--input_file", default=None, type=str) 35 | parser.add_argument("--output_file", default=None, type=str) 36 | parser.add_argument('--source_names', type=str, default=None, 37 | help="source_names") 38 | parser.add_argument('--target_names', type=str, default=None, 39 | help="target_names") 40 | parser.add_argument('--sub_task', type=str, default=None, 41 | help="sub_task") 42 | args = parser.parse_args() 43 | 44 | with open(args.input_file, 'r') as fi: 45 | with open(args.output_file, 'w') as fw: 46 | for line in fi: 47 | json_data = json.loads(line) 48 | cur_keys = list(json_data.keys()) 49 | if args.sub_task in ['MultilingualTrans', 'RareTrans']: 50 | source_lang = cur_keys[2] 51 | target_lang = cur_keys[3] 52 | elif args.sub_task in ['LLMTrans']: 53 | source_lang = cur_keys[3] 54 | target_lang = cur_keys[2] 55 | else: 56 | source_lang = cur_keys[1] 57 | target_lang = cur_keys[2] 58 | if source_lang in args.source_names.split(',') and target_lang in args.target_names.split(','): 59 | source = 'Translate ' + source_lang + ' to ' + target_lang + ': ' + json_data[source_lang] 60 | target = json_data[target_lang] 61 | del json_data[source_lang] 62 | del json_data[target_lang] 63 | json_data['source'] = source 64 | json_data['target'] = target 65 | json_string = json.dumps(json_data) 66 | fw.write(json_string + '\n') 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /CodeT5+/run_score.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
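# run_score.py scores a predictions file in which every line is a JSON object with
# "source" (prefixed with "Translate <source> to <target>: "), "target" and "prediction"
# fields. It keeps only the language pairs selected by --source_names / --target_names and
# prints exact match and BLEU (plus CodeBLEU when --codebleu is passed; --naive scores the
# unmodified source code as a copy baseline instead of the prediction).
# A hypothetical invocation (the input path is an assumption; the flags are the ones parsed below):
#   python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \
#       --source_names Python --target_names Java --codebleu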
16 | """ 17 | 18 | """ 19 | 20 | import logging 21 | import argparse 22 | import numpy as np 23 | import json 24 | import re 25 | 26 | from evaluator.CodeBLEU import calc_code_bleu 27 | from evaluator.bleu import _bleu, _bleu_json, _bleu_json_select 28 | 29 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 30 | datefmt='%m/%d/%Y %H:%M:%S', 31 | level=logging.INFO) 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--input_file", default=None, type=str) 38 | parser.add_argument('--source_names', type=str, default=None, 39 | help="source_names") 40 | parser.add_argument('--target_names', type=str, default=None, 41 | help="target_names") 42 | parser.add_argument('--codebleu', action='store_true') 43 | parser.add_argument('--naive', action='store_true') 44 | args = parser.parse_args() 45 | dev_accs = [] 46 | hypothesis = [] 47 | pre_references = [] 48 | with open(args.input_file, 'r') as f: 49 | for line in f: 50 | json_data = json.loads(line) 51 | matches = re.search(r"^Translate (\S+) to (\S+): ", json_data['source']) 52 | source_name = matches.groups()[0] 53 | target_name = matches.groups()[1] 54 | source_code = re.sub(r"^Translate (\S+) to (\S+): ", "", json_data['source']) 55 | if source_name in args.source_names.split(',') and target_name in args.target_names.split(','): 56 | if args.naive: 57 | dev_accs.append(source_code.strip() == json_data['target'].strip()) 58 | hypothesis.append(source_code.strip()) 59 | else: 60 | dev_accs.append(json_data['prediction'].strip() == json_data['target'].strip()) 61 | hypothesis.append(json_data['prediction'].strip()) 62 | pre_references.append(json_data['target'].strip()) 63 | 64 | pre_references = [pre_references] 65 | bleu = round(_bleu_json_select(args.input_file, args, args.naive), 2) 66 | if args.codebleu: 67 | codebleu = calc_code_bleu.get_codebleu_list(pre_references, hypothesis, 'python') 68 | result = {'em': round(np.mean(dev_accs) * 100, 2), 'bleu': bleu, 'codebleu': round(codebleu * 100, 2)} 69 | else: 70 | result = {'em': round(np.mean(dev_accs) * 100, 2), 'bleu': bleu} 71 | print(result) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /CodeT5+/run_train_DLTrans.sh: -------------------------------------------------------------------------------- 1 | DATADIR="" 2 | run() { 3 | #preprocess data 4 | NAME=$(date +%Y%m%d%H)_DLTrans_many_to_many_seed${seed}_lr${lr}_maxlen512 5 | OUTPUT_DIR=output/saved_models/DLTrans/many_to_many/${NAME} 6 | echo ${NAME} 7 | mkdir -p ${OUTPUT_DIR} 8 | mkdir -p ${OUTPUT_DIR}/cache_data 9 | 10 | for name in 'test' 'valid' 'train'; do 11 | python run_preprocess.py --input_file ${DATADIR}/DLTrans/dl_${name}.json \ 12 | --output_file ${OUTPUT_DIR}/cache_data/dl_${name}.json \ 13 | --source_names ${source_names} --target_names ${target_names} --sub_task 'DLTrans' 14 | done 15 | 16 | #train and predict model 17 | PORT_ID=$(expr $RANDOM + 1000) 18 | 19 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node 1 --master_port ${PORT_ID} \ 20 | run_translation.py \ 21 | --model_name_or_path ./pretrain/codet5p-220m \ 22 | --do_train \ 23 | --do_eval \ 24 | --do_predict \ 25 | --train_file ${OUTPUT_DIR}/cache_data/dl_train.json \ 26 | --validation_file ${OUTPUT_DIR}/cache_data/dl_valid.json \ 27 | --test_file ${OUTPUT_DIR}/cache_data/dl_test.json \ 28 | --source_prefix "" \ 29 | --output_dir 
${OUTPUT_DIR} \ 30 | --text_column source \ 31 | --summary_column target \ 32 | --max_source_length 512 \ 33 | --max_target_length 512 \ 34 | --per_device_train_batch_size=4 \ 35 | --gradient_accumulation_steps=4 \ 36 | --per_device_eval_batch_size=4 \ 37 | --learning_rate $lr \ 38 | --num_train_epochs 5 \ 39 | --metric_for_best_model loss \ 40 | --save_total_limit 1 \ 41 | --save_strategy epoch \ 42 | --load_best_model_at_end \ 43 | --evaluation_strategy epoch \ 44 | --overwrite_output_dir \ 45 | --predict_with_generate \ 46 | --logging_strategy epoch \ 47 | --logging_dir ${OUTPUT_DIR} \ 48 | --num_beams 5 \ 49 | --warmup_steps 200 \ 50 | --fp16 \ 51 | --seed $seed \ 52 | --report_to tensorboard 2>&1 | tee ${OUTPUT_DIR}/log_train.txt 53 | 54 | #calculate score 55 | echo ${source_names}_to_${target_names} 56 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 57 | --source_names ${source_names} \ 58 | --target_names ${target_names} \ 59 | --codebleu \ 60 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_names}_to_${target_names}.log 61 | 62 | for source_name in mxnet paddle pytorch tensorflow; do 63 | for target_name in mxnet paddle pytorch tensorflow; do 64 | if [ "$source_name" != "$target_name" ]; then 65 | echo ${source_name}_to_${target_name} 66 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 67 | --source_names ${source_name} \ 68 | --target_names ${target_name} \ 69 | --codebleu \ 70 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_name}_to_${target_name}.log 71 | fi 72 | done 73 | done 74 | 75 | } 76 | 77 | GPUID=0 78 | seed=1234 79 | lr=3e-5 80 | source_names="mxnet,paddle,pytorch,tensorflow" 81 | target_names="mxnet,paddle,pytorch,tensorflow" 82 | run & 83 | GPUID=1 84 | seed=2345 85 | lr=3e-5 86 | source_names="mxnet,paddle,pytorch,tensorflow" 87 | target_names="mxnet,paddle,pytorch,tensorflow" 88 | run & 89 | GPUID=2 90 | seed=3456 91 | lr=3e-5 92 | source_names="mxnet,paddle,pytorch,tensorflow" 93 | target_names="mxnet,paddle,pytorch,tensorflow" 94 | run & 95 | 96 | -------------------------------------------------------------------------------- /CodeT5+/run_train_MultilingualTrans_many_to_many.sh: -------------------------------------------------------------------------------- 1 | DATADIR="" # Change to your data path. 
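# A hypothetical layout, assuming DATADIR="/data/CodeTransOcean": the preprocessing loop
# below expects ${DATADIR}/MultilingualTrans/multilingual_train.json, multilingual_valid.json
# and multilingual_test.json to exist before this script is launched.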
2 | 3 | run() { 4 | for seed in 1234 2345 3456; do 5 | #preprocess data 6 | NAME=$(date +%Y%m%d%H)_MultilingualTrans_many_to_many_seed${seed}_lr${lr}_maxlen1536_warmup200 7 | OUTPUT_DIR=output/saved_models/MultilingualTrans/many_to_many_${lr}_maxlen1536_seed${seed}/${NAME} 8 | echo ${NAME} 9 | mkdir -p ${OUTPUT_DIR} 10 | mkdir -p ${OUTPUT_DIR}/cache_data 11 | 12 | for name in 'test' 'valid' 'train'; do 13 | python run_preprocess.py --input_file ${DATADIR}/MultilingualTrans/multilingual_${name}.json \ 14 | --output_file ${OUTPUT_DIR}/cache_data/multilingual_${name}.json \ 15 | --source_names ${source_names} --target_names ${target_names} --sub_task 'MultilingualTrans' 16 | done 17 | 18 | # train and predict model 19 | PORT_ID=$(expr $RANDOM + 1000) 20 | 21 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node 4 --master_port ${PORT_ID} \ 22 | run_translation.py \ 23 | --model_name_or_path ./pretrain/codet5p-220m \ 24 | --do_train \ 25 | --do_eval \ 26 | --do_predict \ 27 | --train_file ${OUTPUT_DIR}/cache_data/multilingual_train.json \ 28 | --validation_file ${OUTPUT_DIR}/cache_data/multilingual_valid.json \ 29 | --test_file ${OUTPUT_DIR}/cache_data/multilingual_test.json \ 30 | --source_prefix "" \ 31 | --output_dir ${OUTPUT_DIR} \ 32 | --text_column source \ 33 | --summary_column target \ 34 | --max_source_length 1536 \ 35 | --max_target_length 1536 \ 36 | --per_device_train_batch_size=2 \ 37 | --gradient_accumulation_steps=2 \ 38 | --per_device_eval_batch_size=24 \ 39 | --learning_rate $lr \ 40 | --num_train_epochs 5 \ 41 | --metric_for_best_model loss \ 42 | --save_total_limit 2 \ 43 | --save_strategy epoch \ 44 | --load_best_model_at_end \ 45 | --evaluation_strategy epoch \ 46 | --overwrite_output_dir \ 47 | --predict_with_generate \ 48 | --logging_strategy epoch \ 49 | --logging_dir ${OUTPUT_DIR} \ 50 | --num_beams 1 \ 51 | --seed $seed \ 52 | --fp16 \ 53 | --warmup_steps 200 \ 54 | --report_to tensorboard 2>&1 | tee ${OUTPUT_DIR}/log_train.txt 55 | 56 | #calculate score 57 | echo ${source_names}_to_${target_names} 58 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 59 | --source_names ${source_names} \ 60 | --target_names ${target_names} \ 61 | --codebleu \ 62 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_names}_to_${target_names}.log 63 | 64 | for source_name in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 65 | for target_name in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 66 | if [ "$source_name" != "$target_name" ]; then 67 | echo ${source_name}_to_${target_name} 68 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 69 | --source_names ${source_name} \ 70 | --target_names ${target_name} \ 71 | --codebleu \ 72 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_name}_to_${target_name}.log 73 | fi 74 | done 75 | done 76 | done 77 | 78 | } 79 | 80 | GPUID='0,1,2,3' 81 | lr=3e-5 82 | source_names="C,C++,C#,Java,Go,PHP,Python,VB" 83 | target_names="C,C++,C#,Java,Go,PHP,Python,VB" 84 | run & 85 | -------------------------------------------------------------------------------- /CodeT5+/run_train_MultilingualTrans_many_to_one.sh: -------------------------------------------------------------------------------- 1 | DATADIR="" 2 | 3 | run() { 4 | for seed in 1234 2345 3456; do 5 | for target_names in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 6 | #preprocess data 7 | NAME=$(date +%Y%m%d%H)_MultilingualTrans_many_to_one_seed${seed}_lr${lr}_maxlen1536_warmup200_${source_names}_to_${target_names} 8 | 
OUTPUT_DIR=output/saved_models/MultilingualTrans/many_to_one_${lr}_maxlen1536_seed${seed}/${NAME} 9 | echo ${NAME} 10 | mkdir -p ${OUTPUT_DIR} 11 | mkdir -p ${OUTPUT_DIR}/cache_data 12 | 13 | for name in 'test' 'valid' 'train'; do 14 | python run_preprocess.py --input_file ${DATADIR}/MultilingualTrans/multilingual_${name}.json \ 15 | --output_file ${OUTPUT_DIR}/cache_data/multilingual_${name}.json \ 16 | --source_names ${source_names} --target_names ${target_names} --sub_task 'MultilingualTrans' 17 | done 18 | 19 | # train and predict model 20 | PORT_ID=$(expr $RANDOM + 1000) 21 | 22 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node 4 --master_port ${PORT_ID} \ 23 | run_translation.py \ 24 | --model_name_or_path ./pretrain/codet5p-220m \ 25 | --do_train \ 26 | --do_eval \ 27 | --do_predict \ 28 | --train_file ${OUTPUT_DIR}/cache_data/multilingual_train.json \ 29 | --validation_file ${OUTPUT_DIR}/cache_data/multilingual_valid.json \ 30 | --test_file ${OUTPUT_DIR}/cache_data/multilingual_test.json \ 31 | --source_prefix "" \ 32 | --output_dir ${OUTPUT_DIR} \ 33 | --text_column source \ 34 | --summary_column target \ 35 | --max_source_length 1536 \ 36 | --max_target_length 1536 \ 37 | --per_device_train_batch_size=2 \ 38 | --gradient_accumulation_steps=2 \ 39 | --per_device_eval_batch_size=24 \ 40 | --learning_rate $lr \ 41 | --num_train_epochs 5 \ 42 | --metric_for_best_model loss \ 43 | --save_total_limit 2 \ 44 | --save_strategy epoch \ 45 | --load_best_model_at_end \ 46 | --evaluation_strategy epoch \ 47 | --overwrite_output_dir \ 48 | --predict_with_generate \ 49 | --logging_strategy epoch \ 50 | --logging_dir ${OUTPUT_DIR} \ 51 | --num_beams 1 \ 52 | --seed $seed \ 53 | --fp16 \ 54 | --warmup_steps 200 \ 55 | --report_to tensorboard 2>&1 | tee ${OUTPUT_DIR}/log_train.txt 56 | 57 | #calculate score 58 | echo ${source_names}_to_${target_names} 59 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 60 | --source_names ${source_names} \ 61 | --target_names ${target_names} \ 62 | --codebleu \ 63 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_names}_to_${target_names}.log 64 | 65 | target_name=${target_names} 66 | for source_name in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 67 | if [ "$source_name" != "$target_name" ]; then 68 | echo ${source_name}_to_${target_name} 69 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 70 | --source_names ${source_name} \ 71 | --target_names ${target_name} \ 72 | --codebleu \ 73 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_name}_to_${target_name}.log 74 | fi 75 | done 76 | done 77 | done 78 | 79 | } 80 | 81 | GPUID='0,1,2,3' 82 | lr=3e-5 83 | source_names="C,C++,C#,Java,Go,PHP,Python,VB" 84 | run & 85 | -------------------------------------------------------------------------------- /CodeT5+/run_train_MultilingualTrans_one_to_many.sh: -------------------------------------------------------------------------------- 1 | DATADIR="" 2 | 3 | run() { 4 | for seed in 1234 2345 3456; do 5 | for source_names in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 6 | #preprocess data 7 | NAME=$(date +%Y%m%d%H)_MultilingualTrans_one_to_many_seed${seed}_lr${lr}_maxlen1536_warmup200_${source_names}_to_${target_names} 8 | OUTPUT_DIR=output/saved_models/MultilingualTrans/one_to_many_${lr}_maxlen1536_seed${seed}/${NAME} 9 | echo ${NAME} 10 | mkdir -p ${OUTPUT_DIR} 11 | mkdir -p ${OUTPUT_DIR}/cache_data 12 | 13 | for name in 'test' 'valid' 'train'; do 14 | python run_preprocess.py 
--input_file ${DATADIR}/MultilingualTrans/multilingual_${name}.json \ 15 | --output_file ${OUTPUT_DIR}/cache_data/multilingual_${name}.json \ 16 | --source_names ${source_names} --target_names ${target_names} --sub_task 'MultilingualTrans' 17 | done 18 | 19 | # train and predict model 20 | PORT_ID=$(expr $RANDOM + 1000) 21 | 22 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node 2 --master_port ${PORT_ID} \ 23 | run_translation.py \ 24 | --model_name_or_path ./pretrain/codet5p-220m \ 25 | --do_train \ 26 | --do_eval \ 27 | --do_predict \ 28 | --train_file ${OUTPUT_DIR}/cache_data/multilingual_train.json \ 29 | --validation_file ${OUTPUT_DIR}/cache_data/multilingual_valid.json \ 30 | --test_file ${OUTPUT_DIR}/cache_data/multilingual_test.json \ 31 | --source_prefix "" \ 32 | --output_dir ${OUTPUT_DIR} \ 33 | --text_column source \ 34 | --summary_column target \ 35 | --max_source_length 1536 \ 36 | --max_target_length 1536 \ 37 | --per_device_train_batch_size=2 \ 38 | --gradient_accumulation_steps=4 \ 39 | --per_device_eval_batch_size=24 \ 40 | --learning_rate $lr \ 41 | --num_train_epochs 5 \ 42 | --metric_for_best_model loss \ 43 | --save_total_limit 2 \ 44 | --save_strategy epoch \ 45 | --load_best_model_at_end \ 46 | --evaluation_strategy epoch \ 47 | --overwrite_output_dir \ 48 | --predict_with_generate \ 49 | --logging_strategy epoch \ 50 | --logging_dir ${OUTPUT_DIR} \ 51 | --num_beams 1 \ 52 | --seed $seed \ 53 | --fp16 \ 54 | --warmup_steps 200 \ 55 | --report_to tensorboard 2>&1 | tee ${OUTPUT_DIR}/log_train.txt 56 | 57 | #calculate score 58 | echo ${source_names}_to_${target_names} 59 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 60 | --source_names ${source_names} \ 61 | --target_names ${target_names} \ 62 | --codebleu \ 63 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_names}_to_${target_names}.log 64 | 65 | source_name=${source_names} 66 | for target_name in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 67 | if [ "$source_name" != "$target_name" ]; then 68 | echo ${source_name}_to_${target_name} 69 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 70 | --source_names ${source_name} \ 71 | --target_names ${target_name} \ 72 | --codebleu \ 73 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_name}_to_${target_name}.log 74 | fi 75 | done 76 | done 77 | done 78 | 79 | } 80 | 81 | GPUID='0,1' 82 | lr=3e-5 83 | target_names="C,C++,C#,Java,Go,PHP,Python,VB" 84 | run & 85 | -------------------------------------------------------------------------------- /CodeT5+/run_train_MultilingualTrans_one_to_one.sh: -------------------------------------------------------------------------------- 1 | DATADIR="" 2 | 3 | run() { 4 | 5 | for seed in 1234 2345 3456; do 6 | for target_names in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 7 | if [ "$source_names" != "$target_names" ]; then 8 | #preprocess data 9 | NAME=$(date +%Y%m%d%H)_MultilingualTrans_one_to_one_seed${seed}_lr${lr}_maxlen1536_warmup200_${source_names}_to_${target_names} 10 | OUTPUT_DIR=output/saved_models/MultilingualTrans/one_to_one_${lr}_maxlen1536_seed${seed}/${NAME} 11 | echo ${NAME} 12 | mkdir -p ${OUTPUT_DIR} 13 | mkdir -p ${OUTPUT_DIR}/cache_data 14 | 15 | for name in 'test' 'valid' 'train'; do 16 | python run_preprocess.py --input_file ${DATADIR}/MultilingualTrans/multilingual_${name}.json \ 17 | --output_file ${OUTPUT_DIR}/cache_data/multilingual_${name}.json \ 18 | --source_names ${source_names} --target_names ${target_names} --sub_task 
'MultilingualTrans' 19 | done 20 | 21 | # train and predict model 22 | PORT_ID=$(expr $RANDOM + 1000) 23 | 24 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node 1 --master_port ${PORT_ID} \ 25 | run_translation.py \ 26 | --model_name_or_path ./pretrain/codet5p-220m \ 27 | --do_train \ 28 | --do_eval \ 29 | --do_predict \ 30 | --train_file ${OUTPUT_DIR}/cache_data/multilingual_train.json \ 31 | --validation_file ${OUTPUT_DIR}/cache_data/multilingual_valid.json \ 32 | --test_file ${OUTPUT_DIR}/cache_data/multilingual_test.json \ 33 | --source_prefix "" \ 34 | --output_dir ${OUTPUT_DIR} \ 35 | --text_column source \ 36 | --summary_column target \ 37 | --max_source_length 1536 \ 38 | --max_target_length 1536 \ 39 | --per_device_train_batch_size=2 \ 40 | --gradient_accumulation_steps=8 \ 41 | --per_device_eval_batch_size=24 \ 42 | --learning_rate $lr \ 43 | --num_train_epochs 5 \ 44 | --metric_for_best_model loss \ 45 | --save_total_limit 2 \ 46 | --save_strategy epoch \ 47 | --load_best_model_at_end \ 48 | --evaluation_strategy epoch \ 49 | --overwrite_output_dir \ 50 | --predict_with_generate \ 51 | --logging_strategy epoch \ 52 | --logging_dir ${OUTPUT_DIR} \ 53 | --num_beams 1 \ 54 | --seed $seed \ 55 | --fp16 \ 56 | --warmup_steps 200 \ 57 | --report_to tensorboard 2>&1 | tee ${OUTPUT_DIR}/log_train.txt 58 | 59 | #calculate score 60 | echo ${source_names}_to_${target_names} 61 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 62 | --source_names ${source_names} \ 63 | --target_names ${target_names} \ 64 | --codebleu \ 65 | 2>&1 | tee ${OUTPUT_DIR}/score_${source_names}_to_${target_names}.log 66 | 67 | rm -rf ${OUTPUT_DIR}/checkpoint* 68 | 69 | fi 70 | 71 | done 72 | done 73 | 74 | } 75 | 76 | run2() { 77 | source_names="C" 78 | run 79 | source_names="C++" 80 | run 81 | } 82 | 83 | run3() { 84 | source_names="C#" 85 | run 86 | source_names="Java" 87 | run 88 | } 89 | 90 | run4() { 91 | source_names="Go" 92 | run 93 | source_names="PHP" 94 | run 95 | } 96 | 97 | run5() { 98 | source_names="Python" 99 | run 100 | source_names="VB" 101 | run 102 | } 103 | 104 | lr=3e-5 105 | GPUID='0' 106 | run2 & 107 | GPUID='1' 108 | run3 & 109 | GPUID='2' 110 | run4 & 111 | GPUID='3' 112 | run5 & 113 | -------------------------------------------------------------------------------- /CodeT5+/run_train_RareTrans_many_to_many.sh: -------------------------------------------------------------------------------- 1 | DATADIR="" 2 | 3 | run() { 4 | for seed in 1234 2345 3456; do 5 | #preprocess data 6 | NAME=$(date +%Y%m%d%H)_RareTrans_many_to_many_seed${seed}_lr${lr}_maxlen1536_warmup200 7 | OUTPUT_DIR=output/saved_models/RareTrans/many_to_many_seed${seed}/${NAME} 8 | echo ${NAME} 9 | mkdir -p ${OUTPUT_DIR} 10 | mkdir -p ${OUTPUT_DIR}/cache_data 11 | 12 | for name in 'test' 'valid' 'train'; do 13 | python run_preprocess.py --input_file ${DATADIR}/RareTrans/rare_${name}.json \ 14 | --output_file ${OUTPUT_DIR}/cache_data/rare_${name}.json \ 15 | --source_names ${source_names} --target_names ${target_names} --sub_task 'RareTrans' 16 | done 17 | 18 | # train and predict model 19 | PORT_ID=$(expr $RANDOM + 1000) 20 | 21 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node 4 --master_port ${PORT_ID} \ 22 | run_translation.py \ 23 | --model_name_or_path ./pretrain/codet5p-220m \ 24 | --do_train \ 25 | --do_eval \ 26 | --do_predict \ 27 | --train_file ${OUTPUT_DIR}/cache_data/rare_train.json \ 28 | --validation_file 
${OUTPUT_DIR}/cache_data/rare_valid.json \ 29 | --test_file ${OUTPUT_DIR}/cache_data/rare_test.json \ 30 | --source_prefix "" \ 31 | --output_dir ${OUTPUT_DIR} \ 32 | --text_column source \ 33 | --summary_column target \ 34 | --max_source_length 1536 \ 35 | --max_target_length 1536 \ 36 | --per_device_train_batch_size=2 \ 37 | --gradient_accumulation_steps=2 \ 38 | --per_device_eval_batch_size=24 \ 39 | --learning_rate $lr \ 40 | --num_train_epochs 5 \ 41 | --metric_for_best_model loss \ 42 | --save_total_limit 2 \ 43 | --save_strategy epoch \ 44 | --load_best_model_at_end \ 45 | --evaluation_strategy epoch \ 46 | --overwrite_output_dir \ 47 | --predict_with_generate \ 48 | --logging_strategy epoch \ 49 | --logging_dir ${OUTPUT_DIR} \ 50 | --num_beams 1 \ 51 | --seed $seed \ 52 | --fp16 \ 53 | --warmup_steps 200 \ 54 | --report_to tensorboard 2>&1 | tee ${OUTPUT_DIR}/log_train.txt 55 | 56 | #calculate score 57 | echo ${source_names}_to_${target_names} 58 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 59 | --source_names ${source_names} \ 60 | --target_names ${target_names} \ 61 | --codebleu \ 62 | 2>&1 | tee ${OUTPUT_DIR}/score_many_to_many.log 63 | 64 | for target_name in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 65 | if [ "$source_names" != "$target_name" ]; then 66 | echo ${source_names}_to_${target_name} 67 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 68 | --source_names ${source_names} \ 69 | --target_names ${target_name} \ 70 | --codebleu \ 71 | 2>&1 | tee ${OUTPUT_DIR}/score_many_to_${target_name}.log 72 | fi 73 | done 74 | done 75 | } 76 | 77 | GPUID='0,1,2,3' 78 | lr=2e-5 79 | source_names="AWK,Ada,Arturo,AutoHotKey,BBC_Basic,C,C#,C++,COBOL,Clojure,Common_Lisp,D,Delphi,Elixir,Erlang,F#,Factor,Forth,Fortran,Go,Groovy,Haskell,Icon,J,Java,Julia,Kotlin,Lua,MATLAB,Mathematica,Nim,OCaml,PHP,Pascal,Perl,PowerShell,Python,R,REXX,Racket,Ruby,Rust,Scala,Swift,Tcl,VB" 80 | target_names="AWK,Ada,Arturo,AutoHotKey,BBC_Basic,C,C#,C++,COBOL,Clojure,Common_Lisp,D,Delphi,Elixir,Erlang,F#,Factor,Forth,Fortran,Go,Groovy,Haskell,Icon,J,Java,Julia,Kotlin,Lua,MATLAB,Mathematica,Nim,OCaml,PHP,Pascal,Perl,PowerShell,Python,R,REXX,Racket,Ruby,Rust,Scala,Swift,Tcl,VB" 81 | run & 82 | 83 | -------------------------------------------------------------------------------- /CodeT5+/run_train_RareTrans_many_to_many_only_rare_to_popular.sh: -------------------------------------------------------------------------------- 1 | DATADIR="" 2 | 3 | run() { 4 | for seed in 1234 2345 3456; do 5 | #preprocess data 6 | NAME=$(date +%Y%m%d%H)_RareTrans_many_to_many_seed${seed}_lr${lr}_maxlen1536_warmup200 7 | OUTPUT_DIR=output/saved_models/RareTrans/many_to_many_only_rare_to_popular_seed${seed}/${NAME} 8 | echo ${NAME} 9 | mkdir -p ${OUTPUT_DIR} 10 | mkdir -p ${OUTPUT_DIR}/cache_data 11 | 12 | for name in 'test' 'valid' 'train'; do 13 | python run_preprocess.py --input_file ${DATADIR}/RareTrans/rare_${name}.json \ 14 | --output_file ${OUTPUT_DIR}/cache_data/rare_${name}.json \ 15 | --source_names ${source_names} --target_names ${target_names} --sub_task 'RareTrans' 16 | done 17 | 18 | # train and predict model 19 | PORT_ID=$(expr $RANDOM + 1000) 20 | 21 | CUDA_VISIBLE_DEVICES=$GPUID python -m torch.distributed.launch --nproc_per_node 4 --master_port ${PORT_ID} \ 22 | run_translation.py \ 23 | --model_name_or_path ./pretrain/codet5p-220m \ 24 | --do_train \ 25 | --do_eval \ 26 | --do_predict \ 27 | --train_file ${OUTPUT_DIR}/cache_data/rare_train.json \ 
28 | --validation_file ${OUTPUT_DIR}/cache_data/rare_valid.json \ 29 | --test_file ${OUTPUT_DIR}/cache_data/rare_test.json \ 30 | --source_prefix "" \ 31 | --output_dir ${OUTPUT_DIR} \ 32 | --text_column source \ 33 | --summary_column target \ 34 | --max_source_length 1536 \ 35 | --max_target_length 1536 \ 36 | --per_device_train_batch_size=1 \ 37 | --gradient_accumulation_steps=4 \ 38 | --per_device_eval_batch_size=8 \ 39 | --learning_rate $lr \ 40 | --num_train_epochs 5 \ 41 | --metric_for_best_model loss \ 42 | --save_total_limit 2 \ 43 | --save_strategy epoch \ 44 | --load_best_model_at_end \ 45 | --evaluation_strategy epoch \ 46 | --overwrite_output_dir \ 47 | --predict_with_generate \ 48 | --logging_strategy epoch \ 49 | --logging_dir ${OUTPUT_DIR} \ 50 | --num_beams 1 \ 51 | --seed $seed \ 52 | --fp16 \ 53 | --warmup_steps 200 \ 54 | --report_to tensorboard 2>&1 | tee ${OUTPUT_DIR}/log_train.txt 55 | 56 | #calculate score 57 | echo ${source_names}_to_${target_names} 58 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 59 | --source_names ${source_names} \ 60 | --target_names ${target_names} \ 61 | --codebleu \ 62 | 2>&1 | tee ${OUTPUT_DIR}/score_many_to_many.log 63 | 64 | for target_name in "C" "C++" "C#" "Java" "Go" "PHP" "Python" "VB"; do 65 | if [ "$source_names" != "$target_name" ]; then 66 | echo ${source_names}_to_${target_name} 67 | python run_score.py --input_file ${OUTPUT_DIR}/generated_predictions.json \ 68 | --source_names ${source_names} \ 69 | --target_names ${target_name} \ 70 | --codebleu \ 71 | 2>&1 | tee ${OUTPUT_DIR}/score_many_to_${target_name}.log 72 | fi 73 | done 74 | done 75 | 76 | } 77 | 78 | GPUID='0,1,2,3' 79 | lr=2e-5 80 | source_names="AWK,Ada,Arturo,AutoHotKey,BBC_Basic,C,C#,C++,COBOL,Clojure,Common_Lisp,D,Delphi,Elixir,Erlang,F#,Factor,Forth,Fortran,Go,Groovy,Haskell,Icon,J,Java,Julia,Kotlin,Lua,MATLAB,Mathematica,Nim,OCaml,PHP,Pascal,Perl,PowerShell,Python,R,REXX,Racket,Ruby,Rust,Scala,Swift,Tcl,VB" 81 | target_names="C,C++,C#,Java,Go,PHP,Python,VB" 82 | run & 83 | 84 | -------------------------------------------------------------------------------- /CodeT5+/run_translation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
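# In this repository the script is driven by the run_train_*.sh wrappers: the train/validation/test
# files are JSON Lines produced by run_preprocess.py and are read below via --text_column source
# and --summary_column target. A single record looks roughly like this (values are illustrative):
#   {"source": "Translate Python to Java: print(1 + 2)", "target": "System.out.println(1 + 2);", ...}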
20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | # import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset 31 | 32 | import evaluate 33 | import transformers 34 | from filelock import FileLock 35 | from transformers import ( 36 | AutoConfig, 37 | AutoModelForSeq2SeqLM, 38 | AutoTokenizer, 39 | DataCollatorForSeq2Seq, 40 | HfArgumentParser, 41 | MBart50Tokenizer, 42 | MBart50TokenizerFast, 43 | MBartTokenizer, 44 | MBartTokenizerFast, 45 | Seq2SeqTrainer, 46 | Seq2SeqTrainingArguments, 47 | set_seed, 48 | ) 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry 51 | from transformers.utils.versions import require_version 52 | 53 | 54 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 55 | check_min_version("4.25.0") 56 | 57 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 58 | 59 | logger = logging.getLogger(__name__) 60 | 61 | # A list of all multilingual tokenizer which require lang attribute. 62 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast] 63 | 64 | 65 | @dataclass 66 | class ModelArguments: 67 | """ 68 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 69 | """ 70 | 71 | model_name_or_path: str = field( 72 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 73 | ) 74 | config_name: Optional[str] = field( 75 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 76 | ) 77 | tokenizer_name: Optional[str] = field( 78 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 79 | ) 80 | cache_dir: Optional[str] = field( 81 | default=None, 82 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 83 | ) 84 | use_fast_tokenizer: bool = field( 85 | default=True, 86 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 87 | ) 88 | model_revision: str = field( 89 | default="main", 90 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 91 | ) 92 | use_auth_token: bool = field( 93 | default=False, 94 | metadata={ 95 | "help": ( 96 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 97 | "with private models)." 98 | ) 99 | }, 100 | ) 101 | resize_position_embeddings: Optional[bool] = field( 102 | default=None, 103 | metadata={ 104 | "help": ( 105 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 106 | "the model's position embeddings." 107 | ) 108 | }, 109 | ) 110 | 111 | 112 | @dataclass 113 | class DataTrainingArguments: 114 | """ 115 | Arguments pertaining to what data we are going to input our model for training and eval. 
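For the CodeTransOcean experiments only a handful of these fields are exercised: the run_train_*.sh
scripts point train_file / validation_file / test_file at the preprocessed JSON files and set
text_column=source, summary_column=target together with the max_source_length / max_target_length limits.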
116 | """ 117 | 118 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 119 | 120 | dataset_name: Optional[str] = field( 121 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 122 | ) 123 | dataset_config_name: Optional[str] = field( 124 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 125 | ) 126 | text_column: Optional[str] = field( 127 | default=None, 128 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 129 | ) 130 | summary_column: Optional[str] = field( 131 | default=None, 132 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 133 | ) 134 | train_file: Optional[str] = field( 135 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 136 | ) 137 | validation_file: Optional[str] = field( 138 | default=None, 139 | metadata={ 140 | "help": ( 141 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 142 | ) 143 | }, 144 | ) 145 | test_file: Optional[str] = field( 146 | default=None, 147 | metadata={ 148 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 149 | }, 150 | ) 151 | overwrite_cache: bool = field( 152 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 153 | ) 154 | preprocessing_num_workers: Optional[int] = field( 155 | default=None, 156 | metadata={"help": "The number of processes to use for the preprocessing."}, 157 | ) 158 | max_source_length: Optional[int] = field( 159 | default=1024, 160 | metadata={ 161 | "help": ( 162 | "The maximum total input sequence length after tokenization. Sequences longer " 163 | "than this will be truncated, sequences shorter will be padded." 164 | ) 165 | }, 166 | ) 167 | max_target_length: Optional[int] = field( 168 | default=128, 169 | metadata={ 170 | "help": ( 171 | "The maximum total sequence length for target text after tokenization. Sequences longer " 172 | "than this will be truncated, sequences shorter will be padded." 173 | ) 174 | }, 175 | ) 176 | val_max_target_length: Optional[int] = field( 177 | default=None, 178 | metadata={ 179 | "help": ( 180 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 181 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 182 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 183 | "during ``evaluate`` and ``predict``." 184 | ) 185 | }, 186 | ) 187 | pad_to_max_length: bool = field( 188 | default=False, 189 | metadata={ 190 | "help": ( 191 | "Whether to pad all samples to model maximum sentence length. " 192 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 193 | "efficient on GPU but very bad for TPU." 194 | ) 195 | }, 196 | ) 197 | max_train_samples: Optional[int] = field( 198 | default=None, 199 | metadata={ 200 | "help": ( 201 | "For debugging purposes or quicker training, truncate the number of training examples to this " 202 | "value if set." 
203 | ) 204 | }, 205 | ) 206 | max_eval_samples: Optional[int] = field( 207 | default=None, 208 | metadata={ 209 | "help": ( 210 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 211 | "value if set." 212 | ) 213 | }, 214 | ) 215 | max_predict_samples: Optional[int] = field( 216 | default=None, 217 | metadata={ 218 | "help": ( 219 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 220 | "value if set." 221 | ) 222 | }, 223 | ) 224 | num_beams: Optional[int] = field( 225 | default=None, 226 | metadata={ 227 | "help": ( 228 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 229 | "which is used during ``evaluate`` and ``predict``." 230 | ) 231 | }, 232 | ) 233 | ignore_pad_token_for_loss: bool = field( 234 | default=True, 235 | metadata={ 236 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 237 | }, 238 | ) 239 | source_prefix: Optional[str] = field( 240 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 241 | ) 242 | 243 | forced_bos_token: Optional[str] = field( 244 | default=None, 245 | metadata={ 246 | "help": ( 247 | "The token to force as the first generated token after the decoder_start_token_id." 248 | "Useful for multilingual models like mBART where the first generated token" 249 | "needs to be the target language token (Usually it is the target language token)" 250 | ) 251 | }, 252 | ) 253 | 254 | def __post_init__(self): 255 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 256 | raise ValueError("Need either a dataset name or a training/validation file.") 257 | else: 258 | if self.train_file is not None: 259 | extension = self.train_file.split(".")[-1] 260 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 261 | if self.validation_file is not None: 262 | extension = self.validation_file.split(".")[-1] 263 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 264 | if self.val_max_target_length is None: 265 | self.val_max_target_length = self.max_target_length 266 | 267 | 268 | summarization_name_mapping = { 269 | "amazon_reviews_multi": ("review_body", "review_title"), 270 | "big_patent": ("description", "abstract"), 271 | "cnn_dailymail": ("article", "highlights"), 272 | "orange_sum": ("text", "summary"), 273 | "pn_summary": ("article", "summary"), 274 | "psc": ("extract_text", "summary_text"), 275 | "samsum": ("dialogue", "summary"), 276 | "thaisum": ("body", "summary"), 277 | "xglue": ("news_body", "news_title"), 278 | "xsum": ("document", "summary"), 279 | "wiki_summary": ("article", "highlights"), 280 | "multi_news": ("document", "summary"), 281 | } 282 | 283 | 284 | def main(): 285 | # See all possible arguments in src/transformers/training_args.py 286 | # or by passing the --help flag to this script. 287 | # We now keep distinct sets of args, for a cleaner separation of concerns. 288 | 289 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 290 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 291 | # If we pass only one argument to the script and it's the path to a json file, 292 | # let's parse it to get our arguments. 
293 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 294 | else: 295 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 296 | 297 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 298 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 299 | send_example_telemetry("run_summarization", model_args, data_args) 300 | 301 | # Setup logging 302 | logging.basicConfig( 303 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 304 | datefmt="%m/%d/%Y %H:%M:%S", 305 | handlers=[logging.StreamHandler(sys.stdout)], 306 | ) 307 | log_level = training_args.get_process_log_level() 308 | logger.setLevel(log_level) 309 | datasets.utils.logging.set_verbosity(log_level) 310 | transformers.utils.logging.set_verbosity(log_level) 311 | transformers.utils.logging.enable_default_handler() 312 | transformers.utils.logging.enable_explicit_format() 313 | 314 | # Log on each process the small summary: 315 | logger.warning( 316 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 317 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 318 | ) 319 | logger.info(f"Training/evaluation parameters {training_args}") 320 | 321 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 322 | "t5-small", 323 | "t5-base", 324 | "t5-large", 325 | "t5-3b", 326 | "t5-11b", 327 | ]: 328 | logger.warning( 329 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 330 | "`--source_prefix 'summarize: ' `" 331 | ) 332 | 333 | # Detecting last checkpoint. 334 | last_checkpoint = None 335 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 336 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 337 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 338 | raise ValueError( 339 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 340 | "Use --overwrite_output_dir to overcome." 341 | ) 342 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 343 | logger.info( 344 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 345 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 346 | ) 347 | 348 | # Set seed before initializing model. 349 | set_seed(training_args.seed) 350 | 351 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 352 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 353 | # (the dataset will be downloaded automatically from the datasets Hub). 354 | # 355 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 356 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 357 | # 358 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 359 | # download the dataset. 360 | if data_args.dataset_name is not None: 361 | # Downloading and loading a dataset from the hub. 
362 | raw_datasets = load_dataset( 363 | data_args.dataset_name, 364 | data_args.dataset_config_name, 365 | cache_dir=model_args.cache_dir, 366 | use_auth_token=True if model_args.use_auth_token else None, 367 | ) 368 | else: 369 | data_files = {} 370 | if data_args.train_file is not None: 371 | data_files["train"] = data_args.train_file 372 | extension = data_args.train_file.split(".")[-1] 373 | if data_args.validation_file is not None: 374 | data_files["validation"] = data_args.validation_file 375 | extension = data_args.validation_file.split(".")[-1] 376 | if data_args.test_file is not None: 377 | data_files["test"] = data_args.test_file 378 | extension = data_args.test_file.split(".")[-1] 379 | raw_datasets = load_dataset( 380 | extension, 381 | data_files=data_files, 382 | cache_dir=model_args.cache_dir, 383 | use_auth_token=True if model_args.use_auth_token else None, 384 | ) 385 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 386 | # https://huggingface.co/docs/datasets/loading_datasets.html. 387 | 388 | # Load pretrained model and tokenizer 389 | # 390 | # Distributed training: 391 | # The .from_pretrained methods guarantee that only one local process can concurrently 392 | # download model & vocab. 393 | config = AutoConfig.from_pretrained( 394 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 395 | cache_dir=model_args.cache_dir, 396 | revision=model_args.model_revision, 397 | use_auth_token=True if model_args.use_auth_token else None, 398 | ) 399 | tokenizer = AutoTokenizer.from_pretrained( 400 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 401 | cache_dir=model_args.cache_dir, 402 | use_fast=model_args.use_fast_tokenizer, 403 | revision=model_args.model_revision, 404 | use_auth_token=True if model_args.use_auth_token else None, 405 | ) 406 | model = AutoModelForSeq2SeqLM.from_pretrained( 407 | model_args.model_name_or_path, 408 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 409 | config=config, 410 | cache_dir=model_args.cache_dir, 411 | revision=model_args.model_revision, 412 | use_auth_token=True if model_args.use_auth_token else None, 413 | ) 414 | 415 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 416 | # on a small vocab and want a smaller embedding size, remove this test. 417 | embedding_size = model.get_input_embeddings().weight.shape[0] 418 | if len(tokenizer) > embedding_size: 419 | model.resize_token_embeddings(len(tokenizer)) 420 | 421 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 422 | if isinstance(tokenizer, MBartTokenizer): 423 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] 424 | else: 425 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) 426 | 427 | if model.config.decoder_start_token_id is None: 428 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 429 | 430 | if ( 431 | hasattr(model.config, "max_position_embeddings") 432 | and model.config.max_position_embeddings < data_args.max_source_length 433 | ): 434 | if model_args.resize_position_embeddings is None: 435 | logger.warning( 436 | "Increasing the model's number of position embedding vectors from" 437 | f" {model.config.max_position_embeddings} to {data_args.max_source_length}." 
438 | ) 439 | model.resize_position_embeddings(data_args.max_source_length) 440 | elif model_args.resize_position_embeddings: 441 | model.resize_position_embeddings(data_args.max_source_length) 442 | else: 443 | raise ValueError( 444 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has" 445 | f" {model.config.max_position_embeddings} position encodings. Consider either reducing" 446 | f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the" 447 | " model's position encodings by passing `--resize_position_embeddings`." 448 | ) 449 | 450 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 451 | 452 | # Preprocessing the datasets. 453 | # We need to tokenize inputs and targets. 454 | if training_args.do_train: 455 | column_names = raw_datasets["train"].column_names 456 | elif training_args.do_eval: 457 | column_names = raw_datasets["validation"].column_names 458 | elif training_args.do_predict: 459 | column_names = raw_datasets["test"].column_names 460 | else: 461 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 462 | return 463 | 464 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): 465 | assert ( 466 | data_args.lang is not None 467 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" 468 | 469 | tokenizer.src_lang = data_args.lang 470 | tokenizer.tgt_lang = data_args.lang 471 | 472 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token 473 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. 474 | forced_bos_token_id = ( 475 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None 476 | ) 477 | model.config.forced_bos_token_id = forced_bos_token_id 478 | 479 | # Get the column names for input/target. 480 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 481 | if data_args.text_column is None: 482 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 483 | else: 484 | text_column = data_args.text_column 485 | if text_column not in column_names: 486 | raise ValueError( 487 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 488 | ) 489 | if data_args.summary_column is None: 490 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 491 | else: 492 | summary_column = data_args.summary_column 493 | if summary_column not in column_names: 494 | raise ValueError( 495 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 496 | ) 497 | 498 | # Temporarily set max_target_length for training. 499 | max_target_length = data_args.max_target_length 500 | padding = "max_length" if data_args.pad_to_max_length else False 501 | 502 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 503 | logger.warning( 504 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 505 | f"`{model.__class__.__name__}`. 
This will lead to loss being calculated twice and will take up more memory" 506 | ) 507 | 508 | def preprocess_function(examples): 509 | # remove pairs where at least one record is None 510 | 511 | inputs, targets = [], [] 512 | for i in range(len(examples[text_column])): 513 | if examples[text_column][i] and examples[summary_column][i]: 514 | inputs.append(examples[text_column][i]) 515 | targets.append(examples[summary_column][i]) 516 | 517 | inputs = [prefix + inp for inp in inputs] 518 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 519 | 520 | # Tokenize targets with the `text_target` keyword argument 521 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) 522 | 523 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 524 | # padding in the loss. 525 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 526 | labels["input_ids"] = [ 527 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 528 | ] 529 | 530 | model_inputs["labels"] = labels["input_ids"] 531 | return model_inputs 532 | 533 | if training_args.do_train: 534 | if "train" not in raw_datasets: 535 | raise ValueError("--do_train requires a train dataset") 536 | train_dataset = raw_datasets["train"] 537 | if data_args.max_train_samples is not None: 538 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 539 | train_dataset = train_dataset.select(range(max_train_samples)) 540 | with training_args.main_process_first(desc="train dataset map pre-processing"): 541 | train_dataset = train_dataset.map( 542 | preprocess_function, 543 | batched=True, 544 | num_proc=data_args.preprocessing_num_workers, 545 | remove_columns=column_names, 546 | load_from_cache_file=not data_args.overwrite_cache, 547 | desc="Running tokenizer on train dataset", 548 | ) 549 | 550 | if training_args.do_eval: 551 | max_target_length = data_args.val_max_target_length 552 | if "validation" not in raw_datasets: 553 | raise ValueError("--do_eval requires a validation dataset") 554 | eval_dataset = raw_datasets["validation"] 555 | if data_args.max_eval_samples is not None: 556 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 557 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 558 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 559 | eval_dataset = eval_dataset.map( 560 | preprocess_function, 561 | batched=True, 562 | num_proc=data_args.preprocessing_num_workers, 563 | remove_columns=column_names, 564 | load_from_cache_file=not data_args.overwrite_cache, 565 | desc="Running tokenizer on validation dataset", 566 | ) 567 | 568 | if training_args.do_predict: 569 | max_target_length = data_args.val_max_target_length 570 | if "test" not in raw_datasets: 571 | raise ValueError("--do_predict requires a test dataset") 572 | predict_dataset = raw_datasets["test"] 573 | if data_args.max_predict_samples is not None: 574 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 575 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 576 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 577 | predict_dataset = predict_dataset.map( 578 | preprocess_function, 579 | batched=True, 580 | num_proc=data_args.preprocessing_num_workers, 581 | remove_columns=column_names, 582 | 
load_from_cache_file=not data_args.overwrite_cache, 583 | desc="Running tokenizer on prediction dataset", 584 | ) 585 | 586 | # Data collator 587 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 588 | data_collator = DataCollatorForSeq2Seq( 589 | tokenizer, 590 | model=model, 591 | label_pad_token_id=label_pad_token_id, 592 | pad_to_multiple_of=8 if training_args.fp16 else None, 593 | ) 594 | 595 | # Initialize our Trainer 596 | trainer = Seq2SeqTrainer( 597 | model=model, 598 | args=training_args, 599 | train_dataset=train_dataset if training_args.do_train else None, 600 | eval_dataset=eval_dataset if training_args.do_eval else None, 601 | tokenizer=tokenizer, 602 | data_collator=data_collator, 603 | ) 604 | 605 | # Training 606 | if training_args.do_train: 607 | checkpoint = None 608 | if training_args.resume_from_checkpoint is not None: 609 | checkpoint = training_args.resume_from_checkpoint 610 | elif last_checkpoint is not None: 611 | checkpoint = last_checkpoint 612 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 613 | trainer.save_model() # Saves the tokenizer too for easy upload 614 | 615 | metrics = train_result.metrics 616 | max_train_samples = ( 617 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 618 | ) 619 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 620 | 621 | trainer.log_metrics("train", metrics) 622 | trainer.save_metrics("train", metrics) 623 | trainer.save_state() 624 | 625 | # Evaluation 626 | results = {} 627 | max_length = ( 628 | training_args.generation_max_length 629 | if training_args.generation_max_length is not None 630 | else data_args.val_max_target_length 631 | ) 632 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 633 | if training_args.do_eval: 634 | logger.info("*** Evaluate ***") 635 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 636 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 637 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 638 | 639 | trainer.log_metrics("eval", metrics) 640 | trainer.save_metrics("eval", metrics) 641 | 642 | if training_args.do_predict: 643 | logger.info("*** Predict ***") 644 | 645 | predict_results = trainer.predict( 646 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 647 | ) 648 | metrics = predict_results.metrics 649 | max_predict_samples = ( 650 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 651 | ) 652 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 653 | 654 | trainer.log_metrics("predict", metrics) 655 | trainer.save_metrics("predict", metrics) 656 | 657 | if trainer.is_world_process_zero(): 658 | if training_args.predict_with_generate: 659 | predictions = tokenizer.batch_decode( 660 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 661 | ) 662 | predictions = [pred.strip() for pred in predictions] 663 | print('predictions', len(predictions)) 664 | import json 665 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.json") 666 | with open(output_prediction_file, "w") as writer: 667 | # print(raw_datasets["test"]) 668 | for idx, data in enumerate(raw_datasets["test"]): 669 | if idx >= 
max_predict_samples: 670 | break 671 | # print(data) 672 | json_data = data.copy() 673 | json_data['prediction'] = predictions[idx] 674 | string_data = json.dumps(json_data) 675 | writer.write(string_data + '\n') 676 | 677 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 678 | if data_args.dataset_name is not None: 679 | kwargs["dataset_tags"] = data_args.dataset_name 680 | if data_args.dataset_config_name is not None: 681 | kwargs["dataset_args"] = data_args.dataset_config_name 682 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 683 | else: 684 | kwargs["dataset"] = data_args.dataset_name 685 | 686 | if data_args.lang is not None: 687 | kwargs["language"] = data_args.lang 688 | 689 | if training_args.push_to_hub: 690 | trainer.push_to_hub(**kwargs) 691 | else: 692 | trainer.create_model_card(**kwargs) 693 | 694 | return results 695 | 696 | 697 | def _mp_fn(index): 698 | # For xla_spawn (TPUs) 699 | main() 700 | 701 | 702 | if __name__ == "__main__": 703 | main() 704 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CodeTransOcean: A Comprehensive Multilingual Benchmark for Code Translation](https://arxiv.org/abs/2310.04951) 2 | 3 | 9 | 10 |
11 | 12 | **CodeTransOcean** is a large-scale, comprehensive benchmark that supports the largest variety of programming languages for code translation. CodeTransOcean consists of three novel multilingual datasets: **MultilingualTrans**, which supports translation between multiple popular programming languages; **NicheTrans**, for translating between niche programming languages and popular ones; and **LLMTrans**, for evaluating the executability of code translated by large language models (LLMs). CodeTransOcean also includes a novel cross-framework dataset, **DLTrans**, for translating deep learning code across different frameworks. 13 | 14 | 15 | 
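As a quick, hedged illustration (not part of the repository files), the benchmark can typically be loaded with the 🤗 `datasets` library roughly as follows. The Hub ID comes from the Hugging Face link in the Datasets section below; the configuration name `"LLMTrans"` and the split/field layout are assumptions, so please check the dataset card for the exact identifiers.

```python
# Minimal loading sketch. Assumptions: a configuration named "LLMTrans" with a
# "test" split exists -- verify the available configs on the dataset card.
from datasets import load_dataset

dataset = load_dataset("WeixiangYan/CodeTransOcean", "LLMTrans")
print(dataset)             # shows the available splits and their sizes
print(dataset["test"][0])  # shows one source/target translation example
```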
16 | 17 |
18 | 19 | 20 | ## Datasets 21 | 🤗[Hugging Face](https://huggingface.co/datasets/WeixiangYan/CodeTransOcean) or [Google Drive](https://drive.google.com/file/d/1xw6Edqf_nknKoei_LC49n4EtvNQezKGe/view?usp=sharing) 22 | 23 | 24 | ## Code 25 | Experiments on the MultilingualTrans, NicheTrans, and DLTrans datasets use CodeT5+; the code is in the [CodeT5+](https://github.com/WeixiangYAN/CodeTransOcean/tree/main/CodeT5%2B) directory. 26 | 27 | Experiments on the LLMTrans dataset use GPT-3.5; the code is in the [ChatGPT](https://github.com/WeixiangYAN/CodeTransOcean/tree/main/ChatGPT) directory. 28 | 29 | 30 | 31 | ## Citation 32 | Please cite the paper if you use the data or code from CodeTransOcean. 33 | ``` 34 | @article{yan2023codetransocean, 35 | title={CodeTransOcean: A Comprehensive Multilingual Benchmark for Code Translation}, 36 | author={Yan, Weixiang and Tian, Yuchen and Li, Yunzhe and Chen, Qian and Wang, Wen}, 37 | journal={arXiv preprint arXiv:2310.04951}, 38 | year={2023} 39 | } 40 | ``` 41 | 42 | ## Contact 43 | For questions, please feel free to reach out via email at ``yanweixiang.ywx@gmail.com``. 44 | -------------------------------------------------------------------------------- /images/Google_Drive_Logo_16px.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/images/Google_Drive_Logo_16px.png -------------------------------------------------------------------------------- /images/codetransocean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/images/codetransocean.png -------------------------------------------------------------------------------- /images/leaderboard6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeixiangYAN/CodeTransOcean/42e2cd3b41b3a18a6dba3dfdf425f772360304ca/images/leaderboard6.png --------------------------------------------------------------------------------