├── parser
│   ├── build.sh
│   ├── __init__.py
│   ├── build.py
│   ├── utils.py
│   └── DFG.py
├── CODE_OF_CONDUCT.md
├── LICENSE
├── utils
│   ├── split_codes.py
│   ├── get_res.py
│   ├── search_bm25.py
│   └── search_dense.py
├── SECURITY.md
├── model.py
├── generate
│   ├── beam.py
│   ├── dataset.py
│   └── run_lm.py
├── README.md
├── infer.py
├── process_java.py
├── codenet_test.py
└── process_python.py
/parser/build.sh:
--------------------------------------------------------------------------------
1 | git clone https://github.com/tree-sitter/tree-sitter-python
2 | git clone https://github.com/tree-sitter/tree-sitter-java
3 | python build.py
4 |
--------------------------------------------------------------------------------
/parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import (remove_comments_and_docstrings,
2 | tree_to_token_index,
3 | index_to_code_token,
4 | tree_to_variable_index,
5 | traverse
6 | )
7 | from .DFG import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
--------------------------------------------------------------------------------
/parser/build.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from tree_sitter import Language, Parser
5 |
6 | Language.build_library(
 7 |   # Store the compiled library in the current directory
8 | 'my-languages.so',
9 |
10 | # Include one or more languages
11 | [
12 | 'tree-sitter-python',
13 | 'tree-sitter-java'
14 | ]
15 | )
16 |
17 |
--------------------------------------------------------------------------------
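Once `build.sh` has produced `my-languages.so`, the grammars can be loaded with the legacy `py-tree-sitter` API used throughout this repository. A minimal sketch (assuming it is run from the `parser` directory and an older `tree_sitter` release that still ships `Language.build_library` and accepts a library path):

```python
from tree_sitter import Language, Parser

# Load the Python grammar compiled by build.py (path relative to parser/).
PY_LANGUAGE = Language('my-languages.so', 'python')

parser = Parser()
parser.set_language(PY_LANGUAGE)

# Parse a small snippet and print the syntax tree's S-expression.
tree = parser.parse(b"def add(a, b):\n    return a + b\n")
print(tree.root_node.sexp())
```
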
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Shuai Lu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/split_codes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
 4 | # This script splits long code sequences into fixed-length chunks (300 tokens by default) for the search corpus.
5 |
6 | import os
7 | import pickle
8 | import argparse
9 |
10 | def main():
11 | parser = argparse.ArgumentParser(description='')
12 | parser.add_argument('--file_name', '-i', required=True, help="filename of the codes, should be in txt format")
13 | parser.add_argument('--length', '-l', type=int, default=300, help="length of the chunk")
14 | args = parser.parse_args()
15 |
16 | lines = open(args.file_name, "r").readlines()
17 | wf = open(args.file_name.split("/")[-1].split(".")[0]+"_split.txt", "w")
18 | nexts = []
19 | cnt = 0
20 | for line in lines:
21 | tokens = line.strip().split()
22 | if len(tokens) <= args.length:
23 | wf.write(" ".join(tokens)+"\n")
24 | nexts.append(cnt)
25 | cnt += 1
26 | else:
27 | for i in range(0, len(tokens), args.length):
28 | wf.write(" ".join(tokens[i:i+args.length])+"\n")
29 | nexts.append(cnt+1)
30 | cnt += 1
31 | nexts[-1] -= 1
32 | wf.close()
33 | pickle.dump(nexts, open(args.file_name.split("/")[-1].split(".")[0]+"_split_nexts.pkl", "wb"))
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
38 |
39 |
--------------------------------------------------------------------------------
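A minimal usage sketch for the splitter (the input path is an assumption); it writes `train_split.txt` and `train_split_nexts.pkl` (the chunk-to-next-chunk mapping) to the working directory:

```bash
python utils/split_codes.py -i data/train.txt -l 300
```
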
/utils/get_res.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | import os
4 | import pickle
5 | import argparse
6 | import json
7 | import numpy as np
8 | from tqdm import tqdm
9 | import random
10 |
11 | def hybrid_scores(bm25_scores, dense_scores, alpha, beilv=100):
12 |     # beilv: re-scaling factor for the dense scores since they are percentages; alpha weights the bm25 scores.
13 | scores = {}
14 | for idx, v in tqdm(dense_scores.items()):
15 | new_v = {}
16 | if idx not in bm25_scores:
17 | scores[idx] = v
18 | continue
19 | v2 = bm25_scores[idx]
20 | v_min = min(list(v.values()))
21 | v2_min = min(list(v2.values()))
22 | for _id, score in v.items():
23 | if _id not in v2:
24 | new_v[_id] = beilv * score + alpha * v2_min
25 | else:
26 | new_v[_id] = beilv * score + alpha * v2[_id]
27 | for _id, score in v2.items():
28 | if _id not in new_v:
29 | new_v[_id] = alpha * score + beilv * v_min
30 | scores[idx] = new_v
31 | return scores
32 |
33 | def get_res(bm25_file, dense_file, save_file, alpha):
34 | if bm25_file != "":
35 | bm25_scores = pickle.load(open(bm25_file, "rb"))
36 | print("bm25 scores loaded")
37 | else:
38 | bm25_scores = {}
39 | if dense_file != "":
40 | dense_scores = pickle.load(open(dense_file, "rb"))
41 | print("dense scores loaded")
42 | else:
43 | dense_scores = {}
44 |
45 | res = {}
46 | if len(bm25_scores) > 0 and len(dense_scores) > 0:
47 | scores = hybrid_scores(bm25_scores, dense_scores, alpha, 100)
48 | elif len(bm25_scores) > 0:
49 | scores = bm25_scores
50 | else:
51 | scores = dense_scores
52 | for idx, v in tqdm(scores.items()):
53 | v = sorted(v.items(), key=lambda x:-x[1])
54 | # res[int(idx)] = int(v[0][0]) if v[0][0] != idx else int(v[1][0])
55 | res[int(idx)] = int(v[0][0])
56 |
57 | pickle.dump(res, open(save_file, "wb"))
58 |
59 |
60 | def main():
61 | parser = argparse.ArgumentParser(description='')
62 | parser.add_argument('--save_name', '-o', required=True, help="save file name")
63 | parser.add_argument('--bm25_res', '-b', default="", help="bm25 result file")
64 | parser.add_argument('--dense_res', '-d', default="", help="dense result file")
65 |     parser.add_argument("--alpha", type=float, default=1.1, help="weight of the bm25 score in the hybrid score")
66 | args = parser.parse_args()
67 |
68 | get_res(args.bm25_res, args.dense_res, args.save_name, args.alpha)
69 |
70 |
71 | if __name__ == "__main__":
72 | main()
73 |
74 |
75 |
--------------------------------------------------------------------------------
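A usage sketch for merging retrieval results (file names are assumptions; `--alpha` weights the BM25 side of the hybrid score):

```bash
# Hybrid of BM25 and dense retrieval
python utils/get_res.py -b bm25_res.pkl -d dense_res.pkl -o search_res.pkl --alpha 1.1

# Dense retrieval only
python utils/get_res.py -d dense_res.pkl -o search_res.pkl
```
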
/utils/search_bm25.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | import os
4 | import argparse
5 | import json
6 | from tqdm import tqdm
7 | import pickle
8 | import random
9 | import csv
10 | import time
11 |
12 | from beir.datasets.data_loader import GenericDataLoader
13 | from beir.retrieval.evaluation import EvaluateRetrieval
14 | from beir.retrieval.search.lexical import BM25Search as BM25
15 |
16 | def build(corpus_file, query_file, temp_dir):
17 | # corpus and query file have to be plain text files
18 | print("Building bm25 corpus")
19 | corpus = []
20 | queries = []
21 | ids = []
22 | datas = open(corpus_file).readlines()
23 | lines = open(query_file).readlines()
24 | try:
25 | os.mkdir(temp_dir)
26 | except FileExistsError:
27 | pass
28 | fidx = open(os.path.join(temp_dir, "corpus.jsonl"), "w")
29 | fq = open(os.path.join(temp_dir, "query.jsonl"), "w")
30 | fr = open(os.path.join(temp_dir, "res.tsv"), "w")
31 | csv_fr = csv.writer(fr, delimiter='\t')
32 |     fr.write("q\td\ts\n")
33 | for i,line in enumerate(tqdm(datas)):
34 | fidx.write(json.dumps({"_id":str(i), "text":line.strip()})+"\n")
35 | for i,line in enumerate(tqdm(lines)):
36 | content = json.loads(line)
37 | idx = content["id"]
38 | csv_fr.writerow([str(idx), str(idx), 1])
39 | code = content["input"].strip()
40 | fq.write(json.dumps({"_id":str(idx), "text":code})+"\n")
41 |     fidx.close(); fq.close(); fr.close()  # flush the beir-format files before they are read back
42 |
43 | def main():
44 | parser = argparse.ArgumentParser(description='')
45 | parser.add_argument('--search_corpus', '-i', required=True, help="search corpus file, plain text file")
46 | parser.add_argument('--query_file', '-q', required=True, help="queries file, json file")
47 |     parser.add_argument('--save_name', '-o', required=True, help="save file name")
48 | parser.add_argument('--temp_path', '-t', default="beir", help="temp dir to save beir-format data")
49 | args = parser.parse_args()
50 |
51 | build(args.search_corpus, args.query_file, args.temp_path)
52 | time.sleep(10)
53 |
54 | corpus, queries, qrels = GenericDataLoader(
55 | corpus_file=os.path.join(args.temp_path, "corpus.jsonl"),
56 | query_file=os.path.join(args.temp_path, "query.jsonl"),
57 | qrels_file=os.path.join(args.temp_path, "res.tsv")
58 | ).load_custom()
59 |
60 | model = BM25(index_name="reacc", hostname="127.0.0.1:9200", initialize=True)
61 | retriever = EvaluateRetrieval(model)
62 |
63 | results = retriever.retrieve(corpus, queries)
64 | pickle.dump(results, open(args.save_name, "wb"))
65 |
66 | if __name__ == "__main__":
67 | main()
68 |
--------------------------------------------------------------------------------
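A usage sketch (paths are assumptions; an Elasticsearch instance must already be listening on `127.0.0.1:9200`, as hard-coded above):

```bash
python utils/search_bm25.py \
    -i data/train_split.txt \
    -q dataset/py150/test.json \
    -o bm25_res.pkl \
    -t beir_tmp
```
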
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/utils/search_dense.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | import os
4 | import json
5 | import pickle
6 | import argparse
7 | import numpy as np
8 | from tqdm import tqdm
9 | import faiss
10 |
11 |
12 | def search(index_file, query_file, save_name):
13 | index_data = pickle.load(open(index_file, "rb"))
14 | query_data = pickle.load(open(query_file, "rb"))
15 | ids = []
16 | indexs = []
17 | id2n = {}
18 | for i, (idx, vec) in enumerate(index_data.items()):
19 | ids.append(idx)
20 | indexs.append(vec)
21 | id2n[idx] = i
22 | queries = []
23 | idxq = []
24 | for idx, vec in query_data.items():
25 | queries.append(vec)
26 | idxq.append(idx)
27 | ids = np.array(ids)
28 | indexs = np.array(indexs)
29 | queries = np.array(queries)
30 |
31 | # build faiss index
32 | d = 768
33 | k = 101
34 | index = faiss.IndexFlatIP(d)
35 | assert index.is_trained
36 |
37 | index_id = faiss.IndexIDMap(index)
38 | index_id.add_with_ids(indexs, ids)
39 |
40 | res = {}
41 | D, I = index_id.search(queries, k)
42 | for i, (sd, si) in enumerate(zip(D, I)):
43 | res[str(idxq[i])] = {}
44 | for pd, pi in zip(sd, si):
45 | res[str(idxq[i])][str(pi)] = pd
46 | # if pi != idxq[i]:
47 | # res[str(idxq[i])][str(pi)] = pd
48 |
49 | pickle.dump(res, open(save_name, "wb"))
50 |
51 | def search_multi(index_file, query_file, save_name):
52 | index_data = pickle.load(open(index_file, "rb"))
53 | query_data = pickle.load(open(query_file, "rb"))
54 | ids = []
55 | indexs = []
56 | for i, (idx, vecs) in enumerate(index_data.items()):
57 | for vec in vecs:
58 | ids.append(idx)
59 | indexs.append(vec)
60 | queries = []
61 | idxq = []
62 | for idx, vec in query_data.items():
63 | queries.append(vec[0])
64 | idxq.append(idx)
65 | ids = np.array(ids)
66 | indexs = np.array(indexs)
67 | queries = np.array(queries)
68 | print(indexs.shape, queries.shape)
69 |
70 | # build faiss index
71 | d = 768
72 | k = 100
73 | index = faiss.IndexFlatIP(d)
74 | assert index.is_trained
75 |
76 | index_id = faiss.IndexIDMap(index)
77 | index_id.add_with_ids(indexs, ids)
78 |
79 |     gpu_res = faiss.StandardGpuResources()  # keep a named reference so the GPU resources stay alive
80 |     gpu_index = faiss.index_cpu_to_gpu(gpu_res, 0, index_id)
81 |
82 | res = {}
83 | D, I = gpu_index.search(queries, k)
84 | for i, (sd, si) in enumerate(zip(D, I)):
85 | res[str(idxq[i])] = {}
86 | for pd, pi in zip(sd, si):
87 | if str(pi) not in res[str(idxq[i])]:
88 | res[str(idxq[i])][str(pi)] = pd
89 | if len(res[str(idxq[i])]) > 100:
90 | break
91 |
92 | pickle.dump(res, open(save_name, "wb"))
93 |
94 | def main():
95 | parser = argparse.ArgumentParser(description='')
96 | parser.add_argument('--index_file', '-i', required=True, help="filename of index embeddings saved")
97 | parser.add_argument('--query_file', '-q', required=True, help="file containing query embeddings")
98 | parser.add_argument('--save_name', '-o', required=True, help="save file name")
99 | parser.add_argument("--multi", action='store_true', help="set true if one query/doc has multi embeddings")
100 | args = parser.parse_args()
101 |
102 | if args.multi:
103 | search_multi(args.index_file, args.query_file, args.save_name)
104 | else:
105 | search(args.index_file, args.query_file, args.save_name)
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
110 |
--------------------------------------------------------------------------------
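A usage sketch (file names are assumptions; the `.pkl` inputs are the embedding files written by `infer.py`). Pass `--multi` when each snippet was encoded into several vectors, e.g. with `--num_vec=8`:

```bash
python utils/search_dense.py \
    -i save_vec.pkl \
    -q query_vec.pkl \
    -o dense_res.pkl \
    --multi
```
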
/parser/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | import re
5 | from io import StringIO
6 | import tokenize
7 |
8 | def traverse(node, indent=0):
9 | print(" "*indent + node.type, node.start_point, "-", node.end_point)
10 | for child in node.children:
11 | traverse(child, indent+1)
12 |
13 | def remove_comments_and_docstrings(source,lang):
14 | if lang in ['python']:
15 | """
16 | Returns 'source' minus comments and docstrings.
17 | """
18 | io_obj = StringIO(source)
19 | out = ""
20 | prev_toktype = tokenize.INDENT
21 | last_lineno = -1
22 | last_col = 0
23 | for tok in tokenize.generate_tokens(io_obj.readline):
24 | token_type = tok[0]
25 | token_string = tok[1]
26 | start_line, start_col = tok[2]
27 | end_line, end_col = tok[3]
28 | ltext = tok[4]
29 | if start_line > last_lineno:
30 | last_col = 0
31 | if start_col > last_col:
32 | out += (" " * (start_col - last_col))
33 | # Remove comments:
34 | if token_type == tokenize.COMMENT:
35 | pass
36 | # This series of conditionals removes docstrings:
37 | elif token_type == tokenize.STRING:
38 | if prev_toktype != tokenize.INDENT:
39 | # This is likely a docstring; double-check we're not inside an operator:
40 | if prev_toktype != tokenize.NEWLINE:
41 | if start_col > 0:
42 | out += token_string
43 | else:
44 | out += token_string
45 | prev_toktype = token_type
46 | last_col = end_col
47 | last_lineno = end_line
48 | temp=[]
49 | for x in out.split('\n'):
50 | if x.strip()!="":
51 | temp.append(x)
52 | return '\n'.join(temp)
53 | elif lang in ['ruby']:
54 | return source
55 | else:
56 | def replacer(match):
57 | s = match.group(0)
58 | if s.startswith('/'):
59 | return " " # note: a space and not an empty string
60 | else:
61 | return s
62 | pattern = re.compile(
63 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"',
64 | re.DOTALL | re.MULTILINE
65 | )
66 | temp=[]
67 | for x in re.sub(pattern, replacer, source).split('\n'):
68 | if x.strip()!="":
69 | temp.append(x)
70 | return '\n'.join(temp)
71 |
72 | def tree_to_token_index(root_node):
73 | if (len(root_node.children)==0 or root_node.type=='string'):
74 | return [(root_node.start_point,root_node.end_point)]
75 | else:
76 | code_tokens=[]
77 | for child in root_node.children:
78 | code_tokens+=tree_to_token_index(child)
79 | return code_tokens
80 |
81 | def tree_to_variable_index(root_node,index_to_code):
82 | if (len(root_node.children)==0 or root_node.type=='string'):
83 | index=(root_node.start_point,root_node.end_point)
84 | _,code=index_to_code[index]
85 | if root_node.type!=code:
86 | return [(root_node.start_point,root_node.end_point)]
87 | else:
88 | return []
89 | else:
90 | code_tokens=[]
91 | for child in root_node.children:
92 | code_tokens+=tree_to_variable_index(child,index_to_code)
93 | return code_tokens
94 |
95 | def index_to_code_token(index,code):
96 | start_point=index[0]
97 | end_point=index[1]
98 | if start_point[0]==end_point[0]:
99 | s=code[start_point[0]][start_point[1]:end_point[1]]
100 | else:
101 | s=""
102 | s+=code[start_point[0]][start_point[1]:]
103 | for i in range(start_point[0]+1,end_point[0]):
104 | s+=code[i]
105 | s+=code[end_point[0]][:end_point[1]]
106 | return s
107 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | import os
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from transformers import RobertaModel
8 | import logging
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
13 | datefmt='%m/%d/%Y %H:%M:%S',
14 | level=logging.INFO)
15 |
16 | class RobertaLMHead(nn.Module):
17 | """Roberta Head for masked language modeling."""
18 |
19 | def __init__(self, config):
20 | super().__init__()
21 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
22 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
23 |
24 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
25 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
26 |
27 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
28 | self.decoder.bias = self.bias
29 |
30 | def forward(self, features, **kwargs):
31 | x = self.dense(features)
32 | x = F.gelu(x)
33 | x = self.layer_norm(x)
34 |
35 | # project back to size of vocabulary with bias
36 | x = self.decoder(x)
37 |
38 | return x
39 |
40 | class SimpleModel(nn.Module):
41 | def __init__(self, config, args):
42 | super().__init__()
43 | if args.from_pretrained:
44 | self.encoder = RobertaModel.from_pretrained(args.pretrained_dir, add_pooling_layer=False)
45 | logger.warning(f"Loading encoder from {args.pretrained_dir}")
46 | else:
47 | self.encoder = RobertaModel(config, add_pooling_layer=False)
48 | self.encoder.resize_token_embeddings(args.vocab_size)
49 |
50 | self.config = config
51 | self.args = args
52 | self.lm_head = RobertaLMHead(config)
53 | self.n_vec = max(0, self.args.num_vec)
54 | self.tie_weights()
55 |
56 | def _tie_or_clone_weights(self, first_module, second_module):
57 | if self.config.torchscript:
58 | first_module.weight = nn.Parameter(second_module.weight.clone())
59 | else:
60 | first_module.weight = second_module.weight
61 |
62 | def tie_weights(self):
63 | self._tie_or_clone_weights(self.lm_head.decoder,
64 | self.encoder.embeddings.word_embeddings)
65 |
66 | def forward(self, inputs_m, inputs1=None, inputs2=None, attn_mask=None, attn_mask1=None, attn_mask2=None, mlm_labels=None):
67 | outputs = self.encoder(inputs_m, attention_mask=attn_mask)[0]
68 |
69 | if inputs1 is None and inputs2 is None:
70 | # infer
71 | if self.n_vec > 0:
72 | outputs = nn.functional.normalize(outputs[:, :self.n_vec, :], dim=2)
73 | else:
74 | outputs = nn.functional.normalize(outputs[:, 0, :], dim=1)
75 | return outputs
76 |
77 | # training
78 | outputs1 = self.encoder(inputs1, attention_mask=attn_mask1)[0][:, 0, :]
79 | outputs2 = self.encoder(inputs2, attention_mask=attn_mask2)[0]
80 |
81 | lm_logits = self.lm_head(outputs)
82 |
83 | if mlm_labels is not None:
84 | loss_fct = nn.CrossEntropyLoss()
85 | mlm_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), mlm_labels.view(-1))
86 | else:
87 | mlm_loss = None
88 |
89 | if self.n_vec > 0:
90 | o1 = nn.functional.normalize(outputs1, dim=1)
91 | o2 = nn.functional.normalize(outputs2[:, :self.n_vec, :], dim=2)
92 | logits, _ = torch.max(torch.einsum('nc,mvc->nmv', [o1, o2]), -1)
93 | else:
94 | o1 = nn.functional.normalize(outputs1, dim=1)
95 | o2 = nn.functional.normalize(outputs2[:, 0, :], dim=1)
96 | logits = torch.einsum('nc,mc->nm', [o1, o2])
97 | logits /= self.args.moco_T
98 | labels = torch.arange(end=logits.shape[0], dtype=torch.long).cuda()
99 |
100 | nce_loss = F.cross_entropy(logits, labels)
101 |
102 | return lm_logits, mlm_loss, nce_loss
--------------------------------------------------------------------------------
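For intuition, the multi-vector scoring in `SimpleModel.forward` reduces to a max over per-vector inner products followed by an in-batch contrastive loss. A tiny standalone sketch (the shapes and the 0.07 temperature are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

n, v, c = 4, 8, 768                               # batch size, vectors per key, hidden size
o1 = F.normalize(torch.randn(n, c), dim=1)        # one [CLS]-style vector per query
o2 = F.normalize(torch.randn(n, v, c), dim=2)     # num_vec vectors per key

# Query-key similarity = max over the key's vectors, as in the n_vec > 0 branch above.
logits, _ = torch.max(torch.einsum('nc,mvc->nmv', [o1, o2]), -1)
labels = torch.arange(n)                          # in-batch positives sit on the diagonal
loss = F.cross_entropy(logits / 0.07, labels)     # 0.07 stands in for args.moco_T
print(logits.shape, loss.item())
```
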
/generate/beam.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | from torch.autograd import Variable
  4 | import copy
  5 |
6 |
7 | class Beam(object):
8 | def __init__(self, size, sos, eos):
9 | self.size = size
10 | self.sent_size = 2*size
11 | self.tt = torch.cuda
12 | # The score for each translation on the beam.
13 | self.scores = self.tt.FloatTensor(size).zero_()
14 | # The backpointers at each time-step.
15 | self.prevKs = []
16 | # The outputs at each time-step.
17 | self.nextYs = [self.tt.LongTensor(size)
18 | .fill_(0)]
19 | self.nextYs[0][:] = sos
20 | # Has EOS topped the beam yet.
21 | self._eos = eos
22 | self.eosTop = False
23 | # Time and k pair for finished.
24 | self.finished = []
25 |
26 | def getCurrentState(self):
27 | "Get the outputs for the current timestep."
28 | batch = self.tt.LongTensor(self.nextYs[-1]).view(-1, 1)
29 | return batch
30 |
31 | def getCurrentOrigin(self):
32 | "Get the backpointers for the current timestep."
33 | return self.prevKs[-1]
34 |
35 | def advance(self, wordLk):
36 | """
37 | Given prob over words for every last beam `wordLk` and attention
38 | `attnOut`: Compute and update the beam search.
39 |
40 | Parameters:
41 |
42 | * `wordLk`- probs of advancing from the last step (K x words)
43 | * `attnOut`- attention at the last step
44 |
45 | Returns: True if beam search is complete.
46 | """
47 | numWords = wordLk.size(1)
48 |
49 | # Sum the previous scores.
50 | if len(self.prevKs) > 0:
51 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
52 |
53 | # Don't let EOS have children.
54 | for i in range(self.nextYs[-1].size(0)):
55 | if self.nextYs[-1][i] in self._eos:
56 | beamLk[i] = -1e20
57 | else:
58 | beamLk = wordLk[0]
59 | flatBeamLk = beamLk.view(-1)
60 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
61 |
62 | self.scores = bestScores
63 |
64 | # bestScoresId is flattened beam x word array, so calculate which
65 | # word and beam each score came from
66 | prevK = bestScoresId // numWords
67 | self.prevKs.append(prevK)
68 | self.nextYs.append((bestScoresId - prevK * numWords))
69 |
70 |
71 | for i in range(self.nextYs[-1].size(0)):
72 | if self.nextYs[-1][i] in self._eos:
73 | s = self.scores[i]
74 | self.finished.append((s, len(self.nextYs) - 1, i))
75 |
76 | # End condition is when top-of-beam is EOS and no global score.
77 | if self.nextYs[-1][0] in self._eos:
78 | self.eosTop = True
79 |
80 | def done(self):
81 | return self.eosTop and len(self.finished) >=self.sent_size
82 |
83 | def getFinal(self):
84 | if len(self.finished) == 0:
85 | self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
86 | self.finished.sort(key=lambda a: -a[0])
87 | if len(self.finished) != self.sent_size:
88 | unfinished=[]
89 | for i in range(self.nextYs[-1].size(0)):
90 | if self.nextYs[-1][i] not in self._eos:
91 | s = self.scores[i]
92 | unfinished.append((s, len(self.nextYs) - 1, i))
93 | unfinished.sort(key=lambda a: -a[0])
94 | self.finished+=unfinished[:self.sent_size-len(self.finished)]
95 | return self.finished[:self.sent_size]
96 |
97 | def getHyp(self, beam_res):
98 | """
99 | Walk back to construct the full hypothesis.
100 | """
101 | hyps=[]
102 | for _,timestep, k in beam_res:
103 | hyp = []
104 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
105 | hyp.append(self.nextYs[j+1][k])
106 | k = self.prevKs[j][k]
107 | hyps.append(hyp[::-1])
108 | return hyps
109 |
110 | def buildTargetTokens(self, preds):
111 | sentence=[]
112 | for pred in preds:
113 | tokens = []
114 | for tok in pred:
115 | tokens.append(tok)
116 | if tok in self._eos:
117 | break
118 | sentence.append(tokens)
119 | return sentence
120 |
--------------------------------------------------------------------------------
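A minimal driving loop for `Beam` (the class allocates `torch.cuda` tensors, so a CUDA device is required; the vocabulary size and token ids below are made up, and real scores would come from the language model):

```python
import torch
from beam import Beam  # run from the generate/ directory

vocab_size, beam_size, sos_id, eos_id = 10, 3, 0, 9
beam = Beam(beam_size, sos_id, [eos_id])     # eos must be a container (checked with `in`)

for _ in range(20):
    if beam.done():
        break
    # Stand-in log-probabilities of shape (beam_size, vocab_size) for each live hypothesis.
    word_lk = torch.randn(beam_size, vocab_size).cuda().log_softmax(dim=-1)
    beam.advance(word_lk)

hyps = beam.getHyp(beam.getFinal())          # walk the backpointers
preds = beam.buildTargetTokens(hyps)         # truncate each hypothesis at EOS
print([[int(t) for t in p] for p in preds])
```
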
/generate/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | from __future__ import absolute_import, division, print_function
4 |
5 | import argparse
6 | import glob
7 | import logging
8 | import os
9 | import pickle
10 | import random
11 | import re
12 | import gc
13 | import shutil
14 | import json
15 | from tqdm import tqdm
16 |
17 | import numpy as np
18 | import torch
19 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
20 | from torch.utils.data.distributed import DistributedSampler
21 |
22 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
23 | BertConfig, BertForMaskedLM, BertTokenizer,
24 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
25 | OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
26 | RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
27 | DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
28 |
29 |
30 | class RelineDataset(Dataset):
31 | def __init__(self, tokenizer, args, logger, file_type='test', block_size=924, load_file=None, search_res=None):
32 | datafile = os.path.join(args.data_dir, f"{file_type}.json")
33 | with open(datafile) as f:
34 | datas = f.readlines()
35 |
36 | if load_file is not None:
37 | id2code = {}
38 | lines = open(os.path.join(args.data_dir, load_file+".txt")).readlines()
39 | for i,line in enumerate(tqdm(lines)):
40 | id2code[i] = line.strip()
41 |
42 | search_results = pickle.load(open(search_res, "rb"))
43 | try:
44 | nexts = pickle.load(open(os.path.join(args.data_dir, load_file+"_nexts.pkl"), "rb"))
45 | except Exception:
46 | nexts = [i for i in range(len(lines))]
47 |
48 | length = len(datas)
49 | logger.info("Data size: %d"%(length))
50 | self.inputs = []
51 | self.gts = []
52 | for i,data in enumerate(datas):
53 | if i % 1000 == 0:
54 | logger.info(f"Encoded {i}/{length} data")
55 | data = json.loads(data.strip())
56 | if load_file is not None:
57 | try:
58 | cand_id = search_results[data["id"]]
59 | cand = id2code[cand_id]
60 | if nexts[cand_id] != cand_id:
61 | cand += id2code[nexts[cand_id]]
62 | cand = tokenizer.encode(cand)
63 | except:
64 | cand = []
65 | else:
66 | cand = []
67 | self.inputs.append((cand + tokenizer.encode(data["input"]))[-block_size:])
68 | self.gts.append(data["gt"])
69 |
70 | def __len__(self):
71 | return len(self.inputs)
72 |
73 | def __getitem__(self, item):
74 | return torch.tensor(self.inputs[item]), self.gts[item]
75 |
76 | class PPLDataset(Dataset):
77 | def __init__(self, tokenizer, args, logger, file_type='test', block_size=1024, load_file=None, search_res=None):
78 | datafile = os.path.join(args.data_dir, f"{file_type}.txt")
79 | with open(datafile) as f:
80 | datas = f.readlines()
81 |
82 | if load_file is not None:
83 | id2code = {}
84 | lines = open(os.path.join(args.data_dir, load_file+".txt")).readlines()
85 | for i,line in enumerate(tqdm(lines)):
86 | id2code[i] = line.strip()
87 |
88 | search_results = pickle.load(open(search_res, "rb"))
89 | try:
90 | nexts = pickle.load(open(os.path.join(args.data_dir, load_file+"_nexts.pkl"), "rb"))
91 | except Exception:
92 | nexts = [i for i in range(len(lines))]
93 |
94 | length = len(datas)
95 | logger.info("Data size: %d"%(length))
96 | self.inputs = []
97 | self.token_labels = []
98 | for i,data in enumerate(tqdm(datas)):
99 | if i % 1000 == 0:
100 | logger.info(f"Encoded {i}/{length} data")
101 | tokens = data.strip().split(" ")
102 | if len(tokens) < 200:
103 | cut = len(tokens)//2
104 | else:
105 | cut = 100
106 |
107 | if load_file is not None:
108 | try:
109 | if i in search_results:
110 | cand_id = search_results[i]
111 | else:
112 | cand_id = search_results[str(i)]
113 | cand = id2code[cand_id]
114 | if nexts[cand_id] != cand_id:
115 | cand += id2code[nexts[cand_id]]
116 | cand = tokenizer.encode(cand)
117 | except:
118 | # print("OK")
119 | cand = []
120 | else:
121 | cand = []
122 |
123 | x1 = tokenizer.encode(" ".join(tokens[:cut]))[-block_size:]
124 | self.inputs.append(x1)
125 | self.token_labels.append([2]*len(x1))
126 |
127 | pre_id = cand + tokenizer.encode(" ".join(tokens[:cut]))
128 | x2_0 = tokenizer.encode(" ".join(tokens[cut:]))[:block_size * 3 // 4]
129 | x2 = (pre_id + x2_0)[-block_size:]
130 | self.inputs.append(x2)
131 | self.token_labels.append([1]*(len(x2)-len(x2_0)) + [2]*len(x2_0))
132 |
133 | def __len__(self):
134 | return len(self.inputs)
135 |
136 | def __getitem__(self, item):
137 | return torch.tensor(self.inputs[item]), torch.tensor(self.token_labels[item])
138 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReACC
2 |
  3 | Source code for the ACL 2022 paper "[ReACC: A Retrieval-Augmented Code Completion Framework](https://arxiv.org/abs/2203.07722)".
  4 | ReACC combines a source code retriever and an auto-regressive language model for programming languages.
5 |
6 | ## Dependency
7 |
8 | - pytorch >= 1.7.0
9 | - transformers >= 4.10.0
10 | - tree_sitter
11 | - faiss-gpu
12 | - beir (for BM25)
 13 | - Elasticsearch (required by the BM25 retriever)
14 |
15 | ## Instructions
16 |
 17 | Here are the instructions for applying the ReACC framework to the code completion task on the PY150 dataset.
18 |
19 | ### 1. Pretrain a retriever
20 |
 21 | Use the released [`microsoft/reacc-py-retriever`](https://huggingface.co/microsoft/reacc-py-retriever) checkpoint as a code-to-code retriever for Python source code.
22 |
23 | ### 2. Build an index for search
24 |
 25 | First, prepare a codebase to retrieve from. It is recommended to split each file/function into small chunks (see `utils/split_codes.py`). Then run the following command to compute representations of all the code in the search corpus.
26 |
27 | ```bash
28 | python -m torch.distributed.launch --nproc_per_node=${PER_NODE_GPU} infer.py \
29 | --data_path=data/train_split.txt \
30 | --save_name=save_vec \
31 | --lang=python \
32 | --pretrained_dir=microsoft/reacc-py-retriever \
33 | --num_vec=8 \
34 | --block_size=512 \
35 | --gpu_per_node ${PER_NODE_GPU} \
36 | --logging_steps=100
37 | ```
38 |
 39 | You can modify the `InferDataset` in `infer.py` to fit your own dataset. Our dataset is formatted as a JSONL file, where each line looks like
40 | ```json
41 | {
42 | "code": "def function()",
43 | "id": 0
44 | }
45 | ```
46 | or a plain text file, in which each line is a code snippet.
47 |
48 | ### 3. Retrieve step
49 |
 50 | ReACC is a two-stage framework. The first stage retrieves similar code given a query. Since the test set is fixed, we retrieve the similar code for all test queries in advance. **It would be better to merge step 3 into step 4.**
51 |
 52 | First, compute the representations of the test queries as in step 2. Then run the script `utils/search_dense.py` to rank candidates by similarity and retrieve the most similar code.
53 |
 54 | If you would like to use the BM25 algorithm to retrieve similar code, run the script `utils/search_bm25.py` (this requires a running Elasticsearch instance).
55 |
 56 | At last, run `utils/get_res.py` to pick the most similar code based on the BM25 results, the dense retrieval results, or a hybrid of both.
57 |
58 | ### 4. Generation step
59 |
 60 | Please download the PY150 dataset first and preprocess it with the scripts in [CodeXGLUE](https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/CodeCompletion-token), then follow CodeXGLUE to fine-tune a model on it, such as CodeGPT.
61 |
 62 | The second stage of ReACC completes code based on the context and the retrieved code. We simply put the retrieved code before the context and concatenate them as the input.
63 |
 64 | Navigate to the `generate` folder. We adapt the code completion scripts from [CodeXGLUE](https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/CodeCompletion-line) and modify `dataset.py` to include the retrieved similar code in the input. Run `run_lm.py` to evaluate your fine-tuned model.
65 |
66 | ```bash
67 | export CUDA_VISIBLE_DEVICES=0
68 | LANG=python
69 | DATADIR=dataset/py150
70 | LITFILE=${DATADIR}/literals.json
71 | OUTPUTDIR=save/py150
72 | PRETRAINDIR=py150-ckpt
73 |
74 | LOADFILE=${DATADIR}/train_split
75 | RESFILE=search_res.pkl
76 | SAVEFILE=prediction.txt
77 |
78 | python -u run_lm.py \
79 | --data_dir=$DATADIR \
80 | --lit_file=$LITFILE \
81 | --langs=$LANG \
82 | --output_dir=$OUTPUTDIR \
83 | --pretrain_dir=$PRETRAINDIR \
84 | --load_file_name=$LOADFILE \
85 | --search_res=$RESFILE \
86 | --save_name=$SAVEFILE \
87 | --model_type=gpt2 \
88 | --block_size=1024 \
89 | --eval_line \
90 | --logging_steps=100 \
91 | --seed=42
92 | ```
93 |
94 |
95 | ## Zero-shot code clone detection
96 | In order to evaluate the effectiveness of the code-to-code retrieval module in ReACC,
 97 | we perform the code clone detection task, which aims to retrieve semantically equivalent programs.
98 |
 99 | We extract the evaluation dataset from CodeNet, the same as in the [UniXcoder paper](https://arxiv.org/abs/2203.03850).
100 | The dataset can be downloaded from [here](https://github.com/microsoft/CodeBERT/tree/master/UniXcoder/downstream-tasks/zero-shot-search/dataset).
101 |
102 | Run `codenet_test.py` to reproduce this experiment.
103 | ```bash
104 | DATADIR=CodeNet
105 | PRETRAINDIR=microsoft/reacc-py-retriever
106 |
107 | python -u codenet_test.py \
108 | --data_dir=$DATADIR \
109 | --pretrained_dir=$PRETRAINDIR \
110 | --lang=python \
111 | --num_vec=8 \
112 | --cut \
113 | --block_size=512 \
114 | --per_gpu_eval_batch_size=64 \
115 | --logging_steps=100 \
116 | --seed=614
117 | ```
118 |
119 | ## Code of Conduct
120 |
121 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
122 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
123 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
124 |
125 | ## License
126 |
127 | Copyright (c) Microsoft Corporation. All rights reserved.
128 |
129 | Licensed under the [MIT](LICENSE) license.
130 |
131 | ## Reference
132 | If you use this code or ReACC, please consider citing us.
133 |     @article{lu2022reacc,
134 | title={ReACC: A Retrieval-Augmented Code Completion Framework},
135 | author={Lu, Shuai and Duan, Nan and Han, Hojae and Guo, Daya and Hwang, Seung-won and Svyatkovskiy, Alexey},
136 | journal={arXiv preprint arXiv:2203.07722},
137 | year={2022}
138 | }
139 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | # This script is used to generate the embedding vectors for the given dataset.
5 |
6 | import argparse
7 | import logging
8 | import os
9 | import random
10 | import re
11 | import json
12 | import pickle
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | from itertools import cycle
17 | from functools import partial
18 | from torch.nn.utils.rnn import pad_sequence
19 | import torch.nn.functional as F
20 | from torch.utils.data import DataLoader, Dataset, SequentialSampler
21 | from torch.utils.data.distributed import DistributedSampler
22 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
23 | RobertaConfig, RobertaModel, RobertaTokenizer)
24 | from tqdm import tqdm
25 | from tree_sitter import Language, Parser
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | class InferDataset(Dataset):
30 | def __init__(self, tokenizer, args, api=True):
31 | if args.local_rank == -1:
32 | local_rank = 0
33 | world_size = 1
34 | else:
35 | local_rank = args.local_rank
36 | world_size = torch.distributed.get_world_size()
37 | self.tokenizer = tokenizer
38 | self.args = args
39 | self.api = api
40 | data_file = args.data_path
41 |
42 |
43 | if args.lang == "java":
44 | from process_java import processor
45 | elif args.lang == "python":
46 | from process_python import processor
47 | self.proc = processor(args.lang, remove_comments=False)
48 |
49 | logger.info(f"Creating features from {data_file}")
50 | data_format = data_file.split(".")[-1]
51 |
52 | self.data = []
53 | self.idx = []
54 | n = 0
55 | with open(data_file) as f:
56 | for _ in f:
57 | n += 1
58 | # n = 100000
59 | st = n//world_size*local_rank
60 | ed = n//world_size*(local_rank+1)
61 | logger.warning(f"device {local_rank} will load {ed-st} data line from {st} to {ed}")
62 | with open(data_file) as f:
63 | for i,line in enumerate(f):
64 | if i >= st and i < ed:
65 | if (i-st) % 100000 == 0:
66 | logger.info(f"device {local_rank} created {i-st}/{ed-st} train data")
67 | if "json" in data_format:
68 | content = json.loads(line)
69 | self.data.append(self.convert_cxg_format_to_normal(content["input"]))
70 | self.idx.append(content["id"])
71 | else: # txt
72 | self.data.append(self.convert_cxg_format_to_normal(line.strip()))
73 | self.idx.append(i)
74 | logger.warning(f"device {local_rank} loaded {len(self.data)} train data from {st} to {ed}")
75 |
76 | def convert_cxg_format_to_normal(self, code):
 77 |         if code.startswith("<s>"):
 78 |             code = code.lstrip("<s>")
 79 |         if code.endswith("</s>"):
 80 |             code = code.rstrip("</s>")
 81 |         code = code.replace("<EOL>", "\n")
 82 |         code = code.replace("<NUM_LIT>", "0").replace("<STR_LIT>", "").replace("<CHAR_LIT>", "")
83 | pattern = re.compile(r"<(STR|NUM|CHAR)_LIT:(.*?)>", re.S)
84 | lits = re.findall(pattern, code)
85 | for lit in lits:
86 | code = code.replace(f"<{lit[0]}_LIT:{lit[1]}>", lit[1])
87 | return code
88 |
89 | def encode(self, code, api_seq):
90 | if self.api:
91 | code_tokens = [self.tokenizer.cls_token] + self.tokenizer.tokenize(code) + \
92 | [self.tokenizer.sep_token] + self.tokenizer.tokenize(" ".join(api_seq)) + [self.tokenizer.sep_token]
93 | else:
94 | code_tokens = [self.tokenizer.cls_token] + self.tokenizer.tokenize(code) + [self.tokenizer.sep_token]
95 | code_tokens = code_tokens[:self.args.block_size]
96 | code_ids = self.tokenizer.convert_tokens_to_ids(code_tokens)
97 | return code_ids
98 |
99 | def process(self, code):
100 | self.proc.update(code)
101 | api_seq = self.proc.get_api_seq()
102 | code = self.proc.untokenize(cut_ratio=0.0)
103 | token_id = self.encode(code, api_seq)
104 | return token_id
105 |
106 | def __len__(self):
107 | return len(self.data)
108 |
109 | def __getitem__(self, item):
110 | return torch.tensor(self.process(self.data[item])), torch.tensor([self.idx[item]])
111 |
112 |
113 |
114 | def set_seed(args):
115 | random.seed(args.seed)
116 | np.random.seed(args.seed)
117 | torch.manual_seed(args.seed)
118 | if args.n_gpu > 0:
119 | torch.cuda.manual_seed_all(args.seed)
120 |
121 | def my_collect_fn(sequences, batch_first=True, padding_value=1):
122 | inputs = []
123 | inputs1 = []
124 | for (x, x1) in sequences:
125 | inputs.append(x)
126 | inputs1.append(x1)
127 | return (
128 | pad_sequence(inputs, batch_first, padding_value),
129 | pad_sequence(inputs1, batch_first, padding_value),
130 | )
131 |
132 | def inference(args, tokenizer, model, save_name, api=False):
133 | #build dataloader
134 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
135 | dataset = InferDataset(tokenizer, args, api=api)
136 | sampler = SequentialSampler(dataset)
137 | dataloader = DataLoader(dataset, sampler=sampler, batch_size=args.eval_batch_size, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), num_workers=4)
138 |
139 | model.to(args.device)
140 | if args.local_rank != -1:
141 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank%args.gpu_per_node],
142 | output_device=args.local_rank%args.gpu_per_node,
143 | find_unused_parameters=True)
144 |
145 | # Eval!
146 | logger.info("***** Running Inference *****")
147 | logger.info(" Num examples = %d", len(dataset))
148 | logger.info(" Batch size = %d", args.eval_batch_size)
149 |
150 | model.eval()
151 |
152 | steps = 0
153 | n_vec = max(0, args.num_vec)
154 | saved = {}
155 | for batch in dataloader:
156 | with torch.no_grad():
157 | (inputs1, inputs2) = batch
158 | inputs1 = inputs1.to(args.device)
159 | attn_mask1 = torch.tensor(inputs1.clone().detach() != tokenizer.pad_token_id, dtype=torch.uint8, device=args.device)
160 | outputs = model(inputs1, attention_mask=attn_mask1)[0]
161 | if n_vec > 0:
162 | outputs = nn.functional.normalize(outputs[:, :n_vec, :], dim=2)
163 | else:
164 | outputs = nn.functional.normalize(outputs[:, 0, :], dim=1)
165 | outputs = outputs.detach().to("cpu").numpy()
166 | idxs = inputs2.numpy()
167 | for i in range(outputs.shape[0]):
168 | saved[idxs[i][0]] = outputs[i]
169 | steps += 1
170 | if steps % args.logging_steps == 0:
171 | logger.info(f"Inferenced {steps} steps")
172 |
173 | if args.local_rank != -1:
174 | pickle.dump(saved, open(save_name+f"_{args.local_rank}.pkl", "wb"))
175 | else:
176 | pickle.dump(saved, open(save_name+".pkl", "wb"))
177 |
178 | def merge(args, num, save_name):
179 | saved = {}
180 | for i in range(num):
181 | saved.update(pickle.load(open(save_name+f"_{i}.pkl", "rb")))
182 | os.remove(save_name+f"_{i}.pkl")
183 | pickle.dump(saved, open(save_name+".pkl", "wb"))
184 |
185 |
186 |
187 |
188 | def main():
189 | parser = argparse.ArgumentParser()
190 |
191 | ## Required parameters
192 | parser.add_argument("--data_path", default=None, type=str, required=True,
193 | help="The input data path.")
194 | parser.add_argument("--save_name", default=None, type=str, required=True,
195 | help="The output directory where the model predictions and checkpoints will be written.")
196 | parser.add_argument("--lang", default=None, type=str, required=True,
197 | help="Language of the dataset.")
198 | parser.add_argument("--pretrained_dir", default=None, type=str,
199 | help="The directory where the trained model and tokenizer are saved.")
200 |
201 | parser.add_argument("--cut_ratio", type=float, default=0.5,
202 | help="Ratio of replaced variables")
203 | parser.add_argument('--num_vec', type=int, default=-1,
204 | help="number of vectors")
205 |
206 | parser.add_argument("--block_size", default=512, type=int,
207 | help="Optional input sequence length after tokenization."
208 | "The training dataset will be truncated in block of this size for training."
209 | "Default to the model max input length for single sentence inputs (take into account special tokens).")
210 |
211 | parser.add_argument("--per_gpu_eval_batch_size", default=16, type=int,
212 | help="Batch size per GPU/CPU for evaluation.")
213 |
214 | parser.add_argument('--logging_steps', type=int, default=10,
215 | help="Log every X updates steps.")
216 | parser.add_argument("--no_cuda", action='store_true',
217 | help="Avoid using CUDA when available")
218 | parser.add_argument('--seed', type=int, default=42,
219 | help="random seed for initialization")
220 |
221 | parser.add_argument("--local_rank", type=int, default=-1,
222 | help="For distributed training: local_rank")
223 | parser.add_argument("--node_index", type=int, default=0,
224 | help="node index if multi-node running")
225 | parser.add_argument("--gpu_per_node", type=int, default=-1,
226 | help="num of gpus per node")
227 |
228 | args = parser.parse_args()
229 |
230 | logger.warning("local_rank: %d, node_index: %d, gpu_per_node: %d"%(args.local_rank, args.node_index, args.gpu_per_node))
231 | # Setup CUDA, GPU & distributed training
232 | if args.local_rank == -1 or args.no_cuda:
233 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
234 | args.n_gpu = torch.cuda.device_count()
235 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
236 | torch.cuda.set_device(args.local_rank)
237 | device = torch.device("cuda", args.local_rank)
238 | torch.distributed.init_process_group(backend='nccl')
239 | args.local_rank += args.node_index * args.gpu_per_node
240 | args.n_gpu = 1
241 | args.device = device
242 |
243 | world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
244 |
245 | # Setup logging
246 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
247 | datefmt='%m/%d/%Y %H:%M:%S',
248 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
249 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, world size: %s",
250 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), world_size)
251 |
252 | # Set seed
253 | set_seed(args)
254 |
255 | args.start_epoch = 0
256 | args.start_step = 0
257 |
258 | tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_dir)
259 | model = RobertaModel.from_pretrained(args.pretrained_dir, add_pooling_layer=False)
260 |
261 | inference(args, tokenizer, model, args.save_name, api=True)
262 | logger.info(f"device {args.local_rank} finished")
263 |
264 | if args.local_rank != -1:
265 | torch.distributed.barrier()
266 | if args.local_rank in [-1, 0]:
267 | import time
268 | time.sleep(10)
269 | merge(args, world_size, save_name=args.save_name)
270 |
271 |
272 | if __name__ == "__main__":
273 | main()
274 |
275 |
--------------------------------------------------------------------------------
/process_java.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Copyright (c) Microsoft Corporation.
3 | # Licensed under the MIT License.
4 | import os
5 | import json
6 | import re
7 | import pickle
8 | from tqdm import tqdm
9 | import random
10 | from collections import Counter, OrderedDict
11 | from tree_sitter import Language, Parser
12 | from parser import (remove_comments_and_docstrings,
13 | tree_to_token_index,
14 | index_to_code_token,
15 | tree_to_variable_index,
16 | traverse)
17 | from textwrap import dedent
18 |
19 |
20 | # ['nwo', 'path', 'language', 'identifier', 'parameters', 'argument_list', 'return_statement', 'docstring', 'docstring_summary', 'docstring_tokens', 'function', 'function_tokens']
21 |
22 | def clean_docstring_comments(comment):
23 | comment = comment.strip().strip(""" "' """)
24 | comment = "\n".join(map(lambda s: s.lstrip("#"), comment.splitlines()))
25 | return dedent(comment)
26 |
27 | class processor(object):
28 | def __init__(self, lang, code=None, remove_comments=False):
29 | LANGUAGE = Language('parser/my-languages.so', lang)
30 | parser = Parser()
31 | parser.set_language(LANGUAGE)
32 | self.parser = [parser]
33 | self.lang = lang
34 | self.remove_comments = remove_comments
35 | self.preserve_words = set(["self", "super", "Exception", "__init__", "__main__"])
36 | if code is None:
37 | self.tree = None
38 | else:
39 | self.update(code)
40 |
41 | def update(self, code, function=False):
42 | if self.lang == "php":
 43 |             code = "<?php "+code+"?>"
44 | if self.lang == "java" and function:
45 | code = "class A{\n"+code+"\n}"
46 | com = True
47 | if self.remove_comments:
48 | com = False
49 | try:
50 | code = remove_comments_and_docstrings(code, self.lang)
51 | except Exception:
52 | com = True
53 | self.code = code
54 | self.code_bytes = code.encode("utf-8")
55 | self.tree = self.parser[0].parse(self.code_bytes)
56 | root_node = self.tree.root_node
57 | tokens_index = tree_to_token_index(root_node)
58 | code = self.code.split('\n')
59 | self.code_tokens = [index_to_code_token(x, code) for x in tokens_index]
60 | self.index_to_code = OrderedDict()
61 | for idx, (index, code) in enumerate(zip(tokens_index, self.code_tokens)):
62 | self.index_to_code[index] = (idx, code)
63 | return com
64 |
65 | def get_doc(self):
66 | """
67 | For file level data, merge the doc in each function
68 | """
69 | self.functions = []
70 | self.get_func_nodes(self.tree.root_node)
71 | docs = ""
72 | for func_node in self.functions:
73 | body_node = func_node.children[-1]
74 | if body_node.children and body_node.children[0].children:
75 | if body_node.children[0].children[0].type in ["string", "comment"]:
76 | docs += clean_docstring_comments(
77 | self.span_select(body_node.children[0].children[0]),
78 | ) + ""
79 | return docs
80 |
81 | def get_func_nodes(self, root):
82 | """
83 | For both function level and file level data
84 | """
85 | for node in root.children:
86 | if node.type == "function_definition":
87 | self.functions.append(node)
88 | else:
89 | self.get_func_nodes(node)
90 |
91 |
92 | def load_names(self, path):
93 | vnames = pickle.load(open(os.path.join(path, "vnames.pkl"), "rb"))
94 | self.preserve_words = vnames["not_replace"]
95 | self.vnames = vnames["cands"]
96 |
97 | def span_select(self, *nodes, indent=False):
98 | if not nodes:
99 | return ""
100 | start, end = nodes[0].start_byte, nodes[-1].end_byte
101 | select = self.code_bytes[start:end].decode("utf-8")
102 | if indent:
103 | return " " * nodes[0].start_point[1] + select
104 | return select
105 |
106 | def get_func_name(self):
107 | """
108 | For function level data only
109 | """
110 | root_node = self.tree.root_node
111 | func_nodes = [node for node in root_node.children if node.type == "function_definition"]
112 | try:
113 | func_name = func_nodes[0].child_by_field_name("name")
114 | except IndexError:
115 | return ""
116 | return self.span_select(func_name)
117 |
118 | def get_var_names(self):
119 | root_node = self.tree.root_node
120 | vnames = set()
121 | self._get_var_names_from_node(root_node, vnames)
122 | return vnames
123 |
124 | def get_api_seq(self):
125 | root_node = self.tree.root_node
126 | api_seq = []
127 | self._get_api_seq(root_node, api_seq)
128 | return api_seq
129 |
130 | def _get_var_names_from_node(self, node, vnames, inactive=False):
131 | if len(node.children) > 0:
132 | if node.type in ["method_invocation"]:
133 | self._get_var_names_from_node(node.children[0], vnames, inactive)
134 | for child in node.children[1:]:
135 | if child.type == "argument_list":
136 | self._get_var_names_from_node(child, vnames, inactive)
137 | else:
138 | self._get_var_names_from_node(child, vnames, True)
139 | else:
140 | for child in node.children:
141 | if node.type in ["field_access", "modifiers", "import_declaration", "package_declaration"]:
142 | self._get_var_names_from_node(child, vnames, True)
143 | else:
144 | self._get_var_names_from_node(child, vnames, inactive)
145 | elif node.type == "identifier":
146 | if not inactive:
147 | vnames.add(self.span_select(node))
148 |
149 | def _get_api_seq(self, node, api_seq):
150 | if node.type == "method_invocation":
151 | obj = node.child_by_field_name("object")
152 | func = node.child_by_field_name("name")
153 | if obj:
154 | api_seq.append(self.span_select(obj) + "." + self.span_select(func))
155 | else:
156 | api_seq.append(self.span_select(func))
157 | else:
158 | for child in node.children:
159 | self._get_api_seq(child, api_seq)
160 |
161 | def process(self, ratio=0.5, indent=False, add_dead_code=True, cut_ratio=0.0, function=False):
162 | vnames = [x for x in self.get_var_names() if x not in self.preserve_words]
163 | vnames = random.sample(vnames, int(len(vnames)*ratio))
164 | cands = random.sample(self.vnames, len(vnames)+3)
165 | dead_vars = cands[-3:]
166 | if add_dead_code:
167 | deadcode = self.insert_dead_code(dead_vars)
168 | else:
169 | deadcode = None
170 | replaced = {v: c for v, c in zip(vnames, cands[:-3])}
171 | self.index_to_new_code = {}
172 | self._replace_var_names_from_node(self.tree.root_node, replaced)
173 | code_string = self.untokenize(indent, deadcode, True, cut_ratio=cut_ratio, function=function)
174 | return code_string
175 |
176 | def process_no_replace(self, indent=False, add_dead_code=True, cut_ratio=0.0, function=False):
177 | dead_vars = random.sample(self.vnames, 3)
178 | if add_dead_code:
179 | deadcode = self.insert_dead_code(dead_vars)
180 | else:
181 | deadcode = None
182 | code_string = self.untokenize(indent, deadcode, False, cut_ratio=cut_ratio, function=function)
183 | return code_string
184 |
185 | def insert_dead_code(self, v):
186 | # dead code types, vars that can't appear in original code
187 | # A = B, A
188 | # A.C(B, 1), A
189 | # A = B + C, AB
190 | # A = B(C), AB
191 | # A = B.C(), ABC
192 | # for (String i: A){\nB(C)\n}
193 | # A = [B for B in range(C)]
194 | # if (C){\nA = B()\n}
195 | dead_type = random.randrange(7)
196 | if dead_type == 0:
197 | return f"{v[0]} = {v[1]};\n"
198 | elif dead_type == 1:
199 | return f"{v[0]}.{v[2]}({v[1]}, 1);\n"
200 | elif dead_type == 2:
201 | return f"{v[0]} = {v[1]} + {v[2]};\n"
202 | elif dead_type == 3:
203 | return f"{v[0]} = {v[1]}({v[2]});\n"
204 | elif dead_type == 4:
205 | return f"{v[0]} = {v[1]}.{v[2]}();\n"
206 | elif dead_type == 5:
207 | return "for (String i: %s){\n%s(%s)\n;}\n"%(v[0], v[1], v[2])
208 | elif dead_type == 6:
209 | return "if (%s){\n%s = %s\n;}\n"%(v[2], v[0], v[1])
210 |
211 | def _replace_var_names_from_node(self, node, replaced):
212 | if len(node.children) > 0:
213 | for child in node.children:
214 | self._replace_var_names_from_node(child, replaced)
215 | elif node.type == "identifier":
216 | try:
217 | idf = self.index_to_code[(node.start_point, node.end_point)][1]
218 | except KeyError:
219 | idf = "None"
220 | if idf in replaced:
221 | self.index_to_new_code[(node.start_point, node.end_point)] = replaced[idf]
222 |
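    | # untokenize() rebuilds a flattened token stream from the parsed positions:
    | # it can insert the dead-code snippet somewhere in the middle of the file,
    | # optionally truncate ("cut") at a random or fixed token position, and either
    | # emit explicit indent tokens (indent=True) or plain spacing.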
223 | def untokenize(self, indent=False, deadcode=None, replaced=False, cut_ratio=0.0, function=False, fix_cut_pos=False):
224 | code_string = ""
225 | prev_sp = None
226 | prev_ep = None
227 | prev_indent = 0
228 | indent_size = -1
229 | if function:
230 | total_line = list(self.index_to_code.keys())[-2][0][0]
231 | else:
232 | total_line = list(self.index_to_code.keys())[-1][0][0]
233 | insert_line = random.randint(total_line//5, total_line*4//5)
234 | cut = random.random() < cut_ratio
235 | if cut:
236 | if fix_cut_pos:
237 | cut_pos = (len(self.index_to_code)-4)//2
238 | else:
239 | cut_pos = random.randint((len(self.index_to_code)-4)//3, (len(self.index_to_code)-4)*2//3)
240 | if function:
241 | poses = list(self.index_to_code.keys())[3:-1]
242 | else:
243 | poses = list(self.index_to_code.keys())
244 | for ip, pos in enumerate(poses):
245 | sp = pos[0]
246 | ep = pos[1]
247 | if cut and ip >= cut_pos:
248 | break
249 | if replaced and pos in self.index_to_new_code:
250 | add_token = self.index_to_new_code[pos]
251 | else:
252 | add_token = self.index_to_code[pos][1]
253 | if prev_sp is None or (sp[0] == prev_ep[0] and sp[1] == prev_ep[1]):
254 | code_string += add_token
255 | elif sp[0] == prev_ep[0]:
256 | if code_string and code_string[-1] != " ":
257 | code_string += " "
258 | code_string += add_token
259 | else:
260 | # if cut and cut_line >= 1 and cut_line <= prev_ep[0]:
261 | # break
262 | if replaced and deadcode:
263 | if insert_line <= prev_ep[0]:
264 | code_string += "\n" + deadcode
265 | insert_line = total_line+2
266 | if indent and add_token:
267 | code_string += "\n"
268 | omit = False
269 | if sp[1] != prev_indent and prev_indent == 0 and indent_size == -1:
270 | indent_size = sp[1] - prev_indent
271 | if sp[1] - prev_indent > 0:
272 | if sp[1] - prev_indent > 2 * indent_size:
273 | omit = True
274 | else:
275 | for i in range(prev_indent, sp[1], indent_size):
276 | code_string += "<INDENT>"
277 | elif sp[1] - prev_indent < 0:
278 | for i in range(sp[1], prev_indent, indent_size):
279 | code_string += "<DEDENT>"
280 | code_string += add_token
281 | if not omit:
282 | prev_indent = sp[1]
283 | else:
284 | code_string += "\n"
285 | code_string += " " if sp[1] else ""
286 | code_string += add_token
287 | prev_sp, prev_ep = sp, ep
288 | return re.sub(re.compile(r"\s*\n"), "\n", code_string.lstrip()).replace("\n", "<EOL>")
289 |
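    | # Note: untokenize() returns one-line code in which newlines become <EOL> and,
    | # with indent=True, indentation changes become <INDENT>/<DEDENT> (token names
    | # reconstructed here); convert_to_normal() below is the rough inverse, e.g.
    | #   "if ( x ) {<EOL><INDENT>y = x ;<EOL><DEDENT>}"
    | #   -> "if ( x ) {\n    y = x ;\n}"   (4 spaces per indent level)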
290 | def convert_to_normal(self, code):
291 | lines = code.split("<EOL>")
292 | indent_size = 4
293 | indent = 0
294 | res = ""
295 | for line in lines:
296 | indent += line.count("<INDENT>")
297 | indent -= line.count("<DEDENT>")
298 | res += "\n" + " "*indent_size*indent + line.replace("<INDENT>", "").replace("<DEDENT>", "")
299 | return res
300 |
301 |
302 |
303 |
304 |
--------------------------------------------------------------------------------
/codenet_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | import argparse
4 | import logging
5 | import os
6 | import random
7 | import re
8 | import json
9 | import pickle
10 | from collections import Counter
11 | import numpy as np
12 | import torch
13 | from itertools import cycle
14 | from functools import partial
15 | from torch.nn.utils.rnn import pad_sequence
16 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
17 | from torch.utils.data.distributed import DistributedSampler
18 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
19 | RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaTokenizer)
20 | from tqdm import tqdm
21 | from tree_sitter import Language, Parser
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | at_K = 100
26 |
27 | class InputFeatures(object):
28 | """A single training/test features for a example."""
29 | def __init__(self,
30 | code_ids,
31 | index,
32 | label
33 |
34 | ):
35 | self.code_ids = code_ids
36 | self.index= index
37 | self.label = label
38 |
39 |
40 | class CodeWithDocNoRepDataset(Dataset):
41 | def __init__(self, tokenizer, args, file_type, cut_ratio=0.0):
42 | self.tokenizer = tokenizer
43 | self.args = args
44 | data_file = os.path.join(args.data_dir, f"{file_type}.jsonl")
45 |
46 | if args.lang == "java":
47 | from process_java import processor
48 | elif args.lang == "python":
49 | from process_python import processor
50 | self.proc = processor(args.lang, remove_comments=False)
51 | # self.proc.load_names(args.vars_dir)
52 |
53 | #load index
54 | logger.info(f"Creating features from {data_file}")
55 |
56 |
57 | self.examples = []
58 | lines = open(data_file).readlines()
59 | for i,line in enumerate(lines):
60 | content = json.loads(line)
61 | self.proc.update(content["func"])
62 | code = self.proc.untokenize(cut_ratio=cut_ratio, fix_cut_pos=True)
63 | self.proc.update(self.proc.convert_to_normal(code))
64 | api_seq = self.proc.get_api_seq()
65 | token_id = self.encode_v3(code, api_seq)
66 | self.examples.append(InputFeatures(token_id, content["index"], int(content["label"])))
67 | logger.info(f"loaded {len(self.examples)} data")
68 |
69 | self.label_examples={}
70 | for e in self.examples:
71 | if e.label not in self.label_examples:
72 | self.label_examples[e.label] = []
73 | self.label_examples[e.label].append(e)
74 |
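    | # encode_v3 packs "[CLS] code tokens [SEP] API-call sequence [SEP]" and truncates
    | # to block_size before converting tokens to ids.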
75 | def encode_v3(self, code, api_seq):
76 | code_tokens = [self.tokenizer.cls_token] + self.tokenizer.tokenize(code) + \
77 | [self.tokenizer.sep_token] + self.tokenizer.tokenize(" ".join(api_seq)) + [self.tokenizer.sep_token]
78 | code_tokens = code_tokens[:self.args.block_size]
79 | code_ids = self.tokenizer.convert_tokens_to_ids(code_tokens)
80 | return code_ids
81 |
82 | def __len__(self):
83 | return len(self.examples)
84 |
85 | def __getitem__(self, i):
86 | return torch.tensor(self.examples[i].code_ids), self.examples[i].index
87 |
88 | def set_seed(args):
89 | random.seed(args.seed)
90 | np.random.seed(args.seed)
91 | torch.manual_seed(args.seed)
92 | if args.n_gpu > 0:
93 | torch.cuda.manual_seed_all(args.seed)
94 |
95 | def my_collect_fn(sequences, batch_first=True, padding_value=1):
96 | inputs1 = []
97 | inputs2 = []
98 | for (x1, x2) in sequences:
99 | inputs1.append(x1)
100 | inputs2.append(x2)
101 | return (
102 | pad_sequence(inputs1, batch_first, padding_value),
103 | inputs2
104 | )
105 |
106 |
107 | def eval_bm25_beir(args, tokenizer, file_name, candidate_file_name, cut=False):
108 | from beir.datasets.data_loader import GenericDataLoader
109 | from beir.retrieval.evaluation import EvaluateRetrieval
110 | from beir.retrieval.search.lexical import BM25Search as BM25
111 |
112 | if args.lang == "java":
113 | from process_java import processor
114 | elif args.lang == "python":
115 | from process_python import processor
116 | proc = processor(args.lang, remove_comments=False)
117 |
118 | idx2label = {}
119 | label2num = Counter()
120 | lines = open(os.path.join(args.data_dir, f"{candidate_file_name}.jsonl")).readlines()
121 | corpus = {}
122 | for i,line in enumerate(tqdm(lines)):
123 | content = json.loads(line)
124 | # proc.update(content["func"])
125 | # code = proc.untokenize()
126 | code = content["func"]
127 | idx2label[content["index"]] = content["label"]
128 | label2num[content["label"]] += 1
129 | corpus[content["index"]] = {"text": code}
130 |
131 | lines = open(os.path.join(args.data_dir, f"{file_name}.jsonl")).readlines()
132 | queries = {}
133 | qrels = {}
134 | for i,line in enumerate(tqdm(lines)):
135 | content = json.loads(line)
136 | ori_code_tokens = content["func"].split()
137 | if cut:
138 | code = " ".join(ori_code_tokens[:len(ori_code_tokens)//2])
139 | else:
140 | code = " ".join(ori_code_tokens)
141 | # proc.update(content["func"])
142 | # code = proc.untokenize(cut_ratio=1.0 if cut else 0.0, fix_cut_pos=True)
143 | queries[content["index"]] = code
144 | qrels[content["index"]] = {content["index"]: 1}
145 |
146 | model = BM25(index_name="codenet", hostname="http://localhost:9200", initialize=True)
147 | retriever = EvaluateRetrieval(model, k_values=[at_K+1])
148 | scores = retriever.retrieve(corpus, queries)
149 |
150 | # pickle.dump(scores, open(os.path.join(args.data_dir, "bm25_scores.pkl"), "wb"))
151 |
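    | # MAP@K / Precision@1 over the BM25 ranking: for each query, sort candidates by
    | # score, skip the query itself (tracked by `cont`), count a hit when a candidate
    | # shares the query's label, and normalize average precision by
    | # min(at_K, number of candidates with that label).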
152 | MAP = []
153 | PREC = 0.0
154 | for idx, v in tqdm(scores.items()):
155 | v = sorted(v.items(), key=lambda x:-x[1])
156 | label = idx2label[idx]
157 | div = min(at_K, label2num[label])
158 | Avep = []
159 | cont = 0
160 | for i, (_id, score) in enumerate(v):
161 | if i - cont >= at_K:
162 | break
163 | if _id == idx:
164 | cont += 1
165 | continue
166 | if idx2label[_id] == label:
167 | Avep.append((len(Avep)+1)/(i+1-cont))
168 | if i - cont == 0:
169 | PREC += 1.0
170 | MAP.append(sum(Avep)/div)
171 |
172 | result = {
173 | "eval_map":float(np.mean(MAP)),
174 | "eval_prec":float(PREC/len(MAP))
175 | }
176 | for key in sorted(result.keys()):
177 | logger.info(" %s = %s", key, str(round(result[key], 4)))
178 |
179 |
180 | def evaluate(args, model, tokenizer, file_name, candidate_file_name, cut):
181 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
182 |
183 | query_dataset = CodeWithDocNoRepDataset(tokenizer, args, file_name, cut_ratio=1.0 if cut else 0.0)
184 | query_sampler = SequentialSampler(query_dataset)
185 | query_dataloader = DataLoader(query_dataset, sampler=query_sampler, batch_size=args.eval_batch_size, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), num_workers=4)
186 |
187 | candidate_dataset = CodeWithDocNoRepDataset(tokenizer, args, candidate_file_name)
188 | candidate_sampler = SequentialSampler(candidate_dataset)
189 | candidate_dataloader = DataLoader(candidate_dataset, sampler=candidate_sampler, batch_size=args.eval_batch_size, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), num_workers=4)
190 |
191 |
192 | idx2label = {}
193 | label2num = Counter()
194 | lines = open(os.path.join(args.data_dir, f"{candidate_file_name}.jsonl")).readlines()
195 | corpus = {}
196 | for i,line in enumerate(tqdm(lines)):
197 | content = json.loads(line)
198 | idx2label[content["index"]] = content["label"]
199 | label2num[int(content["label"])] += 1
200 |
201 | # multi-gpu evaluate
202 | if args.n_gpu > 1:
203 | model = torch.nn.DataParallel(model)
204 |
205 | # Eval!
206 | logger.info("***** Running evaluation *****")
207 | logger.info(" Num Query = %d", len(query_dataset))
208 | logger.info(" Num Candidate = %d", len(candidate_dataset))
209 | logger.info(" Batch size = %d", args.eval_batch_size)
210 | model.to(args.device)
211 |
212 | model.eval()
213 | query_vecs = []
214 | query_indexs = []
215 | candidate_vecs = []
216 | candidate_indexs = []
217 |
218 | for batch in tqdm(query_dataloader, total=len(query_dataloader)):
219 | code_inputs = batch[0].to(args.device)
220 | index = batch[1]
221 | with torch.no_grad():
222 | attn_mask = (code_inputs != tokenizer.pad_token_id).to(dtype=torch.uint8, device=args.device)
223 | code_vec = model(code_inputs, attention_mask=attn_mask)[0]
224 | code_vec = torch.nn.functional.normalize(code_vec[:, 0, :], dim=1)
225 | query_vecs.append(code_vec.cpu().numpy())
226 | query_indexs.extend(index)
227 |
228 | for batch in tqdm(candidate_dataloader,total=len(candidate_dataloader)):
229 | code_inputs = batch[0].to(args.device)
230 | index = batch[1]
231 | with torch.no_grad():
232 | attn_mask = (code_inputs != tokenizer.pad_token_id).to(dtype=torch.uint8, device=args.device)
233 | code_vec = model(code_inputs, attention_mask=attn_mask)[0]
234 | if args.num_vec > 0:
235 | code_vec = torch.nn.functional.normalize(code_vec[:, :args.num_vec, :], dim=2)
236 | else:
237 | code_vec = torch.nn.functional.normalize(code_vec[:, 0, :], dim=1)
238 | candidate_vecs.append(code_vec.cpu().numpy())
239 | candidate_indexs.extend(index)
240 |
241 | model.train()
242 |
243 | query_vecs = np.concatenate(query_vecs, 0)
244 | candidate_vecs = np.concatenate(candidate_vecs, 0)
245 | query_labels = [idx2label[x] for x in query_indexs]
246 | candidate_labels = [idx2label[x] for x in candidate_indexs]
247 |
248 | if args.num_vec > 0:
249 | scores=np.einsum('nd,mvd->nmv', query_vecs, candidate_vecs).max(-1)
250 | else:
251 | scores=np.matmul(query_vecs, candidate_vecs.T)
252 | sort_ids = np.argsort(scores, axis=-1, kind='quicksort', order=None)[:,::-1]
253 |
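    | # Ranking metrics per query, excluding the query itself (tracked by `cont`):
    | # MAP_at_K is average precision over the top at_K results normalized by
    | # min(at_K, label frequency), MAP averages precision over all retrieved
    | # positives, and PREC counts queries whose top-1 result shares the label.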
254 | MAP = []
255 | MAP_at_K = []
256 | PREC = 0.0
257 | for i in tqdm(range(scores.shape[0]), total=scores.shape[0]):
258 | cont = 0
259 | label = int(query_labels[i])
260 | div = min(at_K, label2num[label])
261 | query_index = query_indexs[i]
262 | Avep = []
263 | for j,index in enumerate(list(sort_ids[i])):
264 | if query_index == candidate_indexs[index]:
265 | cont += 1
266 | continue
267 | if j - cont == at_K:
268 | MAP_at_K.append(sum(Avep)/div)
269 | if int(candidate_labels[index]) == label:
270 | Avep.append((len(Avep)+1)/(j+1-cont))
271 | if j - cont == 0:
272 | PREC += 1.0
273 | if len(Avep) > 0:
274 | MAP.append(sum(Avep)/len(Avep))
275 | else:
276 | MAP.append(0.0)
277 |
278 | result = {
279 | "Data size":len(MAP),
280 | "eval_map":float(np.mean(MAP)),
281 | f"eval_map_at_{at_K}":float(np.mean(MAP_at_K)),
282 | "eval_prec":float(PREC/len(MAP))
283 | }
284 | for key in sorted(result.keys()):
285 | logger.info(" %s = %s", key, str(round(result[key], 4)))
286 |
287 |
288 | def main():
289 | parser = argparse.ArgumentParser()
290 |
291 | ## Required parameters
292 | parser.add_argument("--data_dir", default=None, type=str, required=True,
293 | help="The input data path.")
294 | parser.add_argument("--pretrained_dir", default=None, type=str,
295 | help="The directory where the trained model and tokenizer are saved.")
296 | parser.add_argument("--lang", default="python", type=str,
297 | help="Language of dataset")
298 |
299 | parser.add_argument("--cut", action='store_true',
300 | help="Ratio of replaced variables")
301 | parser.add_argument('--num_vec', type=int, default=-1,
302 | help="number of vectors")
303 | parser.add_argument("--block_size", default=512, type=int,
304 | help="Optional input sequence length after tokenization."
305 | "The training dataset will be truncated in block of this size for training."
306 | "Default to the model max input length for single sentence inputs (take into account special tokens).")
307 |
308 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
309 | help="Batch size per GPU/CPU for evaluation.")
310 | parser.add_argument('--logging_steps', type=int, default=10,
311 | help="Log every X updates steps.")
312 | parser.add_argument("--no_cuda", action='store_true',
313 | help="Avoid using CUDA when available")
314 | parser.add_argument('--seed', type=int, default=42,
315 | help="random seed for initialization")
316 | args = parser.parse_args()
317 |
318 |
319 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
320 | args.n_gpu = torch.cuda.device_count()
321 | args.device = device
322 |
323 | # Setup logging
324 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
325 | datefmt='%m/%d/%Y %H:%M:%S',
326 | level=logging.INFO)
327 |
328 | # Set seed
329 | set_seed(args)
330 |
331 | tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_dir)
332 | model = RobertaModel.from_pretrained(args.pretrained_dir, add_pooling_layer=False)
333 |
334 |
335 | evaluate(args, model, tokenizer, args.lang, args.lang, args.cut)
336 | # eval_bm25_beir(args, tokenizer, "java", "java", cut=True)
337 |
338 | if __name__ == "__main__":
339 | main()
340 |
341 |
--------------------------------------------------------------------------------
/process_python.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # Copyright (c) Microsoft Corporation.
3 | # Licensed under the MIT License.
4 | import os
5 | import json
6 | import re
7 | import pickle
8 | from tqdm import tqdm
9 | import random
10 | from collections import Counter, OrderedDict
11 | from tree_sitter import Language, Parser
12 | from parser import (remove_comments_and_docstrings,
13 | tree_to_token_index,
14 | index_to_code_token,
15 | tree_to_variable_index,
16 | traverse)
17 | from textwrap import dedent
18 |
19 |
20 | # ['nwo', 'path', 'language', 'identifier', 'parameters', 'argument_list', 'return_statement', 'docstring', 'docstring_summary', 'docstring_tokens', 'function', 'function_tokens']
21 |
22 | def clean_docstring_comments(comment):
23 | comment = comment.strip().strip(""" "' """)
24 | comment = "\n".join(map(lambda s: s.lstrip("#"), comment.splitlines()))
25 | return dedent(comment)
26 |
27 | class processor(object):
28 | def __init__(self, lang, code=None, remove_comments=False):
29 | LANGUAGE = Language('parser/my-languages.so', lang)
30 | parser = Parser()
31 | parser.set_language(LANGUAGE)
32 | self.parser = [parser]
33 | self.lang = lang
34 | self.remove_comments = remove_comments
35 | self.preserve_words = set(["self", "super", "Exception", "__init__", "__main__"])
36 | if code is None:
37 | self.tree = None
38 | else:
39 | self.update(code)
40 |
41 | def update(self, code):
42 | if self.lang == "php":
43 | code = ""
44 | com = True
45 | if self.remove_comments:
46 | com = False
47 | try:
48 | code = remove_comments_and_docstrings(code, self.lang)
49 | except Exception:
50 | com = True
51 | self.code = code
52 | self.code_bytes = code.encode("utf-8")
53 | self.tree = self.parser[0].parse(self.code_bytes)
54 | root_node = self.tree.root_node
55 | tokens_index = tree_to_token_index(root_node)
56 | code = self.code.split('\n')
57 | self.code_tokens = [index_to_code_token(x, code) for x in tokens_index]
58 | self.index_to_code = OrderedDict()
59 | for idx, (index, code) in enumerate(zip(tokens_index, self.code_tokens)):
60 | self.index_to_code[index] = (idx, code)
61 | return com
62 |
63 | def get_doc(self):
64 | """
65 | For file level data, merge the doc in each function
66 | """
67 | self.functions = []
68 | self.get_func_nodes(self.tree.root_node)
69 | docs = ""
70 | for func_node in self.functions:
71 | body_node = func_node.children[-1]
72 | if body_node.children and body_node.children[0].children:
73 | if body_node.children[0].children[0].type in ["string", "comment"]:
74 | docs += clean_docstring_comments(
75 | self.span_select(body_node.children[0].children[0]),
76 | ) + ""
77 | return docs
78 |
79 | def get_func_nodes(self, root):
80 | """
81 | For both function level and file level data
82 | """
83 | for node in root.children:
84 | if node.type == "function_definition":
85 | self.functions.append(node)
86 | else:
87 | self.get_func_nodes(node)
88 |
89 |
90 | def load_names(self, path):
91 | self.vnames = pickle.load(open(os.path.join(path, "vnames.pkl"), "rb"))
92 | self.vnames = [x for (x, _) in self.vnames.most_common(50000) if x not in self.preserve_words]
93 | self.fnames = pickle.load(open(os.path.join(path, "fnames.pkl"), "rb"))
94 | self.fnames = [x for (x, _) in self.fnames.most_common(5000) if x not in self.preserve_words]
95 |
96 | def span_select(self, *nodes, indent=False):
97 | if not nodes:
98 | return ""
99 | start, end = nodes[0].start_byte, nodes[-1].end_byte
100 | select = self.code_bytes[start:end].decode("utf-8")
101 | if indent:
102 | return " " * nodes[0].start_point[1] + select
103 | return select
104 |
105 | def get_func_name(self):
106 | """
107 | For function level data only
108 | """
109 | root_node = self.tree.root_node
110 | func_nodes = [node for node in root_node.children if node.type == "function_definition"]
111 | try:
112 | func_name = func_nodes[0].child_by_field_name("name")
113 | except IndexError:
114 | return ""
115 | return self.span_select(func_name)
116 |
117 | def get_var_names(self):
118 | root_node = self.tree.root_node
119 | vnames = set()
120 | self._get_var_names_from_node(root_node, vnames)
121 | return vnames
122 |
123 | def get_api_seq(self):
124 | root_node = self.tree.root_node
125 | api_seq = []
126 | self._get_api_seq(root_node, api_seq)
127 | return api_seq
128 |
129 | def _get_var_names_from_node(self, node, vnames, inactive=False):
130 | if len(node.children) > 0:
131 | for child in node.children:
132 | if (
133 | (node.type == "call" and child.type != "argument_list") or
134 | (node.type == "attribute") or
135 | (node.type in ["import_statement", "import_from_statement"])
136 | ):
137 | self._get_var_names_from_node(child, vnames, True)
138 | else:
139 | self._get_var_names_from_node(child, vnames, inactive)
140 | elif node.type == "identifier":
141 | if not inactive:
142 | vnames.add(self.span_select(node))
143 |
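    | # Collects the sequence of called functions; for nested calls the chain is
    | # buffered in `tmp` and flushed in reverse, so inner calls appear before the
    | # outer call that wraps them (e.g. f(g(x)) -> ["g", "f"]).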
144 | def _get_api_seq(self, node, api_seq, tmp=None):
145 | if node.type == "call":
146 | api = node.child_by_field_name("function")
147 | if tmp:
148 | tmp.append(self.span_select(api))
149 | ant = False
150 | else:
151 | tmp = [self.span_select(api)]
152 | ant = True
153 | for child in node.children:
154 | self._get_api_seq(child, api_seq, tmp)
155 | if ant:
156 | api_seq += tmp[::-1]
157 | tmp = None
158 | else:
159 | for child in node.children:
160 | self._get_api_seq(child, api_seq, tmp)
161 |
162 | def process(self, ratio=0.85, indent=True, add_dead_code=True, cut_ratio=0.0):
163 | fname = self.get_func_name()
164 | vnames = [x for x in self.get_var_names() if x not in self.preserve_words]
165 | vnames = random.sample(vnames, int(len(vnames)*ratio))
166 | cands = random.sample(self.vnames, len(vnames)+3)
167 | dead_vars = cands[-3:]
168 | if add_dead_code:
169 | deadcode = self.insert_dead_code(dead_vars)
170 | else:
171 | deadcode = None
172 | replaced = {v: c for v, c in zip(vnames, cands[:-3])}
173 | if ratio > 0 and fname and fname not in replaced:
174 | replaced[fname] = random.choice(self.fnames)
175 | self.index_to_new_code = {}
176 | self._replace_var_names_from_node(self.tree.root_node, replaced)
177 | code_string = self.untokenize(indent, deadcode, True, cut_ratio=cut_ratio)
178 | return code_string
179 |
180 | def process_no_replace(self, indent=True, add_dead_code=True, cut_ratio=0.0):
181 | dead_vars = random.sample(self.vnames, 3)
182 | if add_dead_code:
183 | deadcode = self.insert_dead_code(dead_vars)
184 | else:
185 | deadcode = None
186 | code_string = self.untokenize(indent, deadcode, False, cut_ratio=cut_ratio)
187 | return code_string
188 |
189 | def create_mask_seq(self, indent=True):
190 | fname = self.get_func_name()
191 | vnames = [x for x in self.get_var_names() if x not in self.preserve_words]
192 | replaced = {v: f"[MASK_{i}]" for i, v in enumerate(vnames)}
193 | if fname and fname not in replaced:
194 | replaced[fname] = "[MASK_F]"
195 | self.index_to_new_code = {}
196 | self._replace_var_names_from_node(self.tree.root_node, replaced)
197 | code_string = self.untokenize(indent, replaced=True)
198 | return code_string, replaced
199 |
200 | def insert_dead_code(self, v):
201 | # dead code types, vars that can't appear in original code
202 | # A = B, A
203 | # A(B, 0), A
204 | # A = B + C, AB
205 | # A = B(C), AB
206 | # A = B.C(), ABC
207 | # A = [B for B in range(C)]
208 | # A = B if C else 0
209 | dead_type = random.randrange(7)
210 | if dead_type == 0:
211 | return f"{v[0]} = {v[1]}"
212 | elif dead_type == 1:
213 | return f"{v[0]}({v[1]}, 0)"
214 | elif dead_type == 2:
215 | return f"{v[0]} = {v[1]} + {v[2]}"
216 | elif dead_type == 3:
217 | return f"{v[0]} = {v[1]}({v[2]})"
218 | elif dead_type == 4:
219 | return f"{v[0]} = {v[1]}.{v[2]}()"
220 | elif dead_type == 5:
221 | return f"{v[0]} = [{v[1]} for {v[1]} in range({v[2]})]"
222 | elif dead_type == 6:
223 | return f"{v[0]} = {v[1]} if {v[2]} else 0"
224 |
225 | def _replace_var_names_from_node(self, node, replaced, inactive=False):
226 | if len(node.children) > 0:
227 | if node.type == "attribute":
228 | self._replace_var_names_from_node(node.children[0], replaced, inactive)
229 | for child in node.children[1:]:
230 | self._replace_var_names_from_node(child, replaced, True)
231 | else:
232 | for child in node.children:
233 | if (
234 | (node.type == "call" and child.type not in ["attribute", "argument_list"]) or
235 | (node.type in ["import_statement", "import_from_statement"])
236 | ):
237 | self._replace_var_names_from_node(child, replaced, True)
238 | else:
239 | self._replace_var_names_from_node(child, replaced, inactive)
240 | elif node.type == "identifier":
241 | if not inactive:
242 | try:
243 | idf = self.index_to_code[(node.start_point, node.end_point)][1]
244 | except KeyError:
245 | idf = "None"
246 | if idf in replaced:
247 | self.index_to_new_code[(node.start_point, node.end_point)] = replaced[idf]
248 |
249 | def untokenize(self, indent=True, deadcode=None, replaced=False, cut_ratio=0.0, fix_cut_pos=False):
250 | code_string = ""
251 | prev_sp = None
252 | prev_ep = None
253 | prev_indent = 0
254 | indent_size = -1
255 | total_line = list(self.index_to_code.keys())[-1][0][0]
256 | insert_line = random.randint(total_line//5, total_line*4//5)
257 | cut = random.random() < cut_ratio
258 | if cut:
259 | if fix_cut_pos:
260 | cut_pos = len(self.index_to_code)//2
261 | else:
262 | cut_pos = random.randint(len(self.index_to_code)//3, len(self.index_to_code)*2//3)
263 | for ip, pos in enumerate(self.index_to_code):
264 | sp = pos[0]
265 | ep = pos[1]
266 | if cut and ip >= cut_pos:
267 | break
268 | if replaced and pos in self.index_to_new_code:
269 | add_token = self.index_to_new_code[pos]
270 | else:
271 | add_token = self.index_to_code[pos][1]
272 | if prev_sp is None or (sp[0] == prev_ep[0] and sp[1] == prev_ep[1]):
273 | code_string += add_token
274 | elif sp[0] == prev_ep[0]:
275 | if code_string and code_string[-1] != " ":
276 | code_string += " "
277 | code_string += add_token
278 | else:
279 | # if cut and cut_line >= 1 and cut_line <= prev_ep[0]:
280 | # break
281 | if replaced and deadcode:
282 | if insert_line <= prev_ep[0]:
283 | if sp[1] <= prev_indent:
284 | code_string += "\n" + deadcode
285 | insert_line = total_line+2
286 | if indent and add_token:
287 | code_string += "\n"
288 | omit = False
289 | if sp[1] != prev_indent and prev_indent == 0 and indent_size == -1:
290 | indent_size = sp[1] - prev_indent
291 | if sp[1] - prev_indent > 0:
292 | if sp[1] - prev_indent > 2 * indent_size:
293 | omit = True
294 | else:
295 | for i in range(prev_indent, sp[1], indent_size):
296 | code_string += "<INDENT>"
297 | elif sp[1] - prev_indent < 0:
298 | for i in range(sp[1], prev_indent, indent_size):
299 | code_string += "<DEDENT>"
300 | code_string += add_token
301 | if not omit:
302 | prev_indent = sp[1]
303 | else:
304 | code_string += "\n"
305 | for i in range(sp[1]):
306 | code_string += " "
307 | code_string += add_token
308 | prev_sp, prev_ep = sp, ep
309 | return re.sub(re.compile(r"\s*\n"), "\n", code_string.lstrip()).replace("\n", "<EOL>")
310 |
311 | def convert_to_normal(self, code):
312 | lines = code.split("<EOL>")
313 | indent_size = 4
314 | indent = 0
315 | res = ""
316 | for line in lines:
317 | indent += line.count("<INDENT>")
318 | indent -= line.count("<DEDENT>")
319 | res += "\n" + " "*indent_size*indent + line.replace("<INDENT>", "").replace("<DEDENT>", "")
320 | return res
321 |
322 | def extract_dataflow(self):
323 | try:
324 | root_node = self.tree.root_node
325 | try:
326 | DFG, _ = self.parser[1](root_node, self.index_to_code, {})
327 | except Exception:
328 | DFG = []
329 | DFG = sorted(DFG, key=lambda x: x[1])
330 | indexs = set()
331 | for d in DFG:
332 | if len(d[-1]) != 0:
333 | indexs.add(d[1])
334 | for x in d[-1]:
335 | indexs.add(x)
336 | new_DFG = []
337 | for d in DFG:
338 | if d[1] in indexs:
339 | new_DFG.append(d)
340 | dfg = new_DFG
341 | except Exception:
342 | dfg = []
343 | return dfg
344 |
345 |
346 |
--------------------------------------------------------------------------------
/generate/run_lm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | from __future__ import absolute_import, division, print_function
4 |
5 | import argparse
6 | from functools import partial
7 | import glob
8 | import logging
9 | import os
10 | import pickle
11 | import random
12 | import re
13 | import shutil
14 | import json
15 |
16 | import numpy as np
17 | import torch
18 | from torch.nn.utils.rnn import pad_sequence
19 | from torch.nn import CrossEntropyLoss
20 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset
21 | from torch.utils.data.distributed import DistributedSampler
22 | from dataset import RelineDataset, PPLDataset
23 | from beam import Beam
24 |
25 | from fuzzywuzzy import fuzz
26 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
27 | BertConfig, BertForMaskedLM, BertTokenizer,
28 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
29 | OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
30 | RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
31 | DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 | MODEL_CLASSES = {
36 | 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
37 | 'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
38 | 'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
39 | 'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
40 | 'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
41 | }
42 |
43 | def set_seed(args):
44 | random.seed(args.seed)
45 | np.random.seed(args.seed)
46 | torch.manual_seed(args.seed)
47 | if args.n_gpu > 0:
48 | torch.cuda.manual_seed_all(args.seed)
49 |
50 | def update_config(args, config):
51 | # config.n_positions = config.n_ctx = args.block_size
52 | config.vocab_size = args.vocab_size
53 |
54 | def get_special_tokens(path):
55 | lits = json.load(open(path))
56 | tokens = ["<STR_LIT>", "<NUM_LIT>", "<CHAR_LIT>"]
57 | for lit in lits["str"]:
58 | tokens.append(f"<STR_LIT:{lit}>")
59 | for lit in lits["num"]:
60 | tokens.append(f"<NUM_LIT:{lit}>")
61 | for lit in lits["char"]:
62 | tokens.append(f"<CHAR_LIT:{lit}>")
63 | return tokens
64 |
65 | def my_collect_fn(sequences, batch_first=True, padding_value=1):
66 | inputs1 = []
67 | inputs2 = []
68 | for (x1, x2) in sequences:
69 | inputs1.append(x1)
70 | inputs2.append(x2)
71 | return (
72 | pad_sequence(inputs1, batch_first, padding_value),
73 | pad_sequence(inputs2, batch_first, 0),
74 | )
75 |
76 | def eval_ppl(args, model, tokenizer, file_type='test', load_file="train", res_file="dense.pkl"):
77 | model.to(args.device)
78 | if load_file is None:
79 | dataset = PPLDataset(tokenizer, args, logger, file_type=file_type, block_size=args.block_size)
80 | else:
81 | dataset = PPLDataset(
82 | tokenizer, args, logger, file_type=file_type, block_size=args.block_size,
83 | load_file=load_file,
84 | search_res=os.path.join(args.data_dir, "search_results", res_file),
85 | )
86 |
87 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
88 | # Note that DistributedSampler samples randomly
89 | eval_sampler = SequentialSampler(dataset)
90 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, collate_fn=partial(my_collect_fn, batch_first=True, padding_value=tokenizer.pad_token_id), batch_size=args.eval_batch_size)
91 |
92 | # multi-gpu evaluate
93 | if args.n_gpu > 1:
94 | model = torch.nn.DataParallel(model)
95 |
96 | logger.info("***** Running test *****")
97 | logger.info(" Num examples = %d", len(dataset))
98 | logger.info(" Batch size = %d", args.eval_batch_size)
99 | eval_loss = 0.0
100 | num_tok = 0
101 | nb_eval_steps = 0
102 | model.eval()
103 |
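    | # token_labels come from PPLDataset (not shown here): presumably 0 marks padding
    | # (excluded from attention) and 2 marks the completion span, so the loss and
    | # perplexity below are computed only over those target positions.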
104 | for step, (batch, token_labels) in enumerate(eval_dataloader):
105 |
106 | inputs = batch.to(args.device)
107 | attn_mask = (token_labels != 0).to(dtype=torch.uint8, device=args.device)
108 | loss_mask = (token_labels == 2).to(dtype=torch.uint8, device=args.device)
109 | with torch.no_grad():
110 | outputs = model(inputs, attention_mask=attn_mask)
111 | logits = outputs[0]
112 | labels = inputs
113 | shift_logits = logits[..., :-1, :].contiguous()
114 | shift_labels = labels[..., 1:].contiguous()
115 | # Flatten the tokens
116 | loss_fct = CrossEntropyLoss()
117 | flatten_shift_loss_mask = loss_mask[..., :-1].contiguous().view(-1)
118 | ids = torch.nonzero(flatten_shift_loss_mask).view(-1)
119 | all_labels = shift_labels.view(-1)[ids]
120 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[ids], all_labels)
121 | eval_loss += loss.item()*all_labels.shape[0]
122 | num_tok += all_labels.shape[0]
123 |
124 | if step % args.logging_steps == 0:
125 | logger.info(f"Test steps: {step}")
126 | nb_eval_steps += 1
127 |
128 | eval_loss = eval_loss / num_tok
129 | perplexity = torch.exp(torch.tensor(eval_loss))
130 |
131 | result = {
132 | "perplexity": float(perplexity)
133 | }
134 |
135 | for key in sorted(result.keys()):
136 | logger.info(" %s = %s", key, str(result[key]))
137 |
138 | def eval_line_completion(args, model, tokenizer, file_type='test', load_file=None, res_file=None, save_name=None):
139 | """
140 | Evaluate line-level code completion on exact match and edit similarity.
141 |
142 | It is recommended to use a single GPU because the generation cannot be batched.
143 | """
144 |
145 | def DecodeIds(idxs):
146 | codes = ""
147 | for idx in idxs:
148 | to_add = tokenizer.convert_ids_to_tokens(idx)
149 | if tokenizer.convert_ids_to_tokens(idx)[0] == '\u0120':
150 | if not codes.endswith(" "):
151 | codes += " " + to_add[1:]
152 | else:
153 | codes += to_add[1:]
154 | elif (
155 | idx in [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id] or
156 | tokenizer.convert_ids_to_tokens(idx).startswith("").strip()
227 | else:
228 | text = DecodeIds(t).strip("{").strip()
229 | preds.append(text)
230 | gts.append(gt[0])
231 | edit_sim += fuzz.ratio(text, gt[0])
232 | em += 1 if text == gt[0] else 0
233 | if step % args.logging_steps == 0:
234 | logger.warning(f"{step} are done!")
235 |
236 | with open(save_name, "w") as f:
237 | for pred_text in preds:
238 | f.write(pred_text+"\n")
239 |
240 | logger.warning(f"Test {len(preds)} samples")
241 | logger.warning(f"Edit sim: {edit_sim/len(preds)}, EM: {em/len(preds)}")
242 |
243 |
244 | def main():
245 | parser = argparse.ArgumentParser()
246 |
247 | ## Required parameters
248 | parser.add_argument("--data_dir", default=None, type=str, required=True,
249 | help="The input data path.")
250 | parser.add_argument("--langs", default=None, type=str, required=True,
251 | help="Languages to train, if all, train all languages in data_dir")
252 | parser.add_argument("--output_dir", default=None, type=str, required=True,
253 | help="The output directory where the model predictions and checkpoints will be written.")
254 |
255 | parser.add_argument("--load_file_name", default=None, type=str,
256 | help="search corpus to load")
257 | parser.add_argument("--search_res", default=None, type=str,
258 | help="file that saves search results")
259 | parser.add_argument("--save_name", default=None, type=str,
260 | help="file to save model predictions")
261 |
262 | ## Other parameters
263 | parser.add_argument("--model_type", default="gpt2", type=str,
264 | help="The model architecture to be fine-tuned.")
265 | parser.add_argument("--pretrain_dir", default="", type=str,
266 | help="The output directory where the model predictions and checkpoints will be written.")
267 | parser.add_argument("--config_dir", type=str,
268 | help="config name. Required when training from scratch")
269 | parser.add_argument("--tokenizer_dir", type=str,
270 | help="Pre-trained tokenizer dir. Required when training from scratch")
271 | parser.add_argument("--lit_file", type=str,
272 | help="literals json file")
273 |
274 | parser.add_argument("--block_size", default=1024, type=int,
275 | help="Optional input sequence length after tokenization."
276 | "The training dataset will be truncated in block of this size for training."
277 | "Default to the model max input length for single sentence inputs (take into account special tokens).")
278 | parser.add_argument("--eval_line", action='store_true',
279 | help="Whether to run eval on the dev set.")
280 |
281 | parser.add_argument("--per_gpu_eval_batch_size", default=12, type=int,
282 | help="Batch size per GPU/CPU for evaluation.")
283 |
284 | parser.add_argument('--logging_steps', type=int, default=1000,
285 | help="Log every X updates steps.")
286 | parser.add_argument("--no_cuda", action='store_true',
287 | help="Avoid using CUDA when available")
288 | parser.add_argument('--seed', type=int, default=42,
289 | help="random seed for initialization")
290 |
291 | pool = None
292 | args = parser.parse_args()
293 |
294 | # args.output_dir = os.path.join(args.output_dir, args.dataset)
295 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
296 | datefmt='%m/%d/%Y %H:%M:%S',
297 | level=logging.INFO)
298 |
299 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
300 | args.n_gpu = torch.cuda.device_count()
301 | args.device = device
302 | # args.batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
303 |
304 | # Set seed
305 | set_seed(args)
306 |
307 | # get special tokens
308 | special_tokens = get_special_tokens(args.lit_file)
309 |
310 | # Load pre-trained model
311 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
312 | pretrained = args.pretrain_dir
313 | if pretrained:
314 | tokenizer = tokenizer_class.from_pretrained(pretrained, sep_token='<EOL>', bos_token='<s>', eos_token='</s>', pad_token='<pad>', unk_token='<|UNKNOWN|>', additional_special_tokens=special_tokens)
315 | logger.warning(f"Loading model from {pretrained}")
316 | model = model_class.from_pretrained(pretrained)
317 | model.resize_token_embeddings(len(tokenizer))
318 | else:
319 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_dir, sep_token='<EOL>', bos_token='<s>', eos_token='</s>', pad_token='<pad>', unk_token='<|UNKNOWN|>', additional_special_tokens=special_tokens)
320 | args.vocab_size = len(tokenizer)
321 | config = config_class.from_pretrained(args.config_dir)
322 | model = model_class(config)
323 | model.resize_token_embeddings(len(tokenizer))
324 |
325 | model_parameters = model.parameters()
326 | num_params = sum([np.prod(p.size()) for p in model_parameters])
327 | logger.warning(f"Model has a total of {num_params} trainable parameters")
328 |
329 | logger.warning("Training/evaluation parameters %s", args)
330 |
331 | # Only works on single GPU
332 | if args.eval_line:
333 | eval_line_completion(
334 | args,
335 | model,
336 | tokenizer,
337 | file_type="test",
338 | load_file=args.load_file_name,
339 | res_file=args.search_res,
340 | save_name=args.save_name
341 | )
342 |
343 |
344 | if __name__ == "__main__":
345 | main()
346 |
--------------------------------------------------------------------------------
/parser/DFG.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 |
4 | from tree_sitter import Language, Parser
5 | from .utils import (remove_comments_and_docstrings,
6 | tree_to_token_index,
7 | index_to_code_token,
8 | tree_to_variable_index)
9 |
10 |
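    | # Each DFG entry is a 5-tuple:
    | #   (token, token_index, edge_type, [source tokens], [source token indices])
    | # where edge_type is 'comesFrom' or 'computedFrom', and `states` maps a variable
    | # name to the indices of its most recent definition(s).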
11 | def DFG_python(root_node,index_to_code,states):
12 | assignment=['assignment','augmented_assignment','for_in_clause']
13 | if_statement=['if_statement']
14 | for_statement=['for_statement']
15 | while_statement=['while_statement']
16 | do_first_statement=['for_in_clause']
17 | def_statement=['default_parameter']
18 | states=states.copy()
19 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment':
20 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)]
21 | if root_node.type==code:
22 | return [],states
23 | elif code in states:
24 | return [(code,idx,'comesFrom',[code],states[code].copy())],states
25 | else:
26 | if root_node.type=='identifier':
27 | states[code]=[idx]
28 | return [(code,idx,'comesFrom',[],[])],states
29 | elif root_node.type in def_statement:
30 | name=root_node.child_by_field_name('name')
31 | value=root_node.child_by_field_name('value')
32 | DFG=[]
33 | if value is None:
34 | indexs=tree_to_variable_index(name,index_to_code)
35 | for index in indexs:
36 | idx,code=index_to_code[index]
37 | DFG.append((code,idx,'comesFrom',[],[]))
38 | states[code]=[idx]
39 | return sorted(DFG,key=lambda x:x[1]),states
40 | else:
41 | name_indexs=tree_to_variable_index(name,index_to_code)
42 | value_indexs=tree_to_variable_index(value,index_to_code)
43 | temp,states=DFG_python(value,index_to_code,states)
44 | DFG+=temp
45 | for index1 in name_indexs:
46 | idx1,code1=index_to_code[index1]
47 | for index2 in value_indexs:
48 | idx2,code2=index_to_code[index2]
49 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2]))
50 | states[code1]=[idx1]
51 | return sorted(DFG,key=lambda x:x[1]),states
52 | elif root_node.type in assignment:
53 | if root_node.type=='for_in_clause':
54 | right_nodes=[root_node.children[-1]]
55 | left_nodes=[root_node.child_by_field_name('left')]
56 | else:
57 | if root_node.child_by_field_name('right') is None:
58 | return [],states
59 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=',']
60 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=',']
61 | if len(right_nodes)!=len(left_nodes):
62 | left_nodes=[root_node.child_by_field_name('left')]
63 | right_nodes=[root_node.child_by_field_name('right')]
64 | if len(left_nodes)==0:
65 | left_nodes=[root_node.child_by_field_name('left')]
66 | if len(right_nodes)==0:
67 | right_nodes=[root_node.child_by_field_name('right')]
68 | DFG=[]
69 | for node in right_nodes:
70 | temp,states=DFG_python(node,index_to_code,states)
71 | DFG+=temp
72 |
73 | for left_node,right_node in zip(left_nodes,right_nodes):
74 | left_tokens_index=tree_to_variable_index(left_node,index_to_code)
75 | right_tokens_index=tree_to_variable_index(right_node,index_to_code)
76 | temp=[]
77 | for token1_index in left_tokens_index:
78 | idx1,code1=index_to_code[token1_index]
79 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index],
80 | [index_to_code[x][0] for x in right_tokens_index]))
81 | states[code1]=[idx1]
82 | DFG+=temp
83 | return sorted(DFG,key=lambda x:x[1]),states
84 | elif root_node.type in if_statement:
85 | DFG=[]
86 | current_states=states.copy()
87 | others_states=[]
88 | tag=False
89 | if 'else' in root_node.type:
90 | tag=True
91 | for child in root_node.children:
92 | if 'else' in child.type:
93 | tag=True
94 | if child.type not in ['elif_clause','else_clause']:
95 | temp,current_states=DFG_python(child,index_to_code,current_states)
96 | DFG+=temp
97 | else:
98 | temp,new_states=DFG_python(child,index_to_code,states)
99 | DFG+=temp
100 | others_states.append(new_states)
101 | others_states.append(current_states)
102 | if tag is False:
103 | others_states.append(states)
104 | new_states={}
105 | for dic in others_states:
106 | for key in dic:
107 | if key not in new_states:
108 | new_states[key]=dic[key].copy()
109 | else:
110 | new_states[key]+=dic[key]
111 | for key in new_states:
112 | new_states[key]=sorted(list(set(new_states[key])))
113 | return sorted(DFG,key=lambda x:x[1]),new_states
114 | elif root_node.type in for_statement:
115 | DFG=[]
116 | for i in range(2):
117 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=',']
118 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=',']
119 | if len(right_nodes)!=len(left_nodes):
120 | left_nodes=[root_node.child_by_field_name('left')]
121 | right_nodes=[root_node.child_by_field_name('right')]
122 | if len(left_nodes)==0:
123 | left_nodes=[root_node.child_by_field_name('left')]
124 | if len(right_nodes)==0:
125 | right_nodes=[root_node.child_by_field_name('right')]
126 | for node in right_nodes:
127 | temp,states=DFG_python(node,index_to_code,states)
128 | DFG+=temp
129 | for left_node,right_node in zip(left_nodes,right_nodes):
130 | left_tokens_index=tree_to_variable_index(left_node,index_to_code)
131 | right_tokens_index=tree_to_variable_index(right_node,index_to_code)
132 | temp=[]
133 | for token1_index in left_tokens_index:
134 | idx1,code1=index_to_code[token1_index]
135 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index],
136 | [index_to_code[x][0] for x in right_tokens_index]))
137 | states[code1]=[idx1]
138 | DFG+=temp
139 | if root_node.children[-1].type=="block":
140 | temp,states=DFG_python(root_node.children[-1],index_to_code,states)
141 | DFG+=temp
142 | dic={}
143 | for x in DFG:
144 | if (x[0],x[1],x[2]) not in dic:
145 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
146 | else:
147 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
148 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
149 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
150 | return sorted(DFG,key=lambda x:x[1]),states
151 | elif root_node.type in while_statement:
152 | DFG=[]
153 | for i in range(2):
154 | for child in root_node.children:
155 | temp,states=DFG_python(child,index_to_code,states)
156 | DFG+=temp
157 | dic={}
158 | for x in DFG:
159 | if (x[0],x[1],x[2]) not in dic:
160 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
161 | else:
162 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
163 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
164 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
165 | return sorted(DFG,key=lambda x:x[1]),states
166 | else:
167 | DFG=[]
168 | for child in root_node.children:
169 | if child.type in do_first_statement:
170 | temp,states=DFG_python(child,index_to_code,states)
171 | DFG+=temp
172 | for child in root_node.children:
173 | if child.type not in do_first_statement:
174 | temp,states=DFG_python(child,index_to_code,states)
175 | DFG+=temp
176 |
177 | return sorted(DFG,key=lambda x:x[1]),states
178 |
179 |
180 | def DFG_java(root_node,index_to_code,states):
181 | assignment=['assignment_expression']
182 | def_statement=['variable_declarator']
183 | increment_statement=['update_expression']
184 | if_statement=['if_statement','else']
185 | for_statement=['for_statement']
186 | enhanced_for_statement=['enhanced_for_statement']
187 | while_statement=['while_statement']
188 | do_first_statement=[]
189 | states=states.copy()
190 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment':
191 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)]
192 | if root_node.type==code:
193 | return [],states
194 | elif code in states:
195 | return [(code,idx,'comesFrom',[code],states[code].copy())],states
196 | else:
197 | if root_node.type=='identifier':
198 | states[code]=[idx]
199 | return [(code,idx,'comesFrom',[],[])],states
200 | elif root_node.type in def_statement:
201 | name=root_node.child_by_field_name('name')
202 | value=root_node.child_by_field_name('value')
203 | DFG=[]
204 | if value is None:
205 | indexs=tree_to_variable_index(name,index_to_code)
206 | for index in indexs:
207 | idx,code=index_to_code[index]
208 | DFG.append((code,idx,'comesFrom',[],[]))
209 | states[code]=[idx]
210 | return sorted(DFG,key=lambda x:x[1]),states
211 | else:
212 | name_indexs=tree_to_variable_index(name,index_to_code)
213 | value_indexs=tree_to_variable_index(value,index_to_code)
214 | temp,states=DFG_java(value,index_to_code,states)
215 | DFG+=temp
216 | for index1 in name_indexs:
217 | idx1,code1=index_to_code[index1]
218 | for index2 in value_indexs:
219 | idx2,code2=index_to_code[index2]
220 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2]))
221 | states[code1]=[idx1]
222 | return sorted(DFG,key=lambda x:x[1]),states
223 | elif root_node.type in assignment:
224 | left_nodes=root_node.child_by_field_name('left')
225 | right_nodes=root_node.child_by_field_name('right')
226 | DFG=[]
227 | temp,states=DFG_java(right_nodes,index_to_code,states)
228 | DFG+=temp
229 | name_indexs=tree_to_variable_index(left_nodes,index_to_code)
230 | value_indexs=tree_to_variable_index(right_nodes,index_to_code)
231 | for index1 in name_indexs:
232 | idx1,code1=index_to_code[index1]
233 | for index2 in value_indexs:
234 | idx2,code2=index_to_code[index2]
235 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
236 | states[code1]=[idx1]
237 | return sorted(DFG,key=lambda x:x[1]),states
238 | elif root_node.type in increment_statement:
239 | DFG=[]
240 | indexs=tree_to_variable_index(root_node,index_to_code)
241 | for index1 in indexs:
242 | idx1,code1=index_to_code[index1]
243 | for index2 in indexs:
244 | idx2,code2=index_to_code[index2]
245 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
246 | states[code1]=[idx1]
247 | return sorted(DFG,key=lambda x:x[1]),states
248 | elif root_node.type in if_statement:
249 | DFG=[]
250 | current_states=states.copy()
251 | others_states=[]
252 | flag=False
253 | tag=False
254 | if 'else' in root_node.type:
255 | tag=True
256 | for child in root_node.children:
257 | if 'else' in child.type:
258 | tag=True
259 | if child.type not in if_statement and flag is False:
260 | temp,current_states=DFG_java(child,index_to_code,current_states)
261 | DFG+=temp
262 | else:
263 | flag=True
264 | temp,new_states=DFG_java(child,index_to_code,states)
265 | DFG+=temp
266 | others_states.append(new_states)
267 | others_states.append(current_states)
268 | if tag is False:
269 | others_states.append(states)
270 | new_states={}
271 | for dic in others_states:
272 | for key in dic:
273 | if key not in new_states:
274 | new_states[key]=dic[key].copy()
275 | else:
276 | new_states[key]+=dic[key]
277 | for key in new_states:
278 | new_states[key]=sorted(list(set(new_states[key])))
279 | return sorted(DFG,key=lambda x:x[1]),new_states
280 | elif root_node.type in for_statement:
281 | DFG=[]
282 | for child in root_node.children:
283 | temp,states=DFG_java(child,index_to_code,states)
284 | DFG+=temp
285 | flag=False
286 | for child in root_node.children:
287 | if flag:
288 | temp,states=DFG_java(child,index_to_code,states)
289 | DFG+=temp
290 | elif child.type=="local_variable_declaration":
291 | flag=True
292 | dic={}
293 | for x in DFG:
294 | if (x[0],x[1],x[2]) not in dic:
295 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
296 | else:
297 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
298 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
299 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
300 | return sorted(DFG,key=lambda x:x[1]),states
301 | elif root_node.type in enhanced_for_statement:
302 | name=root_node.child_by_field_name('name')
303 | value=root_node.child_by_field_name('value')
304 | body=root_node.child_by_field_name('body')
305 | DFG=[]
306 | for i in range(2):
307 | temp,states=DFG_java(value,index_to_code,states)
308 | DFG+=temp
309 | name_indexs=tree_to_variable_index(name,index_to_code)
310 | value_indexs=tree_to_variable_index(value,index_to_code)
311 | for index1 in name_indexs:
312 | idx1,code1=index_to_code[index1]
313 | for index2 in value_indexs:
314 | idx2,code2=index_to_code[index2]
315 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
316 | states[code1]=[idx1]
317 | temp,states=DFG_java(body,index_to_code,states)
318 | DFG+=temp
319 | dic={}
320 | for x in DFG:
321 | if (x[0],x[1],x[2]) not in dic:
322 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
323 | else:
324 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
325 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
326 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
327 | return sorted(DFG,key=lambda x:x[1]),states
328 | elif root_node.type in while_statement:
329 | DFG=[]
330 | for i in range(2):
331 | for child in root_node.children:
332 | temp,states=DFG_java(child,index_to_code,states)
333 | DFG+=temp
334 | dic={}
335 | for x in DFG:
336 | if (x[0],x[1],x[2]) not in dic:
337 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
338 | else:
339 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
340 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
341 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
342 | return sorted(DFG,key=lambda x:x[1]),states
343 | else:
344 | DFG=[]
345 | for child in root_node.children:
346 | if child.type in do_first_statement:
347 | temp,states=DFG_java(child,index_to_code,states)
348 | DFG+=temp
349 | for child in root_node.children:
350 | if child.type not in do_first_statement:
351 | temp,states=DFG_java(child,index_to_code,states)
352 | DFG+=temp
353 |
354 | return sorted(DFG,key=lambda x:x[1]),states
355 |
356 | def DFG_csharp(root_node,index_to_code,states):
357 | assignment=['assignment_expression']
358 | def_statement=['variable_declarator']
359 | increment_statement=['postfix_unary_expression']
360 | if_statement=['if_statement','else']
361 | for_statement=['for_statement']
362 | enhanced_for_statement=['for_each_statement']
363 | while_statement=['while_statement']
364 | do_first_statement=[]
365 | states=states.copy()
366 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment':
367 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)]
368 | if root_node.type==code:
369 | return [],states
370 | elif code in states:
371 | return [(code,idx,'comesFrom',[code],states[code].copy())],states
372 | else:
373 | if root_node.type=='identifier':
374 | states[code]=[idx]
375 | return [(code,idx,'comesFrom',[],[])],states
376 | elif root_node.type in def_statement:
377 | if len(root_node.children)==2:
378 | name=root_node.children[0]
379 | value=root_node.children[1]
380 | else:
381 | name=root_node.children[0]
382 | value=None
383 | DFG=[]
384 | if value is None:
385 | indexs=tree_to_variable_index(name,index_to_code)
386 | for index in indexs:
387 | idx,code=index_to_code[index]
388 | DFG.append((code,idx,'comesFrom',[],[]))
389 | states[code]=[idx]
390 | return sorted(DFG,key=lambda x:x[1]),states
391 | else:
392 | name_indexs=tree_to_variable_index(name,index_to_code)
393 | value_indexs=tree_to_variable_index(value,index_to_code)
394 | temp,states=DFG_csharp(value,index_to_code,states)
395 | DFG+=temp
396 | for index1 in name_indexs:
397 | idx1,code1=index_to_code[index1]
398 | for index2 in value_indexs:
399 | idx2,code2=index_to_code[index2]
400 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2]))
401 | states[code1]=[idx1]
402 | return sorted(DFG,key=lambda x:x[1]),states
403 | elif root_node.type in assignment:
404 | left_nodes=root_node.child_by_field_name('left')
405 | right_nodes=root_node.child_by_field_name('right')
406 | DFG=[]
407 | temp,states=DFG_csharp(right_nodes,index_to_code,states)
408 | DFG+=temp
409 | name_indexs=tree_to_variable_index(left_nodes,index_to_code)
410 | value_indexs=tree_to_variable_index(right_nodes,index_to_code)
411 | for index1 in name_indexs:
412 | idx1,code1=index_to_code[index1]
413 | for index2 in value_indexs:
414 | idx2,code2=index_to_code[index2]
415 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
416 | states[code1]=[idx1]
417 | return sorted(DFG,key=lambda x:x[1]),states
418 | elif root_node.type in increment_statement:
419 | DFG=[]
420 | indexs=tree_to_variable_index(root_node,index_to_code)
421 | for index1 in indexs:
422 | idx1,code1=index_to_code[index1]
423 | for index2 in indexs:
424 | idx2,code2=index_to_code[index2]
425 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
426 | states[code1]=[idx1]
427 | return sorted(DFG,key=lambda x:x[1]),states
428 | elif root_node.type in if_statement:
429 | DFG=[]
430 | current_states=states.copy()
431 | others_states=[]
432 | flag=False
433 | tag=False
434 | if 'else' in root_node.type:
435 | tag=True
436 | for child in root_node.children:
437 | if 'else' in child.type:
438 | tag=True
439 | if child.type not in if_statement and flag is False:
440 | temp,current_states=DFG_csharp(child,index_to_code,current_states)
441 | DFG+=temp
442 | else:
443 | flag=True
444 | temp,new_states=DFG_csharp(child,index_to_code,states)
445 | DFG+=temp
446 | others_states.append(new_states)
447 | others_states.append(current_states)
448 | if tag is False:
449 | others_states.append(states)
450 | new_states={}
451 | for dic in others_states:
452 | for key in dic:
453 | if key not in new_states:
454 | new_states[key]=dic[key].copy()
455 | else:
456 | new_states[key]+=dic[key]
457 | for key in new_states:
458 | new_states[key]=sorted(list(set(new_states[key])))
459 | return sorted(DFG,key=lambda x:x[1]),new_states
460 | elif root_node.type in for_statement:
461 | DFG=[]
462 | for child in root_node.children:
463 | temp,states=DFG_csharp(child,index_to_code,states)
464 | DFG+=temp
465 | flag=False
466 | for child in root_node.children:
467 | if flag:
468 | temp,states=DFG_csharp(child,index_to_code,states)
469 | DFG+=temp
470 | elif child.type=="local_variable_declaration":
471 | flag=True
472 | dic={}
473 | for x in DFG:
474 | if (x[0],x[1],x[2]) not in dic:
475 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
476 | else:
477 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
478 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
479 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
480 | return sorted(DFG,key=lambda x:x[1]),states
481 | elif root_node.type in enhanced_for_statement:
482 | name=root_node.child_by_field_name('left')
483 | value=root_node.child_by_field_name('right')
484 | body=root_node.child_by_field_name('body')
485 | DFG=[]
486 | for i in range(2):
487 | temp,states=DFG_csharp(value,index_to_code,states)
488 | DFG+=temp
489 | name_indexs=tree_to_variable_index(name,index_to_code)
490 | value_indexs=tree_to_variable_index(value,index_to_code)
491 | for index1 in name_indexs:
492 | idx1,code1=index_to_code[index1]
493 | for index2 in value_indexs:
494 | idx2,code2=index_to_code[index2]
495 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
496 | states[code1]=[idx1]
497 | temp,states=DFG_csharp(body,index_to_code,states)
498 | DFG+=temp
499 | dic={}
500 | for x in DFG:
501 | if (x[0],x[1],x[2]) not in dic:
502 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
503 | else:
504 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
505 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
506 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
507 | return sorted(DFG,key=lambda x:x[1]),states
508 | elif root_node.type in while_statement:
509 | DFG=[]
510 | for i in range(2):
511 | for child in root_node.children:
512 | temp,states=DFG_csharp(child,index_to_code,states)
513 | DFG+=temp
514 | dic={}
515 | for x in DFG:
516 | if (x[0],x[1],x[2]) not in dic:
517 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
518 | else:
519 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
520 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
521 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
522 | return sorted(DFG,key=lambda x:x[1]),states
523 | else:
524 | DFG=[]
525 | for child in root_node.children:
526 | if child.type in do_first_statement:
527 | temp,states=DFG_csharp(child,index_to_code,states)
528 | DFG+=temp
529 | for child in root_node.children:
530 | if child.type not in do_first_statement:
531 | temp,states=DFG_csharp(child,index_to_code,states)
532 | DFG+=temp
533 |
534 | return sorted(DFG,key=lambda x:x[1]),states
535 |
536 |
537 |
538 |
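# NOTE (editorial comment, not in the original source): all DFG_* extractors in this
# module share one calling convention. Each takes a tree-sitter node, an
# `index_to_code` mapping {(start_point, end_point): (token_index, token_text)}, and a
# `states` dict mapping variable text to the indices of its most recent definitions.
# Each returns a list of data-flow edges of the form
# (token, token_index, 'comesFrom' | 'computedFrom', [source_tokens], [source_indices])
# plus the updated `states`. DFG_ruby below additionally splits parallel assignments
# (`a, b = x, y`) element-wise when both sides have the same number of children.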
539 | def DFG_ruby(root_node,index_to_code,states):
540 | assignment=['assignment','operator_assignment']
541 | if_statement=['if','elsif','else','unless','when']
542 | for_statement=['for']
543 | while_statement=['while_modifier','until']
544 | do_first_statement=[]
545 | def_statement=['keyword_parameter']
546 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment':
547 | states=states.copy()
548 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)]
549 | if root_node.type==code:
550 | return [],states
551 | elif code in states:
552 | return [(code,idx,'comesFrom',[code],states[code].copy())],states
553 | else:
554 | if root_node.type=='identifier':
555 | states[code]=[idx]
556 | return [(code,idx,'comesFrom',[],[])],states
557 | elif root_node.type in def_statement:
558 | name=root_node.child_by_field_name('name')
559 | value=root_node.child_by_field_name('value')
560 | DFG=[]
561 | if value is None:
562 | indexs=tree_to_variable_index(name,index_to_code)
563 | for index in indexs:
564 | idx,code=index_to_code[index]
565 | DFG.append((code,idx,'comesFrom',[],[]))
566 | states[code]=[idx]
567 | return sorted(DFG,key=lambda x:x[1]),states
568 | else:
569 | name_indexs=tree_to_variable_index(name,index_to_code)
570 | value_indexs=tree_to_variable_index(value,index_to_code)
571 | temp,states=DFG_ruby(value,index_to_code,states)
572 | DFG+=temp
573 | for index1 in name_indexs:
574 | idx1,code1=index_to_code[index1]
575 | for index2 in value_indexs:
576 | idx2,code2=index_to_code[index2]
577 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2]))
578 | states[code1]=[idx1]
579 | return sorted(DFG,key=lambda x:x[1]),states
580 | elif root_node.type in assignment:
581 | left_nodes=[x for x in root_node.child_by_field_name('left').children if x.type!=',']
582 | right_nodes=[x for x in root_node.child_by_field_name('right').children if x.type!=',']
583 | if len(right_nodes)!=len(left_nodes):
584 | left_nodes=[root_node.child_by_field_name('left')]
585 | right_nodes=[root_node.child_by_field_name('right')]
586 | if len(left_nodes)==0:
587 | left_nodes=[root_node.child_by_field_name('left')]
588 | if len(right_nodes)==0:
589 | right_nodes=[root_node.child_by_field_name('right')]
590 | if root_node.type=="operator_assignment":
591 | left_nodes=[root_node.children[0]]
592 | right_nodes=[root_node.children[-1]]
593 |
594 | DFG=[]
595 | for node in right_nodes:
596 | temp,states=DFG_ruby(node,index_to_code,states)
597 | DFG+=temp
598 |
599 | for left_node,right_node in zip(left_nodes,right_nodes):
600 | left_tokens_index=tree_to_variable_index(left_node,index_to_code)
601 | right_tokens_index=tree_to_variable_index(right_node,index_to_code)
602 | temp=[]
603 | for token1_index in left_tokens_index:
604 | idx1,code1=index_to_code[token1_index]
605 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index],
606 | [index_to_code[x][0] for x in right_tokens_index]))
607 | states[code1]=[idx1]
608 | DFG+=temp
609 | return sorted(DFG,key=lambda x:x[1]),states
610 | elif root_node.type in if_statement:
611 | DFG=[]
612 | current_states=states.copy()
613 | others_states=[]
614 | tag=False
615 | if 'else' in root_node.type:
616 | tag=True
617 | for child in root_node.children:
618 | if 'else' in child.type:
619 | tag=True
620 | if child.type not in if_statement:
621 | temp,current_states=DFG_ruby(child,index_to_code,current_states)
622 | DFG+=temp
623 | else:
624 | temp,new_states=DFG_ruby(child,index_to_code,states)
625 | DFG+=temp
626 | others_states.append(new_states)
627 | others_states.append(current_states)
628 | if tag is False:
629 | others_states.append(states)
630 | new_states={}
631 | for dic in others_states:
632 | for key in dic:
633 | if key not in new_states:
634 | new_states[key]=dic[key].copy()
635 | else:
636 | new_states[key]+=dic[key]
637 | for key in new_states:
638 | new_states[key]=sorted(list(set(new_states[key])))
639 | return sorted(DFG,key=lambda x:x[1]),new_states
640 | elif root_node.type in for_statement:
641 | DFG=[]
642 | for i in range(2):
643 | left_nodes=[root_node.child_by_field_name('pattern')]
644 | right_nodes=[root_node.child_by_field_name('value')]
645 | assert len(right_nodes)==len(left_nodes)
646 | for node in right_nodes:
647 | temp,states=DFG_ruby(node,index_to_code,states)
648 | DFG+=temp
649 | for left_node,right_node in zip(left_nodes,right_nodes):
650 | left_tokens_index=tree_to_variable_index(left_node,index_to_code)
651 | right_tokens_index=tree_to_variable_index(right_node,index_to_code)
652 | temp=[]
653 | for token1_index in left_tokens_index:
654 | idx1,code1=index_to_code[token1_index]
655 | temp.append((code1,idx1,'computedFrom',[index_to_code[x][1] for x in right_tokens_index],
656 | [index_to_code[x][0] for x in right_tokens_index]))
657 | states[code1]=[idx1]
658 | DFG+=temp
659 | temp,states=DFG_ruby(root_node.child_by_field_name('body'),index_to_code,states)
660 | DFG+=temp
661 | dic={}
662 | for x in DFG:
663 | if (x[0],x[1],x[2]) not in dic:
664 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
665 | else:
666 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
667 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
668 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
669 | return sorted(DFG,key=lambda x:x[1]),states
670 | elif root_node.type in while_statement:
671 | DFG=[]
672 | for i in range(2):
673 | for child in root_node.children:
674 | temp,states=DFG_ruby(child,index_to_code,states)
675 | DFG+=temp
676 | dic={}
677 | for x in DFG:
678 | if (x[0],x[1],x[2]) not in dic:
679 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
680 | else:
681 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
682 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
683 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
684 | return sorted(DFG,key=lambda x:x[1]),states
685 | else:
686 | DFG=[]
687 | for child in root_node.children:
688 | if child.type in do_first_statement:
689 | temp,states=DFG_ruby(child,index_to_code,states)
690 | DFG+=temp
691 | for child in root_node.children:
692 | if child.type not in do_first_statement:
693 | temp,states=DFG_ruby(child,index_to_code,states)
694 | DFG+=temp
695 |
696 | return sorted(DFG,key=lambda x:x[1]),states
697 |
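# NOTE (editorial): Go variant. `var_spec` yields 'comesFrom' definition edges,
# `assignment_statement` and `inc_statement` yield 'computedFrom' edges, and the two
# arms of an if/else are analysed against separate copies of `states` that are merged
# afterwards, so later uses can refer to definitions from either branch.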
698 | def DFG_go(root_node,index_to_code,states):
699 | assignment=['assignment_statement',]
700 | def_statement=['var_spec']
701 | increment_statement=['inc_statement']
702 | if_statement=['if_statement','else']
703 | for_statement=['for_statement']
704 | enhanced_for_statement=[]
705 | while_statement=[]
706 | do_first_statement=[]
707 | states=states.copy()
708 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment':
709 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)]
710 | if root_node.type==code:
711 | return [],states
712 | elif code in states:
713 | return [(code,idx,'comesFrom',[code],states[code].copy())],states
714 | else:
715 | if root_node.type=='identifier':
716 | states[code]=[idx]
717 | return [(code,idx,'comesFrom',[],[])],states
718 | elif root_node.type in def_statement:
719 | name=root_node.child_by_field_name('name')
720 | value=root_node.child_by_field_name('value')
721 | DFG=[]
722 | if value is None:
723 | indexs=tree_to_variable_index(name,index_to_code)
724 | for index in indexs:
725 | idx,code=index_to_code[index]
726 | DFG.append((code,idx,'comesFrom',[],[]))
727 | states[code]=[idx]
728 | return sorted(DFG,key=lambda x:x[1]),states
729 | else:
730 | name_indexs=tree_to_variable_index(name,index_to_code)
731 | value_indexs=tree_to_variable_index(value,index_to_code)
732 | temp,states=DFG_go(value,index_to_code,states)
733 | DFG+=temp
734 | for index1 in name_indexs:
735 | idx1,code1=index_to_code[index1]
736 | for index2 in value_indexs:
737 | idx2,code2=index_to_code[index2]
738 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2]))
739 | states[code1]=[idx1]
740 | return sorted(DFG,key=lambda x:x[1]),states
741 | elif root_node.type in assignment:
742 | left_nodes=root_node.child_by_field_name('left')
743 | right_nodes=root_node.child_by_field_name('right')
744 | DFG=[]
745 | temp,states=DFG_go(right_nodes,index_to_code,states)
746 | DFG+=temp
747 | name_indexs=tree_to_variable_index(left_nodes,index_to_code)
748 | value_indexs=tree_to_variable_index(right_nodes,index_to_code)
749 | for index1 in name_indexs:
750 | idx1,code1=index_to_code[index1]
751 | for index2 in value_indexs:
752 | idx2,code2=index_to_code[index2]
753 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
754 | states[code1]=[idx1]
755 | return sorted(DFG,key=lambda x:x[1]),states
756 | elif root_node.type in increment_statement:
757 | DFG=[]
758 | indexs=tree_to_variable_index(root_node,index_to_code)
759 | for index1 in indexs:
760 | idx1,code1=index_to_code[index1]
761 | for index2 in indexs:
762 | idx2,code2=index_to_code[index2]
763 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
764 | states[code1]=[idx1]
765 | return sorted(DFG,key=lambda x:x[1]),states
766 | elif root_node.type in if_statement:
767 | DFG=[]
768 | current_states=states.copy()
769 | others_states=[]
770 | flag=False
771 | tag=False
772 | if 'else' in root_node.type:
773 | tag=True
774 | for child in root_node.children:
775 | if 'else' in child.type:
776 | tag=True
777 | if child.type not in if_statement and flag is False:
778 | temp,current_states=DFG_go(child,index_to_code,current_states)
779 | DFG+=temp
780 | else:
781 | flag=True
782 | temp,new_states=DFG_go(child,index_to_code,states)
783 | DFG+=temp
784 | others_states.append(new_states)
785 | others_states.append(current_states)
786 | if tag is False:
787 | others_states.append(states)
788 | new_states={}
789 | for dic in others_states:
790 | for key in dic:
791 | if key not in new_states:
792 | new_states[key]=dic[key].copy()
793 | else:
794 | new_states[key]+=dic[key]
795 | for key in states:
796 | if key not in new_states:
797 | new_states[key]=states[key]
798 | else:
799 | new_states[key]+=states[key]
800 | for key in new_states:
801 | new_states[key]=sorted(list(set(new_states[key])))
802 | return sorted(DFG,key=lambda x:x[1]),new_states
803 | elif root_node.type in for_statement:
804 | DFG=[]
805 | for child in root_node.children:
806 | temp,states=DFG_go(child,index_to_code,states)
807 | DFG+=temp
808 | flag=False
809 | for child in root_node.children:
810 | if flag:
811 | temp,states=DFG_go(child,index_to_code,states)
812 | DFG+=temp
813 | elif child.type=="for_clause":
814 | if child.child_by_field_name('update') is not None:
815 | temp,states=DFG_go(child.child_by_field_name('update'),index_to_code,states)
816 | DFG+=temp
817 | flag=True
818 | dic={}
819 | for x in DFG:
820 | if (x[0],x[1],x[2]) not in dic:
821 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
822 | else:
823 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
824 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
825 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
826 | return sorted(DFG,key=lambda x:x[1]),states
827 | else:
828 | DFG=[]
829 | for child in root_node.children:
830 | if child.type in do_first_statement:
831 | temp,states=DFG_go(child,index_to_code,states)
832 | DFG+=temp
833 | for child in root_node.children:
834 | if child.type not in do_first_statement:
835 | temp,states=DFG_go(child,index_to_code,states)
836 | DFG+=temp
837 |
838 | return sorted(DFG,key=lambda x:x[1]),states
839 |
840 |
841 |
842 |
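# NOTE (editorial): PHP variant. In the `foreach_statement` branch the first
# `variable_name` child is taken as the iterated collection and the second as the
# loop variable, which receives 'computedFrom' edges from the collection's tokens;
# the body is processed twice so that second-pass edges can see definitions made
# during the first pass (the same cheap fixed-point trick used for the other loops).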
843 | def DFG_php(root_node,index_to_code,states):
844 | assignment=['assignment_expression','augmented_assignment_expression']
845 | def_statement=['simple_parameter']
846 | increment_statement=['update_expression']
847 | if_statement=['if_statement','else_clause']
848 | for_statement=['for_statement']
849 | enhanced_for_statement=['foreach_statement']
850 | while_statement=['while_statement']
851 | do_first_statement=[]
852 | states=states.copy()
853 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment':
854 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)]
855 | if root_node.type==code:
856 | return [],states
857 | elif code in states:
858 | return [(code,idx,'comesFrom',[code],states[code].copy())],states
859 | else:
860 | if root_node.type=='identifier':
861 | states[code]=[idx]
862 | return [(code,idx,'comesFrom',[],[])],states
863 | elif root_node.type in def_statement:
864 | name=root_node.child_by_field_name('name')
865 | value=root_node.child_by_field_name('default_value')
866 | DFG=[]
867 | if value is None:
868 | indexs=tree_to_variable_index(name,index_to_code)
869 | for index in indexs:
870 | idx,code=index_to_code[index]
871 | DFG.append((code,idx,'comesFrom',[],[]))
872 | states[code]=[idx]
873 | return sorted(DFG,key=lambda x:x[1]),states
874 | else:
875 | name_indexs=tree_to_variable_index(name,index_to_code)
876 | value_indexs=tree_to_variable_index(value,index_to_code)
877 | temp,states=DFG_php(value,index_to_code,states)
878 | DFG+=temp
879 | for index1 in name_indexs:
880 | idx1,code1=index_to_code[index1]
881 | for index2 in value_indexs:
882 | idx2,code2=index_to_code[index2]
883 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2]))
884 | states[code1]=[idx1]
885 | return sorted(DFG,key=lambda x:x[1]),states
886 | elif root_node.type in assignment:
887 | left_nodes=root_node.child_by_field_name('left')
888 | right_nodes=root_node.child_by_field_name('right')
889 | DFG=[]
890 | temp,states=DFG_php(right_nodes,index_to_code,states)
891 | DFG+=temp
892 | name_indexs=tree_to_variable_index(left_nodes,index_to_code)
893 | value_indexs=tree_to_variable_index(right_nodes,index_to_code)
894 | for index1 in name_indexs:
895 | idx1,code1=index_to_code[index1]
896 | for index2 in value_indexs:
897 | idx2,code2=index_to_code[index2]
898 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
899 | states[code1]=[idx1]
900 | return sorted(DFG,key=lambda x:x[1]),states
901 | elif root_node.type in increment_statement:
902 | DFG=[]
903 | indexs=tree_to_variable_index(root_node,index_to_code)
904 | for index1 in indexs:
905 | idx1,code1=index_to_code[index1]
906 | for index2 in indexs:
907 | idx2,code2=index_to_code[index2]
908 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
909 | states[code1]=[idx1]
910 | return sorted(DFG,key=lambda x:x[1]),states
911 | elif root_node.type in if_statement:
912 | DFG=[]
913 | current_states=states.copy()
914 | others_states=[]
915 | flag=False
916 | tag=False
917 | if 'else' in root_node.type:
918 | tag=True
919 | for child in root_node.children:
920 | if 'else' in child.type:
921 | tag=True
922 | if child.type not in if_statement and flag is False:
923 | temp,current_states=DFG_php(child,index_to_code,current_states)
924 | DFG+=temp
925 | else:
926 | flag=True
927 | temp,new_states=DFG_php(child,index_to_code,states)
928 | DFG+=temp
929 | others_states.append(new_states)
930 | others_states.append(current_states)
931 | new_states={}
932 | for dic in others_states:
933 | for key in dic:
934 | if key not in new_states:
935 | new_states[key]=dic[key].copy()
936 | else:
937 | new_states[key]+=dic[key]
938 | for key in states:
939 | if key not in new_states:
940 | new_states[key]=states[key]
941 | else:
942 | new_states[key]+=states[key]
943 | for key in new_states:
944 | new_states[key]=sorted(list(set(new_states[key])))
945 | return sorted(DFG,key=lambda x:x[1]),new_states
946 | elif root_node.type in for_statement:
947 | DFG=[]
948 | for child in root_node.children:
949 | temp,states=DFG_php(child,index_to_code,states)
950 | DFG+=temp
951 | flag=False
952 | for child in root_node.children:
953 | if flag:
954 | temp,states=DFG_php(child,index_to_code,states)
955 | DFG+=temp
956 | elif child.type=="assignment_expression":
957 | flag=True
958 | dic={}
959 | for x in DFG:
960 | if (x[0],x[1],x[2]) not in dic:
961 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
962 | else:
963 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
964 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
965 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
966 | return sorted(DFG,key=lambda x:x[1]),states
967 | elif root_node.type in enhanced_for_statement:
968 | name=None
969 | value=None
970 | for child in root_node.children:
971 | if child.type=='variable_name' and value is None:
972 | value=child
973 | elif child.type=='variable_name' and name is None:
974 | name=child
975 | break
976 | body=root_node.child_by_field_name('body')
977 | DFG=[]
978 | for i in range(2):
979 | temp,states=DFG_php(value,index_to_code,states)
980 | DFG+=temp
981 | name_indexs=tree_to_variable_index(name,index_to_code)
982 | value_indexs=tree_to_variable_index(value,index_to_code)
983 | for index1 in name_indexs:
984 | idx1,code1=index_to_code[index1]
985 | for index2 in value_indexs:
986 | idx2,code2=index_to_code[index2]
987 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
988 | states[code1]=[idx1]
989 | temp,states=DFG_php(body,index_to_code,states)
990 | DFG+=temp
991 | dic={}
992 | for x in DFG:
993 | if (x[0],x[1],x[2]) not in dic:
994 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
995 | else:
996 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
997 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
998 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
999 | return sorted(DFG,key=lambda x:x[1]),states
1000 | elif root_node.type in while_statement:
1001 | DFG=[]
1002 | for i in range(2):
1003 | for child in root_node.children:
1004 | temp,states=DFG_php(child,index_to_code,states)
1005 | DFG+=temp
1006 | dic={}
1007 | for x in DFG:
1008 | if (x[0],x[1],x[2]) not in dic:
1009 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
1010 | else:
1011 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
1012 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
1013 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
1014 | return sorted(DFG,key=lambda x:x[1]),states
1015 | else:
1016 | DFG=[]
1017 | for child in root_node.children:
1018 | if child.type in do_first_statement:
1019 | temp,states=DFG_php(child,index_to_code,states)
1020 | DFG+=temp
1021 | for child in root_node.children:
1022 | if child.type not in do_first_statement:
1023 | temp,states=DFG_php(child,index_to_code,states)
1024 | DFG+=temp
1025 |
1026 | return sorted(DFG,key=lambda x:x[1]),states
1027 |
1028 |
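# NOTE (editorial): JavaScript variant. `variable_declarator` covers declarations,
# `assignment_pattern` / `augmented_assignment_expression` cover assignments, and
# while-loops are traversed twice so that uses at the top of the loop pick up
# definitions made later in the body on the previous pass.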
1029 | def DFG_javascript(root_node,index_to_code,states):
1030 | assignment=['assignment_pattern','augmented_assignment_expression']
1031 | def_statement=['variable_declarator']
1032 | increment_statement=['update_expression']
1033 | if_statement=['if_statement','else']
1034 | for_statement=['for_statement']
1035 | enhanced_for_statement=[]
1036 | while_statement=['while_statement']
1037 | do_first_statement=[]
1038 | states=states.copy()
1039 | if (len(root_node.children)==0 or root_node.type=='string') and root_node.type!='comment':
1040 | idx,code=index_to_code[(root_node.start_point,root_node.end_point)]
1041 | if root_node.type==code:
1042 | return [],states
1043 | elif code in states:
1044 | return [(code,idx,'comesFrom',[code],states[code].copy())],states
1045 | else:
1046 | if root_node.type=='identifier':
1047 | states[code]=[idx]
1048 | return [(code,idx,'comesFrom',[],[])],states
1049 | elif root_node.type in def_statement:
1050 | name=root_node.child_by_field_name('name')
1051 | value=root_node.child_by_field_name('value')
1052 | DFG=[]
1053 | if value is None:
1054 | indexs=tree_to_variable_index(name,index_to_code)
1055 | for index in indexs:
1056 | idx,code=index_to_code[index]
1057 | DFG.append((code,idx,'comesFrom',[],[]))
1058 | states[code]=[idx]
1059 | return sorted(DFG,key=lambda x:x[1]),states
1060 | else:
1061 | name_indexs=tree_to_variable_index(name,index_to_code)
1062 | value_indexs=tree_to_variable_index(value,index_to_code)
1063 | temp,states=DFG_javascript(value,index_to_code,states)
1064 | DFG+=temp
1065 | for index1 in name_indexs:
1066 | idx1,code1=index_to_code[index1]
1067 | for index2 in value_indexs:
1068 | idx2,code2=index_to_code[index2]
1069 | DFG.append((code1,idx1,'comesFrom',[code2],[idx2]))
1070 | states[code1]=[idx1]
1071 | return sorted(DFG,key=lambda x:x[1]),states
1072 | elif root_node.type in assignment:
1073 | left_nodes=root_node.child_by_field_name('left')
1074 | right_nodes=root_node.child_by_field_name('right')
1075 | DFG=[]
1076 | temp,states=DFG_javascript(right_nodes,index_to_code,states)
1077 | DFG+=temp
1078 | name_indexs=tree_to_variable_index(left_nodes,index_to_code)
1079 | value_indexs=tree_to_variable_index(right_nodes,index_to_code)
1080 | for index1 in name_indexs:
1081 | idx1,code1=index_to_code[index1]
1082 | for index2 in value_indexs:
1083 | idx2,code2=index_to_code[index2]
1084 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
1085 | states[code1]=[idx1]
1086 | return sorted(DFG,key=lambda x:x[1]),states
1087 | elif root_node.type in increment_statement:
1088 | DFG=[]
1089 | indexs=tree_to_variable_index(root_node,index_to_code)
1090 | for index1 in indexs:
1091 | idx1,code1=index_to_code[index1]
1092 | for index2 in indexs:
1093 | idx2,code2=index_to_code[index2]
1094 | DFG.append((code1,idx1,'computedFrom',[code2],[idx2]))
1095 | states[code1]=[idx1]
1096 | return sorted(DFG,key=lambda x:x[1]),states
1097 | elif root_node.type in if_statement:
1098 | DFG=[]
1099 | current_states=states.copy()
1100 | others_states=[]
1101 | flag=False
1102 | tag=False
1103 | if 'else' in root_node.type:
1104 | tag=True
1105 | for child in root_node.children:
1106 | if 'else' in child.type:
1107 | tag=True
1108 | if child.type not in if_statement and flag is False:
1109 | temp,current_states=DFG_javascript(child,index_to_code,current_states)
1110 | DFG+=temp
1111 | else:
1112 | flag=True
1113 | temp,new_states=DFG_javascript(child,index_to_code,states)
1114 | DFG+=temp
1115 | others_states.append(new_states)
1116 | others_states.append(current_states)
1117 | if tag is False:
1118 | others_states.append(states)
1119 | new_states={}
1120 | for dic in others_states:
1121 | for key in dic:
1122 | if key not in new_states:
1123 | new_states[key]=dic[key].copy()
1124 | else:
1125 | new_states[key]+=dic[key]
1126 | for key in states:
1127 | if key not in new_states:
1128 | new_states[key]=states[key]
1129 | else:
1130 | new_states[key]+=states[key]
1131 | for key in new_states:
1132 | new_states[key]=sorted(list(set(new_states[key])))
1133 | return sorted(DFG,key=lambda x:x[1]),new_states
1134 | elif root_node.type in for_statement:
1135 | DFG=[]
1136 | for child in root_node.children:
1137 | temp,states=DFG_javascript(child,index_to_code,states)
1138 | DFG+=temp
1139 | flag=False
1140 | for child in root_node.children:
1141 | if flag:
1142 | temp,states=DFG_javascript(child,index_to_code,states)
1143 | DFG+=temp
1144 | elif child.type=="variable_declaration":
1145 | flag=True
1146 | dic={}
1147 | for x in DFG:
1148 | if (x[0],x[1],x[2]) not in dic:
1149 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
1150 | else:
1151 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
1152 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
1153 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
1154 | return sorted(DFG,key=lambda x:x[1]),states
1155 | elif root_node.type in while_statement:
1156 | DFG=[]
1157 | for i in range(2):
1158 | for child in root_node.children:
1159 | temp,states=DFG_javascript(child,index_to_code,states)
1160 | DFG+=temp
1161 | dic={}
1162 | for x in DFG:
1163 | if (x[0],x[1],x[2]) not in dic:
1164 | dic[(x[0],x[1],x[2])]=[x[3],x[4]]
1165 | else:
1166 | dic[(x[0],x[1],x[2])][0]=list(set(dic[(x[0],x[1],x[2])][0]+x[3]))
1167 | dic[(x[0],x[1],x[2])][1]=sorted(list(set(dic[(x[0],x[1],x[2])][1]+x[4])))
1168 | DFG=[(x[0],x[1],x[2],y[0],y[1]) for x,y in sorted(dic.items(),key=lambda t:t[0][1])]
1169 | return sorted(DFG,key=lambda x:x[1]),states
1170 | else:
1171 | DFG=[]
1172 | for child in root_node.children:
1173 | if child.type in do_first_statement:
1174 | temp,states=DFG_javascript(child,index_to_code,states)
1175 | DFG+=temp
1176 | for child in root_node.children:
1177 | if child.type not in do_first_statement:
1178 | temp,states=DFG_javascript(child,index_to_code,states)
1179 | DFG+=temp
1180 |
1181 | return sorted(DFG,key=lambda x:x[1]),states
1182 |
1183 |
1184 |
1185 |
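# Usage sketch (editorial, illustrative only; names other than the DFG_* functions are
# assumptions, not part of this repository). Given a compiled tree-sitter grammar and an
# `index_to_code` mapping built from the leaf tokens of the parse tree, extraction for
# any of the languages above looks roughly like:
#
#     from tree_sitter import Language, Parser
#     parser = Parser()
#     parser.set_language(LANGUAGE)                   # LANGUAGE: a tree_sitter.Language
#     tree = parser.parse(bytes(source_code, 'utf8'))
#     dfg, states = DFG_javascript(tree.root_node, index_to_code, {})
#
# Each edge in `dfg` is (token, token_index, edge_type, source_tokens, source_indices),
# where edge_type is 'comesFrom' (definition/read) or 'computedFrom' (assignment).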
--------------------------------------------------------------------------------