├── parser
│   ├── build.sh
│   ├── __init__.py
│   ├── build.py
│   ├── utils.py
│   └── DFG.py
├── CODE_OF_CONDUCT.md
├── LICENSE
├── utils
│   ├── split_codes.py
│   ├── get_res.py
│   ├── search_bm25.py
│   └── search_dense.py
├── SECURITY.md
├── model.py
├── generate
│   ├── beam.py
│   ├── dataset.py
│   └── run_lm.py
├── README.md
├── infer.py
├── process_java.py
├── codenet_test.py
└── process_python.py
/parser/build.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/tree-sitter/tree-sitter-python 2 | git clone https://github.com/tree-sitter/tree-sitter-java 3 | python build.py 4 | -------------------------------------------------------------------------------- /parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import (remove_comments_and_docstrings, 2 | tree_to_token_index, 3 | index_to_code_token, 4 | tree_to_variable_index, 5 | traverse 6 | ) 7 | from .DFG import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp -------------------------------------------------------------------------------- /parser/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from tree_sitter import Language, Parser 5 | 6 | Language.build_library( 7 | # Store the library as `my-languages.so` in the current directory 8 | 'my-languages.so', 9 | 10 | # Include one or more languages 11 | [ 12 | 'tree-sitter-python', 13 | 'tree-sitter-java' 14 | ] 15 | ) 16 | 17 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shuai Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/split_codes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # This script splits long code files in the search base into chunks of the same length (300 tokens by default). 5 | 6 | import os 7 | import pickle 8 | import argparse 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(description='') 12 | parser.add_argument('--file_name', '-i', required=True, help="filename of the codes, should be in txt format") 13 | parser.add_argument('--length', '-l', type=int, default=300, help="length of the chunk") 14 | args = parser.parse_args() 15 | 16 | lines = open(args.file_name, "r").readlines() 17 | wf = open(args.file_name.split("/")[-1].split(".")[0]+"_split.txt", "w") 18 | nexts = [] 19 | cnt = 0 20 | for line in lines: 21 | tokens = line.strip().split() 22 | if len(tokens) <= args.length: 23 | wf.write(" ".join(tokens)+"\n") 24 | nexts.append(cnt) 25 | cnt += 1 26 | else: 27 | for i in range(0, len(tokens), args.length): 28 | wf.write(" ".join(tokens[i:i+args.length])+"\n") 29 | nexts.append(cnt+1) 30 | cnt += 1 31 | nexts[-1] -= 1 32 | wf.close() 33 | pickle.dump(nexts, open(args.file_name.split("/")[-1].split(".")[0]+"_split_nexts.pkl", "wb")) 34 | 35 | if __name__ == "__main__": 36 | main() 37 | 38 | 39 | -------------------------------------------------------------------------------- /utils/get_res.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import os 4 | import pickle 5 | import argparse 6 | import json 7 | import numpy as np 8 | from tqdm import tqdm 9 | import random 10 | 11 | def hybrid_scores(bm25_scores, dense_scores, alpha, beilv=100): 12 | # beilv ("multiplier"): factor that re-scales the dense score, which is a percentage-like value, to the magnitude of the BM25 scores.
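    # Reading of the combination rule implemented below (a description of this
    # code, not an official formula): for a candidate present in both rankings,
    #     hybrid = beilv * dense_score + alpha * bm25_score
    # and when a candidate is missing from one ranking, that ranking's minimum
    # score is substituted as a conservative back-off.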
13 | scores = {} 14 | for idx, v in tqdm(dense_scores.items()): 15 | new_v = {} 16 | if idx not in bm25_scores: 17 | scores[idx] = v 18 | continue 19 | v2 = bm25_scores[idx] 20 | v_min = min(list(v.values())) 21 | v2_min = min(list(v2.values())) 22 | for _id, score in v.items(): 23 | if _id not in v2: 24 | new_v[_id] = beilv * score + alpha * v2_min 25 | else: 26 | new_v[_id] = beilv * score + alpha * v2[_id] 27 | for _id, score in v2.items(): 28 | if _id not in new_v: 29 | new_v[_id] = alpha * score + beilv * v_min 30 | scores[idx] = new_v 31 | return scores 32 | 33 | def get_res(bm25_file, dense_file, save_file, alpha): 34 | if bm25_file != "": 35 | bm25_scores = pickle.load(open(bm25_file, "rb")) 36 | print("bm25 scores loaded") 37 | else: 38 | bm25_scores = {} 39 | if dense_file != "": 40 | dense_scores = pickle.load(open(dense_file, "rb")) 41 | print("dense scores loaded") 42 | else: 43 | dense_scores = {} 44 | 45 | res = {} 46 | if len(bm25_scores) > 0 and len(dense_scores) > 0: 47 | scores = hybrid_scores(bm25_scores, dense_scores, alpha, 100) 48 | elif len(bm25_scores) > 0: 49 | scores = bm25_scores 50 | else: 51 | scores = dense_scores 52 | for idx, v in tqdm(scores.items()): 53 | v = sorted(v.items(), key=lambda x:-x[1]) 54 | # res[int(idx)] = int(v[0][0]) if v[0][0] != idx else int(v[1][0]) 55 | res[int(idx)] = int(v[0][0]) 56 | 57 | pickle.dump(res, open(save_file, "wb")) 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser(description='') 62 | parser.add_argument('--save_name', '-o', required=True, help="save file name") 63 | parser.add_argument('--bm25_res', '-b', default="", help="bm25 result file") 64 | parser.add_argument('--dense_res', '-d', default="", help="dense result file") 65 | parser.add_argument("--alpha", type=float, default=1.1, help="weight of the BM25 score in the hybrid score") 66 | args = parser.parse_args() 67 | 68 | get_res(args.bm25_res, args.dense_res, args.save_name, args.alpha) 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | 74 | 75 | -------------------------------------------------------------------------------- /utils/search_bm25.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
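# This script retrieves similar code with BM25: it converts a plain-text
# corpus and a jsonl query file into BEIR's format, then searches via
# Elasticsearch, which is expected to be running at 127.0.0.1:9200
# (see the BM25Search call below).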
3 | import os 4 | import argparse 5 | import json 6 | from tqdm import tqdm 7 | import pickle 8 | import random 9 | import csv 10 | import time 11 | 12 | from beir.datasets.data_loader import GenericDataLoader 13 | from beir.retrieval.evaluation import EvaluateRetrieval 14 | from beir.retrieval.search.lexical import BM25Search as BM25 15 | 16 | def build(corpus_file, query_file, temp_dir): 17 | # corpus and query file have to be plain text files 18 | print("Building bm25 corpus") 19 | corpus = [] 20 | queries = [] 21 | ids = [] 22 | datas = open(corpus_file).readlines() 23 | lines = open(query_file).readlines() 24 | try: 25 | os.mkdir(temp_dir) 26 | except FileExistsError: 27 | pass 28 | fidx = open(os.path.join(temp_dir, "corpus.jsonl"), "w") 29 | fq = open(os.path.join(temp_dir, "query.jsonl"), "w") 30 | fr = open(os.path.join(temp_dir, "res.tsv"), "w") 31 | csv_fr = csv.writer(fr, delimiter='\t') 32 | fr.write("q\td\ts\n") 33 | for i,line in enumerate(tqdm(datas)): 34 | fidx.write(json.dumps({"_id":str(i), "text":line.strip()})+"\n") 35 | for i,line in enumerate(tqdm(lines)): 36 | content = json.loads(line) 37 | idx = content["id"] 38 | csv_fr.writerow([str(idx), str(idx), 1]) 39 | code = content["input"].strip() 40 | fq.write(json.dumps({"_id":str(idx), "text":code})+"\n") 41 | fidx.close(); fq.close(); fr.close() # close the files so the data is flushed before indexing 42 | 43 | def main(): 44 | parser = argparse.ArgumentParser(description='') 45 | parser.add_argument('--search_corpus', '-i', required=True, help="search corpus file, plain text file") 46 | parser.add_argument('--query_file', '-q', required=True, help="queries file, json file") 47 | parser.add_argument('--save_name', '-o', required=True, help="save file name") 48 | parser.add_argument('--temp_path', '-t', default="beir", help="temp dir to save beir-format data") 49 | args = parser.parse_args() 50 | 51 | build(args.search_corpus, args.query_file, args.temp_path) 52 | time.sleep(10) 53 | 54 | corpus, queries, qrels = GenericDataLoader( 55 | corpus_file=os.path.join(args.temp_path, "corpus.jsonl"), 56 | query_file=os.path.join(args.temp_path, "query.jsonl"), 57 | qrels_file=os.path.join(args.temp_path, "res.tsv") 58 | ).load_custom() 59 | 60 | model = BM25(index_name="reacc", hostname="127.0.0.1:9200", initialize=True) 61 | retriever = EvaluateRetrieval(model) 62 | 63 | results = retriever.retrieve(corpus, queries) 64 | pickle.dump(results, open(args.save_name, "wb")) 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /utils/search_dense.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
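# This script performs dense retrieval with FAISS over the embeddings
# produced by infer.py: it builds an inner-product index from the saved
# id -> vector pickles and dumps the top-k neighbors of each query.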
3 | import os 4 | import json 5 | import pickle 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | import faiss 10 | 11 | 12 | def search(index_file, query_file, save_name): 13 | index_data = pickle.load(open(index_file, "rb")) 14 | query_data = pickle.load(open(query_file, "rb")) 15 | ids = [] 16 | indexs = [] 17 | id2n = {} 18 | for i, (idx, vec) in enumerate(index_data.items()): 19 | ids.append(idx) 20 | indexs.append(vec) 21 | id2n[idx] = i 22 | queries = [] 23 | idxq = [] 24 | for idx, vec in query_data.items(): 25 | queries.append(vec) 26 | idxq.append(idx) 27 | ids = np.array(ids) 28 | indexs = np.array(indexs) 29 | queries = np.array(queries) 30 | 31 | # build faiss index 32 | d = 768 33 | k = 101 34 | index = faiss.IndexFlatIP(d) 35 | assert index.is_trained 36 | 37 | index_id = faiss.IndexIDMap(index) 38 | index_id.add_with_ids(indexs, ids) 39 | 40 | res = {} 41 | D, I = index_id.search(queries, k) 42 | for i, (sd, si) in enumerate(zip(D, I)): 43 | res[str(idxq[i])] = {} 44 | for pd, pi in zip(sd, si): 45 | res[str(idxq[i])][str(pi)] = pd 46 | # if pi != idxq[i]: 47 | # res[str(idxq[i])][str(pi)] = pd 48 | 49 | pickle.dump(res, open(save_name, "wb")) 50 | 51 | def search_multi(index_file, query_file, save_name): 52 | index_data = pickle.load(open(index_file, "rb")) 53 | query_data = pickle.load(open(query_file, "rb")) 54 | ids = [] 55 | indexs = [] 56 | for i, (idx, vecs) in enumerate(index_data.items()): 57 | for vec in vecs: 58 | ids.append(idx) 59 | indexs.append(vec) 60 | queries = [] 61 | idxq = [] 62 | for idx, vec in query_data.items(): 63 | queries.append(vec[0]) 64 | idxq.append(idx) 65 | ids = np.array(ids) 66 | indexs = np.array(indexs) 67 | queries = np.array(queries) 68 | print(indexs.shape, queries.shape) 69 | 70 | # build faiss index 71 | d = 768 72 | k = 100 73 | index = faiss.IndexFlatIP(d) 74 | assert index.is_trained 75 | 76 | index_id = faiss.IndexIDMap(index) 77 | index_id.add_with_ids(indexs, ids) 78 | 79 | res = faiss.StandardGpuResources() 80 | gpu_index = faiss.index_cpu_to_gpu(res, 0, index_id) 81 | 82 | res = {} 83 | D, I = gpu_index.search(queries, k) 84 | for i, (sd, si) in enumerate(zip(D, I)): 85 | res[str(idxq[i])] = {} 86 | for pd, pi in zip(sd, si): 87 | if str(pi) not in res[str(idxq[i])]: 88 | res[str(idxq[i])][str(pi)] = pd 89 | if len(res[str(idxq[i])]) > 100: 90 | break 91 | 92 | pickle.dump(res, open(save_name, "wb")) 93 | 94 | def main(): 95 | parser = argparse.ArgumentParser(description='') 96 | parser.add_argument('--index_file', '-i', required=True, help="filename of index embeddings saved") 97 | parser.add_argument('--query_file', '-q', required=True, help="file containing query embeddings") 98 | parser.add_argument('--save_name', '-o', required=True, help="save file name") 99 | parser.add_argument("--multi", action='store_true', help="set true if one query/doc has multi embeddings") 100 | args = parser.parse_args() 101 | 102 | if args.multi: 103 | search_multi(args.index_file, args.query_file, args.save_name) 104 | else: 105 | search(args.index_file, args.query_file, args.save_name) 106 | 107 | if __name__ == "__main__": 108 | main() 109 | 110 | -------------------------------------------------------------------------------- /parser/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
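# Helpers shared by the processors: comment/docstring removal and utilities
# for mapping tree-sitter parse trees to token spans and back.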
3 | 4 | import re 5 | from io import StringIO 6 | import tokenize 7 | 8 | def traverse(node, indent=0): 9 | print(" "*indent + node.type, node.start_point, "-", node.end_point) 10 | for child in node.children: 11 | traverse(child, indent+1) 12 | 13 | def remove_comments_and_docstrings(source,lang): 14 | if lang in ['python']: 15 | """ 16 | Returns 'source' minus comments and docstrings. 17 | """ 18 | io_obj = StringIO(source) 19 | out = "" 20 | prev_toktype = tokenize.INDENT 21 | last_lineno = -1 22 | last_col = 0 23 | for tok in tokenize.generate_tokens(io_obj.readline): 24 | token_type = tok[0] 25 | token_string = tok[1] 26 | start_line, start_col = tok[2] 27 | end_line, end_col = tok[3] 28 | ltext = tok[4] 29 | if start_line > last_lineno: 30 | last_col = 0 31 | if start_col > last_col: 32 | out += (" " * (start_col - last_col)) 33 | # Remove comments: 34 | if token_type == tokenize.COMMENT: 35 | pass 36 | # This series of conditionals removes docstrings: 37 | elif token_type == tokenize.STRING: 38 | if prev_toktype != tokenize.INDENT: 39 | # This is likely a docstring; double-check we're not inside an operator: 40 | if prev_toktype != tokenize.NEWLINE: 41 | if start_col > 0: 42 | out += token_string 43 | else: 44 | out += token_string 45 | prev_toktype = token_type 46 | last_col = end_col 47 | last_lineno = end_line 48 | temp=[] 49 | for x in out.split('\n'): 50 | if x.strip()!="": 51 | temp.append(x) 52 | return '\n'.join(temp) 53 | elif lang in ['ruby']: 54 | return source 55 | else: 56 | def replacer(match): 57 | s = match.group(0) 58 | if s.startswith('/'): 59 | return " " # note: a space and not an empty string 60 | else: 61 | return s 62 | pattern = re.compile( 63 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', 64 | re.DOTALL | re.MULTILINE 65 | ) 66 | temp=[] 67 | for x in re.sub(pattern, replacer, source).split('\n'): 68 | if x.strip()!="": 69 | temp.append(x) 70 | return '\n'.join(temp) 71 | 72 | def tree_to_token_index(root_node): 73 | if (len(root_node.children)==0 or root_node.type=='string'): 74 | return [(root_node.start_point,root_node.end_point)] 75 | else: 76 | code_tokens=[] 77 | for child in root_node.children: 78 | code_tokens+=tree_to_token_index(child) 79 | return code_tokens 80 | 81 | def tree_to_variable_index(root_node,index_to_code): 82 | if (len(root_node.children)==0 or root_node.type=='string'): 83 | index=(root_node.start_point,root_node.end_point) 84 | _,code=index_to_code[index] 85 | if root_node.type!=code: 86 | return [(root_node.start_point,root_node.end_point)] 87 | else: 88 | return [] 89 | else: 90 | code_tokens=[] 91 | for child in root_node.children: 92 | code_tokens+=tree_to_variable_index(child,index_to_code) 93 | return code_tokens 94 | 95 | def index_to_code_token(index,code): 96 | start_point=index[0] 97 | end_point=index[1] 98 | if start_point[0]==end_point[0]: 99 | s=code[start_point[0]][start_point[1]:end_point[1]] 100 | else: 101 | s="" 102 | s+=code[start_point[0]][start_point[1]:] 103 | for i in range(start_point[0]+1,end_point[0]): 104 | s+=code[i] 105 | s+=code[end_point[0]][:end_point[1]] 106 | return s 107 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
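# The retriever model: a RoBERTa encoder trained jointly with a masked
# language modeling head and an in-batch contrastive (InfoNCE) loss.
# When args.num_vec > 0, each code snippet is represented by its first
# num_vec token embeddings and similarity is the max over per-vector
# scores (see forward() below).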
3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from transformers import RobertaModel 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | 16 | class RobertaLMHead(nn.Module): 17 | """Roberta Head for masked language modeling.""" 18 | 19 | def __init__(self, config): 20 | super().__init__() 21 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 22 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 23 | 24 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 25 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 26 | 27 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 28 | self.decoder.bias = self.bias 29 | 30 | def forward(self, features, **kwargs): 31 | x = self.dense(features) 32 | x = F.gelu(x) 33 | x = self.layer_norm(x) 34 | 35 | # project back to size of vocabulary with bias 36 | x = self.decoder(x) 37 | 38 | return x 39 | 40 | class SimpleModel(nn.Module): 41 | def __init__(self, config, args): 42 | super().__init__() 43 | if args.from_pretrained: 44 | self.encoder = RobertaModel.from_pretrained(args.pretrained_dir, add_pooling_layer=False) 45 | logger.warning(f"Loading encoder from {args.pretrained_dir}") 46 | else: 47 | self.encoder = RobertaModel(config, add_pooling_layer=False) 48 | self.encoder.resize_token_embeddings(args.vocab_size) 49 | 50 | self.config = config 51 | self.args = args 52 | self.lm_head = RobertaLMHead(config) 53 | self.n_vec = max(0, self.args.num_vec) 54 | self.tie_weights() 55 | 56 | def _tie_or_clone_weights(self, first_module, second_module): 57 | if self.config.torchscript: 58 | first_module.weight = nn.Parameter(second_module.weight.clone()) 59 | else: 60 | first_module.weight = second_module.weight 61 | 62 | def tie_weights(self): 63 | self._tie_or_clone_weights(self.lm_head.decoder, 64 | self.encoder.embeddings.word_embeddings) 65 | 66 | def forward(self, inputs_m, inputs1=None, inputs2=None, attn_mask=None, attn_mask1=None, attn_mask2=None, mlm_labels=None): 67 | outputs = self.encoder(inputs_m, attention_mask=attn_mask)[0] 68 | 69 | if inputs1 is None and inputs2 is None: 70 | # infer 71 | if self.n_vec > 0: 72 | outputs = nn.functional.normalize(outputs[:, :self.n_vec, :], dim=2) 73 | else: 74 | outputs = nn.functional.normalize(outputs[:, 0, :], dim=1) 75 | return outputs 76 | 77 | # training 78 | outputs1 = self.encoder(inputs1, attention_mask=attn_mask1)[0][:, 0, :] 79 | outputs2 = self.encoder(inputs2, attention_mask=attn_mask2)[0] 80 | 81 | lm_logits = self.lm_head(outputs) 82 | 83 | if mlm_labels is not None: 84 | loss_fct = nn.CrossEntropyLoss() 85 | mlm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), mlm_labels.view(-1)) 86 | else: 87 | mlm_loss = None 88 | 89 | if self.n_vec > 0: 90 | o1 = nn.functional.normalize(outputs1, dim=1) 91 | o2 = nn.functional.normalize(outputs2[:, :self.n_vec, :], dim=2) 92 | logits, _ = torch.max(torch.einsum('nc,mvc->nmv', [o1, o2]), -1) 93 | else: 94 | o1 = nn.functional.normalize(outputs1, dim=1) 95 | o2 = nn.functional.normalize(outputs2[:, 0, :], dim=1) 96 | logits = torch.einsum('nc,mc->nm', [o1, o2]) 97 | logits /= self.args.moco_T 98 | labels = torch.arange(end=logits.shape[0], dtype=torch.long).cuda() 99 | 100 | nce_loss = F.cross_entropy(logits, 
labels) 101 | 102 | return lm_logits, mlm_loss, nce_loss -------------------------------------------------------------------------------- /generate/beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch 4 | from torch.autograd import Variable 5 | import copy 6 | 7 | class Beam(object): 8 | def __init__(self, size, sos, eos): 9 | self.size = size 10 | self.sent_size = 2*size 11 | self.tt = torch.cuda 12 | # The score for each translation on the beam. 13 | self.scores = self.tt.FloatTensor(size).zero_() 14 | # The backpointers at each time-step. 15 | self.prevKs = [] 16 | # The outputs at each time-step. 17 | self.nextYs = [self.tt.LongTensor(size) 18 | .fill_(0)] 19 | self.nextYs[0][:] = sos 20 | # Has EOS topped the beam yet. 21 | self._eos = eos 22 | self.eosTop = False 23 | # Time and k pair for finished. 24 | self.finished = [] 25 | 26 | def getCurrentState(self): 27 | "Get the outputs for the current timestep." 28 | batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1) 29 | return batch 30 | 31 | def getCurrentOrigin(self): 32 | "Get the backpointers for the current timestep." 33 | return self.prevKs[-1] 34 | 35 | def advance(self, wordLk): 36 | """ 37 | Given prob over words for every last beam `wordLk` and attention 38 | `attnOut`: Compute and update the beam search. 39 | 40 | Parameters: 41 | 42 | * `wordLk`- probs of advancing from the last step (K x words) 43 | * `attnOut`- attention at the last step 44 | 45 | Returns: True if beam search is complete. 46 | """ 47 | numWords = wordLk.size(1) 48 | 49 | # Sum the previous scores. 50 | if len(self.prevKs) > 0: 51 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 52 | 53 | # Don't let EOS have children. 54 | for i in range(self.nextYs[-1].size(0)): 55 | if self.nextYs[-1][i] in self._eos: 56 | beamLk[i] = -1e20 57 | else: 58 | beamLk = wordLk[0] 59 | flatBeamLk = beamLk.view(-1) 60 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 61 | 62 | self.scores = bestScores 63 | 64 | # bestScoresId is flattened beam x word array, so calculate which 65 | # word and beam each score came from 66 | prevK = bestScoresId // numWords 67 | self.prevKs.append(prevK) 68 | self.nextYs.append((bestScoresId - prevK * numWords)) 69 | 70 | 71 | for i in range(self.nextYs[-1].size(0)): 72 | if self.nextYs[-1][i] in self._eos: 73 | s = self.scores[i] 74 | self.finished.append((s, len(self.nextYs) - 1, i)) 75 | 76 | # End condition is when top-of-beam is EOS and no global score. 77 | if self.nextYs[-1][0] in self._eos: 78 | self.eosTop = True 79 | 80 | def done(self): 81 | return self.eosTop and len(self.finished) >=self.sent_size 82 | 83 | def getFinal(self): 84 | if len(self.finished) == 0: 85 | self.finished.append((self.scores[0], len(self.nextYs) - 1, 0)) 86 | self.finished.sort(key=lambda a: -a[0]) 87 | if len(self.finished) != self.sent_size: 88 | unfinished=[] 89 | for i in range(self.nextYs[-1].size(0)): 90 | if self.nextYs[-1][i] not in self._eos: 91 | s = self.scores[i] 92 | unfinished.append((s, len(self.nextYs) - 1, i)) 93 | unfinished.sort(key=lambda a: -a[0]) 94 | self.finished+=unfinished[:self.sent_size-len(self.finished)] 95 | return self.finished[:self.sent_size] 96 | 97 | def getHyp(self, beam_res): 98 | """ 99 | Walk back to construct the full hypothesis. 
100 | """ 101 | hyps=[] 102 | for _,timestep, k in beam_res: 103 | hyp = [] 104 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 105 | hyp.append(self.nextYs[j+1][k]) 106 | k = self.prevKs[j][k] 107 | hyps.append(hyp[::-1]) 108 | return hyps 109 | 110 | def buildTargetTokens(self, preds): 111 | sentence=[] 112 | for pred in preds: 113 | tokens = [] 114 | for tok in pred: 115 | tokens.append(tok) 116 | if tok in self._eos: 117 | break 118 | sentence.append(tokens) 119 | return sentence 120 | -------------------------------------------------------------------------------- /generate/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import argparse 6 | import glob 7 | import logging 8 | import os 9 | import pickle 10 | import random 11 | import re 12 | import gc 13 | import shutil 14 | import json 15 | from tqdm import tqdm 16 | 17 | import numpy as np 18 | import torch 19 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset 20 | from torch.utils.data.distributed import DistributedSampler 21 | 22 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 23 | BertConfig, BertForMaskedLM, BertTokenizer, 24 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, 25 | OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, 26 | RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, 27 | DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) 28 | 29 | 30 | class RelineDataset(Dataset): 31 | def __init__(self, tokenizer, args, logger, file_type='test', block_size=924, load_file=None, search_res=None): 32 | datafile = os.path.join(args.data_dir, f"{file_type}.json") 33 | with open(datafile) as f: 34 | datas = f.readlines() 35 | 36 | if load_file is not None: 37 | id2code = {} 38 | lines = open(os.path.join(args.data_dir, load_file+".txt")).readlines() 39 | for i,line in enumerate(tqdm(lines)): 40 | id2code[i] = line.strip() 41 | 42 | search_results = pickle.load(open(search_res, "rb")) 43 | try: 44 | nexts = pickle.load(open(os.path.join(args.data_dir, load_file+"_nexts.pkl"), "rb")) 45 | except Exception: 46 | nexts = [i for i in range(len(lines))] 47 | 48 | length = len(datas) 49 | logger.info("Data size: %d"%(length)) 50 | self.inputs = [] 51 | self.gts = [] 52 | for i,data in enumerate(datas): 53 | if i % 1000 == 0: 54 | logger.info(f"Encoded {i}/{length} data") 55 | data = json.loads(data.strip()) 56 | if load_file is not None: 57 | try: 58 | cand_id = search_results[data["id"]] 59 | cand = id2code[cand_id] 60 | if nexts[cand_id] != cand_id: 61 | cand += id2code[nexts[cand_id]] 62 | cand = tokenizer.encode(cand) 63 | except: 64 | cand = [] 65 | else: 66 | cand = [] 67 | self.inputs.append((cand + tokenizer.encode(data["input"]))[-block_size:]) 68 | self.gts.append(data["gt"]) 69 | 70 | def __len__(self): 71 | return len(self.inputs) 72 | 73 | def __getitem__(self, item): 74 | return torch.tensor(self.inputs[item]), self.gts[item] 75 | 76 | class PPLDataset(Dataset): 77 | def __init__(self, tokenizer, args, logger, file_type='test', block_size=1024, load_file=None, search_res=None): 78 | datafile = os.path.join(args.data_dir, f"{file_type}.txt") 79 | with open(datafile) as f: 80 | datas = f.readlines() 81 | 82 | if load_file is not None: 83 | id2code = {} 84 | lines = open(os.path.join(args.data_dir, 
load_file+".txt")).readlines() 85 | for i,line in enumerate(tqdm(lines)): 86 | id2code[i] = line.strip() 87 | 88 | search_results = pickle.load(open(search_res, "rb")) 89 | try: 90 | nexts = pickle.load(open(os.path.join(args.data_dir, load_file+"_nexts.pkl"), "rb")) 91 | except Exception: 92 | nexts = [i for i in range(len(lines))] 93 | 94 | length = len(datas) 95 | logger.info("Data size: %d"%(length)) 96 | self.inputs = [] 97 | self.token_labels = [] 98 | for i,data in enumerate(tqdm(datas)): 99 | if i % 1000 == 0: 100 | logger.info(f"Encoded {i}/{length} data") 101 | tokens = data.strip().split(" ") 102 | if len(tokens) < 200: 103 | cut = len(tokens)//2 104 | else: 105 | cut = 100 106 | 107 | if load_file is not None: 108 | try: 109 | if i in search_results: 110 | cand_id = search_results[i] 111 | else: 112 | cand_id = search_results[str(i)] 113 | cand = id2code[cand_id] 114 | if nexts[cand_id] != cand_id: 115 | cand += id2code[nexts[cand_id]] 116 | cand = tokenizer.encode(cand) 117 | except: 118 | # print("OK") 119 | cand = [] 120 | else: 121 | cand = [] 122 | 123 | x1 = tokenizer.encode(" ".join(tokens[:cut]))[-block_size:] 124 | self.inputs.append(x1) 125 | self.token_labels.append([2]*len(x1)) 126 | 127 | pre_id = cand + tokenizer.encode(" ".join(tokens[:cut])) 128 | x2_0 = tokenizer.encode(" ".join(tokens[cut:]))[:block_size * 3 // 4] 129 | x2 = (pre_id + x2_0)[-block_size:] 130 | self.inputs.append(x2) 131 | self.token_labels.append([1]*(len(x2)-len(x2_0)) + [2]*len(x2_0)) 132 | 133 | def __len__(self): 134 | return len(self.inputs) 135 | 136 | def __getitem__(self, item): 137 | return torch.tensor(self.inputs[item]), torch.tensor(self.token_labels[item]) 138 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReACC 2 | 3 | Source codes for ACL 2022 paper "[ReACC: A Retrieval-Augmented Code Completion Framework](https://arxiv.org/abs/2203.07722)" 4 | ReACC combines a source code retiever and an auto-regresstive language model for programming languages. 5 | 6 | ## Dependency 7 | 8 | - pytorch >= 1.7.0 9 | - transformers >= 4.10.0 10 | - tree_sitter 11 | - faiss-gpu 12 | - beir (for BM25) 13 | - elastic search 14 | 15 | ## Instructions 16 | 17 | Here are the instructions to apply ReACC framework on the code completion task with PY150 dataset. 18 | 19 | ### 1. Pretrain a retriever 20 | 21 | Leverage [`microsoft/reacc-py-retriever`](https://huggingface.co/microsoft/reacc-py-retriever) as a code-to-code retriever for python source codes. 22 | 23 | ### 2. Build an index for search 24 | 25 | First, you have to prepare a codebase for retrieving. It is recommended to split each file/function into small chunks. (refer to `utils/split_codes.py`). Then run the command to get representations of all the codes in search corpus. 26 | 27 | ```bash 28 | python -m torch.distributed.launch --nproc_per_node=${PER_NODE_GPU} infer.py \ 29 | --data_path=data/train_split.txt \ 30 | --save_name=save_vec \ 31 | --lang=python \ 32 | --pretrained_dir=microsoft/reacc-py-retriever \ 33 | --num_vec=8 \ 34 | --block_size=512 \ 35 | --gpu_per_node ${PER_NODE_GPU} \ 36 | --logging_steps=100 37 | ``` 38 | 39 | You can modify the `InferDataset` in `infer.py` to fit your own dataset. Our dataset is formated as a jsonl file, where each line is like 40 | ```json 41 | { 42 | "code": "def function()", 43 | "id": 0 44 | } 45 | ``` 46 | or a plain text file, in which each line is a code snippet. 
47 | 48 | ### 3. Retrieval step 49 | 50 | ReACC is a two-stage framework. The first stage retrieves code that is similar to a given query. As the test set is fixed, we retrieve the similar code for all queries in the test set in advance. **It would be better to merge step 3 into step 4.** 51 | 52 | First, get the representations of the test queries as in step 2. Then run the script `utils/search_dense.py` to rank the candidates by similarity and get the most similar code. 53 | 54 | If you would like to use the BM25 algorithm to retrieve similar code, run the script `utils/search_bm25.py` (this requires a running Elasticsearch instance). 55 | 56 | Finally, run `utils/get_res.py` to get the most similar code based on the BM25 results, the dense retrieval results, or a hybrid of both. 57 | 58 | ### 4. Generation step 59 | 60 | Please download the PY150 dataset first and use the preprocessing scripts in [CodeXGLUE](https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/CodeCompletion-token). Then follow CodeXGLUE to fine-tune a model on it, such as CodeGPT. 61 | 62 | The second stage of ReACC completes code based on the context and the retrieved code. We simply put the retrieved code before the context and concatenate them as the input. 63 | 64 | Navigate to the `generate` folder. We adapt the code completion scripts from [CodeXGLUE](https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/CodeCompletion-line). We modify the script `dataset.py` to include similar code in the input. Run `run_lm.py` to evaluate your fine-tuned model. 65 | 66 | ```bash 67 | export CUDA_VISIBLE_DEVICES=0 68 | LANG=python 69 | DATADIR=dataset/py150 70 | LITFILE=${DATADIR}/literals.json 71 | OUTPUTDIR=save/py150 72 | PRETRAINDIR=py150-ckpt 73 | 74 | LOADFILE=${DATADIR}/train_split 75 | RESFILE=search_res.pkl 76 | SAVEFILE=prediction.txt 77 | 78 | python -u run_lm.py \ 79 | --data_dir=$DATADIR \ 80 | --lit_file=$LITFILE \ 81 | --langs=$LANG \ 82 | --output_dir=$OUTPUTDIR \ 83 | --pretrain_dir=$PRETRAINDIR \ 84 | --load_file_name=$LOADFILE \ 85 | --search_res=$RESFILE \ 86 | --save_name=$SAVEFILE \ 87 | --model_type=gpt2 \ 88 | --block_size=1024 \ 89 | --eval_line \ 90 | --logging_steps=100 \ 91 | --seed=42 92 | ``` 93 | 94 | 95 | ## Zero-shot code clone detection 96 | In order to evaluate the effectiveness of the code-to-code retrieval module in ReACC, 97 | we perform the code clone detection task, which aims to retrieve semantically equivalent programs. 98 | 99 | We extract the evaluation dataset from CodeNet, the same as in the [UniXcoder paper](https://arxiv.org/abs/2203.03850). 100 | The dataset can be downloaded from [here](https://github.com/microsoft/CodeBERT/tree/master/UniXcoder/downstream-tasks/zero-shot-search/dataset). 101 | 102 | Run `codenet_test.py` to reproduce this experiment. 103 | ```bash 104 | DATADIR=CodeNet 105 | PRETRAINDIR=microsoft/reacc-py-retriever 106 | 107 | python -u codenet_test.py \ 108 | --data_dir=$DATADIR \ 109 | --pretrained_dir=$PRETRAINDIR \ 110 | --lang=python \ 111 | --num_vec=8 \ 112 | --cut \ 113 | --block_size=512 \ 114 | --per_gpu_eval_batch_size=64 \ 115 | --logging_steps=100 \ 116 | --seed=614 117 | ``` 118 | 119 | ## Code of Conduct 120 | 121 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 122 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 123 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 124 | 125 | ## License 126 | 127 | Copyright (c) Microsoft Corporation. All rights reserved.
128 | 129 | Licensed under the [MIT](LICENSE) license. 130 | 131 | ## Reference 132 | If you use this code or ReACC, please consider citing us. 133 |
```
@article{lu2022reacc,
134 |   title={ReACC: A Retrieval-Augmented Code Completion Framework},
135 |   author={Lu, Shuai and Duan, Nan and Han, Hojae and Guo, Daya and Hwang, Seung-won and Svyatkovskiy, Alexey},
136 |   journal={arXiv preprint arXiv:2203.07722},
137 |   year={2022}
138 | }
```
139 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | # This script is used to generate the embedding vectors for the given dataset. 5 | 6 | import argparse 7 | import logging 8 | import os 9 | import random 10 | import re 11 | import json 12 | import pickle 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | from itertools import cycle 17 | from functools import partial 18 | from torch.nn.utils.rnn import pad_sequence 19 | import torch.nn.functional as F 20 | from torch.utils.data import DataLoader, Dataset, SequentialSampler 21 | from torch.utils.data.distributed import DistributedSampler 22 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 23 | RobertaConfig, RobertaModel, RobertaTokenizer) 24 | from tqdm import tqdm 25 | from tree_sitter import Language, Parser 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | class InferDataset(Dataset): 30 | def __init__(self, tokenizer, args, api=True): 31 | if args.local_rank == -1: 32 | local_rank = 0 33 | world_size = 1 34 | else: 35 | local_rank = args.local_rank 36 | world_size = torch.distributed.get_world_size() 37 | self.tokenizer = tokenizer 38 | self.args = args 39 | self.api = api 40 | data_file = args.data_path 41 | 42 | 43 | if args.lang == "java": 44 | from process_java import processor 45 | elif args.lang == "python": 46 | from process_python import processor 47 | self.proc = processor(args.lang, remove_comments=False) 48 | 49 | logger.info(f"Creating features from {data_file}") 50 | data_format = data_file.split(".")[-1] 51 | 52 | self.data = [] 53 | self.idx = [] 54 | n = 0 55 | with open(data_file) as f: 56 | for _ in f: 57 | n += 1 58 | # n = 100000 59 | st = n//world_size*local_rank 60 | ed = n//world_size*(local_rank+1) 61 | logger.warning(f"device {local_rank} will load {ed-st} data line from {st} to {ed}") 62 | with open(data_file) as f: 63 | for i,line in enumerate(f): 64 | if i >= st and i < ed: 65 | if (i-st) % 100000 == 0: 66 | logger.info(f"device {local_rank} created {i-st}/{ed-st} train data") 67 | if "json" in data_format: 68 | content = json.loads(line) 69 | self.data.append(self.convert_cxg_format_to_normal(content["input"])) 70 | self.idx.append(content["id"]) 71 | else: # txt 72 | self.data.append(self.convert_cxg_format_to_normal(line.strip())) 73 | self.idx.append(i) 74 | logger.warning(f"device {local_rank} loaded {len(self.data)} train data from {st} to {ed}") 75 | 76 | def convert_cxg_format_to_normal(self, code): 77 | if code.startswith("<s>"): 78 | code = code.lstrip("<s>") 79 | if code.endswith("</s>"): 80 | code = code.rstrip("</s>") 81 | code = code.replace("<EOL>", "\n") 82 | code = code.replace("<NUM_LIT>", "0").replace("<STR_LIT>", "").replace("<CHAR_LIT>", "") 83 | pattern = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S) 84 | lits = re.findall(pattern, code) 85 | for lit in lits: 86 | code = code.replace(f"<{lit[0]}_LIT:{lit[1]}>", lit[1]) 87 | return code 88 | 89 | def encode(self, code, api_seq): 90 | if self.api: 91 | code_tokens = [self.tokenizer.cls_token] + self.tokenizer.tokenize(code) + \ 92 | [self.tokenizer.sep_token] + self.tokenizer.tokenize(" ".join(api_seq)) + [self.tokenizer.sep_token] 93 | else: 94 | code_tokens = [self.tokenizer.cls_token] + self.tokenizer.tokenize(code) + [self.tokenizer.sep_token] 95 | code_tokens = code_tokens[:self.args.block_size] 96 |
code_ids = self.tokenizer.convert_tokens_to_ids(code_tokens) 97 | return code_ids 98 | 99 | def process(self, code): 100 | self.proc.update(code) 101 | api_seq = self.proc.get_api_seq() 102 | code = self.proc.untokenize(cut_ratio=0.0) 103 | token_id = self.encode(code, api_seq) 104 | return token_id 105 | 106 | def __len__(self): 107 | return len(self.data) 108 | 109 | def __getitem__(self, item): 110 | return torch.tensor(self.process(self.data[item])), torch.tensor([self.idx[item]]) 111 | 112 | 113 | 114 | def set_seed(args): 115 | random.seed(args.seed) 116 | np.random.seed(args.seed) 117 | torch.manual_seed(args.seed) 118 | if args.n_gpu > 0: 119 | torch.cuda.manual_seed_all(args.seed) 120 | 121 | def my_collect_fn(sequences, batch_first=True, padding_value=1): 122 | inputs = [] 123 | inputs1 = [] 124 | for (x, x1) in sequences: 125 | inputs.append(x) 126 | inputs1.append(x1) 127 | return ( 128 | pad_sequence(inputs, batch_first, padding_value), 129 | pad_sequence(inputs1, batch_first, padding_value), 130 | ) 131 | 132 | def inference(args, tokenizer, model, save_name, api=False): 133 | #build dataloader 134 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 135 | dataset = InferDataset(tokenizer, args, api=api) 136 | sampler = SequentialSampler(dataset) 137 | dataloader = DataLoader(dataset, sampler=sampler, batch_size=args.eval_batch_size, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), num_workers=4) 138 | 139 | model.to(args.device) 140 | if args.local_rank != -1: 141 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank%args.gpu_per_node], 142 | output_device=args.local_rank%args.gpu_per_node, 143 | find_unused_parameters=True) 144 | 145 | # Eval! 
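    # Each batch below is encoded once; the first num_vec token embeddings
    # (or just the first token's embedding when num_vec <= 0) are
    # L2-normalized and collected into an id -> vector dict, which is
    # pickled per rank and merged afterwards.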
146 | logger.info("***** Running Inference *****") 147 | logger.info(" Num examples = %d", len(dataset)) 148 | logger.info(" Batch size = %d", args.eval_batch_size) 149 | 150 | model.eval() 151 | 152 | steps = 0 153 | n_vec = max(0, args.num_vec) 154 | saved = {} 155 | for batch in dataloader: 156 | with torch.no_grad(): 157 | (inputs1, inputs2) = batch 158 | inputs1 = inputs1.to(args.device) 159 | attn_mask1 = torch.tensor(inputs1.clone().detach() != tokenizer.pad_token_id, dtype=torch.uint8, device=args.device) 160 | outputs = model(inputs1, attention_mask=attn_mask1)[0] 161 | if n_vec > 0: 162 | outputs = nn.functional.normalize(outputs[:, :n_vec, :], dim=2) 163 | else: 164 | outputs = nn.functional.normalize(outputs[:, 0, :], dim=1) 165 | outputs = outputs.detach().to("cpu").numpy() 166 | idxs = inputs2.numpy() 167 | for i in range(outputs.shape[0]): 168 | saved[idxs[i][0]] = outputs[i] 169 | steps += 1 170 | if steps % args.logging_steps == 0: 171 | logger.info(f"Inferenced {steps} steps") 172 | 173 | if args.local_rank != -1: 174 | pickle.dump(saved, open(save_name+f"_{args.local_rank}.pkl", "wb")) 175 | else: 176 | pickle.dump(saved, open(save_name+".pkl", "wb")) 177 | 178 | def merge(args, num, save_name): 179 | saved = {} 180 | for i in range(num): 181 | saved.update(pickle.load(open(save_name+f"_{i}.pkl", "rb"))) 182 | os.remove(save_name+f"_{i}.pkl") 183 | pickle.dump(saved, open(save_name+".pkl", "wb")) 184 | 185 | 186 | 187 | 188 | def main(): 189 | parser = argparse.ArgumentParser() 190 | 191 | ## Required parameters 192 | parser.add_argument("--data_path", default=None, type=str, required=True, 193 | help="The input data path.") 194 | parser.add_argument("--save_name", default=None, type=str, required=True, 195 | help="The output directory where the model predictions and checkpoints will be written.") 196 | parser.add_argument("--lang", default=None, type=str, required=True, 197 | help="Language of the dataset.") 198 | parser.add_argument("--pretrained_dir", default=None, type=str, 199 | help="The directory where the trained model and tokenizer are saved.") 200 | 201 | parser.add_argument("--cut_ratio", type=float, default=0.5, 202 | help="Ratio of replaced variables") 203 | parser.add_argument('--num_vec', type=int, default=-1, 204 | help="number of vectors") 205 | 206 | parser.add_argument("--block_size", default=512, type=int, 207 | help="Optional input sequence length after tokenization." 208 | "The training dataset will be truncated in block of this size for training." 
209 | "Default to the model max input length for single sentence inputs (take into account special tokens).") 210 | 211 | parser.add_argument("--per_gpu_eval_batch_size", default=16, type=int, 212 | help="Batch size per GPU/CPU for evaluation.") 213 | 214 | parser.add_argument('--logging_steps', type=int, default=10, 215 | help="Log every X updates steps.") 216 | parser.add_argument("--no_cuda", action='store_true', 217 | help="Avoid using CUDA when available") 218 | parser.add_argument('--seed', type=int, default=42, 219 | help="random seed for initialization") 220 | 221 | parser.add_argument("--local_rank", type=int, default=-1, 222 | help="For distributed training: local_rank") 223 | parser.add_argument("--node_index", type=int, default=0, 224 | help="node index if multi-node running") 225 | parser.add_argument("--gpu_per_node", type=int, default=-1, 226 | help="num of gpus per node") 227 | 228 | args = parser.parse_args() 229 | 230 | logger.warning("local_rank: %d, node_index: %d, gpu_per_node: %d"%(args.local_rank, args.node_index, args.gpu_per_node)) 231 | # Setup CUDA, GPU & distributed training 232 | if args.local_rank == -1 or args.no_cuda: 233 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 234 | args.n_gpu = torch.cuda.device_count() 235 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 236 | torch.cuda.set_device(args.local_rank) 237 | device = torch.device("cuda", args.local_rank) 238 | torch.distributed.init_process_group(backend='nccl') 239 | args.local_rank += args.node_index * args.gpu_per_node 240 | args.n_gpu = 1 241 | args.device = device 242 | 243 | world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1 244 | 245 | # Setup logging 246 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 247 | datefmt='%m/%d/%Y %H:%M:%S', 248 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 249 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, world size: %s", 250 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), world_size) 251 | 252 | # Set seed 253 | set_seed(args) 254 | 255 | args.start_epoch = 0 256 | args.start_step = 0 257 | 258 | tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_dir) 259 | model = RobertaModel.from_pretrained(args.pretrained_dir, add_pooling_layer=False) 260 | 261 | inference(args, tokenizer, model, args.save_name, api=True) 262 | logger.info(f"device {args.local_rank} finished") 263 | 264 | if args.local_rank != -1: 265 | torch.distributed.barrier() 266 | if args.local_rank in [-1, 0]: 267 | import time 268 | time.sleep(10) 269 | merge(args, world_size, save_name=args.save_name) 270 | 271 | 272 | if __name__ == "__main__": 273 | main() 274 | 275 | -------------------------------------------------------------------------------- /process_java.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT License. 
4 | import os 5 | import json 6 | import re 7 | import pickle 8 | from tqdm import tqdm 9 | import random 10 | from collections import Counter, OrderedDict 11 | from tree_sitter import Language, Parser 12 | from parser import (remove_comments_and_docstrings, 13 | tree_to_token_index, 14 | index_to_code_token, 15 | tree_to_variable_index, 16 | traverse) 17 | from textwrap import dedent 18 | 19 | 20 | # ['nwo', 'path', 'language', 'identifier', 'parameters', 'argument_list', 'return_statement', 'docstring', 'docstring_summary', 'docstring_tokens', 'function', 'function_tokens'] 21 | 22 | def clean_docstring_comments(comment): 23 | comment = comment.strip().strip(""" "' """) 24 | comment = "\n".join(map(lambda s: s.lstrip("#"), comment.splitlines())) 25 | return dedent(comment) 26 | 27 | class processor(object): 28 | def __init__(self, lang, code=None, remove_comments=False): 29 | LANGUAGE = Language('parser/my-languages.so', lang) 30 | parser = Parser() 31 | parser.set_language(LANGUAGE) 32 | self.parser = [parser] 33 | self.lang = lang 34 | self.remove_comments = remove_comments 35 | self.preserve_words = set(["self", "super", "Exception", "__init__", "__main__"]) 36 | if code is None: 37 | self.tree = None 38 | else: 39 | self.update(code) 40 | 41 | def update(self, code, function=False): 42 | if self.lang == "php": 43 | code = "<?php "+code+"?>" 44 | if self.lang == "java" and function: 45 | code = "class A{\n"+code+"\n}" 46 | com = True 47 | if self.remove_comments: 48 | com = False 49 | try: 50 | code = remove_comments_and_docstrings(code, self.lang) 51 | except Exception: 52 | com = True 53 | self.code = code 54 | self.code_bytes = code.encode("utf-8") 55 | self.tree = self.parser[0].parse(self.code_bytes) 56 | root_node = self.tree.root_node 57 | tokens_index = tree_to_token_index(root_node) 58 | code = self.code.split('\n') 59 | self.code_tokens = [index_to_code_token(x, code) for x in tokens_index] 60 | self.index_to_code = OrderedDict() 61 | for idx, (index, code) in enumerate(zip(tokens_index, self.code_tokens)): 62 | self.index_to_code[index] = (idx, code) 63 | return com 64 | 65 | def get_doc(self): 66 | """ 67 | For file level data, merge the docs from each function 68 | """ 69 | self.functions = [] 70 | self.get_func_nodes(self.tree.root_node) 71 | docs = "" 72 | for func_node in self.functions: 73 | body_node = func_node.children[-1] 74 | if body_node.children and body_node.children[0].children: 75 | if body_node.children[0].children[0].type in ["string", "comment"]: 76 | docs += clean_docstring_comments( 77 | self.span_select(body_node.children[0].children[0]), 78 | ) + "" 79 | return docs 80 | 81 | def get_func_nodes(self, root): 82 | """ 83 | For both function level and file level data 84 | """ 85 | for node in root.children: 86 | if node.type == "function_definition": 87 | self.functions.append(node) 88 | else: 89 | self.get_func_nodes(node) 90 | 91 | 92 | def load_names(self, path): 93 | vnames = pickle.load(open(os.path.join(path, "vnames.pkl"), "rb")) 94 | self.preserve_words = vnames["not_replace"] 95 | self.vnames = vnames["cands"] 96 | 97 | def span_select(self, *nodes, indent=False): 98 | if not nodes: 99 | return "" 100 | start, end = nodes[0].start_byte, nodes[-1].end_byte 101 | select = self.code_bytes[start:end].decode("utf-8") 102 | if indent: 103 | return " " * nodes[0].start_point[1] + select 104 | return select 105 | 106 | def get_func_name(self): 107 | """ 108 | For function level data only 109 | """ 110 | root_node = self.tree.root_node 111 | func_nodes = [node for node in
root_node.children if node.type == "function_definition"] 112 | try: 113 | func_name = func_nodes[0].child_by_field_name("name") 114 | except IndexError: 115 | return "" 116 | return self.span_select(func_name) 117 | 118 | def get_var_names(self): 119 | root_node = self.tree.root_node 120 | vnames = set() 121 | self._get_var_names_from_node(root_node, vnames) 122 | return vnames 123 | 124 | def get_api_seq(self): 125 | root_node = self.tree.root_node 126 | api_seq = [] 127 | self._get_api_seq(root_node, api_seq) 128 | return api_seq 129 | 130 | def _get_var_names_from_node(self, node, vnames, inactive=False): 131 | if len(node.children) > 0: 132 | if node.type in ["method_invocation"]: 133 | self._get_var_names_from_node(node.children[0], vnames, inactive) 134 | for child in node.children[1:]: 135 | if child.type == "argument_list": 136 | self._get_var_names_from_node(child, vnames, inactive) 137 | else: 138 | self._get_var_names_from_node(child, vnames, True) 139 | else: 140 | for child in node.children: 141 | if node.type in ["field_access", "modifiers", "import_declaration", "package_declaration"]: 142 | self._get_var_names_from_node(child, vnames, True) 143 | else: 144 | self._get_var_names_from_node(child, vnames, inactive) 145 | elif node.type == "identifier": 146 | if not inactive: 147 | vnames.add(self.span_select(node)) 148 | 149 | def _get_api_seq(self, node, api_seq): 150 | if node.type == "method_invocation": 151 | obj = node.child_by_field_name("object") 152 | func = node.child_by_field_name("name") 153 | if obj: 154 | api_seq.append(self.span_select(obj) + "." + self.span_select(func)) 155 | else: 156 | api_seq.append(self.span_select(func)) 157 | else: 158 | for child in node.children: 159 | self._get_api_seq(child, api_seq) 160 | 161 | def process(self, ratio=0.5, indent=False, add_dead_code=True, cut_ratio=0.0, function=False): 162 | vnames = [x for x in self.get_var_names() if x not in self.preserve_words] 163 | vnames = random.sample(vnames, int(len(vnames)*ratio)) 164 | cands = random.sample(self.vnames, len(vnames)+3) 165 | dead_vars = cands[-3:] 166 | if add_dead_code: 167 | deadcode = self.insert_dead_code(dead_vars) 168 | else: 169 | deadcode = None 170 | replaced = {v: c for v, c in zip(vnames, cands[:-3])} 171 | self.index_to_new_code = {} 172 | self._replace_var_names_from_node(self.tree.root_node, replaced) 173 | code_string = self.untokenize(indent, deadcode, True, cut_ratio=cut_ratio, function=function) 174 | return code_string 175 | 176 | def process_no_replace(self, indent=False, add_dead_code=True, cut_ratio=0.0, function=False): 177 | dead_vars = random.sample(self.vnames, 3) 178 | if add_dead_code: 179 | deadcode = self.insert_dead_code(dead_vars) 180 | else: 181 | deadcode = None 182 | code_string = self.untokenize(indent, deadcode, False, cut_ratio=cut_ratio, function=function) 183 | return code_string 184 | 185 | def insert_dead_code(self, v): 186 | # dead code types, vars that can't appear in original code 187 | # A = B, A 188 | # A.C(B, 1), A 189 | # A = B + C, AB 190 | # A = B(C), AB 191 | # A = B.C(), ABC 192 | # for (String i: A){\nB(C)\n} 193 | # A = [B for B in range(C)] 194 | # if (C){\nA = B()\n} 195 | dead_type = random.randrange(7) 196 | if dead_type == 0: 197 | return f"{v[0]} = {v[1]};\n" 198 | elif dead_type == 1: 199 | return f"{v[0]}.{v[2]}({v[1]}, 1);\n" 200 | elif dead_type == 2: 201 | return f"{v[0]} = {v[1]} + {v[2]};\n" 202 | elif dead_type == 3: 203 | return f"{v[0]} = {v[1]}({v[2]});\n" 204 | elif dead_type == 4: 205 | return 
f"{v[0]} = {v[1]}.{v[2]}();\n" 206 | elif dead_type == 5: 207 | return "for (String i: %s){\n%s(%s)\n;}\n"%(v[0], v[1], v[2]) 208 | elif dead_type == 6: 209 | return "if (%s){\n%s = %s\n;}\n"%(v[2], v[0], v[1]) 210 | 211 | def _replace_var_names_from_node(self, node, replaced): 212 | if len(node.children) > 0: 213 | for child in node.children: 214 | self._replace_var_names_from_node(child, replaced) 215 | elif node.type == "identifier": 216 | try: 217 | idf = self.index_to_code[(node.start_point, node.end_point)][1] 218 | except KeyError: 219 | idf = "None" 220 | if idf in replaced: 221 | self.index_to_new_code[(node.start_point, node.end_point)] = replaced[idf] 222 | 223 | def untokenize(self, indent=False, deadcode=None, replaced=False, cut_ratio=0.0, function=False, fix_cut_pos=False): 224 | code_string = "" 225 | prev_sp = None 226 | prev_ep = None 227 | prev_indent = 0 228 | indent_size = -1 229 | if function: 230 | total_line = list(self.index_to_code.keys())[-2][0][0] 231 | else: 232 | total_line = list(self.index_to_code.keys())[-1][0][0] 233 | insert_line = random.randint(total_line//5, total_line*4//5) 234 | cut = random.random() < cut_ratio 235 | if cut: 236 | if fix_cut_pos: 237 | cut_pos = (len(self.index_to_code)-4)//2 238 | else: 239 | cut_pos = random.randint((len(self.index_to_code)-4)//3, (len(self.index_to_code)-4)*2//3) 240 | if function: 241 | poses = list(self.index_to_code.keys())[3:-1] 242 | else: 243 | poses = list(self.index_to_code.keys()) 244 | for ip, pos in enumerate(poses): 245 | sp = pos[0] 246 | ep = pos[1] 247 | if cut and ip >= cut_pos: 248 | break 249 | if replaced and pos in self.index_to_new_code: 250 | add_token = self.index_to_new_code[pos] 251 | else: 252 | add_token = self.index_to_code[pos][1] 253 | if prev_sp is None or (sp[0] == prev_ep[0] and sp[1] == prev_ep[1]): 254 | code_string += add_token 255 | elif sp[0] == prev_ep[0]: 256 | if code_string and code_string[-1] != " ": 257 | code_string += " " 258 | code_string += add_token 259 | else: 260 | # if cut and cut_line >= 1 and cut_line <= prev_ep[0]: 261 | # break 262 | if replaced and deadcode: 263 | if insert_line <= prev_ep[0]: 264 | code_string += "\n" + deadcode 265 | insert_line = total_line+2 266 | if indent and add_token: 267 | code_string += "\n" 268 | omit = False 269 | if sp[1] != prev_indent and prev_indent == 0 and indent_size == -1: 270 | indent_size = sp[1] - prev_indent 271 | if sp[1] - prev_indent > 0: 272 | if sp[1] - prev_indent > 2 * indent_size: 273 | omit = True 274 | else: 275 | for i in range(prev_indent, sp[1], indent_size): 276 | code_string += "" 277 | elif sp[1] - prev_indent < 0: 278 | for i in range(sp[1], prev_indent, indent_size): 279 | code_string += "" 280 | code_string += add_token 281 | if not omit: 282 | prev_indent = sp[1] 283 | else: 284 | code_string += "\n" 285 | code_string += " " if sp[1] else "" 286 | code_string += add_token 287 | prev_sp, prev_ep = sp, ep 288 | return re.sub(re.compile("\s*\n"), "\n", code_string.lstrip()).replace("\n", "") 289 | 290 | def convert_to_normal(self, code): 291 | lines = code.split("") 292 | indent_size = 4 293 | indent = 0 294 | res = "" 295 | for line in lines: 296 | indent += line.count("") 297 | indent -= line.count("") 298 | res += "\n" + " "*indent_size*indent + line.replace("", "").replace("", "") 299 | return res 300 | 301 | 302 | 303 | 304 | -------------------------------------------------------------------------------- /codenet_test.py: 
/codenet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | import re 8 | import json 9 | import pickle 10 | from collections import Counter 11 | import numpy as np 12 | import torch 13 | from itertools import cycle 14 | from functools import partial 15 | from torch.nn.utils.rnn import pad_sequence 16 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset 17 | from torch.utils.data.distributed import DistributedSampler 18 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 19 | RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaTokenizer) 20 | from tqdm import tqdm 21 | from tree_sitter import Language, Parser 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | at_K = 100 26 | 27 | class InputFeatures(object): 28 | """A single set of training/test features for an example.""" 29 | def __init__(self, 30 | code_ids, 31 | index, 32 | label 33 | 34 | ): 35 | self.code_ids = code_ids 36 | self.index = index 37 | self.label = label 38 | 39 | 40 | class CodeWithDocNoRepDataset(Dataset): 41 | def __init__(self, tokenizer, args, file_type, cut_ratio=0.0): 42 | self.tokenizer = tokenizer 43 | self.args = args 44 | data_file = os.path.join(args.data_dir, f"{file_type}.jsonl") 45 | 46 | if args.lang == "java": 47 | from process_java import processor 48 | elif args.lang == "python": 49 | from process_python import processor 50 | self.proc = processor(args.lang, remove_comments=False) 51 | # self.proc.load_names(args.vars_dir) 52 | 53 | #load index 54 | logger.info(f"Creating features from {data_file}") 55 | 56 | 57 | self.examples = [] 58 | lines = open(data_file).readlines() 59 | for i,line in enumerate(lines): 60 | content = json.loads(line) 61 | self.proc.update(content["func"]) 62 | code = self.proc.untokenize(cut_ratio=cut_ratio, fix_cut_pos=True) 63 | self.proc.update(self.proc.convert_to_normal(code)) 64 | api_seq = self.proc.get_api_seq() 65 | token_id = self.encode_v3(code, api_seq) 66 | self.examples.append(InputFeatures(token_id, content["index"], int(content["label"]))) 67 | logger.info(f"loaded {len(self.examples)} examples") 68 | 69 | self.label_examples={} 70 | for e in self.examples: 71 | if e.label not in self.label_examples: 72 | self.label_examples[e.label] = [] 73 | self.label_examples[e.label].append(e) 74 | 75 | def encode_v3(self, code, api_seq): 76 | code_tokens = [self.tokenizer.cls_token] + self.tokenizer.tokenize(code) + \ 77 | [self.tokenizer.sep_token] + self.tokenizer.tokenize(" ".join(api_seq)) + [self.tokenizer.sep_token] 78 | code_tokens = code_tokens[:self.args.block_size] 79 | code_ids = self.tokenizer.convert_tokens_to_ids(code_tokens) 80 | return code_ids 81 | 82 | def __len__(self): 83 | return len(self.examples) 84 | 85 | def __getitem__(self, i): 86 | return torch.tensor(self.examples[i].code_ids), self.examples[i].index 87 | 88 | def set_seed(args): 89 | random.seed(args.seed) 90 | np.random.seed(args.seed) 91 | torch.manual_seed(args.seed) 92 | if args.n_gpu > 0: 93 | torch.cuda.manual_seed_all(args.seed) 94 | 95 | def my_collect_fn(sequences, batch_first=True, padding_value=1): 96 | inputs1 = [] 97 | inputs2 = [] 98 | for (x1, x2) in sequences: 99 | inputs1.append(x1) 100 | inputs2.append(x2) 101 | return ( 102 | pad_sequence(inputs1, batch_first, padding_value), 103 | inputs2 104 | ) 105 | 106 | 107 | def 
eval_bm25_beir(args, tokenizer, file_name, candidate_file_name, cut=False): 108 | from beir.datasets.data_loader import GenericDataLoader 109 | from beir.retrieval.evaluation import EvaluateRetrieval 110 | from beir.retrieval.search.lexical import BM25Search as BM25 111 | 112 | if args.lang == "java": 113 | from process_java import processor 114 | elif args.lang == "python": 115 | from process_python import processor 116 | proc = processor(args.lang, remove_comments=False) 117 | 118 | idx2label = {} 119 | label2num = Counter() 120 | lines = open(os.path.join(args.data_dir, f"{candidate_file_name}.jsonl")).readlines() 121 | corpus = {} 122 | for i,line in enumerate(tqdm(lines)): 123 | content = json.loads(line) 124 | # proc.update(content["func"]) 125 | # code = proc.untokenize() 126 | code = content["func"] 127 | idx2label[content["index"]] = content["label"] 128 | label2num[content["label"]] += 1 129 | corpus[content["index"]] = {"text": code} 130 | 131 | lines = open(os.path.join(args.data_dir, f"{file_name}.jsonl")).readlines() 132 | queries = {} 133 | qrels = {} 134 | for i,line in enumerate(tqdm(lines)): 135 | content = json.loads(line) 136 | ori_code_tokens = content["func"].split() 137 | if cut: 138 | code = " ".join(ori_code_tokens[:len(ori_code_tokens)//2]) 139 | else: 140 | code = " ".join(ori_code_tokens) 141 | # proc.update(content["func"]) 142 | # code = proc.untokenize(cut_ratio=1.0 if cut else 0.0, fix_cut_pos=True) 143 | queries[content["index"]] = code 144 | qrels[content["index"]] = {content["index"]: 1} 145 | 146 | model = BM25(index_name="codenet", hostname="http://localhost:9200", initialize=True) 147 | retriever = EvaluateRetrieval(model, k_values=[at_K+1]) 148 | scores = retriever.retrieve(corpus, queries) 149 | 150 | # pickle.dump(scores, open(os.path.join(args.data_dir, "bm25_scores.pkl"), "wb")) 151 | 152 | MAP = [] 153 | PREC = 0.0 154 | for idx, v in tqdm(scores.items()): 155 | v = sorted(v.items(), key=lambda x:-x[1]) 156 | label = idx2label[idx] 157 | div = min(at_K, label2num[label]) 158 | Avep = [] 159 | cont = 0 160 | for i, (_id, score) in enumerate(v): 161 | if i - cont >= at_K: 162 | break 163 | if _id == idx: 164 | cont += 1 165 | continue 166 | if idx2label[_id] == label: 167 | Avep.append((len(Avep)+1)/(i+1-cont)) 168 | if i - cont == 0: 169 | PREC += 1.0 170 | MAP.append(sum(Avep)/div) 171 | 172 | result = { 173 | "eval_map":float(np.mean(MAP)), 174 | "eval_prec":float(PREC/len(MAP)) 175 | } 176 | for key in sorted(result.keys()): 177 | logger.info(" %s = %s", key, str(round(result[key], 4))) 178 | 179 | 180 | def evaluate(args, model, tokenizer, file_name, candidate_file_name, cut): 181 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 182 | 183 | query_dataset = CodeWithDocNoRepDataset(tokenizer, args, file_name, cut_ratio=1.0 if cut else 0.0) 184 | query_sampler = SequentialSampler(query_dataset) 185 | query_dataloader = DataLoader(query_dataset, sampler=query_sampler, batch_size=args.eval_batch_size, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), num_workers=4) 186 | 187 | candidate_dataset = CodeWithDocNoRepDataset(tokenizer, args, candidate_file_name) 188 | candidate_sampler = SequentialSampler(candidate_dataset) 189 | candidate_dataloader = DataLoader(candidate_dataset, sampler=candidate_sampler, batch_size=args.eval_batch_size, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), num_workers=4) 190 | 191 | 192 | idx2label = {} 193 | 
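# Note: the MAP/precision loops below exclude the query itself from its own
# ranking: `cont` counts self-matches seen so far and every later hit's rank
# is shifted by that amount. For example, with the self-match at sorted
# position 0 and same-label hits at positions 1 and 3, Avep = [1/1, 2/3]
# and AP = (1 + 2/3) / 2.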
label2num = Counter() 194 | lines = open(os.path.join(args.data_dir, f"{candidate_file_name}.jsonl")).readlines() 195 | corpus = {} 196 | for i,line in enumerate(tqdm(lines)): 197 | content = json.loads(line) 198 | idx2label[content["index"]] = content["label"] 199 | label2num[int(content["label"])] += 1 200 | 201 | # multi-gpu evaluate 202 | if args.n_gpu > 1: 203 | model = torch.nn.DataParallel(model) 204 | 205 | # Eval! 206 | logger.info("***** Running evaluation *****") 207 | logger.info(" Num Query = %d", len(query_dataset)) 208 | logger.info(" Num Candidate = %d", len(candidate_dataset)) 209 | logger.info(" Batch size = %d", args.eval_batch_size) 210 | model.to(args.device) 211 | 212 | model.eval() 213 | query_vecs = [] 214 | query_indexs = [] 215 | candidate_vecs = [] 216 | candidate_indexs = [] 217 | 218 | for batch in tqdm(query_dataloader, total=len(query_dataloader)): 219 | code_inputs = batch[0].to(args.device) 220 | index = batch[1] 221 | with torch.no_grad(): 222 | attn_mask = torch.tensor(code_inputs.clone().detach() != tokenizer.pad_token_id, dtype=torch.uint8, device=args.device) 223 | code_vec = model(code_inputs, attention_mask=attn_mask)[0] 224 | code_vec = torch.nn.functional.normalize(code_vec[:, 0, :], dim=1) 225 | query_vecs.append(code_vec.cpu().numpy()) 226 | query_indexs.extend(index) 227 | 228 | for batch in tqdm(candidate_dataloader,total=len(candidate_dataloader)): 229 | code_inputs = batch[0].to(args.device) 230 | index = batch[1] 231 | with torch.no_grad(): 232 | attn_mask = torch.tensor(code_inputs.clone().detach() != tokenizer.pad_token_id, dtype=torch.uint8, device=args.device) 233 | code_vec = model(code_inputs, attention_mask=attn_mask)[0] 234 | if args.num_vec > 0: 235 | code_vec = torch.nn.functional.normalize(code_vec[:, :args.num_vec, :], dim=2) 236 | else: 237 | code_vec = torch.nn.functional.normalize(code_vec[:, 0, :], dim=1) 238 | candidate_vecs.append(code_vec.cpu().numpy()) 239 | candidate_indexs.extend(index) 240 | 241 | model.train() 242 | 243 | query_vecs = np.concatenate(query_vecs, 0) 244 | candidate_vecs = np.concatenate(candidate_vecs, 0) 245 | query_labels = [idx2label[x] for x in query_indexs] 246 | candidate_labels = [idx2label[x] for x in candidate_indexs] 247 | 248 | if args.num_vec > 0: 249 | scores=np.einsum('nd,mvd->nmv', query_vecs, candidate_vecs).max(-1) 250 | else: 251 | scores=np.matmul(query_vecs, candidate_vecs.T) 252 | sort_ids = np.argsort(scores, axis=-1, kind='quicksort', order=None)[:,::-1] 253 | 254 | MAP = [] 255 | MAP_at_K = [] 256 | PREC = 0.0 257 | for i in tqdm(range(scores.shape[0]), total=scores.shape[0]): 258 | cont = 0 259 | label = int(query_labels[i]) 260 | div = min(at_K, label2num[label]) 261 | query_index = query_indexs[i] 262 | Avep = [] 263 | for j,index in enumerate(list(sort_ids[i])): 264 | if query_index == candidate_indexs[index]: 265 | cont += 1 266 | continue 267 | if j - cont == at_K: 268 | MAP_at_K.append(sum(Avep)/div) 269 | if int(candidate_labels[index]) == label: 270 | Avep.append((len(Avep)+1)/(j+1-cont)) 271 | if j - cont == 0: 272 | PREC += 1.0 273 | if len(Avep) > 0: 274 | MAP.append(sum(Avep)/len(Avep)) 275 | else: 276 | MAP.append(0.0) 277 | 278 | result = { 279 | "Data size":len(MAP), 280 | "eval_map":float(np.mean(MAP)), 281 | f"eval_map_at_{at_K}":float(np.mean(MAP_at_K)), 282 | "eval_prec":float(PREC/len(MAP)) 283 | } 284 | for key in sorted(result.keys()): 285 | logger.info(" %s = %s", key, str(round(result[key], 4))) 286 | 287 | 288 | def main(): 289 | parser = 
argparse.ArgumentParser() 290 | 291 | ## Required parameters 292 | parser.add_argument("--data_dir", default=None, type=str, required=True, 293 | help="The input data path.") 294 | parser.add_argument("--pretrained_dir", default=None, type=str, 295 | help="The directory where the trained model and tokenizer are saved.") 296 | parser.add_argument("--lang", default="python", type=str, 297 | help="Language of dataset") 298 | 299 | parser.add_argument("--cut", action='store_true', 300 | help="Cut each query to its first half before retrieval") 301 | parser.add_argument('--num_vec', type=int, default=-1, 302 | help="number of vectors") 303 | parser.add_argument("--block_size", default=512, type=int, 304 | help="Optional input sequence length after tokenization." 305 | "The training dataset will be truncated in block of this size for training." 306 | "Default to the model max input length for single sentence inputs (take into account special tokens).") 307 | 308 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, 309 | help="Batch size per GPU/CPU for evaluation.") 310 | parser.add_argument('--logging_steps', type=int, default=10, 311 | help="Log every X updates steps.") 312 | parser.add_argument("--no_cuda", action='store_true', 313 | help="Avoid using CUDA when available") 314 | parser.add_argument('--seed', type=int, default=42, 315 | help="random seed for initialization") 316 | args = parser.parse_args() 317 | 318 | 319 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 320 | args.n_gpu = torch.cuda.device_count() 321 | args.device = device 322 | 323 | # Setup logging 324 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 325 | datefmt='%m/%d/%Y %H:%M:%S', 326 | level=logging.INFO) 327 | 328 | # Set seed 329 | set_seed(args) 330 | 331 | tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_dir) 332 | model = RobertaModel.from_pretrained(args.pretrained_dir, add_pooling_layer=False) 333 | 334 | 335 | evaluate(args, model, tokenizer, args.lang, args.lang, args.cut) 336 | # eval_bm25_beir(args, tokenizer, "java", "java", cut=True) 337 | 338 | if __name__ == "__main__": 339 | main() 340 | 341 | -------------------------------------------------------------------------------- /process_python.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT License. 
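# Minimal usage sketch for the processor below (illustrative input; requires
# parser/my-languages.so built via parser/build.sh):
#   proc = processor("python")
#   proc.update("def add(a, b):\n    return a + b\n")
#   proc.get_func_name()          # -> "add"
#   flat = proc.untokenize()      # newlines/indent become <EOL>/<INDENT>/<DEDENT>
#   proc.convert_to_normal(flat)  # restores plain, 4-space-indented text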
4 | import os 5 | import json 6 | import re 7 | import pickle 8 | from tqdm import tqdm 9 | import random 10 | from collections import Counter, OrderedDict 11 | from tree_sitter import Language, Parser 12 | from parser import (remove_comments_and_docstrings, 13 | tree_to_token_index, 14 | index_to_code_token, 15 | tree_to_variable_index, 16 | traverse) 17 | from textwrap import dedent 18 | 19 | 20 | # ['nwo', 'path', 'language', 'identifier', 'parameters', 'argument_list', 'return_statement', 'docstring', 'docstring_summary', 'docstring_tokens', 'function', 'function_tokens'] 21 | 22 | def clean_docstring_comments(comment): 23 | comment = comment.strip().strip(""" "' """) 24 | comment = "\n".join(map(lambda s: s.lstrip("#"), comment.splitlines())) 25 | return dedent(comment) 26 | 27 | class processor(object): 28 | def __init__(self, lang, code=None, remove_comments=False): 29 | LANGUAGE = Language('parser/my-languages.so', lang) 30 | parser = Parser() 31 | parser.set_language(LANGUAGE) 32 | self.parser = [parser]  # note: extract_dataflow expects a DFG function at self.parser[1], which is never set here 33 | self.lang = lang 34 | self.remove_comments = remove_comments 35 | self.preserve_words = set(["self", "super", "Exception", "__init__", "__main__"]) 36 | if code is None: 37 | self.tree = None 38 | else: 39 | self.update(code) 40 | 41 | def update(self, code): 42 | if self.lang == "php": 43 | code = "<?php " + code + "?>"  # wrap snippets so tree-sitter-php can parse them 44 | com = True 45 | if self.remove_comments: 46 | com = False 47 | try: 48 | code = remove_comments_and_docstrings(code, self.lang) 49 | except Exception: 50 | com = True 51 | self.code = code 52 | self.code_bytes = code.encode("utf-8") 53 | self.tree = self.parser[0].parse(self.code_bytes) 54 | root_node = self.tree.root_node 55 | tokens_index = tree_to_token_index(root_node) 56 | code = self.code.split('\n') 57 | self.code_tokens = [index_to_code_token(x, code) for x in tokens_index] 58 | self.index_to_code = OrderedDict() 59 | for idx, (index, code) in enumerate(zip(tokens_index, self.code_tokens)): 60 | self.index_to_code[index] = (idx, code) 61 | return com 62 | 63 | def get_doc(self): 64 | """ 65 | For file level data, merge the doc in each function 66 | """ 67 | self.functions = [] 68 | self.get_func_nodes(self.tree.root_node) 69 | docs = "" 70 | for func_node in self.functions: 71 | body_node = func_node.children[-1] 72 | if body_node.children and body_node.children[0].children: 73 | if body_node.children[0].children[0].type in ["string", "comment"]: 74 | docs += clean_docstring_comments( 75 | self.span_select(body_node.children[0].children[0]), 76 | ) + "<EOL>"  # docs are joined with the <EOL> separator 77 | return docs 78 | 79 | def get_func_nodes(self, root): 80 | """ 81 | For both function level and file level data 82 | """ 83 | for node in root.children: 84 | if node.type == "function_definition": 85 | self.functions.append(node) 86 | else: 87 | self.get_func_nodes(node) 88 | 89 | 90 | def load_names(self, path): 91 | self.vnames = pickle.load(open(os.path.join(path, "vnames.pkl"), "rb")) 92 | self.vnames = [x for (x, _) in self.vnames.most_common(50000) if x not in self.preserve_words] 93 | self.fnames = pickle.load(open(os.path.join(path, "fnames.pkl"), "rb")) 94 | self.fnames = [x for (x, _) in self.fnames.most_common(5000) if x not in self.preserve_words] 95 | 96 | def span_select(self, *nodes, indent=False): 97 | if not nodes: 98 | return "" 99 | start, end = nodes[0].start_byte, nodes[-1].end_byte 100 | select = self.code_bytes[start:end].decode("utf-8") 101 | if indent: 102 | return " " * nodes[0].start_point[1] + select 103 | return select 104 | 105 | def get_func_name(self): 106 | """ 107 | For function 
level data only 108 | """ 109 | root_node = self.tree.root_node 110 | func_nodes = [node for node in root_node.children if node.type == "function_definition"] 111 | try: 112 | func_name = func_nodes[0].child_by_field_name("name") 113 | except IndexError: 114 | return "" 115 | return self.span_select(func_name) 116 | 117 | def get_var_names(self): 118 | root_node = self.tree.root_node 119 | vnames = set() 120 | self._get_var_names_from_node(root_node, vnames) 121 | return vnames 122 | 123 | def get_api_seq(self): 124 | root_node = self.tree.root_node 125 | api_seq = [] 126 | self._get_api_seq(root_node, api_seq) 127 | return api_seq 128 | 129 | def _get_var_names_from_node(self, node, vnames, inactive=False): 130 | if len(node.children) > 0: 131 | for child in node.children: 132 | if ( 133 | (node.type == "call" and child.type != "argument_list") or 134 | (node.type == "attribute") or 135 | (node.type in ["import_statement", "import_from_statement"]) 136 | ): 137 | self._get_var_names_from_node(child, vnames, True) 138 | else: 139 | self._get_var_names_from_node(child, vnames, inactive) 140 | elif node.type == "identifier": 141 | if not inactive: 142 | vnames.add(self.span_select(node)) 143 | 144 | def _get_api_seq(self, node, api_seq, tmp=None): 145 | if node.type == "call": 146 | api = node.child_by_field_name("function") 147 | if tmp: 148 | tmp.append(self.span_select(api)) 149 | ant = False 150 | else: 151 | tmp = [self.span_select(api)] 152 | ant = True 153 | for child in node.children: 154 | self._get_api_seq(child, api_seq, tmp) 155 | if ant: 156 | api_seq += tmp[::-1] 157 | tmp = None 158 | else: 159 | for child in node.children: 160 | self._get_api_seq(child, api_seq, tmp) 161 | 162 | def process(self, ratio=0.85, indent=True, add_dead_code=True, cut_ratio=0.0): 163 | fname = self.get_func_name() 164 | vnames = [x for x in self.get_var_names() if x not in self.preserve_words] 165 | vnames = random.sample(vnames, int(len(vnames)*ratio)) 166 | cands = random.sample(self.vnames, len(vnames)+3) 167 | dead_vars = cands[-3:] 168 | if add_dead_code: 169 | deadcode = self.insert_dead_code(dead_vars) 170 | else: 171 | deadcode = None 172 | replaced = {v: c for v, c in zip(vnames, cands[:-3])} 173 | if ratio > 0 and fname and fname not in replaced: 174 | replaced[fname] = random.choice(self.fnames) 175 | self.index_to_new_code = {} 176 | self._replace_var_names_from_node(self.tree.root_node, replaced) 177 | code_string = self.untokenize(indent, deadcode, True, cut_ratio=cut_ratio) 178 | return code_string 179 | 180 | def process_no_replace(self, indent=True, add_dead_code=True, cut_ratio=0.0): 181 | dead_vars = random.sample(self.vnames, 3) 182 | if add_dead_code: 183 | deadcode = self.insert_dead_code(dead_vars) 184 | else: 185 | deadcode = None 186 | code_string = self.untokenize(indent, deadcode, False, cut_ratio=cut_ratio) 187 | return code_string 188 | 189 | def create_mask_seq(self, indent=True): 190 | fname = self.get_func_name() 191 | vnames = [x for x in self.get_var_names() if x not in self.preserve_words] 192 | replaced = {v: f"[MASK_{i}]" for i, v in enumerate(vnames)} 193 | if fname and fname not in replaced: 194 | replaced[fname] = "[MASK_F]" 195 | self.index_to_new_code = {} 196 | self._replace_var_names_from_node(self.tree.root_node, replaced) 197 | code_string = self.untokenize(indent, replaced=True) 198 | return code_string, replaced 199 | 200 | def insert_dead_code(self, v): 201 | # dead code types, vars that can't appear in original code 202 | # A = B, A 203 | # A(B, 0), 
A 204 | # A = B + C, AB 205 | # A = B(C), AB 206 | # A = B.C(), ABC 207 | # A = [B for B in range(C)] 208 | # A = B if C else 0 209 | dead_type = random.randrange(7) 210 | if dead_type == 0: 211 | return f"{v[0]} = {v[1]}" 212 | elif dead_type == 1: 213 | return f"{v[0]}({v[1]}, 0)" 214 | elif dead_type == 2: 215 | return f"{v[0]} = {v[1]} + {v[2]}" 216 | elif dead_type == 3: 217 | return f"{v[0]} = {v[1]}({v[2]})" 218 | elif dead_type == 4: 219 | return f"{v[0]} = {v[1]}.{v[2]}()" 220 | elif dead_type == 5: 221 | return f"{v[0]} = [{v[1]} for {v[1]} in range({v[2]})]" 222 | elif dead_type == 6: 223 | return f"{v[0]} = {v[1]} if {v[2]} else 0" 224 | 225 | def _replace_var_names_from_node(self, node, replaced, inactive=False): 226 | if len(node.children) > 0: 227 | if node.type == "attribute": 228 | self._replace_var_names_from_node(node.children[0], replaced, inactive) 229 | for child in node.children[1:]: 230 | self._replace_var_names_from_node(child, replaced, True) 231 | else: 232 | for child in node.children: 233 | if ( 234 | (node.type == "call" and child.type not in ["attribute", "argument_list"]) or 235 | (node.type in ["import_statement", "import_from_statement"]) 236 | ): 237 | self._replace_var_names_from_node(child, replaced, True) 238 | else: 239 | self._replace_var_names_from_node(child, replaced, inactive) 240 | elif node.type == "identifier": 241 | if not inactive: 242 | try: 243 | idf = self.index_to_code[(node.start_point, node.end_point)][1] 244 | except KeyError: 245 | idf = "None" 246 | if idf in replaced: 247 | self.index_to_new_code[(node.start_point, node.end_point)] = replaced[idf] 248 | 249 | def untokenize(self, indent=True, deadcode=None, replaced=False, cut_ratio=0.0, fix_cut_pos=False): 250 | code_string = "" 251 | prev_sp = None 252 | prev_ep = None 253 | prev_indent = 0 254 | indent_size = -1 255 | total_line = list(self.index_to_code.keys())[-1][0][0] 256 | insert_line = random.randint(total_line//5, total_line*4//5) 257 | cut = random.random() < cut_ratio 258 | if cut: 259 | if fix_cut_pos: 260 | cut_pos = len(self.index_to_code)//2 261 | else: 262 | cut_pos = random.randint(len(self.index_to_code)//3, len(self.index_to_code)*2//3) 263 | for ip, pos in enumerate(self.index_to_code): 264 | sp = pos[0] 265 | ep = pos[1] 266 | if cut and ip >= cut_pos: 267 | break 268 | if replaced and pos in self.index_to_new_code: 269 | add_token = self.index_to_new_code[pos] 270 | else: 271 | add_token = self.index_to_code[pos][1] 272 | if prev_sp is None or (sp[0] == prev_ep[0] and sp[1] == prev_ep[1]): 273 | code_string += add_token 274 | elif sp[0] == prev_ep[0]: 275 | if code_string and code_string[-1] != " ": 276 | code_string += " " 277 | code_string += add_token 278 | else: 279 | # if cut and cut_line >= 1 and cut_line <= prev_ep[0]: 280 | # break 281 | if replaced and deadcode: 282 | if insert_line <= prev_ep[0]: 283 | if sp[1] <= prev_indent: 284 | code_string += "\n" + deadcode 285 | insert_line = total_line+2 286 | if indent and add_token: 287 | code_string += "\n" 288 | omit = False 289 | if sp[1] != prev_indent and prev_indent == 0 and indent_size == -1: 290 | indent_size = sp[1] - prev_indent 291 | if sp[1] - prev_indent > 0: 292 | if sp[1] - prev_indent > 2 * indent_size: 293 | omit = True 294 | else: 295 | for i in range(prev_indent, sp[1], indent_size): 296 | code_string += "<INDENT>"  # layout token; removed again by convert_to_normal 297 | elif sp[1] - prev_indent < 0: 298 | for i in range(sp[1], prev_indent, indent_size): 299 | code_string += "<DEDENT>" 300 | code_string += add_token 301 | if not omit: 302 | prev_indent = sp[1] 303 | 
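# By here a positive column delta has emitted one <INDENT> per inferred
# level and a negative delta one <DEDENT> per level, so e.g.
# "def f(x):\n    return x" flattens to "def f(x):<EOL><INDENT>return x".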
else: 304 | code_string += "\n" 305 | for i in range(sp[1]): 306 | code_string += " " 307 | code_string += add_token 308 | prev_sp, prev_ep = sp, ep 309 | return re.sub(re.compile(r"\s*\n"), "\n", code_string.lstrip()).replace("\n", "<EOL>") 310 | 311 | def convert_to_normal(self, code): 312 | lines = code.split("<EOL>") 313 | indent_size = 4 314 | indent = 0 315 | res = "" 316 | for line in lines: 317 | indent += line.count("<INDENT>") 318 | indent -= line.count("<DEDENT>") 319 | res += "\n" + " "*indent_size*indent + line.replace("<INDENT>", "").replace("<DEDENT>", "") 320 | return res 321 | 322 | def extract_dataflow(self): 323 | try: 324 | root_node = self.tree.root_node 325 | try: 326 | DFG, _ = self.parser[1](root_node, self.index_to_code, {})  # no DFG function is stored at index 1 here, so this raises and DFG stays [] 327 | except Exception: 328 | DFG = [] 329 | DFG = sorted(DFG, key=lambda x: x[1]) 330 | indexs = set() 331 | for d in DFG: 332 | if len(d[-1]) != 0: 333 | indexs.add(d[1]) 334 | for x in d[-1]: 335 | indexs.add(x) 336 | new_DFG = [] 337 | for d in DFG: 338 | if d[1] in indexs: 339 | new_DFG.append(d) 340 | dfg = new_DFG 341 | except Exception: 342 | dfg = [] 343 | return dfg 344 | 345 | 346 | -------------------------------------------------------------------------------- /generate/run_lm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import argparse 6 | from functools import partial 7 | import glob 8 | import logging 9 | import os 10 | import pickle 11 | import random 12 | import re 13 | import shutil 14 | import json 15 | 16 | import numpy as np 17 | import torch 18 | from torch.nn.utils.rnn import pad_sequence 19 | from torch.nn import CrossEntropyLoss 20 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from dataset import RelineDataset, PPLDataset 23 | from beam import Beam 24 | 25 | from fuzzywuzzy import fuzz 26 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 27 | BertConfig, BertForMaskedLM, BertTokenizer, 28 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, 29 | OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, 30 | RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, 31 | DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | MODEL_CLASSES = { 36 | 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), 37 | 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 38 | 'bert': (BertConfig, BertForMaskedLM, BertTokenizer), 39 | 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), 40 | 'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer) 41 | } 42 | 43 | def set_seed(args): 44 | random.seed(args.seed) 45 | np.random.seed(args.seed) 46 | torch.manual_seed(args.seed) 47 | if args.n_gpu > 0: 48 | torch.cuda.manual_seed_all(args.seed) 49 | 50 | def update_config(args, config): 51 | # config.n_positions = config.n_ctx = args.block_size 52 | config.vocab_size = args.vocab_size 53 | 54 | def get_special_tokens(path): 55 | lits = json.load(open(path)) 56 | tokens = ["<STR_LIT>", "<NUM_LIT>", "<CHAR_LIT>"]  # CodeXGLUE-style literal placeholder tokens 57 | for lit in lits["str"]: 58 | tokens.append(f"<STR_LIT:{lit}>") 59 | for lit in lits["num"]: 60 | tokens.append(f"<NUM_LIT:{lit}>") 61 | for lit in lits["char"]: 62 | tokens.append(f"<CHAR_LIT:{lit}>") 63 | return tokens 64 | 65 | def my_collect_fn(sequences, batch_first=True, 
padding_value=1): 66 | inputs1 = [] 67 | inputs2 = [] 68 | for (x1, x2) in sequences: 69 | inputs1.append(x1) 70 | inputs2.append(x2) 71 | return ( 72 | pad_sequence(inputs1, batch_first, padding_value), 73 | pad_sequence(inputs2, batch_first, 0),  # labels padded with 0, which eval_ppl treats as masked out 74 | ) 75 | 76 | def eval_ppl(args, model, tokenizer, file_type='test', load_file="train", res_file="dense.pkl"): 77 | model.to(args.device) 78 | if load_file is None: 79 | dataset = PPLDataset(tokenizer, args, logger, file_type=file_type, block_size=args.block_size) 80 | else: 81 | dataset = PPLDataset( 82 | tokenizer, args, logger, file_type=file_type, block_size=args.block_size, 83 | load_file=load_file, 84 | search_res=os.path.join(args.data_dir, "search_results", res_file), 85 | ) 86 | 87 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 88 | # Note that DistributedSampler samples randomly 89 | eval_sampler = SequentialSampler(dataset) 90 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), batch_size=args.eval_batch_size) 91 | 92 | # multi-gpu evaluate 93 | if args.n_gpu > 1: 94 | model = torch.nn.DataParallel(model) 95 | 96 | logger.info("***** Running test *****") 97 | logger.info(" Num examples = %d", len(dataset)) 98 | logger.info(" Batch size = %d", args.eval_batch_size) 99 | eval_loss = 0.0 100 | num_tok = 0 101 | nb_eval_steps = 0 102 | model.eval() 103 | 104 | for step, (batch, token_labels) in enumerate(eval_dataloader): 105 | 106 | inputs = batch.to(args.device) 107 | attn_mask = torch.tensor(token_labels.clone().detach() != 0, dtype=torch.uint8, device=args.device) 108 | loss_mask = torch.tensor(token_labels.clone().detach() == 2, dtype=torch.uint8, device=args.device)  # token_labels: 0 = padding, 1 = context, 2 = completion tokens scored by the loss (assumed dataset convention) 109 | with torch.no_grad(): 110 | outputs = model(inputs, attention_mask=attn_mask) 111 | logits = outputs[0] 112 | labels = inputs 113 | shift_logits = logits[..., :-1, :].contiguous() 114 | shift_labels = labels[..., 1:].contiguous() 115 | # Flatten the tokens 116 | loss_fct = CrossEntropyLoss() 117 | flatten_shift_loss_mask = loss_mask[..., :-1].contiguous().view(-1) 118 | ids = torch.nonzero(flatten_shift_loss_mask).view(-1) 119 | all_labels = shift_labels.view(-1)[ids] 120 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[ids], all_labels) 121 | eval_loss += loss.item()*all_labels.shape[0] 122 | num_tok += all_labels.shape[0] 123 | 124 | if step % args.logging_steps == 0: 125 | logger.info(f"Test steps: {step}") 126 | nb_eval_steps += 1 127 | 128 | eval_loss = eval_loss / num_tok 129 | perplexity = torch.exp(torch.tensor(eval_loss)) 130 | 131 | result = { 132 | "perplexity": float(perplexity) 133 | } 134 | 135 | for key in sorted(result.keys()): 136 | logger.info(" %s = %s", key, str(result[key])) 137 | 138 | def eval_line_completion(args, model, tokenizer, file_type='test', load_file=None, res_file=None, save_name=None): 139 | """ 140 | Evaluate line level code completion on exact match and edit similarity. 141 | 142 | It is recommended to use a single GPU because the decoding cannot be batched. 
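Exact match counts a prediction correct only if the generated line equals
the ground truth; edit similarity is fuzzywuzzy's character-level ratio in
[0, 100], e.g. fuzz.ratio("return a + b", "return a - b") -> 92.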
143 | """ 144 | 145 | def DecodeIds(idxs): 146 | codes = "" 147 | for idx in idxs: 148 | to_add = tokenizer.convert_ids_to_tokens(idx) 149 | if tokenizer.convert_ids_to_tokens(idx)[0] == '\u0120': 150 | if not codes.endswith(" "): 151 | codes += " " + to_add[1:] 152 | else: 153 | codes += to_add[1:] 154 | elif ( 155 | idx in [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id] or 156 | tokenizer.convert_ids_to_tokens(idx).startswith("