├── LICENSE
├── README.md
├── analyze_results.py
├── att_model.py
├── block_diagram.png
├── check_duplication.py
├── context.py
├── create_hole_data.py
├── data_utils.py
├── generate_completions.py
├── generate_rule_representations.py
├── get_info_from_hole_predictions.py
├── model_preprocessed_data.py
├── parse_tree.py
├── preprocessed_data.py
├── projects.txt
├── rearrange_data.py
├── rule_classifier_preprocessed_data.py
├── rule_config.py
├── rule_inference_preprocessed_data.py
├── rule_representation_data.py
├── script_analyze_results.py
├── script_completions.py
├── script_gen_and_preprocess_data.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Disha Shrivastava
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Repository-Level Prompt Generation for Large Language Models of Code
2 | Disha Shrivastava, Hugo Larochelle, Daniel Tarlow
3 |
4 | This repository contains the implementation and data for our work [Repository-Level Prompt Generation for Large Language Models of Code](https://arxiv.org/abs/2206.12839). A block diagram of our approach is shown below; for more details, please refer to the paper.
5 |
6 |
7 | ![Block diagram of the RLPG approach](block_diagram.png)
8 |
9 |
10 | ## Dependencies
11 | * Access to the OpenAI Codex API: https://openai.com/blog/openai-codex/. The API key should be placed in a file named `openai_api_key` (a minimal usage sketch is given below)
12 | * pytorch: https://pytorch.org/
13 | * Huggingface's transformers: https://pypi.org/project/transformers/
14 | * tree-sitter-java: https://github.com/tree-sitter/tree-sitter-java
15 | * tqdm
16 | * tensorboard
17 |
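For reference, here is a minimal sketch of reading the key file and requesting a completion through the legacy `openai` Completion API; the engine name and decoding parameters below are illustrative assumptions, not taken from this repo.

```python
import openai

# Read the API key from the `openai_api_key` file mentioned above.
with open('openai_api_key') as f:
    openai.api_key = f.read().strip()

# Single completion request through the legacy (pre-1.0) openai interface;
# the engine name and decoding parameters are illustrative assumptions.
response = openai.Completion.create(
    engine='code-davinci-002',
    prompt='public static int add(int a, int b) {',
    max_tokens=24,
    temperature=0.0,
)
print(response['choices'][0]['text'])
```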
18 | ## Updates
19 | * We have released the trained checkpoints for the RLPG-H and RLPG-R models. You can download them [here](https://drive.google.com/file/d/1txmObhsA_Cs8paj1x8IGsoUqX7oHNxbw/view?usp=share_link). While loading the state_dict, please set the `strict` parameter to `False` (a minimal loading sketch is shown below). Example usage can be found in `rule_inference_preprocessed_data.py`.
20 | * We are also releasing the prediction probabilities of the RLPG-H and RLPG-R models for each hole in our validation and test data. You can download them [here](https://drive.google.com/file/d/1WSPf4p0tfWs2nLgbpk53Kh5qJG3_33-b/view?usp=share_link). Each file contains the probabilities for all prompt proposals as given by the corresponding trained RLPG model. Example usage can be found in `get_info_from_hole_predictions.py`.
21 |
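A minimal loading sketch, assuming the released RLPG-R checkpoint matches `BasicAggModel` in `att_model.py`; the checkpoint filename is a placeholder, and `rule_inference_preprocessed_data.py` remains the authoritative example.

```python
import torch
from att_model import BasicAggModel

# RLPG-R aggregation model with the default hyperparameters from att_model.py.
model = BasicAggModel(device='cpu')

# `strict=False` ignores state_dict keys that do not match the module definition,
# as recommended in the note above. The filename is a placeholder; if the checkpoint
# wraps the weights (e.g. under a 'model' key), unwrap it before loading.
state_dict = torch.load('rlpg_r_checkpoint.pt', map_location='cpu')
model.load_state_dict(state_dict, strict=False)
model.eval()
```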
22 | ## Code
23 | ### Data preprocessing
24 | The web URLs for all the repositories used in our work are provided in `projects.txt`. Download the repositories and store them in a folder called gcode-data, then run `script_gen_and_preprocess_data.py`. This produces an output file called commands_gen_and_preprocess; running it executes three scripts:
25 | - `create_hole_data.py`: creates the hole completion data by choosing the midpoint of each line as the hole position.
26 | - `parse_tree.py`: creates a parse tree for each file and stores the repo-level meta-info needed to obtain the rule context.
27 | - `check_duplication.py`: checks for duplicates within a repo.
28 | Running these creates a new folder called rule_classifier_data with train, val and test subfolders. Inside each split, there is one folder per repository containing the following (a sketch for inspecting `hole_data` follows this list):
29 |
30 | * The repository itself, with its .java files and directory structure preserved.
31 | * hole_data
32 | * file_class_data
33 | * parsed_data
34 | * duplicates
35 |
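As a quick sanity check, `hole_data` written by `create_hole_data.py` is a pickled dictionary mapping each file path to a list of `(line, character)` hole positions; it can be inspected like this (the repo name is only an example):

```python
import pickle

# `hole_data` maps each .java file path to the holes chosen in it,
# where every hole is a (line_number, character_offset) pair.
with open('rule_classifier_data/train/javasummerframework/hole_data', 'rb') as f:
    hole_data = pickle.load(f)

total_holes = sum(len(holes) for holes in hole_data.values())
print(len(hole_data), 'files,', total_holes, 'holes')
```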
36 | ### Generating completions using Codex, i.e., obtaining the ground truth for training the rule classifier
37 | `script_completions.py`
38 | Generates a file commands_completion. Running it creates a new folder called results with train, val and test subfolders. Inside each split there are ten folders, one per rule context location, and each of these contains .json files, one per rule context type. Each row of a file records the application of that particular rule to a hole: the target hole, the predicted hole, the prompt, and the validity of the rule.
39 |
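Based on how `analyze_results.py` reads these files, each row is a JSON object on its own line and can be parsed roughly as follows (the path below uses the default repo and split):

```python
import json

# Every line of a results .json file is one JSON record describing the
# application of a single prompt proposal (rule) to a single hole.
result_file = 'results/rule_classifier_data/test/dovetaildb/in_file/codex_4072.json'
with open(result_file) as f:
    for line in f:
        record = json.loads(line)
        hole_id = record['hole_identity']             # file path + hole position
        target = record['ground-truth hole']          # ground-truth hole string
        prediction = record['prediction']             # Codex completion for the prompt
        exact_match = prediction.rstrip() == target.rstrip()  # success criterion used in analyze_results.py
```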
40 | ### Generating the oracle
41 | `script_analyze_results.py`
42 | Generates a file commands_analyze_results. Running it creates a file called oracle inside each repo folder in rule_classifier_data. This file contains the collated information about which rules succeeded for each target hole.
43 |
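The oracle file is a pickle written by `analyze_results.py`: a dictionary keyed by hole identity whose values are multi-hot vectors over context locations (`'cl'`), context types (`'ct'`), and their combinations (`'com'`). A minimal way to load it:

```python
import pickle

# Load the per-hole oracle produced by script_analyze_results.py / analyze_results.py.
with open('rule_classifier_data/test/dovetaildb/oracle', 'rb') as f:
    oracle = pickle.load(f)

hole_id, entry = next(iter(oracle.items()))
print(hole_id)
print(entry['cl'])   # multi-hot vector over successful context locations
print(entry['ct'])   # multi-hot vector over successful context types
print(entry['com'])  # multi-hot vector over (location, type, ratio) combinations
```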
44 | ### Generating the rule context representations for the rule classifier
45 | `generate_rule_representations.py`
46 |
47 | Example usage: `python generate_rule_representations.py --data_split=val --repo=jata4test --emb_model_type=codebert`
48 |
49 | This creates a codebert_mod folder inside rule_classifier_data/val/jata4test. Each file in this folder contains the rule context representations obtained from CodeBERT for each hole.
50 |
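For reference, a rule context string can be embedded with CodeBERT through Hugging Face `transformers` along these lines; the pooling choice is an assumption, and `generate_rule_representations.py` is the authoritative implementation.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# CodeBERT encoder used to represent rule contexts.
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
model = AutoModel.from_pretrained('microsoft/codebert-base')

rule_context = 'public void close() { stream.close(); }'  # illustrative context string
inputs = tokenizer(rule_context, return_tensors='pt', truncation=True, max_length=512)
with torch.no_grad():
    outputs = model(**inputs)
# First ([CLS]) token embedding as a fixed-size, 768-dimensional representation (assumption).
representation = outputs.last_hidden_state[:, 0, :]
```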
51 | ### Capping the number of holes
52 | `rearrange_data.py`
53 | This script caps the contribution from any single repo at 10000 holes. After running it, each repo folder will contain capped_holes_10000, capped_codebert_mod and capped_oracle_10000.
54 |
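A rough sketch of the capping idea; the actual selection logic and output format live in `rearrange_data.py`, and the random sampling and file layout below are assumptions.

```python
import pickle
import random

CAP = 10000  # maximum number of holes kept per repository

# Flatten hole_data into (file, hole) pairs and keep at most CAP of them.
with open('rule_classifier_data/train/javasummerframework/hole_data', 'rb') as f:
    hole_data = pickle.load(f)

all_holes = [(path, hole) for path, holes in hole_data.items() for hole in holes]
if len(all_holes) > CAP:
    all_holes = random.sample(all_holes, CAP)

capped = {}
for path, hole in all_holes:
    capped.setdefault(path, []).append(hole)

# The on-disk layout of capped_holes_10000 is an assumption here.
with open('rule_classifier_data/train/javasummerframework/capped_holes_10000', 'wb') as f:
    pickle.dump(capped, f)
```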
55 | ### Training the rule classifier
56 | `rule_classifier_preprocessed_data.py`
57 | This needs the capped_codebert_mod folder (rule context representations) and the capped_oracle_10000 file to be present inside each repo folder.
58 | The best model is stored in the models directory along with the tensorboard logs. The output from each epoch is stored in the outputs folder.
59 |
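The general pattern of keeping the best checkpoint in the models directory and logging to tensorboard looks roughly like the sketch below; the stand-in linear model, loss, and random data are placeholders, with the real training loop in `rule_classifier_preprocessed_data.py` (the 768-dimensional inputs and 63 prompt proposals match the defaults in `att_model.py`).

```python
import os
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

# Stand-in model and data so the pattern runs end to end; the real features,
# labels, and model come from rule_classifier_preprocessed_data.py.
model = nn.Linear(768, 63)            # 768-dim representation -> 63 prompt proposals
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()
features = torch.randn(256, 768)
labels = torch.randint(0, 2, (256, 63)).float()

os.makedirs('models', exist_ok=True)
os.makedirs('outputs', exist_ok=True)
writer = SummaryWriter(log_dir='models/logs')   # tensorboard logs next to the saved models

best_loss = float('inf')
for epoch in range(5):
    optimizer.zero_grad()
    logits = model(features)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    writer.add_scalar('loss/train', loss.item(), epoch)
    torch.save(logits.detach(), f'outputs/epoch_{epoch}')      # per-epoch outputs (placeholder content)
    if loss.item() < best_loss:                                # keep only the best checkpoint
        best_loss = loss.item()
        torch.save(model.state_dict(), 'models/best_model.pt')
writer.close()
```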
60 | ### Inference with the rule classifier
61 | `rule_inference_preprocessed_data.py`
62 | This needs the capped_codebert_mod folder (rule context representations) and the capped_oracle_10000 file to be present inside each repo folder.
63 | This produces a file inside the outputs folder that contains the classifier's prediction for each hole.
64 |
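A minimal sketch of scoring the prompt proposals for one hole with the attention model from `att_model.py`; the random tensors stand in for the hole and rule representations, and the final sigmoid is an assumption about how logits are mapped to probabilities.

```python
import torch
from att_model import BasicAggModel

model = BasicAggModel(device='cpu')
model.eval()

batch_size, n_rules, d_model = 1, 63, 768
query = torch.randn(batch_size, d_model)           # representation of the hole window
keys = torch.randn(batch_size, n_rules, d_model)   # one representation per prompt proposal
mask = torch.ones(batch_size, n_rules, d_model)    # nonzero where a proposal is applicable

with torch.no_grad():
    logits, _ = model(query, keys, mask)           # shape (63,) for a single hole
probs = torch.sigmoid(logits)                      # assumption: independent per-proposal probabilities
print(probs.topk(5))                               # five highest-scoring prompt proposals
```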
65 | ### Getting results for variation with k
66 | `get_info_from_hole_predictions.py`
67 | This takes a hole stats file (generated in the previous step) and a value of k as input.
68 |
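Here k is the number of top-ranked prompt proposals tried per hole. A hedged sketch of computing success@k from per-hole proposal probabilities and the oracle follows; the data layouts of `probs_per_hole` and `oracle_per_hole` are assumptions about the released files.

```python
import numpy as np

def success_at_k(probs_per_hole, oracle_per_hole, k):
    """Fraction of holes where at least one of the top-k ranked proposals succeeds.

    probs_per_hole:  dict of hole_id -> array of per-proposal probabilities
    oracle_per_hole: dict of hole_id -> multi-hot array of proposals that succeeded
    """
    hits = 0
    for hole_id, probs in probs_per_hole.items():
        top_k = np.argsort(probs)[::-1][:k]                    # k highest-probability proposals
        if np.asarray(oracle_per_hole[hole_id])[top_k].any():  # any of them succeed for this hole?
            hits += 1
    return hits / len(probs_per_hole)
```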
--------------------------------------------------------------------------------
/analyze_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | from utils import *
5 | import argparse
6 | from rule_config import rule_hyperparams
7 |
8 | def get_hole_identities(hole_filename, duplicate_filename):
9 | hole_data = pickle.load(open(hole_filename, 'rb'))
10 | duplicate_files = open(duplicate_filename, 'r').readlines()
11 | duplicate_files = [ x.strip() for x in duplicate_files]
12 | hole_identities = []
13 | for (k,v) in hole_data.items():
14 | if k not in duplicate_files and not k.startswith('rule_classifier_data/train/rsbotownversion/trunk/scripts/'):
15 | for hole in v:
16 | h_id = k + '_' + str(hole[0]) + '_' + str(hole[1])
17 | hole_identities.append(h_id)
18 | return hole_identities, duplicate_files
19 |
20 | def obtain_full_rule_results(result_file, codex_results, hole_rule_mapping):
21 | rule_results = read_result_file(result_file, hole_rule_mapping)
22 | j = 0
23 | all_results = []
24 | for i in range(len(codex_results)):
25 | codex_res = codex_results[i]
26 | hid = codex_res['hole_identity']
27 | if j < len(rule_results):
28 | rule_res_hid = rule_results[j]['hole_identity']
29 | if hid == rule_res_hid:
30 | all_results.append(rule_results[j])
31 | j+=1
32 | else:
33 | all_results.append(codex_res)
34 | else:
35 | all_results.append(codex_res)
36 | return all_results
37 |
38 | def is_match(res_line):
39 | res_line = json.loads(res_line)
40 | prediction = res_line['prediction']
41 | hole = res_line['ground-truth hole']
42 | pred = prediction.rstrip()
43 | hole = hole.rstrip()
44 | # there is an exact match corresponding to this hole id
45 | if pred == hole:
46 | return True
47 | else:
48 | return False
49 |
50 | def update_hole_rule_mapping(res_line, hid, hole_rule_mapping, rule_parts):
51 | if is_match(res_line):
52 | if hid in hole_rule_mapping:
53 | hole_rule_mapping[hid].append(rule_parts)
54 | else:
55 | hole_rule_mapping[hid] = [rule_parts]
56 | return hole_rule_mapping
57 |
58 | def modify_results(result_lines, duplicate_files):
59 | if not duplicate_files:
60 | return result_lines
61 | else:
62 | #print(len(duplicate_files), len(result_lines))
63 | mod_result_lines = []
64 | for i in range(len(result_lines)):
65 | res_hid = json.loads(result_lines[i])['hole_identity']
66 | if is_valid_hole(res_hid, duplicate_files):
67 | mod_result_lines.append(result_lines[i])
68 | return mod_result_lines
69 |
70 | def read_result_file(rule_result_file, codex_file, hole_identities, hole_rule_mapping, rule_parts, duplicate_files):
71 | #print(rule_result_file)
72 | rule_lines = open(rule_result_file, 'r').readlines()
73 | codex_lines = open(codex_file, 'r').readlines()
74 | rule_lines = modify_results(rule_lines, duplicate_files)
75 | codex_lines = modify_results(codex_lines, duplicate_files)
76 | j = 0
77 | for i in range(len(hole_identities)):
78 | hid = hole_identities[i]
79 | codex_hid = json.loads(codex_lines[i])['hole_identity']
80 | codex_hid = alter_hid(codex_hid, hid)
81 |
82 | if j < len(rule_lines):
83 | try:
84 | rule_hid = json.loads(rule_lines[j])['hole_identity']
85 | except:
86 | print(rule_result_file)
87 | rule_hid = alter_hid(rule_hid, hid)
88 | # use rule result
89 | if hid == rule_hid:
90 | hole_rule_mapping = update_hole_rule_mapping(rule_lines[j], hid, hole_rule_mapping, rule_parts)
91 | j+=1
92 | else:
93 | # use codex result
94 | hole_rule_mapping = update_hole_rule_mapping(codex_lines[i], hid, hole_rule_mapping, rule_parts)
95 | else:
96 | # use codex result
97 | hole_rule_mapping = update_hole_rule_mapping(codex_lines[i], hid, hole_rule_mapping, rule_parts)
98 |
99 | return hole_rule_mapping
100 |
101 | def get_results(base_result_dir, context_location, exclude_codex=True):
102 | context_result_dir = os.path.join(base_result_dir, context_location)
103 | result_files = next(os.walk(context_result_dir), (None, None, []))[2] # [] if no file
104 | if result_files and exclude_codex and 'codex_4072.json' in result_files:
105 | result_files.remove('codex_4072.json')
106 | mod_result_files = [os.path.join(context_result_dir, result_file) for result_file in result_files if result_file]
107 | result_files = [f for f in mod_result_files if os.path.getsize(f)>0]
108 | return result_files
109 |
110 | def check_validity_by_rule_parts(rule):
111 | valid = False
112 | if 'codex_4072' in rule:
113 | context_location = 'codex'
114 | context_type = 'codex'
115 | context_ratio = 0.5
116 | valid = True
117 | else:
118 | context_location = rule.split("/")[-2]
119 | rule_parts = rule.split("/")[-1].split("_")
120 | i = 5
121 | ct = '_'.join(rule_parts[1:5])
122 | # keep removing the parts joined by _ till it matches a valid context_type
123 | while(ct not in context_types_to_index):
124 | i-=1
125 | ct = '_'.join(rule_parts[1:i])
126 | context_type = ct
127 |
128 | mod_rule_parts = rule_parts[i:]
129 | try:
130 | context_ratio = float(mod_rule_parts[3])/4072
131 | if check_rule_validity(context_type, mod_rule_parts):
132 | valid =True
133 | except:
134 | valid = False
135 | if valid:
136 | return context_location, context_type, context_ratio
137 | else:
138 | return '', '', ''
139 |
140 | def get_all_hole_rule_mapping(base_result_dir, hole_identities, duplicate_files):
141 | in_file_files = get_results(base_result_dir, 'in_file', exclude_codex=False)
142 | parent_class_files = get_results(base_result_dir, 'parent_class_file')
143 | import_files = get_results(base_result_dir, 'import_file')
144 | sibling_files = get_results(base_result_dir, 'sibling_file')
145 | similar_name_files = get_results(base_result_dir, 'similar_name_file')
146 | child_class_files = get_results(base_result_dir, 'child_class_file')
147 | import_of_similar_name_files = get_results(base_result_dir, 'import_of_similar_name_file')
148 | import_of_sibling_files = get_results(base_result_dir, 'import_of_sibling_file')
149 | import_of_parent_class_files = get_results(base_result_dir, 'import_of_parent_class_file')
150 | import_of_child_class_files = get_results(base_result_dir, 'import_of_child_class_file')
151 | codex_file = os.path.join(base_result_dir, 'in_file', 'codex_4072.json')
152 |
153 | result_files = in_file_files + parent_class_files + import_files + sibling_files + similar_name_files \
154 | + child_class_files + import_of_similar_name_files + import_of_sibling_files + import_of_child_class_files + import_of_parent_class_files
155 |
156 | # print(len(in_file_files), len(parent_class_files), len(import_files), len(sibling_files), len(similar_name_files), \
157 | # len(child_class_files) , len(import_of_sibling_files), len(import_of_similar_name_files), len(import_of_child_class_files),\
158 | # len(import_of_parent_class_files), len(result_files))
159 |
160 | hole_rule_mapping = {}
161 | for result_file in result_files:
162 | context_location, context_type, context_ratio = check_validity_by_rule_parts(result_file)
163 | if context_location:
164 | hole_rule_mapping = read_result_file(result_file, codex_file, hole_identities, hole_rule_mapping, \
165 | (context_location, context_type, context_ratio), duplicate_files)
166 | return hole_rule_mapping
167 |
168 | def get_failed_holes(successful_holes, hole_identities):
169 | failed_holes = []
170 | for hole_identity in hole_identities:
171 | if hole_identity not in successful_holes:
172 | failed_holes.append(hole_identity)
173 | return failed_holes
174 |
175 | def find_rule_pattern(rule_pattern, rules):
176 | found = False
177 | for rule in rules:
178 | if rule_pattern in rule:
179 | found = True
180 | break
181 | return found
182 |
183 | def find_rule_specific_success(successful_holes, query_file_pattern=''):
184 | count = 0
185 | for h_id, rules in successful_holes.items():
186 | if find_rule_pattern(query_file_pattern, rules):
187 | count +=1
188 | return count
189 |
190 | def find_complementary_rules(successful_holes):
191 | other_rules = []
192 | not_lines_not_iden_codex = []
193 | not_lines_iden = []
194 | lines = []
195 | for h_id, rules in successful_holes.items():
196 | if not find_rule_pattern('lines', rules):
197 | if not find_rule_pattern('identifiers', rules):
198 | if not find_rule_pattern('codex', rules):
199 | other_rules.append((h_id, rules))
200 | else:
201 | not_lines_not_iden_codex.append((h_id, rules))
202 | else:
203 | not_lines_iden.append((h_id, rules))
204 | else:
205 | lines.append((h_id, rules))
206 | return lines, not_lines_iden, not_lines_not_iden_codex, other_rules
207 |
208 | def check_rule_validity(context_type, rule_parts):
209 | valid = False
210 | valid_hyperparams = rule_hyperparams[context_type]
211 | try:
212 | rule_context_ratio = float(rule_parts[3])/4072
213 | except:
214 | return False
215 | rule_prompt_separator = rule_parts[-1]
216 | rule_rule_context_formatting = '_'.join(rule_parts[4:-1])
217 | if rule_context_ratio in valid_hyperparams['context_ratio']:
218 | if rule_prompt_separator in valid_hyperparams['prompt_separator']:
219 | if rule_rule_context_formatting in valid_hyperparams['rule_context_formatting']:
220 | valid = True
221 | return valid
222 |
223 | def get_rule_templated_version(oracle):
224 | mod_oracle = {}
225 | for hid, rules in oracle.items():
226 | context_locations = []
227 | context_types = []
228 | combined = []
229 | for rule in rules:
230 | context_location, context_type, context_ratio = rule
231 | context_locations.append(context_location)
232 | context_types.append(context_type)
233 | if context_location != 'codex':
234 | combined.append(context_location + '#' + context_type + '#' + str(context_ratio))
235 | else:
236 | combined.append('codex')
237 | context_location = get_multi_hot_vector(context_locations, 'cl')
238 | context_type = get_multi_hot_vector(context_types, 'ct')
239 | comb = get_multi_hot_vector(combined, 'com')
240 | mod_oracle[hid] = {'cl': context_location, 'ct': context_type, 'com': comb}
241 | return mod_oracle
242 |
243 |
244 | def find_rule_mapping(successful_holes):
245 | rule_mapping = {}
246 | for hid, rules in successful_holes.items():
247 | for rule in rules:
248 | if rule not in rule_mapping:
249 | rule_mapping[rule] = [hid]
250 | else:
251 | rule_mapping[rule].append(hid)
252 | return rule_mapping
253 |
254 | def find_single_best_rule_success(rule_mapping):
255 | best_single_rule_success = 0
256 | for k, v in rule_mapping.items():
257 | if len(v)> best_single_rule_success:
258 | best_rule_parts = k
259 | best_single_rule_success = len(v)
260 | best_rule = best_rule_parts[0] + '_' + best_rule_parts[1] + '_' + str(best_rule_parts[2])
261 | return best_rule, best_single_rule_success
262 |
263 |
264 | def setup_args():
265 | """
266 | Description: Takes in the command-line arguments from user
267 | """
268 | parser = argparse.ArgumentParser()
269 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data', help="base directory for the data")
270 | parser.add_argument("--data_split", type=str, default='test', help="data split to store the data")
271 | parser.add_argument("--proj_name", type=str, default='dovetaildb', help="name of the input repo")
272 |
273 | return parser.parse_args()
274 |
275 | if __name__ == '__main__':
276 |
277 | args = setup_args()
278 | hole_filename = os.path.join(args.base_dir, args.data_split, args.proj_name, 'hole_data')
279 | duplicate_filename = os.path.join(args.base_dir, args.data_split, args.proj_name, 'duplicates')
280 | hole_identities, duplicate_files = get_hole_identities(hole_filename, duplicate_filename)
281 | print("Total number of holes:", len(hole_identities))
282 | base_result_dir = os.path.join('results', args.base_dir, args.data_split, args.proj_name)
283 | print(len(duplicate_files))
284 | successful_holes = get_all_hole_rule_mapping(base_result_dir, hole_identities, duplicate_files)
285 |   print("Number of holes that got at least one rule successful: ", len(successful_holes))
286 |
287 | oracle = get_rule_templated_version(successful_holes)
288 | with open(os.path.join(args.base_dir, args.data_split, args.proj_name, 'oracle'), 'wb') as f:
289 | pickle.dump(oracle, f)
290 | assert len(successful_holes) == len(oracle)
291 | rule_mapping = find_rule_mapping(successful_holes)
292 | codex_success = len(rule_mapping[('codex', 'codex', 0.5)])
293 | best_rule, best_rule_success = find_single_best_rule_success(rule_mapping)
294 | best_single_rule_success = len(rule_mapping[('in_file', 'lines', 0.75)])
295 | # print(rule_mapping)
296 | print(codex_success, best_single_rule_success, best_rule, best_rule_success)
297 | print(
298 | args.proj_name + ", " + \
299 | str(float(len(successful_holes)*100/len(hole_identities))) + ", " + \
300 | str(float(codex_success*100/len(hole_identities))) + ", " + \
301 | best_rule + ", " +\
302 | str(float(best_rule_success*100/len(hole_identities))) + ", " + \
303 | "in_file_lines_0.75" + ", " +\
304 | str(float(best_single_rule_success*100/len(hole_identities)))
305 | )
306 |
307 | # failed_holes = get_failed_holes(successful_holes, hole_identities)
308 | # print("Number of holes that got no rule successful: ", len(failed_holes))
309 | # with open(os.path.join(base_result_dir, 'failed_cases'), 'wb') as f:
310 | # pickle.dump(failed_holes, f)
311 |
312 | # post_lines_success = find_rule_specific_success(successful_holes, 'lines')
313 | # codex_success = find_rule_specific_success(successful_holes, 'codex')
314 | # identifiers_success = find_rule_specific_success(successful_holes, 'identifiers')
315 | # print("Number of post lines successes: ", post_lines_success)
316 | # print("Number of codex successes: ", codex_success)
317 | # print("Number of identifiers successes: ", identifiers_success)
318 |
319 | # lines, not_lines_iden, not_lines_not_iden_codex, other_rules = find_complementary_rules(successful_holes)
320 | # print("Post Lines: ", len(lines), end=", ")
321 | # print("Not Post Lines, Identifiers: ", len(not_lines_iden), end=", ")
322 | # print("Not Post Lines, Not Identifiers, Codex: ", len(not_lines_not_iden_codex), end=", ")
323 | # print("Other Rules: ", len(other_rules), end=", ")
324 | # print("No Rules: ", len(failed_holes), end="\n")
325 |
326 |
--------------------------------------------------------------------------------
/att_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import torch
5 | import math
6 |
7 |
8 | class ScaledDotProductAttention(nn.Module):
9 | ''' Scaled Dot-Product Attention '''
10 |
11 | def __init__(self, temperature, attn_dropout=0.0):
12 | super().__init__()
13 | self.temperature = temperature
14 | self.dropout = nn.Dropout(attn_dropout)
15 |
16 | def forward(self, q, k, mask=None):
17 |
18 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) #(bs, num_heads, max_len_q, max_len_k)
19 | #print("attn:", attn.shape)
20 |
21 | if mask is not None:
22 | attn = attn.masked_fill(mask == 0, -1e9)
23 |
24 | attn = self.dropout(F.softmax(attn, dim=-1))#(bs, num_heads, max_len_q, max_len_k)
25 | return attn
26 |
27 | class MultiHeadAttention(nn.Module):
28 | def __init__(self, n_head=4, d_model=768, d_k=32, d_v=32,dropout=0.0, include_res_ln=True, return_att_weights=False):
29 | super(MultiHeadAttention, self).__init__()
30 |
31 | self.include_res_ln = include_res_ln
32 | self.return_att_weights = return_att_weights
33 | self.n_head = n_head
34 | self.d_k = d_k
35 | self.d_v = d_v
36 | self.d_model = d_model
37 | self.w_qs = nn.Linear(d_model, n_head * d_k)
38 | self.w_ks = nn.Linear(d_model, n_head * d_k)
39 | self.w_vs = nn.Linear(d_model, n_head * d_v)
40 | self.fc = nn.Linear(n_head * d_v, d_model)
41 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5, attn_dropout=dropout)
42 | self.dropout = nn.Dropout(dropout)
43 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
44 |
45 | def forward(self, query, keys, values, mask):
46 |
47 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
48 | sz_b = query.size(0) #batch_size
49 | q = query #(bs, 1, d_model)
50 | k = keys #(bs, #rules, d_model)
51 | v = values #(bs, #rules, d_model)
52 | len_q, len_k, len_v = q.size(1), k.size(1), v.size(1) #(1, n_rules, n_rules)
53 |
54 | residual = q
55 |
56 | # Pass through the pre-attention projection: b x lq x (n*dv)
57 | # Separate different heads: b x lq x n x dv
58 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) #(bs, 1, n_head, d_k)
59 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) #(bs, n_rules, n_head, d_k)
60 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) #(bs, n_rules , n_head, d_k)
61 |
62 | # Transpose for attention dot product: b x n x lq x dv
63 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) #(bs, n_head, 1, d_k), (bs, n_head, n_rules, d_k), (bs, n_head, n_rules, d_k)
64 |
65 | if mask is not None:
66 | mask = mask.unsqueeze(1) # For head axis broadcasting.
67 |
68 | # calculate attention scores
69 | attn = self.attention(q, k, mask)#(bs, num_heads, 1, n_rules)
70 |
71 | pred = torch.matmul(attn, v) #(bs, num_heads, 1, d_model)
72 | pred = pred.transpose(1, 2).contiguous().view(sz_b, len_q, -1) #(bs, 1, num_heads*d_model)
73 | pred = self.dropout(self.fc(pred)) #(bs, 1, d_model)
74 | if self.include_res_ln:
75 | pred += residual
76 | pred = self.layer_norm(pred)
77 |
78 | pred = pred.squeeze()
79 | attn = attn.squeeze()
80 | if self.return_att_weights:
81 | return pred, attn
82 | else:
83 | return pred, None
84 |
85 | class PositionwiseFeedForward(nn.Module):
86 | ''' A two-feed-forward-layer module '''
87 |
88 | def __init__(self, d_in, d_hid, dropout=0.0):
89 | super().__init__()
90 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise
91 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise
92 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
93 | self.dropout = nn.Dropout(dropout)
94 |
95 | def forward(self, x):
96 | residual = x
97 | x = self.w_2(F.relu(self.w_1(x)))
98 | x = self.dropout(x)
99 | x += residual
100 | x = self.layer_norm(x)
101 | return x
102 |
103 | class BasicAggModel(nn.Module):
104 |
105 | def __init__(self, include_ff=True, include_res_ln=True, dropout=0.0, d_inner=2048, d_model=768, return_att_weights=False, n_head=8,
106 | d_k=96, n_rules=63, device='cpu', is_dense_bias=True):
107 |
108 | super(BasicAggModel, self).__init__()
109 |
110 | self.include_ff = include_ff
111 | self.include_res_ln = include_res_ln
112 | self.d_inner = d_inner
113 | self.d_model = d_model
114 | self.n_rules = n_rules
115 | self.device = device
116 | self.return_att_weights = return_att_weights
117 | self.mha = MultiHeadAttention(n_head=n_head, d_k=d_k, d_v=d_k, dropout=dropout,\
118 | return_att_weights=return_att_weights, include_res_ln=include_res_ln)
119 | self.ff = PositionwiseFeedForward(self.d_model, self.d_inner, dropout=dropout)
120 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
121 | self.dropout = nn.Dropout(p=dropout)
122 |
123 | self.rule_encoder = nn.Embedding(self.n_rules, self.d_model)
124 | if is_dense_bias:
125 | self.dense_proj = nn.Linear(d_model, n_rules)
126 | else:
127 | self.dense_proj = nn.Linear(d_model, n_rules, bias=False)
128 |
129 | for name, p in self.named_parameters():
130 | if name!='dense_proj.weight' and p.dim() > 1:
131 | nn.init.xavier_uniform_(p)
132 |
133 |
134 | def forward(self, query, keys, mask):
135 |
136 | query = torch.unsqueeze(query, 1) #(bs, 1, d_model)
137 | bs = query.size(0) #bs
138 |
139 | mask = torch.sum(mask, dim=-1) #(bs, n_rules)
140 | mask = mask.unsqueeze(1) #(bs, n_rules)
141 |
142 | values = keys
143 |
144 | query = self.layer_norm(query) #(bs, 1, d_model)
145 | keys = self.layer_norm(keys) #(bs, n_rules, d_model)
146 | values = self.layer_norm(values) #(bs, n_rules, d_model)
147 |
148 | #Multi-Head Attention
149 | pred, att_weights = self.mha(query, keys, values, mask) #(bs, d_model), #(bs, num_heads, n_rules)
150 |
151 | #Positionwise FeedForward
152 | if self.include_ff:
153 | pred = self.ff(pred)#(bs, d_model)
154 |
155 | #Final projection to get logits
156 | pred = self.dense_proj(pred)#(bs, num_rules)
157 |
158 | if self.return_att_weights:
159 | return pred, att_weights
160 | else:
161 | return pred, None
162 |
163 |
164 |
--------------------------------------------------------------------------------
/block_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shrivastavadisha/repo_level_prompt_generation/3af5f3424740448d8e325b3726e61944f6eec8b6/block_diagram.png
--------------------------------------------------------------------------------
/check_duplication.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import hashlib
4 | import numpy as np
5 | import argparse
6 |
7 | comments = ['*', '/']
8 |
9 | def chunk_reader(fobj, chunk_size=1024):
10 | """Generator that reads a file in chunks of bytes"""
11 | while True:
12 | chunk = fobj.read(chunk_size)
13 | if not chunk:
14 | return
15 | yield chunk
16 |
17 |
18 | def get_hole_count(full_path):
19 | hole_count = 0
20 | file_lines = open(full_path, encoding="utf8", errors='backslashreplace').readlines()
21 | for line in file_lines:
22 | line = line.strip()
23 | # omitting comments and empty lines (heuristic: NEED TO DOUBLE CHECK FOR WIDE APPLICABILITY)
24 | if line and not (np.any([line.startswith(comment) for comment in comments])):
25 | hole_count+=1
26 | return hole_count
27 |
28 | def check_for_duplicates(paths, hash=hashlib.sha1):
29 | duplicate_files = {}
30 | file_count = 0
31 | hole_count = 0
32 | hashes = {}
33 | duplicate_file_paths = []
34 | for path in paths:
35 | for dirpath, dirnames, filenames in os.walk(path):
36 | for filename in filenames:
37 | if os.path.splitext(filename)[1] == '.java':
38 | full_path = os.path.join(dirpath, filename)
39 | hashobj = hash()
40 | for chunk in chunk_reader(open(full_path, 'rb')):
41 | hashobj.update(chunk)
42 | file_id = (hashobj.digest(), os.path.getsize(full_path))
43 | duplicate = hashes.get(file_id, None)
44 | if duplicate:
45 | if duplicate not in duplicate_files:
46 | duplicate_files[duplicate] = [full_path]
47 | else:
48 | duplicate_files[duplicate].append(full_path)
49 | # print("Duplicate found: ")
50 |
51 | else:
52 | hashes[file_id] = full_path
53 |
54 | for k,v in duplicate_files.items():
55 | num_files = len(v)
56 | file_count+= num_files
57 | ind_hole_count = get_hole_count(k)
58 | hole_count += num_files * ind_hole_count
59 | duplicate_file_paths.append(k)
60 | for path in v:
61 | duplicate_file_paths.append(path)
62 |
63 |
64 | print(paths)
65 | print(len(duplicate_file_paths))
66 | print(str(file_count) + ", " + str(hole_count))
67 | with open(os.path.join(paths[0], "duplicates"), "w") as outfile:
68 | outfile.write("\n".join(duplicate_file_paths))
69 |
70 | def setup_args():
71 | """
72 | Description: Takes in the command-line arguments from user
73 | """
74 | parser = argparse.ArgumentParser()
75 |
76 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
77 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data/val', \
78 | help="base directory for the data")
79 | parser.add_argument("--proj_name", type=str, default='rsbotownversion', \
80 | help="name of the input repo")
81 |
82 | return parser.parse_args()
83 |
84 | if __name__ == '__main__':
85 |
86 | args = setup_args()
87 |
88 | #Fix seeds
89 | np.random.seed(args.seed)
90 | os.environ['PYTHONHASHSEED']=str(args.seed)
91 |
92 | repo_path = os.path.join(args.base_dir, args.proj_name)
93 | if os.path.isdir(repo_path):
94 | check_for_duplicates([repo_path])
95 |
--------------------------------------------------------------------------------
/context.py:
--------------------------------------------------------------------------------
1 | import os
2 | from utils import *
3 | import numpy as np
4 |
5 | class getContext():
6 |
7 | def __init__(self, context_location='in_file', tokenizer=None, file='',
8 | context_len=4072, parse_data=None,
9 | context_type='lines', context_scope='pre', top_k=-1, top_k_type='first',
10 | attention_scores=None, rule_context_formatting='space',
11 | file_lines=None):
12 |
13 | super(getContext, self).__init__()
14 |
15 | self.file = file
16 | self.context_len = context_len
17 | self.context_location = context_location
18 | self.tokenizer = tokenizer
19 | self.top_k = top_k
20 | self.top_k_type = top_k_type
21 | self.context_type = context_type
22 | self.parse_data = parse_data
23 | self.attention_scores = attention_scores
24 | self.rule_context_formatting = rule_context_formatting
25 | if file_lines !=None:
26 | self.file_lines = file_lines
27 | else:
28 | self.file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines()
29 | if self.context_location == 'in_file':
30 | self.set_context_scope_and_inclusion_type(context_scope)
31 | elif self.context_location == 'parent_class_file':
32 | self.codex_completion_inclusion_type = 'back'
33 | else:
34 | self.import_overlap_type_files = {}
35 | self.codex_completion_inclusion_type = 'front'
36 |
37 | def is_non_empty(self, parse_data_attribute_str):
38 | parse_data_attribute = self.parse_data[self.file][parse_data_attribute_str]
39 | if parse_data_attribute:
40 | return True
41 | else:
42 | return False
43 |
44 | def is_out_files(self):
45 | if self.context_location == 'parent_class_file' or self.context_location == 'import_of_parent_class_file':
46 | return self.is_non_empty('parent_class_filenames')
47 | if self.context_location == 'import_file':
48 | return self.is_non_empty('imports')
49 | if self.context_location == 'sibling_file' or self.context_location == 'reverse_sibling_file' or self.context_location == 'import_of_sibling_file':
50 | return self.is_non_empty('sibling_files')
51 | if self.context_location == 'similar_name_file' or self.context_location == 'reverse_similar_name_file' or self.context_location == 'import_of_similar_name_file':
52 | return self.is_non_empty('similar_name_files')
53 | if self.context_location == 'child_class_file' or self.context_location == 'import_of_child_class_file':
54 | return self.is_non_empty('child_class_filenames')
55 |
56 | def set_context_scope_and_inclusion_type(self, new_scope):
57 | self.context_scope = new_scope
58 | if self.context_scope == 'pre' or self.context_scope =='pre_post':
59 | self.codex_completion_inclusion_type ='back'
60 | if self.context_scope == 'post':
61 | self.codex_completion_inclusion_type ='front'
62 |
63 | def set_hole_pos(self, hole_pos):
64 | self.hole_pos = hole_pos
65 |
66 | def get_context_len(self):
67 | return self.context_len
68 |
69 | def set_context_len(self, new_context_len):
70 | self.context_len = new_context_len
71 |
72 | def get_rule_context_format(self, lst):
73 | if self.rule_context_formatting == 'space':
74 | context = " ".join(lst)
75 |
76 | if self.rule_context_formatting == 'newline':
77 | context = "\n".join(lst)
78 |
79 | if self.rule_context_formatting == 'method_name'\
80 | or self.rule_context_formatting == 'class_name'\
81 | or self.rule_context_formatting == 'class_method_name':
82 | context = " ".join(lst)
83 |
84 | if self.rule_context_formatting == 'comment':
85 | context = " ".join(lst)
86 | context = "/**" + context + "*/"
87 |
88 | return context
89 |
90 | def get_nearest_attribute_index(self, attribute_names, type='class'):
91 | hole_start_line = self.hole_pos[0]
92 | min_pos_diff = 100000
93 | min_pos_diff_index = -1
94 | for i in range(len(attribute_names)):
95 | if attribute_names[i]:
96 | attribute_start_line = attribute_names[i][0][0]
97 | if type == 'class':
98 | pos_diff = hole_start_line - attribute_start_line
99 | if type == 'import':
100 | pos_diff = np.abs(hole_start_line - attribute_start_line)
101 | if pos_diff < min_pos_diff:
102 | if type == 'class':
103 | if pos_diff > 0:
104 | min_pos_diff = pos_diff
105 | min_pos_diff_index = i
106 | #print(pos_diff, min_pos_diff, min_pos_diff_index)
107 | if type == 'import':
108 | if pos_diff != 0 or (pos_diff == 0 and attribute_names[i][1][1] < self.hole_pos[1]):
109 | min_pos_diff = pos_diff
110 | min_pos_diff_index = i
111 | if type == 'class':
112 | return min_pos_diff_index
113 | if type == 'import':
114 | return pos_diff
115 |
116 | def get_relevant_import_of_att_files(self, att_type):
117 | att_import_ranking = {}
118 | att_files = self.parse_data[self.file][att_type]
119 | #att_files = list(set(att_files))
120 | for att_file, att_file_overlap in att_files:
121 | if 'small_' in self.file.split('/')[1]:
122 | att_file = '/'.join([att_file.split('/')[0]] + [self.file.split('/')[1]] + att_file.split('/')[2:])
123 | att_file_imports = list(self.parse_data[att_file]['imports'].keys())
124 | for att_file_import in att_file_imports:
125 | if att_file_import in att_import_ranking:
126 | att_import_ranking[att_file_import]+=1
127 | else:
128 | att_import_ranking[att_file_import] = 1
129 | sorted_att_import_ranking = sorted(att_import_ranking.items(), key=lambda x: x[1], reverse=True)
130 | sorted_att_import_ranking = [imp_file for imp_file, _ in sorted_att_import_ranking]
131 | return sorted_att_import_ranking
132 |
133 | def get_relevant_import_files(self):
134 | all_imports = self.parse_data[self.file]['imports']
135 | import_distances_from_hole = {}
136 | for import_file, import_identifier_loc in all_imports.items():
137 | pos_diff = self.get_nearest_attribute_index(import_identifier_loc, type='import')
138 | if pos_diff != -1:
139 | import_distances_from_hole[import_file] = pos_diff
140 |
141 | # less the position difference from the hole, higher the ranking
142 | sorted_import_distances_from_hole = sorted(import_distances_from_hole.items(), key=lambda x: x[1])
143 | sorted_import_files = [imp_file for imp_file, _ in sorted_import_distances_from_hole]
144 | return sorted_import_files
145 |
146 | def get_relevant_files(self, type_str, sort_order='descending'):
147 | all_type_files = self.parse_data[self.file][type_str]
148 | if not all_type_files:
149 | return all_type_files
150 |
151 | sorted_file_imports = self.get_relevant_import_files()
152 | overlapping_type_files = {}
153 | found = False
154 | # find type files(e.g. sibling files) with imports common with current file based on the position with the hole.
155 | # in case multiple such files exist, sort based on number of common import statements.
156 | for imp_file in sorted_file_imports:
157 | if imp_file in self.import_overlap_type_files:
158 | return self.import_overlap_type_files[imp_file]
159 | if found:
160 | break
161 | for type_file, type_file_overlap in all_type_files:
162 | if type_file_overlap > 0:
163 | if 'small_' in self.file.split('/')[1]:
164 | type_file = '/'.join([type_file.split('/')[0]] + [self.file.split('/')[1]] + type_file.split('/')[2:])
165 | type_file_import_files = list(self.parse_data[type_file]['imports'].keys())
166 | if imp_file in type_file_import_files:
167 | overlapping_type_files[type_file] = type_file_overlap
168 | found = True
169 |
170 | if not found:
171 | type_files = [x[0] for x in all_type_files]
172 | return type_files
173 |
174 | # more the overlap, higher the ranking
175 | if sort_order == 'descending':
176 | sorted_overlapping_type_files = sorted(overlapping_type_files.items(), key=lambda x: x[1], reverse=True)
177 | else:
178 | sorted_overlapping_type_files = sorted(overlapping_type_files.items(), key=lambda x: x[1])
179 |
180 | sorted_type_files = [type_file for type_file, _ in sorted_overlapping_type_files]
181 | self.import_overlap_type_files[imp_file] = sorted_type_files
182 | return sorted_type_files
183 |
184 | def get_parent_class_filename(self):
185 | """
186 | Return the parent class filename that corresponds to the immediate scope of the hole location
187 | """
188 | file_parsed_data = self.parse_data[self.file]
189 | parent_class_filenames = file_parsed_data['parent_class_filenames']
190 | parent_class_names = file_parsed_data['parent_class_names']
191 | relevant_index = self.get_nearest_attribute_index(parent_class_names)
192 | if relevant_index != -1:
193 | return parent_class_filenames[relevant_index][0], parent_class_names[relevant_index]
194 | else:
195 | return '', ''
196 |
197 | def get_method_names_and_bodies(self, method_names, method_bodies, file):
198 |
199 | if self.top_k != -1 and len(method_names) >= self.top_k:
200 | method_context_len = int(self.context_len/self.top_k)
201 | else:
202 | method_context_len = int(self.context_len/len(method_names))
203 |
204 | method_contexts = []
205 | context_len = 0
206 | for method_name in method_names:
207 | if method_name:
208 | found = False
209 | for method_body in method_bodies:
210 | # for each method name, find the corresponding method_body
211 | if method_body and method_body[0][0] == method_name[0][0]:
212 | ms, me = method_body
213 | full_ms = (ms[0], 0)
214 | full_me = me
215 | found = True
216 | break
217 |
218 | if found == False:
219 | ms, me = method_name
220 | full_ms = (ms[0], 0)
221 | full_me = (ms[0], -1)
222 |
223 | method_name_and_body = get_string(file, full_ms, full_me)
224 | method_context, method_context_len = get_codex_tokenized_string(self.tokenizer, method_name_and_body, \
225 | method_context_len)
226 | if self.rule_context_formatting == 'method_name'\
227 | or self.rule_context_formatting =='class_method_name':
228 | method_name_str = "[" + get_string(file, method_name[0], method_name[1]) + "]"
229 | method_contexts.append(method_name_str)
230 | method_contexts.append(method_context)
231 | context_len += method_context_len
232 |
233 | context = self.get_rule_context_format(method_contexts)
234 | return context, context_len
235 |
236 | def get_context_string(self, candidate_attributes):
237 |
238 | if self.top_k == -1:
239 | attributes_str = self.get_rule_context_format(candidate_attributes)
240 | else:
241 | if self.top_k_type == 'first':
242 | attributes_str = self.get_rule_context_format(candidate_attributes[:self.top_k])
243 | if self.top_k_type == 'last':
244 | attributes_str = self.get_rule_context_format(candidate_attributes[-self.top_k:])
245 |
246 | context, context_len = get_codex_tokenized_string(self.tokenizer, attributes_str, self.context_len,
247 | type=self.codex_completion_inclusion_type)
248 |
249 | return context, context_len
250 |
251 | def get_attribute_context(self, attributes, file):
252 | candidate_attributes = []
253 | for attribute in attributes:
254 | if attribute:
255 | start, end = attribute
256 |
257 | if self.context_location == 'in_file':
258 | start_line, start_char = start
259 | end_line, end_char = end
260 | hole_pos_line, hole_pos_char = self.hole_pos
261 | #assert start_line == end_line, "attribute doesn't span a single line"
262 |
263 | if self.context_scope == 'pre' or self.context_scope == 'pre_post':
264 | if end_line < hole_pos_line or (end_line == hole_pos_line and end_char < hole_pos_char):
265 | attribute_string = get_string(file, start, end)
266 | candidate_attributes.append(attribute_string.strip())
267 |
268 | if self.context_scope == 'post' or self.context_scope == 'pre_post':
269 | if start_line > hole_pos_line:
270 | attribute_string = get_string(file, start, end)
271 | candidate_attributes.append(attribute_string.strip())
272 |
273 | else:
274 | # checking for overlap with the hole is not needed here as it is a different file
275 | attribute_string = get_string(file, start, end)
276 | candidate_attributes.append(attribute_string.strip())
277 |
278 | context, context_len = self.get_context_string(candidate_attributes)
279 | return context, context_len
280 |
281 | def get_line_context(self, num_of_lines_to_exclude=0):
282 | num_of_lines_to_be_taken = self.top_k
283 | pre_context = ''
284 | post_context = ''
285 | if self.context_scope == 'pre' or self.context_scope == 'pre_post':
286 | end = self.hole_pos
287 | if num_of_lines_to_be_taken == -1:
288 | start = (0, 0)
289 | else:
290 | hole_pos_line = self.hole_pos[0]
291 | start_line = hole_pos_line - num_of_lines_to_be_taken
292 | if start_line < 0:
293 | start_line = 0
294 | start = (start_line, 0)
295 | pre_context = get_string(self.file, start, end)
296 |
297 | if self.context_scope == 'post' or self.context_scope == 'pre_post':
298 | hole_pos_line = self.hole_pos[0]
299 | start = (hole_pos_line + 1 + num_of_lines_to_exclude, 0)
300 | if num_of_lines_to_be_taken != -1:
301 | end_line = hole_pos_line + num_of_lines_to_be_taken + num_of_lines_to_exclude
302 | if end_line >= len(self.file_lines):
303 | end_line = len(self.file_lines) - 1
304 | end_char = len(self.file_lines[end_line])
305 | end = (end_line, end_char)
306 | else:
307 | end_line = len(self.file_lines)-1
308 | end_char = len(self.file_lines[end_line])
309 | end = (end_line, end_char)
310 |
311 | post_context = get_string(self.file, start, end)
312 |
313 | if self.context_scope == 'pre':
314 | context, context_len = get_codex_tokenized_string(self.tokenizer, pre_context, self.context_len,
315 | type=self.codex_completion_inclusion_type)
316 | if self.context_scope == 'post':
317 | context, context_len = get_codex_tokenized_string(self.tokenizer, post_context, self.context_len,
318 | type=self.codex_completion_inclusion_type)
319 | if self.context_scope == 'pre_post':
320 | pre_context, pre_context_len = get_codex_tokenized_string(self.tokenizer, pre_context, int(self.context_len/2),
321 | type='back')
322 | post_context, post_context_len = get_codex_tokenized_string(self.tokenizer, post_context, int(self.context_len/2),
323 | type='front')
324 |
325 | context = pre_context + "\n" + post_context
326 | context_len = pre_context_len + post_context_len
327 |
328 | return context, context_len
329 |
330 | def get_base_context(self):
331 | base_class_names = self.parse_data[self.file]['class_names']
332 | class_index = self.get_nearest_attribute_index(base_class_names)
333 | if class_index != -1:
334 | base_class_name = get_string(self.file, base_class_names[class_index][0], base_class_names[class_index][1])
335 | base_context = "[" + base_class_name + "]"
336 | else:
337 | base_context = ''
338 | return base_context
339 |
340 | def get_attribute_context_from_context_type(self, file_type):
341 | if 'small_' in self.file.split('/')[1]:
342 | file_type = '/'.join([file_type.split('/')[0]] + [self.file.split('/')[1]] + file_type.split('/')[2:])
343 | if self.context_type == 'identifiers':
344 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['identifiers'], file_type)
345 |
346 | if self.context_type == 'type_identifiers':
347 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['type_identifiers'], file_type)
348 |
349 | if self.context_type == 'string_literals':
350 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['string_literals'], file_type)
351 |
352 | if self.context_type == 'method_names':
353 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['all_method_names'], file_type)
354 |
355 | if self.context_type == 'method_names_and_bodies':
356 | method_names = self.parse_data[file_type]['all_method_names']
357 | method_bodies = self.parse_data[file_type]['all_method_bodies']
358 | if method_names:
359 | context, context_len = self.get_method_names_and_bodies(method_names, method_bodies, file_type)
360 | else:
361 | context= ''
362 | context_len = 0
363 |
364 | if self.context_type == 'field_declarations':
365 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['field_declarations'], file_type)
366 |
367 | return context, context_len
368 |
369 | def get_context_from_multiple_files(self, files):
370 | total_context_len = 0
371 | total_context = ''
372 | if files:
373 | for file in files:
374 | if total_context_len < self.get_context_len():
375 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name':
376 | base_context = self.get_base_context()
377 | file_name = file.split('/')[-1].split('.')[0]
378 | file_name = "[" + file_name + "]"
379 | # get context
380 | context, context_len = self.get_attribute_context_from_context_type(file)
381 |
382 | # import contexts are added to the front based on decreasing priority
383 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name':
384 | total_context = file_name + " " + context + " " + total_context
385 | else:
386 | total_context = context + " " + total_context
387 | total_context_len += context_len
388 | else:
389 | break
390 |
391 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name':
392 | total_context = total_context + "\n" + base_context
393 | return total_context, total_context_len
394 |
395 | def get_in_file_context(self, num_of_lines_to_exclude=0):
396 | """
397 | for in_file only post lines makes sense.
398 | for others, first post is tried, if post is not successful in finding any context the pre_post is tried.
399 | """
400 | if self.context_type == 'lines':
401 | self.set_context_scope_and_inclusion_type('post')
402 | context, context_len = self.get_line_context(num_of_lines_to_exclude)
403 | # doesn't mean much to have pre_post in this setting
404 |
405 | if self.context_type == 'identifiers':
406 | self.set_context_scope_and_inclusion_type('post')
407 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['identifiers'], self.file)
408 | if not context:
409 | self.set_context_scope_and_inclusion_type('pre_post')
410 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['identifiers'], self.file)
411 |
412 | if self.context_type == 'type_identifiers':
413 | self.set_context_scope_and_inclusion_type('post')
414 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['type_identifiers'], self.file)
415 | if not context:
416 | self.set_context_scope_and_inclusion_type('pre_post')
417 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['type_identifiers'], self.file)
418 |
419 | if self.context_type == 'string_literals':
420 | self.set_context_scope_and_inclusion_type('post')
421 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['string_literals'], self.file)
422 | if not context:
423 | self.set_context_scope_and_inclusion_type('pre_post')
424 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['string_literals'], self.file)
425 |
426 | if self.context_type == 'method_names':
427 | self.set_context_scope_and_inclusion_type('post')
428 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['all_method_names'], self.file)
429 | if not context:
430 | self.set_context_scope_and_inclusion_type('pre_post')
431 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['all_method_names'], self.file)
432 |
433 | if self.context_type == 'field_declarations':
434 | self.set_context_scope_and_inclusion_type('post')
435 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['field_declarations'], self.file)
436 | if not context:
437 | self.set_context_scope_and_inclusion_type('pre_post')
438 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['field_declarations'], self.file)
439 |
440 | return context, context_len
441 |
442 | def get_parent_class_file_context(self):
443 | self.parent_class_file, self.parent_class_name = self.get_parent_class_filename()
444 | if self.parent_class_file:
445 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name':
446 | base_context = self.get_base_context()
447 | parent_class_name = get_string(self.file, self.parent_class_name[0], self.parent_class_name[1])
448 | parent_context = "[" + parent_class_name + "]"
449 | # get context
450 | context, context_len = self.get_attribute_context_from_context_type(self.parent_class_file)
451 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name':
452 | context = parent_context + " " + context + "\n" + base_context
453 | else:
454 | context = ''
455 | context_len = 0
456 |
457 | return context, context_len
458 |
459 | def get_import_file_context(self):
460 | import_files = self.get_relevant_import_files()
461 | return self.get_context_from_multiple_files(import_files)
462 |
463 | def get_sibling_file_context(self):
464 | if self.context_location.startswith('reverse'):
465 | sort_order ='ascending'
466 | else:
467 | sort_order = 'descending'
468 | sibling_files = self.get_relevant_files(type_str='sibling_files', sort_order=sort_order)
469 | return self.get_context_from_multiple_files(sibling_files)
470 |
471 | def get_similar_name_file_context(self):
472 | if self.context_location.startswith('reverse'):
473 | sort_order ='ascending'
474 | else:
475 | sort_order = 'descending'
476 | similar_name_files = self.get_relevant_files(type_str='similar_name_files', sort_order=sort_order)
477 | return self.get_context_from_multiple_files(similar_name_files)
478 |
479 | def get_child_class_file_context(self):
480 | child_class_files = self.get_relevant_files(type_str='child_class_filenames')
481 | return self.get_context_from_multiple_files(child_class_files)
482 |
483 | def get_import_of_sibling_file_context(self):
484 | imports_of_sibling_files = self.get_relevant_import_of_att_files('sibling_files')
485 | return self.get_context_from_multiple_files(imports_of_sibling_files)
486 |
487 | def get_import_of_similar_name_file_context(self):
488 | imports_of_similar_name_files = self.get_relevant_import_of_att_files('similar_name_files')
489 | return self.get_context_from_multiple_files(imports_of_similar_name_files)
490 |
491 | def get_import_of_parent_class_file_context(self):
492 | imports_of_parent_class_files = self.get_relevant_import_of_att_files('parent_class_filenames')
493 | return self.get_context_from_multiple_files(imports_of_parent_class_files)
494 |
495 | def get_import_of_child_class_file_context(self):
496 | imports_of_child_class_files = self.get_relevant_import_of_att_files('child_class_filenames')
497 | return self.get_context_from_multiple_files(imports_of_child_class_files)
498 |
--------------------------------------------------------------------------------
/create_hole_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import argparse
4 | import pickle
5 | import javac_parser
6 | import shutil
7 |
8 |
9 | '''
10 | For each line in the repo that is not blank or a comment, choose the midpoint of the line (character-wise, not token-wise) as the hole position.
11 | '''
12 |
13 | def choose_holes(project_lines, comments):
14 | data = {}
15 | count = 0
16 | repeated_holes = 0
17 | chosen_lines = []
18 |
19 | selected_lines = np.arange(0, len(project_lines))
20 |
21 | for proj_line_id in selected_lines:
22 | file, file_line_id, line = project_lines[proj_line_id]
23 | # removing leading and trailing whitespaces
24 | line = line.strip()
25 | # omitting comments and empty lines
26 | if line and not (np.any([line.startswith(comment) for comment in comments])):
27 | if proj_line_id in chosen_lines:
28 | repeated_holes+=1
29 | else:
30 | chosen_lines.append(proj_line_id)
31 | count+=1
32 | #get holes from the middle of the lines
33 | mid_point = int(len(line)/2)
34 | chosen_position = mid_point
35 |
36 | if file in data:
37 | data[file].append((file_line_id, chosen_position))
38 | else:
39 | data[file] = [(file_line_id, chosen_position)]
40 |
41 |     # returns the hole data, the number of holes (one per relevant line), and the number of files
42 | return data, len(chosen_lines), len(data)
43 |
44 | def setup_args():
45 | """
46 | Description: Takes in the command-line arguments from user
47 | """
48 | parser = argparse.ArgumentParser()
49 |
50 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
51 | parser.add_argument("--base_dir", type=str, default='gcode-data', \
52 | help="base directory for the data")
53 | parser.add_argument("--data_split", type=str, default='train', \
54 | help="data split to store the data")
55 | parser.add_argument("--language", type=str, default='java', help="java, cpp")
56 | parser.add_argument("--proj_name", type=str, default='javasummerframework', \
57 | help="name of the input repo")
58 |
59 | return parser.parse_args()
60 |
61 | if __name__ == '__main__':
62 |
63 | args = setup_args()
64 |
65 | #Fix seeds
66 | np.random.seed(args.seed)
67 | os.environ['PYTHONHASHSEED']=str(args.seed)
68 |
69 | if args.language == 'java':
70 | file_extensions = ['.java']
71 | comments = ['*', '/']
72 | if args.language == 'lua':
73 | file_extensions = ['.lua']
74 | comments = ['--']
75 | if args.language == 'cpp':
76 |         file_extensions = ['.cc', '.cpp', '.h']
77 | comments = ['/']
78 |
79 | source_data_path = os.path.join(args.base_dir, args.proj_name)
80 | os.makedirs(os.path.join('rule_classifier_data', args.data_split), exist_ok=True)
81 | destination_data_path = os.path.join('rule_classifier_data', args.data_split, args.proj_name)
82 | shutil.move(source_data_path, destination_data_path)
83 |
84 | files = []
85 | for dp, dn, filenames in os.walk(destination_data_path):
86 | for f in filenames:
87 | for file_ext in file_extensions:
88 | if os.path.splitext(f)[1] == file_ext:
89 | files.append(os.path.join(dp, f))
90 |
91 | project_lines = []
92 | for file in files:
93 | file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines()
94 | parsed_file_lines = []
95 | for l in range(len(file_lines)):
96 | line = file_lines[l]
97 | parsed_file_lines.append((file, l, line))
98 | project_lines.extend(parsed_file_lines)
99 | num_of_lines_in_proj = len(project_lines) # number of lines in the project
100 | data, num_of_holes, num_of_files = choose_holes(project_lines, comments=comments)
101 |
102 | with open(os.path.join(destination_data_path, 'hole_data'), 'wb') as f:
103 | pickle.dump(data, f)
104 |
105 | print(args.proj_name + ", " + str(num_of_files) + ", " \
106 | + str(num_of_lines_in_proj) + ", " + str(num_of_holes))
107 |
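  | # Example invocation (values are illustrative):
  | #   python create_hole_data.py --base_dir gcode-data --data_split train --proj_name javasummerframework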
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | from context import *
2 | from utils import *
3 | from rule_config import *
4 | from transformers import GPT2TokenizerFast
5 |
6 | context_location_conversion = {
7 | 'in_file':'in_file', \
8 | 'parent_class_file':'parent_class_file', \
9 | 'import_file':'import_file',\
10 | 'sibling_file':'sibling_files', \
11 | 'similar_name_file':'similar_name_files', \
12 | 'child_class_file':'child_class_filenames', \
13 | 'import_of_sibling_file':'sibling_files', \
14 | 'import_of_similar_name_file':'similar_name_files', \
15 | 'import_of_parent_class_file':'parent_class_filenames', \
16 | 'import_of_child_class_file':'child_class_filenames'
17 | }
18 |
19 |
20 | class RuleDatasetUtils():
21 | def __init__(self, file, parse_datas, hole_pos, tokenizer):
22 | super(RuleDatasetUtils, self).__init__()
23 | #mod_file = '/'. join(['data', 'gcode-data'] + file.split('/')[2:])
24 | self.file = file
25 | self.parse_datas = parse_datas
26 | self.hole_pos = hole_pos
27 | self.tokenizer = tokenizer
28 | #self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
29 |
30 | def get_relevant_files(self, context_location):
31 | if context_location == 'in_file':
32 | return [self.file]
33 | else:
34 | rule_context_obj = getContext(context_location = context_location, file=self.file, parse_data = self.parse_datas)
35 | rule_context_obj.set_hole_pos(self.hole_pos)
36 | if rule_context_obj.is_out_files():
37 | if context_location == 'parent_class_file':
38 |                     relevant_file, _ = rule_context_obj.get_parent_class_filename()
39 |                     relevant_files = [relevant_file] if relevant_file else []
40 | elif context_location == 'import_file':
41 | relevant_files = rule_context_obj.get_relevant_import_files()
42 | elif 'import_of_' in context_location:
43 | relevant_files = rule_context_obj.get_relevant_import_of_att_files(context_location_conversion[context_location])
44 | else:
45 | relevant_files = rule_context_obj.get_relevant_files(context_location_conversion[context_location])
46 | return relevant_files
47 | else:
48 | return []
49 |
50 | def get_usages_from_context_location(self, hole_attributes, context_location):
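  |         # return the first (context file, usages) pair found for any hole attribute in this context location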
51 | cl_files = self.get_relevant_files(context_location)
52 | for hole_att in hole_attributes:
53 | for cl_file in cl_files:
54 | print(cl_file)
55 | file_attributes = self.parse_datas[cl_file]['identifiers']
56 | att_usages = find_usages(hole_att, self.file, file_attributes, cl_file)
57 | if not att_usages:
58 | continue
59 | else:
60 | return (cl_file, att_usages)
61 |
62 | def get_all_usages(self, hole_attributes):
63 | usages = {}
64 | for context_location in context_location_conversion.keys():
65 | print(context_location)
66 | usages[context_location] = self.get_usages_from_context_location(hole_attributes, context_location)
67 | return usages
68 |
69 | def get_default_prompt(self, context_len):
70 |
71 | default_context_obj = getContext(context_location='in_file',
72 | tokenizer=self.tokenizer,
73 | file=self.file,
74 | context_len=context_len,
75 | context_scope='pre',\
76 | context_type='lines',\
77 | top_k=-1)
78 |
79 | default_context_obj.set_hole_pos(self.hole_pos)
80 | default_prompt, default_prompt_len = default_context_obj.get_line_context()
81 | return default_prompt, default_prompt_len
82 |
83 |
84 | def get_all_rules_context(self, num_of_lines_to_exclude=0):
85 | rule_prompts = []
86 | rule_indexes = []
87 | total_context_len = self.tokenizer.model_max_length
88 |
89 |
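  |         # build one prompt per prompt proposal; the 'codex' key corresponds to the default in-file (preceding lines) prompt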
90 | for key, val in combined_to_index.items():
91 | rule_prompt = ''
92 | if key == 'codex':
93 | rule_prompt, rule_prompt_len = self.get_default_prompt(context_len=total_context_len)
94 | if key != 'codex':
95 | context_location, context_type, context_division_ratio = key.split('#')
96 | rule_context_formatting = rule_hyperparams[context_type]['rule_context_formatting'][0]
97 | rule_context_obj = getContext(context_location = context_location, file=self.file, parse_data = self.parse_datas, \
98 | context_type=context_type, rule_context_formatting=rule_context_formatting, \
99 | tokenizer = self.tokenizer, context_len = total_context_len)
100 |
101 | allocated_rule_context_len = int(rule_context_obj.get_context_len()*float(context_division_ratio))
102 | rule_context_obj.set_context_len(allocated_rule_context_len)
103 | rule_context_obj.set_hole_pos(self.hole_pos)
104 |
105 | if context_location == 'in_file':
106 | rule_prompt, rule_prompt_len = rule_context_obj.get_in_file_context(num_of_lines_to_exclude)
107 |
108 |             # check whether there are relevant out-of-file sources for this context location (not applicable to in_file)
109 | is_out_files = rule_context_obj.is_out_files()
110 | if is_out_files:
111 | if context_location == 'parent_class_file':
112 | rule_prompt, rule_prompt_len = rule_context_obj.get_parent_class_file_context()
113 | if context_location == 'import_file':
114 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_file_context()
115 | if context_location == 'sibling_file':
116 | rule_prompt, rule_prompt_len = rule_context_obj.get_sibling_file_context()
117 | if context_location == 'similar_name_file':
118 | rule_prompt, rule_prompt_len = rule_context_obj.get_similar_name_file_context()
119 | if context_location == 'child_class_file':
120 | rule_prompt, rule_prompt_len = rule_context_obj.get_child_class_file_context()
121 | if context_location == 'import_of_similar_name_file':
122 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_similar_name_file_context()
123 | if context_location == 'import_of_parent_class_file':
124 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_parent_class_file_context()
125 | if context_location == 'import_of_child_class_file':
126 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_child_class_file_context()
127 | if context_location == 'import_of_sibling_file':
128 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_sibling_file_context()
129 |
130 | if rule_prompt:
131 | rule_prompts.append(rule_prompt)
132 | rule_indexes.append(val)
133 |
134 | return rule_prompts, rule_indexes
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
/generate_completions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import os
4 | import copy
5 | import pickle
6 | import openai
7 | import argparse
8 | import json
9 | import random
10 | from utils import *
11 | from context import *
12 | from transformers import GPT2TokenizerFast
13 |
14 | def setup_args():
15 | """
16 | Description: Takes in the command-line arguments from user
17 | """
18 | parser = argparse.ArgumentParser()
19 |
20 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
21 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data/val', help="base directory for the data")
22 | parser.add_argument("--repo_name", type=str, default='ircrpgbot', help="name of the repo")
23 |
24 | # completion-related hyperparams
25 | parser.add_argument("--mode", type=str, default='rule', help="codex, rule")
26 |
27 | parser.add_argument("--batch_size", type=int, default=20, \
28 | help="batch size of the prompts to be given to Codex")
29 | parser.add_argument("--completion_len", type=int, default=24, \
30 | help=" length of the 24 completions, so total size will be 4096")
31 | parser.add_argument("--is_run_rule_full", default=False, action='store_true', \
32 | help="whether to run rule based method for the full cases or only rule_triggered cases")
33 |
34 | # context related hyperparams
35 | parser.add_argument("--total_context_len", type=int, default=100, \
36 | help="total size of the context: 24 for completions, so total size will be 4096")
37 | parser.add_argument("--context_division_ratio", type=float, default=0.5, \
38 | help="ratio in which the in-file and out-file context are divided")
39 |
40 | parser.add_argument("--context_location", type=str, default='in_file', \
41 | help="where to take context from\
42 | NOTE that this is always in addition to the previous prompt, \
43 | i.e., in addition to the default prompt for a Codex model")
44 |
45 | parser.add_argument("--context_type", type=str, default='field_declarations',\
46 | help="the type of context to be taken. \
47 |                         For the possible values for each context_location, see the rules file (rule_config.py)")
48 |
49 | # rule-related hyperparameters
50 | parser.add_argument("--top_k", type=int, default=-1,\
51 | help="k value. A value of -1 indicates taking full context of context_type")
52 | parser.add_argument("--top_k_type", type=str, default='first', \
53 | help="first, last")
54 | parser.add_argument("--prompt_separator", type=str, default='space', \
55 | help="space, newline")
56 | parser.add_argument("--rule_context_formatting", type=str, default='space', \
57 | help="space, newline, method_name, class_name, comment, class_method_name")
58 | return parser.parse_args()
59 |
60 | def generate_prediction(prompt):
61 | '''
62 | generate predictions using Codex
63 | '''
64 | try:
65 | response = openai.Completion.create(engine='code-davinci-001',\
66 | prompt=prompt,stop='\n',\
67 | temperature=0.0)
68 |
69 | except:
70 | print ("Waiting")
71 | response = None
72 | return response
73 |
74 | def check_hole_scope(hole_pos, class_spans):
75 | '''
76 |     return the span of the class that encloses the hole position; return None if the hole is not inside any class
77 | '''
78 | for class_span in class_spans:
79 | cs = int(class_span.split('_')[0])
80 | ce = int(class_span.split('_')[1])
81 | l, c = hole_pos
82 | if l == cs or cs == -1:
83 | return None
84 | if cs < l <= ce:
85 | return class_span
86 |
87 | def get_default_prompt(hole_pos=(0,0), context_len=0, tokenizer=None, file=''):
88 |
89 | default_context_obj = getContext(context_location='in_file',
90 | tokenizer=tokenizer,
91 | file=file,
92 | context_len=context_len,
93 | context_scope='pre',\
94 | context_type='lines',\
95 | top_k=-1)
96 |
97 | default_context_obj.set_hole_pos(hole_pos)
98 | default_prompt, default_prompt_len = default_context_obj.get_line_context()
99 | return default_prompt, default_prompt_len
100 |
101 | def get_prompt(rule_context_obj=None, context_location='in_file',
102 | total_context_len=4072, rule_triggered=False, parent_class_filename='',
103 | context_division_ratio=0.5, num_of_lines_to_exclude=0):
104 |
105 | # start by assigning half of the total_context_len to the rule prompt
106 | rule_context_obj.set_context_len(total_context_len)
107 | allocated_rule_context_len = int(rule_context_obj.get_context_len()*context_division_ratio)
108 | rule_context_obj.set_context_len(allocated_rule_context_len)
109 |
110 | if context_location == 'in_file':
111 | rule_prompt, rule_prompt_len = rule_context_obj.get_in_file_context(num_of_lines_to_exclude)
112 | if context_location == 'parent_class_file':
113 | rule_prompt, rule_prompt_len = rule_context_obj.get_parent_class_file_context()
114 | if context_location == 'import_file':
115 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_file_context()
116 | if context_location == 'sibling_file' or context_location == 'reverse_sibling_file':
117 | rule_prompt, rule_prompt_len = rule_context_obj.get_sibling_file_context()
118 | if context_location == 'similar_name_file' or context_location == 'reverse_similar_name_file':
119 | rule_prompt, rule_prompt_len = rule_context_obj.get_similar_name_file_context()
120 | if context_location == 'child_class_file':
121 | rule_prompt, rule_prompt_len = rule_context_obj.get_child_class_file_context()
122 | if context_location == 'import_of_similar_name_file':
123 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_similar_name_file_context()
124 | if context_location == 'import_of_parent_class_file':
125 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_parent_class_file_context()
126 | if context_location == 'import_of_child_class_file':
127 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_child_class_file_context()
128 | if context_location == 'import_of_sibling_file':
129 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_sibling_file_context()
130 |
131 | # if the rule_prompt_len is shorter than the allocated space, use the extra space for the default_prompt
132 | if rule_prompt_len < allocated_rule_context_len:
133 | default_context_len = total_context_len - rule_prompt_len
134 | else:
135 | default_context_len = total_context_len - allocated_rule_context_len
136 | # if something is returned by the rule, it means that the rule is triggered
137 | if rule_prompt_len > 0:
138 | rule_triggered = True
139 | default_prompt, default_prompt_len = get_default_prompt(
140 | hole_pos=getattr(rule_context_obj, 'hole_pos'),
141 | context_len=default_context_len,
142 | tokenizer=getattr(rule_context_obj, 'tokenizer'),
143 | file=getattr(rule_context_obj, 'file')
144 | )
145 | return rule_prompt, default_prompt, rule_triggered
146 |
147 | if __name__ == '__main__':
148 |
149 | args = setup_args()
150 |
151 | #Fix seeds
152 | np.random.seed(args.seed)
153 | os.environ['PYTHONHASHSEED'] = str(args.seed)
154 |
155 | os.environ["OPENAI_API_KEY"] = open('openai_api_key', 'r').read().strip()
156 | openai.api_key = os.getenv("OPENAI_API_KEY")
157 |
158 | #directory for storing results
159 | input_data_dir = os.path.join(args.base_dir, args.repo_name)
160 | result_path = os.path.join('results', args.base_dir, args.repo_name)
161 | os.makedirs(result_path, exist_ok=True)
162 |
163 |
164 | # get tokenizer
165 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
166 |
167 | #get stored parsed data
168 | parsed_data_filename = os.path.join(args.base_dir, args.repo_name, 'parsed_data')
169 | parse_data = pickle.load(open(parsed_data_filename, 'rb'))
170 | #get the holes
171 | hole_filename = os.path.join(args.base_dir, args.repo_name, 'hole_data')
172 | hole_data = pickle.load(open(hole_filename, 'rb'))
173 |
174 | # get all relevant files (in raw form)
175 | files = [os.path.join(dp, f) \
176 | for dp, dn, filenames in os.walk(input_data_dir) \
177 | for f in filenames \
178 | if os.path.splitext(f)[1] == '.java']
179 |
180 | # Create the result file for writing the predictions
181 | result_dir = os.path.join(result_path, args.context_location)
182 | os.makedirs(result_dir, exist_ok=True)
183 |
184 | if args.mode == 'rule':
185 | result_filename = args.mode + '_' + args.context_type + '_' + str(args.top_k_type)\
186 | + '_' + str(args.top_k) \
187 | + '_' + str(int(args.total_context_len * args.context_division_ratio)) \
188 | + '_' + args.rule_context_formatting \
189 | + '_' + args.prompt_separator
190 |
191 | if args.mode =='codex':
192 | result_filename = args.mode + '_' + str(args.total_context_len)
193 |
194 | full_result_filename = os.path.join(result_dir, result_filename + '.json')
195 | print(full_result_filename)
196 | if os.path.isfile(full_result_filename) and os.path.getsize(full_result_filename) > 0:
197 | print(full_result_filename, " : result file already exists")
198 | exit()
199 |
200 | f = open(full_result_filename, 'w')
201 |
202 | all_holes = []
203 | rule_prompts = []
204 | default_prompts = []
205 | all_hole_identities = []
206 | all_rules_triggered = []
207 | total_count = 0
208 | rule_triggered_count = 0
209 |
210 | # get the prompts for all files
211 | for file in files:
212 | if file in hole_data:
213 | file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines()
214 |
215 | # define the rule context object. Depends on the file
216 | rule_context_obj = getContext(context_location = args.context_location, \
217 | tokenizer=tokenizer, file=file,
218 | parse_data = parse_data,
219 | context_type=args.context_type,
220 | top_k=args.top_k,top_k_type=args.top_k_type,
221 | rule_context_formatting=args.rule_context_formatting,
222 | file_lines=file_lines)
223 |
224 | is_out_file = rule_context_obj.is_out_files()
225 |
226 | # go through the holes in the file
227 | for (l,c) in hole_data[file]: # l = line no, c = character offset within line l
228 | if total_count%1000 == 0:
229 | print("Total Count:", total_count)
230 | hole = file_lines[l][c:]
231 | hole_identity = file + '_' + str(l) + '_' + str(c)
232 | hole_pos = (l, c)
233 |
234 |             # if mode is codex, or the rule needs out-of-file context but the file has no such
235 |             # relevant files, then get the default prompt directly
236 | if args.mode == 'codex' or \
237 | (args.mode == 'rule' and args.context_location != 'in_file' and not is_out_file):
238 | default_prompt, default_prompt_len = get_default_prompt(hole_pos, args.total_context_len,
239 | tokenizer, file)
240 | rule_triggered = False
241 | rule_prompt = ''
242 |
243 | else:
244 | rule_context_obj.set_hole_pos(hole_pos)
245 | rule_prompt, default_prompt, rule_triggered = \
246 | get_prompt(
247 | rule_context_obj=rule_context_obj, \
248 | context_location=args.context_location, \
249 | total_context_len=args.total_context_len, \
250 | context_division_ratio=args.context_division_ratio)
251 |
252 | #print("RP: ", rule_prompt)
253 | #print("DP: ", default_prompt)
254 | if rule_triggered == True:
255 | rule_triggered_count+=1
256 | rule_prompts.append(rule_prompt)
257 | default_prompts.append(default_prompt)
258 | all_holes.append(hole)
259 | all_hole_identities.append(hole_identity)
260 | all_rules_triggered.append(rule_triggered)
261 | #all_parent_class_filenames.append(parent_class_filename)
262 |
263 | total_count += 1
264 |
265 | print(total_count, rule_triggered_count)
266 | print(len(all_holes))
267 |
268 |
269 | # create prompts only for the cases where the rules are triggered.
270 | # other cases will be the same as codex, so they can be directly copied from the pre results
271 | prompts = []
272 | for i in range(len(all_holes)):
273 | if (args.mode != 'codex' and args.is_run_rule_full) \
274 | or (args.mode != 'codex' and not args.is_run_rule_full and all_rules_triggered[i])\
275 | or (args.mode == 'codex'):
276 | rule_p = rule_prompts[i]
277 | def_p = default_prompts[i]
278 | prompt_separator = args.prompt_separator
279 |             # if the rule prompt is empty, fall back to a space separator
280 |             if not rule_p and prompt_separator == 'newline':
281 |                 prompt_separator = 'space'
282 | prompt = rule_p + promptseparator2str[prompt_separator] + def_p
283 |
284 | # make sure that the length of the prompt is less than or equal to the total_context_len
285 | codex_tokens = tokenizer(prompt)['input_ids']
286 | if len(codex_tokens) > args.total_context_len:
287 | codex_tokens = codex_tokens[-args.total_context_len:]
288 | prompt = tokenizer.decode(codex_tokens)
289 | if prompt:
290 | assert len(codex_tokens) <= args.total_context_len, 'prompt length exceeds the maximum length'
291 | #print("Hole:", all_holes[i])
292 | #print("Prompt:", prompt)
293 | prompts.append((i, prompt))
294 |
295 | print(len(prompts))
296 |
297 | # prompt the codex model in batches to generate completions with the prompts created before
298 | count = 0
299 | i = 0
300 | while (i < len(prompts)):
301 | print(i)
302 | batch_prompts = prompts[i:i+args.batch_size]
303 | batch_prompt_texts = [x[1] for x in batch_prompts]
304 | #print(batch_prompt_post_texts)
305 | batch_prompt_indexes = [x[0] for x in batch_prompts] # index within the all_* arrays
306 | batch_responses = generate_prediction(batch_prompt_texts)
307 | if batch_responses != None:
308 | for j in range(len(batch_prompts)):
309 | response = batch_responses.choices[j]
310 | prediction = response.text
311 | #prediction_tokens = response.logprobs.tokens
312 | #prediction_token_logprobs = response.logprobs.token_logprobs
313 | hole = all_holes[batch_prompt_indexes[j]]
314 | hole_identity = all_hole_identities[batch_prompt_indexes[j]]
315 | rule_triggered = all_rules_triggered[batch_prompt_indexes[j]]
316 | if rule_triggered:
317 | count+=1
318 | batch_suffix = ''
319 | result = {
320 | 'hole_identity': hole_identity, \
321 | 'prediction': prediction, \
322 | 'ground-truth hole': hole, \
323 | 'prompt': batch_prompt_texts[j], \
324 | 'post_prompt': batch_suffix, \
325 | 'rule_triggered': rule_triggered, \
326 | 'index': batch_prompt_indexes[j] + 1 # this index corresponds to the global index
327 | }
328 | f.write(json.dumps(result))
329 | f.write("\n")
330 | f.flush()
331 | i = i + args.batch_size
332 | else:
333 | # wait for 60s before calling the API again
334 | time.sleep(60)
335 |
336 | f.close()
337 | print(i, j, count, len(prompts))
338 |
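  | # Example invocation (values are illustrative):
  | #   python generate_completions.py --base_dir rule_classifier_data/val --repo_name ircrpgbot --mode codex --total_context_len 4072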
--------------------------------------------------------------------------------
/generate_rule_representations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import argparse
5 | import random
6 | from torch.utils.data import DataLoader
7 | from torch import nn
8 | from tqdm import tqdm
9 | from rule_representation_data import *
10 | from torch import FloatTensor
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 |
15 | def setup_args():
16 | """
17 | Description: Takes in the command-line arguments from user
18 | """
19 | parser = argparse.ArgumentParser()
20 |
21 | # data related hyperparameters
22 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
23 | parser.add_argument("--input_data_dir", type=str, default='rule_classifier_data', help="base directory for the data")
24 | parser.add_argument("--data_split", type=str, default='val', help="train, val, test")
25 |
26 | # model related hyperparameters
27 | parser.add_argument("--emb_model_type", type=str, default='codebert', help="model to obtain embedding from")
28 | parser.add_argument("--repo", type=str, default='jata4test', help="model to obtain embedding from")
29 | return parser.parse_args()
30 |
31 | if __name__ == '__main__':
32 |
33 | args = setup_args()
34 |
35 | #Fix seeds
36 | np.random.seed(args.seed)
37 | os.environ['PYTHONHASHSEED'] = str(args.seed)
38 | torch.manual_seed(args.seed)
39 | random.seed(args.seed)
40 |
41 |
42 | # Define dataloaders
43 |     kwargs = {'num_workers': 8, 'pin_memory': True} if device.type == 'cuda' else {}
44 | tokenizer = set_tokenizer(args.emb_model_type)
45 | base_dir = os.path.join(args.input_data_dir, args.data_split)
46 | dataset = RuleReprDataset(base_dir, emb_model_type = args.emb_model_type, tokenizer=tokenizer)
47 | #for repo in os.listdir(base_dir):
48 | start, end = dataset.get_start_index(args.repo, start_offset=0, interval=0)
49 | print(args.repo, start, end)
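  |     # iterate over the dataset and cache the rule representation of each hole belonging to the chosen repo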
50 | for batch, (rule_context, hole, repo_name) in enumerate(dataset):
51 | if batch > end:
52 | break
53 | if repo_name == args.repo:
54 | save_dir = os.path.join(base_dir, repo_name, args.emb_model_type +'_mod')
55 | os.makedirs(save_dir, exist_ok=True)
56 | rule_representation = {hole: rule_context}
57 | with open(os.path.join(save_dir, str(batch)) , 'wb') as f:
58 | pickle.dump(rule_representation, f)
59 |
--------------------------------------------------------------------------------
/get_info_from_hole_predictions.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import io
4 | import torch
5 | import argparse
6 |
7 | device = 'cpu'
8 |
9 | # Rule indices sorted in decreasing order of success rate on the validation set
10 | rule_order = [5, 7, 6, 1, 20, 22, 2, 0, 25, 3, 23, 24, 28, 26, 4, 21, 62, 31, 27, 29, 30, 8, 34, 32, \
11 | 10, 33, 9, 35, 13, 11, 12, 36, 46, 44, 16, 14, 49, 45, 48, 40, 38, 19, 15, 18, 39, 43, 47,\
12 | 17, 42, 41, 58, 56, 57, 61, 59, 60, 37, 52, 50, 53, 55, 54, 51]
13 |
14 | projects = { 'train': [
15 | ('largemail' , 1653),
16 | ('ftpserverremoteadmin' , 7323),
17 | ('myt5lib' , 838),
18 | ('seamlets' , 4890),
19 | ('gloodb' , 10000),
20 | ('jjskit' , 9043),
21 | ('mobileexpensetracker' , 2298),
22 | ('gfsfa' , 10000),
23 | ('swe574-group3' , 2029),
24 | ('strudem-sicsa' , 6131),
25 | ('soap-dtc' , 1370),
26 | ('openprocesslogger' , 7191),
27 | ('tapestry-sesame', 397),
28 | ('exogdx' , 735),
29 | ('designpatternjavapedro' , 1069),
30 | ('quidsee' , 3020),
31 | ('healpix-rangeset' , 4734),
32 | ('sol-agent-platform' , 10000),
33 | ('rsbotownversion' , 10000),
34 |
35 | ],
36 |
37 | 'val': [
38 | ('tyrond', 721),
39 | ('math-mech-eshop', 2225),
40 | ('infinispan-storage-service', 373),
41 | ('teammates-shakthi', 7665),
42 | ('javasummerframework', 10000),
43 | ('tinwiki', 10000),
44 | ('jloogle', 3145),
45 | ('jcontenedor', 5464),
46 | ('sohocms', 772),
47 | ('affinity_propagation_java', 1466),
48 | ('jata4test', 1921),
49 | ('swinagile', 2595),
50 | ('navigablep2p', 1322),
51 | ('springlime', 879),
52 | ],
53 |
54 | 'test': [
55 | ('dovetaildb', 10000),
56 | ('project-pt-diaoc', 10000),
57 | ('realtimegc', 2513),
58 | ('fswuniceubtemplates', 2070),
59 | ('qwikioffice-java', 1138),
60 | ('glperaudsimon', 1766),
61 | ('xiaonei-java-api', 839),
62 | ('ircrpgbot', 6591),
63 | ('robotsimulator2009w', 7514),
64 | ('gwt-plugindetect', 73),
65 | ('apiitfriends', 1385),
66 | ('wicketbits', 754),
67 | ('hucourses', 590),
68 | ('xfuze', 3055),
69 | ]
70 | }
71 |
72 | def setup_args():
73 | """
74 | Description: Takes in the command-line arguments from user
75 | """
76 | parser = argparse.ArgumentParser()
77 |
78 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
79 | parser.add_argument("--hole_stats_file", type=str, default='hole', help="name of the prediction file to consider")
80 | parser.add_argument("--data_split", type=str, default='val', help="data_split")
81 | parser.add_argument("--base_dir", type=str, default='outputs', help="base dir")
82 | parser.add_argument("--k", type=int, default=1, help="how many rules to draw")
83 | return parser.parse_args()
84 |
85 | class CPU_Unpickler(pickle.Unpickler):
86 | """To load a pickle file stored in torch GPU setting to CPU."""
87 | def find_class(self, module, name):
88 | if module == 'torch.storage' and name == '_load_from_bytes':
89 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
90 | else:
91 | return super().find_class(module, name)
92 |
93 | def get_repo_name(hid):
94 | return hid.split('/')[2]
95 |
96 | def update_dict(dic, data_type):
97 | if 'small_' in data_type:
98 | return dic
99 | else:
100 | mod_dic = {}
101 | for k,v in dic.items():
102 | mod_k = '/'. join(['rule_classifier_data', data_type] + k.split('/')[2:])
103 | mod_dic[mod_k] = v
104 | return mod_dic
105 |
106 | def get_top_k_acc(hole_pred, hole_gt, k=1):
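  |     # success (1.0) if any of the model's top-k predicted rules is successful according to the ground truth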
107 | top_preds, top_pred_indices = torch.topk(hole_pred, k)
108 | for top_pred_idx in top_pred_indices:
109 | if hole_gt[top_pred_idx] == 1:
110 | return 1.0
111 | return 0.0
112 |
113 | def get_rule_wise_nums(oracle):
114 | rule_success = {}
115 | for hid, entry in oracle.items():
116 | hole_gt= entry['com']
117 | for i in range(len(hole_gt)):
118 | if hole_gt[i] == 1:
119 | if i in rule_success:
120 | rule_success[i]+=1
121 | else:
122 | rule_success[i] = 1
123 | return rule_success
124 |
125 | def get_single_rule_acc(hole_gt, k):
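  |     # success (1.0) if any of the k globally best rules (fixed validation-set order) is successful for this hole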
126 | rules = rule_order[:k]
127 | for rule in rules:
128 | if hole_gt[rule] == 1:
129 | return 1.0
130 | return 0.0
131 |
132 | if __name__ == '__main__':
133 |
134 | args = setup_args()
135 |
136 | #Fix seeds
137 | os.environ['PYTHONHASHSEED'] = str(args.seed)
138 | torch.manual_seed(args.seed)
139 |
140 | k = args.k
141 | repo_stats={}
142 | single_rule_stats = {}
143 | # get rlpg predictions
144 | data = CPU_Unpickler(open(os.path.join(args.base_dir, args.data_split, args.hole_stats_file), 'rb')).load()
145 | oracle_dir = 'rule_classifier_data/' + args.data_split
146 | for repo, repo_count in projects[args.data_split ]:
147 | oracle = pickle.load(open(os.path.join(oracle_dir, repo, 'capped_oracle_10000'), 'rb'))
148 | oracle = update_dict(oracle, args.data_split )
149 | for hid, entry in oracle.items():
150 | em = get_top_k_acc(data[hid][1], oracle[hid]['com'], k)
151 | single_rule_em = get_single_rule_acc(oracle[hid]['com'], k)
152 | if repo in repo_stats:
153 | repo_stats[repo]+= em
154 | else:
155 | repo_stats[repo] = em
156 | if repo in single_rule_stats:
157 | single_rule_stats[repo]+= single_rule_em
158 | else:
159 | single_rule_stats[repo] = single_rule_em
160 |
161 | repo_success = 0.0
162 | single_rule_repo_success = 0.0
163 | for repo, repo_count in projects[args.data_split]:
164 | repo_success += repo_stats[repo]*100/repo_count
165 | single_rule_repo_success += single_rule_stats[repo]*100/repo_count
166 |
167 |
168 | total_count = 0
169 | total_success = 0.0
170 | total_single_rule_success = 0.0
171 | for repo, repo_count in projects[args.data_split ]:
172 | total_count+= repo_count
173 | total_success += repo_stats[repo]
174 | total_single_rule_success += single_rule_stats[repo]
175 |
176 | print(args.hole_stats_file + "," + str(k) + "," + str(repo_success/len(projects[args.data_split ])) \
177 | + "," + str(single_rule_repo_success/len(projects[args.data_split ]))\
178 | + "," + str(total_success*100/total_count) + "," + str(total_single_rule_success*100/total_count))
179 |
180 |
--------------------------------------------------------------------------------
/model_preprocessed_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from utils import *
6 | from transformers import AutoModel
7 | from transformers import GPT2Model
8 | from att_model import BasicAggModel
9 |
10 |
11 | class RuleModel(nn.Module):
12 | def __init__(self, emb_model_type, repr_size=768, device='cpu', n_head=1, d_k=128, d_proj=512,
13 | mode='rlpg-h', dropout=0.0):
14 | super(RuleModel, self).__init__()
15 |
16 |
17 | self.n_rules = len(combined_to_index)
18 | self.set_embedding_model(emb_model_type)
19 | self.repr_size = repr_size
20 | self.device = device
21 | self.mode = mode
22 |
23 | if self.mode == 'rlpg-h':
24 | self.hole_dense1 = nn.Linear(self.repr_size, d_proj)
25 | self.hole_dense2 = nn.Linear(d_proj, self.n_rules)
26 |
27 |
28 | if self.mode == 'rlpg-r':
29 | self.att_model = BasicAggModel(d_model=repr_size, n_head=n_head, d_k=d_k, n_rules=self.n_rules, device=self.device, \
30 | dropout=dropout)
31 |
32 |
33 | def get_representation(self, inputs, mask):
34 | outputs = self.emb_model(inputs, attention_mask=mask)
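  |         # encoder models such as CodeBERT expose pooler_output; otherwise fall back to the first token's hidden state (e.g., for GPT-2)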
35 | try:
36 | representation = outputs.pooler_output
37 |         except AttributeError:
38 | representation = outputs.last_hidden_state[:, 0]
39 | return representation
40 |
41 | def get_context_embedding(self, context, attn_mask):
42 | context_embedding = self.get_representation(context, attn_mask)
43 | return context_embedding
44 |
45 | def forward(self, info):
46 |
47 | hole_inputs, hole_mask, rule_context_repr = info
48 | batch_size = hole_inputs.shape[0]
49 | # get hole window representation
50 | hole_window_repr = self.get_context_embedding(hole_inputs, hole_mask)
51 |
52 | # mask for invalid rules. It is 0 at positions where the rule is not valid.
53 | valid_rules_mask = (rule_context_repr != 0) #(bs, n_rules, repr_size)
54 |
55 | #get prediction from hole window
56 | if self.mode == 'rlpg-h':
57 | hole_pred = self.hole_dense2(F.relu(self.hole_dense1(hole_window_repr)))
58 | if len(hole_pred.shape)==1:
59 | hole_pred = torch.unsqueeze(hole_pred, dim=0)
60 | hole_pred = torch.sigmoid(hole_pred)
61 | return hole_pred, valid_rules_mask
62 |
63 | if self.mode == 'rlpg-r':
64 | rule_pred, att_weights = self.att_model(hole_window_repr, rule_context_repr, valid_rules_mask)
65 | if len(rule_pred.shape)==1:
66 | rule_pred = torch.unsqueeze(rule_pred, dim=0)
67 | rule_pred = torch.sigmoid(rule_pred)
68 | return rule_pred, valid_rules_mask
69 |
70 | def set_embedding_model(self, emb_model_type):
71 | if emb_model_type == 'gpt-2':
72 | self.emb_model = GPT2Model.from_pretrained("gpt2")
73 |
74 | # CodeBERT
75 | if emb_model_type == 'codebert':
76 | self.emb_model = AutoModel.from_pretrained("microsoft/codebert-base")
77 |
78 | # GraphCodeBERT
79 | if emb_model_type == 'graphcodebert':
80 | self.emb_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
81 |
82 | # freeze the parameters of the pretrained emb_model
83 | for param in self.emb_model.parameters():
84 | param.requires_grad = False
--------------------------------------------------------------------------------
/parse_tree.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import argparse
4 | import numpy as np
5 | from tree_sitter import Language, Parser
6 | from utils import *
7 | import copy
8 | """
9 | Obtain the parse tree for individual files and collate data at repo-level for rules.
10 | """
11 |
12 | Language.build_library('build/my-languages.so', ['tree-sitter-java'])
13 |
14 | JAVA_LANGUAGE = Language('build/my-languages.so', 'java')
15 |
16 | parser = Parser()
17 | parser.set_language(JAVA_LANGUAGE)
18 |
19 |
20 | def get_sibling_files(file, all_files):
21 | file_parts = file.split('/')
22 | root_dir = '/'.join(file_parts[:-1])
23 | sibling_files = []
24 | for f in os.listdir(root_dir):
25 | if os.path.splitext(f)[1] == '.java' and f != file_parts[-1]:
26 | sib_file = os.path.join(root_dir, f)
27 | sibling_files.append(sib_file)
28 | return sibling_files
29 |
30 | def camel_case_split(s):
31 |     start_idx = [i for i, e in enumerate(s)
32 |                  if e.isupper()] + [len(s)]
33 |
34 |     start_idx = [0] + start_idx
35 |     return [s[x: y] for x, y in zip(start_idx, start_idx[1:])]
36 |
37 | def match_by_parts(file1, file2, split_type):
38 | # omit .java in the end
39 | f1 = file1.split('.')[0]
40 | f2 = file2.split('.')[0]
41 |
42 | if split_type == 'camel-case':
43 | f1_parts = camel_case_split(f1)
44 | f2_parts = camel_case_split(f2)
45 |
46 | if split_type == 'underscore':
47 | f1_parts = f1.split('_')
48 | f2_parts = f2.split('_')
49 |
50 | for p1 in f1_parts:
51 | if p1 and p1 in f2_parts:
52 | #print(split_type, file1, file2, p1, f1_parts, f2_parts)
53 | return True
54 | return False
55 |
56 |
57 | def match_similar_filenames(file1, file2):
58 | # exactly same name
59 | if file1 == file2:
60 | return True
61 |
62 |     #camelcase split similar parts
63 |     if match_by_parts(file1, file2, 'camel-case'):
64 |         return True
65 |     #underscore split similar parts
66 |     return match_by_parts(file1, file2, 'underscore')
67 |
68 |
69 | def get_similar_name_files(file, all_files):
70 | filename = file.split('/')[-1]
71 | similar_name_files = []
72 | for f in all_files:
73 | if f != file and match_similar_filenames(f.split('/')[-1], filename):
74 | similar_name_files.append(f)
75 | return similar_name_files
76 |
77 | def get_tree(filename):
78 | """
79 | obtain parse tree for a file
80 | """
81 | file_str = open(filename, encoding="utf8", errors='backslashreplace').read()
82 | tree = parser.parse(bytes(file_str, "utf-8"))
83 | root_node = tree.root_node
84 | return root_node
85 |
86 | def parse_captures(captures, filename):
87 | text_spans = []
88 | for capture in captures:
89 | #capture[1] = property_name
90 | start, end = capture[0].start_point, capture[0].end_point
91 | #text = get_string(filename, start, end)
92 | text_spans.append((start, end))
93 | return text_spans
94 |
95 | def get_query(attribute_type):
96 |
97 | if attribute_type == 'class_name':
98 | query = JAVA_LANGUAGE.query("""(class_declaration
99 | name: (identifier) @class_name)""")
100 |
101 | if attribute_type == 'class_body':
102 | query = JAVA_LANGUAGE.query("""(class_declaration
103 | body: (class_body) @class_body)""")
104 |
105 | if attribute_type == 'parent_class_name':
106 | query = JAVA_LANGUAGE.query("""(class_declaration
107 | name: (identifier)
108 | superclass: (superclass (type_identifier) @superclass_name))""")
109 |
110 | if attribute_type == 'all_method_name':
111 | query = JAVA_LANGUAGE.query("""(method_declaration
112 | name: (identifier) @all_method_name)""")
113 |
114 | if attribute_type == 'all_method_body':
115 | query = JAVA_LANGUAGE.query("""(method_declaration body: (block) @all_method_block)""")
116 |
117 | if attribute_type == 'import_statement':
118 | query = JAVA_LANGUAGE.query("""(import_declaration (
119 | scoped_identifier
120 | name: (identifier)) @import_statement)""")
121 |
122 | if attribute_type == 'all_field_declaration':
123 | query = JAVA_LANGUAGE.query("""(field_declaration) @field_declaration""")
124 |
125 | if attribute_type == 'all_string_literal':
126 | query = JAVA_LANGUAGE.query("""(string_literal) @string_literal""")
127 |
128 | if attribute_type == 'all_identifier':
129 | query = JAVA_LANGUAGE.query("""(identifier) @identifier""")
130 |
131 | if attribute_type == 'all_type_identifier':
132 | query = JAVA_LANGUAGE.query("""(type_identifier) @type_identifier""")
133 |
134 | return query
135 |
136 | def get_attribute(root_node, filename, attribute_type):
137 |
138 | query = get_query(attribute_type)
139 | captures = query.captures(root_node)
140 | if captures:
141 | attributes = parse_captures(captures, filename)
142 | else:
143 | attributes = [((-1, -1), (-1, -1))]
144 | return attributes
145 |
146 | def get_import_path(import_stat, file):
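  |     # map an import statement to a repo-relative .java path by matching its first segment against the importing file's own path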
147 | import_stat_str = get_string(file, import_stat[0], import_stat[1])
148 | #print(import_stat_str, file)
149 | import_path_parts = import_stat_str.split(".")
150 | absolute_import_path = []
151 | import_path_part = import_path_parts[0]
152 | if import_path_part != 'java':
153 | file_path_parts = file.split("/")
154 | try:
155 | index_pos = len(file_path_parts) - file_path_parts[::-1].index(import_path_part) - 1
156 | absolute_import_path = file_path_parts[:index_pos] + import_path_parts
157 |         except ValueError:
158 |             pass  # the import root does not appear in the file path, so the import cannot be resolved locally
159 | #print(absolute_import_path)
160 | if absolute_import_path:
161 | import_path = '/'.join(absolute_import_path)
162 | import_path = import_path + '.java'
163 | return import_path
164 | else:
165 | return ''
166 |
167 | def get_parent_class_filename(parent_class_name, file_class_info, file):
168 | parent_class_filename = ''
169 | if parent_class_name:
170 | parent_class_name_text = get_string(file, parent_class_name[0], parent_class_name[1])
171 | # we don't want the current file to be the parent class file
172 | copy_file_class_info = copy.deepcopy(file_class_info)
173 | del copy_file_class_info[file]
174 |
175 | if parent_class_name_text:
176 | # search for the parent class name in all files
177 | found = False
178 | for (k,v) in copy_file_class_info.items():
179 | for val in v:
180 | if val==parent_class_name_text:
181 | parent_class_filename = k
182 | found = True
183 | break
184 | return parent_class_filename
185 |
186 | def find_relevant_file_identifier(import_identifier, file_identifiers, file):
187 | candidate_file_identifiers = []
188 | for file_identifier in file_identifiers:
189 | if file_identifier:
190 | file_identifier_str = get_string(file, file_identifier[0], file_identifier[1])
191 | if file_identifier_str == import_identifier:
192 | candidate_file_identifiers.append(file_identifier)
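  |     # skip the first match, which is likely the identifier occurrence inside the import statement itself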
193 | return candidate_file_identifiers[1:]
194 |
195 | def get_imports(import_statements, file, all_identifiers, all_type_identifiers):
196 | imports = {}
197 | file_identifiers = all_identifiers
198 | file_identifiers.extend(all_type_identifiers)
199 | for import_stat in import_statements:
200 | import_file_path = get_import_path(import_stat, file)
201 | if import_file_path and os.path.isfile(import_file_path):
202 | import_identifier = import_file_path.split('/')[-1].split('.')[0]
203 | candidate_file_identifiers = find_relevant_file_identifier(import_identifier, file_identifiers, file)
204 | if candidate_file_identifiers:
205 | imports[import_file_path] = candidate_file_identifiers
206 | return imports
207 |
208 | def check_empty_attribute(attribute):
209 | if len(attribute) == 1 and attribute[0][0][0] == -1:
210 | attribute = []
211 | return attribute
212 |
213 | def update_attribute(parse_data, att_type, files):
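  |     # for each related file, record the number of imports it shares with the current file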
214 | count = 0
215 | for file in files:
216 | current_file_imports = list(parse_data[file]['imports'].keys())
217 | att_files = parse_data[file][att_type]
218 | att_info = []
219 | for att_file in att_files:
220 | if att_file:
221 | att_file_imports = list(parse_data[att_file]['imports'].keys())
222 | overlapping_imports = find_similar_intersection(att_file_imports, current_file_imports)
223 | #if len(overlapping_imports) > 0:
224 | att_info.append((att_file, len(overlapping_imports)))
225 | #print(file, att_file, overlapping_imports)
226 | parse_data[file][att_type] = att_info
227 | if att_info:
228 | count+=1
229 | #print(file, parse_data[file][att_type])
230 | #print(count)
231 | return parse_data
232 |
233 | def update_child_class_info(parse_data, child_class_info):
234 | for file, file_parse_data in parse_data.items():
235 | if file in child_class_info:
236 | parse_data[file]['child_class_filenames'] = child_class_info[file]
237 | else:
238 | parse_data[file]['child_class_filenames'] = []
239 | return parse_data
240 |
241 | def setup_args():
242 | """
243 | Description: Takes in the command-line arguments from user
244 | """
245 | parser = argparse.ArgumentParser()
246 |
247 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
248 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data/val', \
249 | help="base directory for the data")
250 | parser.add_argument("--proj_name", type=str, default='rsbotownversion', \
251 | help="name of the input repo")
252 |
253 | return parser.parse_args()
254 |
255 | if __name__ == '__main__':
256 |
257 | args = setup_args()
258 |
259 | #Fix seeds
260 | np.random.seed(args.seed)
261 | os.environ['PYTHONHASHSEED']=str(args.seed)
262 |
263 | input_data_path = os.path.join(args.base_dir, args.proj_name)
264 | os.makedirs(input_data_path, exist_ok=True)
265 |
266 | files = [os.path.join(dp, f) \
267 | for dp, dn, filenames in os.walk(input_data_path) \
268 | for f in filenames \
269 | if os.path.splitext(f)[1] == '.java']
270 |
271 | file_class_info = {}
272 | for file in files:
273 | root_node = get_tree(file)
274 | class_names = get_attribute(root_node, file, 'class_name')
275 | file_class_names = []
276 | for cn in class_names:
277 | start, end = cn
278 | class_name = get_string(file, start, end)
279 | file_class_names.append(class_name)
280 | file_class_info[file] = file_class_names
281 | #print(file_class_info)
282 |
283 | with open(os.path.join(input_data_path, 'file_class_data'), 'wb') as f:
284 | pickle.dump(file_class_info, f)
285 |
286 | parse_data = {}
287 | child_class_info = {}
288 |
289 | similar_count = 0
290 | sibling_count = 0
291 |
292 | for file in files:
293 | root_node = get_tree(file)
294 | sibling_files = get_sibling_files(file, files)
295 | similar_name_files = get_similar_name_files(file, files)
296 | if len(similar_name_files) > 0:
297 | similar_count +=1
298 | if len(sibling_files) > 0:
299 | sibling_count +=1
300 |
301 | class_names = get_attribute(root_node, file, 'class_name')
302 | class_bodies = get_attribute(root_node, file, 'class_body')
303 | parent_class_names = get_attribute(root_node, file, 'parent_class_name')
304 | all_field_declarations = get_attribute(root_node, file, 'all_field_declaration')
305 | all_string_literals = get_attribute(root_node, file, 'all_string_literal')
306 | all_identifiers = get_attribute(root_node, file, 'all_identifier')
307 | all_type_identifiers = get_attribute(root_node, file, 'all_type_identifier')
308 | all_method_names = get_attribute(root_node, file, 'all_method_name')
309 | all_method_bodies = get_attribute(root_node, file, 'all_method_body')
310 | import_statements = get_attribute(root_node, file, 'import_statement')
311 |
312 | class_names = check_empty_attribute(class_names)
313 | class_bodies = check_empty_attribute(class_bodies)
314 | parent_class_names = check_empty_attribute(parent_class_names)
315 | all_field_declarations = check_empty_attribute(all_field_declarations)
316 | all_identifiers = check_empty_attribute(all_identifiers)
317 | all_type_identifiers = check_empty_attribute(all_type_identifiers)
318 | all_string_literals = check_empty_attribute(all_string_literals)
319 | all_method_names = check_empty_attribute(all_method_names)
320 | all_method_bodies = check_empty_attribute(all_method_bodies)
321 | import_statements = check_empty_attribute(import_statements)
322 |
323 | # get imports
324 | imports = get_imports(import_statements, file, all_identifiers, all_type_identifiers)
325 |
326 | parent_class_filenames = []
327 | mod_parent_class_names = []
328 | for parent_class_name in parent_class_names:
329 | parent_class_filename = get_parent_class_filename(parent_class_name, file_class_info, file)
330 | if parent_class_filename:
331 | mod_parent_class_names.append(parent_class_name)
332 | if parent_class_filename in child_class_info:
333 | child_class_info[parent_class_filename].append(file)
334 | else:
335 | child_class_info[parent_class_filename] = [file]
336 | parent_class_filenames.append(parent_class_filename)
337 |
338 | #print(parent_class_names, parent_class_filenames)
339 | assert len(mod_parent_class_names) == len(parent_class_filenames)
340 |
341 | #store the data in dict form
342 | parse_data[file] = {
343 | 'class_names': class_names,\
344 | 'class_bodies': class_bodies, \
345 | 'parent_class_names': mod_parent_class_names, \
346 | 'parent_class_filenames': parent_class_filenames, \
347 | 'imports': imports, \
348 | 'field_declarations': all_field_declarations, \
349 | 'string_literals': all_string_literals, \
350 | 'identifiers': all_identifiers, \
351 | 'type_identifiers': all_type_identifiers, \
352 | 'all_method_names': all_method_names, \
353 | 'all_method_bodies': all_method_bodies, \
354 | 'sibling_files': sibling_files, \
355 | 'similar_name_files': similar_name_files}
356 |
357 | print(len(files), sibling_count, similar_count)
358 | print("updating sibling files")
359 | parse_data = update_attribute(parse_data, 'sibling_files', files)
360 | print("updating similar_name_files")
361 | parse_data = update_attribute(parse_data, 'similar_name_files', files)
362 | print("updating child class filenames")
363 | parse_data = update_child_class_info(parse_data, child_class_info)
364 | parse_data = update_attribute(parse_data, 'child_class_filenames', files)
365 | print("updating parent class filenames")
366 | parse_data = update_attribute(parse_data, 'parent_class_filenames', files)
367 |
368 | print("Writing parse data...")
369 | with open(os.path.join(input_data_path, 'parsed_data'), 'wb') as f:
370 | pickle.dump(parse_data, f)
371 |
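  | # Example invocation (values are illustrative):
  | #   python parse_tree.py --base_dir rule_classifier_data/val --proj_name rsbotownversion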
--------------------------------------------------------------------------------
/preprocessed_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import pickle
5 | from torch.utils.data import Dataset
6 | from transformers import GPT2TokenizerFast, AutoTokenizer
7 | from transformers import DataCollatorWithPadding
8 | from utils import *
9 | import re
10 |
11 |
12 | class RuleDataset(Dataset):
13 |
14 | def __init__(self, input_data_dir, tokenizer=None, emb_model_type='codebert'):
15 |
16 | self.input_data_dir = input_data_dir
17 | self.tokenizer = tokenizer
18 | data_type = input_data_dir.split('/')[-1]
19 | oracles = {}
20 | self.data_files = []
21 | for dp, dn, filenames in os.walk(input_data_dir):
22 | for f in filenames:
23 | if f == 'capped_oracle_10000':
24 | oracle = pickle.load(open(os.path.join(dp, f), 'rb'))
25 | oracle = self.update_dict(oracle, data_type)
26 | oracles = {**oracles, **oracle}
27 |
28 | for dp, dn, filenames in os.walk(input_data_dir):
29 | if dp.split('/')[-1] == 'capped_'+ emb_model_type + '_mod':
30 | for f in filenames:
31 | self.data_files.append(os.path.join(dp, f))
32 |
33 | self.oracles = oracles
34 | #print(self.oracles.keys())
35 | print("oracle")
36 | print(data_type, len(self.oracles))
37 | print("data_files")
38 | print(data_type, len(self.data_files))
39 | self.data_type = data_type
40 | self.num_combined = len(combined_to_index)
41 |
42 | def __len__(self):
43 | return len(self.data_files)
44 |
45 | def __getitem__(self, idx):
46 | return self.generate_data(self.data_files[idx])
47 |
48 | def update_dict(self, dic, data_type):
49 | if 'small_' in data_type:
50 | return dic
51 | else:
52 | mod_dic = {}
53 | for k,v in dic.items():
54 | mod_k = '/'. join(['rule_classifier_data', data_type] + k.split('/')[2:])
55 | mod_dic[mod_k] = v
56 | return mod_dic
57 |
58 | def generate_data(self, data_file):
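  |         # each pickled data file holds the rule representations of a single hole, so we return after the first entry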
59 | data = pickle.load(open(data_file, 'rb'))
60 | data = self.update_dict(data, self.data_type)
61 | for hole, rule_context in data.items():
62 | hole_context = self.get_hole_context(hole)
63 | if hole in self.oracles:
64 | combined = self.oracles[hole]['com']
65 | failure_flag = 1
66 | else:
67 | combined = np.zeros(self.num_combined)
68 | failure_flag = 0
69 | return hole_context, rule_context, combined, hole, failure_flag
70 |
71 | def get_hole_context(self, hole, num_of_prev_lines=2, num_of_post_lines=2):
72 | '''
73 |         return the tokenized hole context: num_of_prev_lines lines before and num_of_post_lines lines after the hole position in the hole's file
74 | '''
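  |         # a hole id has the form <file path>_<line>_<column>; file names may themselves contain underscores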
75 | hole_parts = hole.split('/')[-1].split('_')
76 | repo_name = hole.split('/')[2]
77 | if len(hole_parts) > 3:
78 | new_hole_parts = hole_parts[:-2]
79 | filename = '_'.join(new_hole_parts)
80 | filename = [filename]
81 | else:
82 | filename = [hole_parts[0]]
83 | file = '/'.join(hole.split('/')[:-1] + filename)
84 | pos = (int(hole_parts[-2]), int(hole_parts[-1]))
85 |
86 | pre_end = pos
87 | pre_start_line = pos[0] - num_of_prev_lines
88 | if pre_start_line < 0:
89 | pre_start_line = 0
90 | pre_start = (pre_start_line, 0)
91 | pre_hole_context = get_string(file, pre_start, pre_end)
92 |
93 | post_hole_context = ""
94 | if num_of_post_lines > 0:
95 | file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines()
96 | post_start_line = pos[0] + 1
97 | if post_start_line < len(file_lines):
98 | post_end_line = pos[0] + num_of_post_lines
99 | if post_end_line >= len(file_lines):
100 | post_end_line = len(file_lines) - 1
101 | post_start = (post_start_line, 0)
102 | post_end = (post_end_line, len(file_lines[post_end_line]))
103 | post_hole_context = get_string(file, post_start, post_end)
104 | hole_context = post_hole_context + " " + pre_hole_context
105 | hole_context = self.tokenizer(hole_context, truncation=True)
106 | return hole_context
107 |
108 | def collate_fn(data):
109 | hole_context, rule_contexts, gt_com, hole_id, failure_flag = zip(*data)
110 | hole_context = data_collator(hole_context)
111 | rule_contexts = torch.stack(rule_contexts, dim=0)
112 | #print("rule_contexts:", torch.sum(rule_contexts, dim=-1))
113 | gt_com = torch.FloatTensor(gt_com)
114 | failure_flag = torch.IntTensor(failure_flag)
115 | return hole_context['input_ids'], hole_context['attention_mask'], \
116 | rule_contexts, \
117 | gt_com, \
118 | hole_id, \
119 | failure_flag
120 |
121 | def set_tokenizer(emb_model_type):
122 | global data_collator
123 | if emb_model_type == 'codebert':
124 | tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
125 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
126 | if emb_model_type == 'graphcodebert':
127 | tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
128 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
129 | if emb_model_type == 'gpt-2':
130 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
131 | tokenizer.pad_token = tokenizer.eos_token
132 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
133 | return tokenizer
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/projects.txt:
--------------------------------------------------------------------------------
1 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/rsbotownversion/source-archive.zip
2 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/sol-agent-platform/source-archive.zip
3 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/project-pt-diaoc/source-archive.zip
4 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/gloodb/source-archive.zip
5 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/dovetaildb/source-archive.zip
6 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/tinwiki/source-archive.zip
7 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jjskit/source-archive.zip
8 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/openprocesslogger/source-archive.zip
9 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/teammates-shakthi/source-archive.zip
10 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/ftpserverremoteadmin/source-archive.zip
11 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/ircrpgbot/source-archive.zip
12 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/strudem-sicsa/source-archive.zip
13 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/seamlets/source-archive.zip
14 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/robotsimulator2009w/source-archive.zip
15 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/healpix-rangeset/source-archive.zip
16 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jcontenedor/source-archive.zip
17 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/qwikioffice-java/source-archive.zip
18 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jloogle/source-archive.zip
19 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/xiaonei-java-api/source-archive.zip
20 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/xfuze/source-archive.zip
21 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/realtimegc/source-archive.zip
22 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/swinagile/source-archive.zip
23 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/swe574-group3/source-archive.zip
24 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/math-mech-eshop/source-archive.zip
25 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/quidsee/source-archive.zip
26 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/glperaudsimon/source-archive.zip
27 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/fswuniceubtemplates/source-archive.zip
28 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/mobileexpensetracker/source-archive.zip
29 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jata4test/source-archive.zip
30 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/largemail/source-archive.zip
31 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/apiitfriends/source-archive.zip
32 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/navigablep2p/source-archive.zip
33 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/wicketbits/source-archive.zip
34 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/soap-dtc/source-archive.zip
35 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/designpatternjavapedro/source-archive.zip
36 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/exogdx/source-archive.zip
37 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/tyrond/source-archive.zip
38 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/springlime/source-archive.zip
39 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/hucourses/source-archive.zip
40 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/sohocms/source-archive.zip
41 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/infinispan-storage-service/source-archive.zip
42 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/myt5lib/source-archive.zip
43 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/tapestry-sesame/source-archive.zip
44 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/gwt-plugindetect/source-archive.zip
45 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/javasummerframework/source-archive.zip
46 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/affinity_propagation_java/source-archive.zip
47 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/gfsfa/source-archive.zip
48 |
--------------------------------------------------------------------------------
/rearrange_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import pickle
4 | import random
5 | import numpy as np
6 |
7 | seed = 9
8 |
9 | #Fix seeds
10 | os.environ['PYTHONHASHSEED'] = str(seed)
11 | random.seed(seed)
12 | np.random.seed(seed)
13 |
14 | projects = { 'train': [
15 | 'gfsfa',
16 | 'sol-agent-platform',
17 | 'gloodb',
18 | 'rsbotownversion',
19 | 'jjskit',
20 | 'ftpserverremoteadmin',
21 | 'openprocesslogger',
22 | 'strudem-sicsa',
23 | 'seamlets',
24 | 'healpix-rangeset',
25 | 'quidsee',
26 | 'mobileexpensetracker',
27 | 'swe574-group3',
28 | 'largemail',
29 | 'soap-dtc',
30 | 'designpatternjavapedro',
31 | 'myt5lib',
32 | 'exogdx',
33 | 'tapestry-sesame'
34 | ],
35 |
36 | 'val': [
37 | 'javasummerframework',
38 | 'tinwiki',
39 | 'teammates-shakthi',
40 | 'jcontenedor',
41 | 'jloogle',
42 | 'swinagile',
43 | 'math-mech-eshop',
44 | 'jata4test',
45 | 'affinity_propagation_java',
46 | 'navigablep2p',
47 | 'springlime',
48 | 'sohocms',
49 | 'tyrond',
50 | 'infinispan-storage-service',
51 | ],
52 |
53 | 'test': [
54 | 'project-pt-diaoc',
55 | 'dovetaildb',
56 | 'robotsimulator2009w',
57 | 'ircrpgbot',
58 | 'xfuze',
59 | 'realtimegc',
60 | 'fswuniceubtemplates',
61 | 'glperaudsimon',
62 | 'apiitfriends',
63 | 'qwikioffice-java',
64 | 'xiaonei-java-api',
65 | 'wicketbits',
66 | 'hucourses',
67 | 'gwt-plugindetect'
68 | ]
69 | }
70 |
71 |
72 |
73 | repo_split_map = {}
74 | for split, repos in projects.items():
75 | for repo in repos:
76 | repo_split_map[repo] = split
77 |
78 | max_holes = 10000
79 |
80 | def is_move(base_dir, split, repo):
81 | new_split = repo_split_map[repo]
82 | if new_split != split:
83 | shutil.move(os.path.join(base_dir, split, repo), os.path.join(base_dir, new_split, repo))
84 |
85 | def find_single_best_rule_success(rule_mapping):
86 |     best_rule, best_single_rule_success = None, 0
87 | for k, v in rule_mapping.items():
88 | if len(v)> best_single_rule_success:
89 | best_rule = k
90 | best_single_rule_success = len(v)
91 | return best_rule, best_single_rule_success
92 |
93 | def find_rule_mapping(oracle):
94 | rule_mapping = {}
95 | for hid, entry in oracle.items():
96 | rules = entry['com']
97 | success_rule_positions = np.where(rules == 1)[0]
98 | for s_r_p in success_rule_positions:
99 | if s_r_p not in rule_mapping:
100 | rule_mapping[s_r_p] = [hid]
101 | else:
102 | rule_mapping[s_r_p].append(hid)
103 | return rule_mapping
104 |
105 | def get_new_oracle_numbers(capped_oracle, repo, total_holes):
106 | rule_mapping = find_rule_mapping(capped_oracle)
107 |     codex_success = len(rule_mapping[62]) # index 62 is 'codex' in combined_to_index (utils.py)
108 |     best_rule, best_rule_success = find_single_best_rule_success(rule_mapping)
109 |     best_single_rule_success = len(rule_mapping[7]) # index 7 is the in_file#lines#0.75 rule named in the print below
110 | print(
111 | repo + ", " + \
112 | str(total_holes) + ", " + \
113 | str(float(len(capped_oracle)*100/total_holes)) + ", " + \
114 | str(float(codex_success*100/total_holes)) + ", " + \
115 | str(best_rule) + ", " +\
116 | str(float(best_rule_success*100/total_holes)) + ", " + \
117 | "in_file_lines_0.75" + ", " +\
118 | str(float(best_single_rule_success*100/total_holes))
119 | )
120 |
121 | def rewrite_rule_context_data(repo_path, capped_holes, emb_model_type):
122 | all_files = os.listdir(os.path.join(repo_path, emb_model_type))
123 | os.makedirs(os.path.join(repo_path, 'capped_'+ emb_model_type), exist_ok=True)
124 | for file in all_files:
125 | hole_path = os.path.join(repo_path, emb_model_type, file)
126 | data = pickle.load(open(hole_path, 'rb'))
127 | hole = list(data.keys())[0]
128 | if hole in capped_holes:
129 | dest_hole_path = os.path.join(repo_path, 'capped_'+ emb_model_type, file)
130 | shutil.copy(hole_path, dest_hole_path)
131 |
132 | def rearrange_data(base_dir, split):
133 | print(split)
134 | all_dirs = os.listdir(os.path.join(base_dir, split))
135 | for repo in all_dirs:
136 | if repo in repo_split_map:
137 | print(base_dir, split, repo)
138 | repo_holes = []
139 | hole_data = pickle.load(open(os.path.join(base_dir, split, repo, 'hole_data'), 'rb'))
140 | oracle = pickle.load(open(os.path.join(base_dir, split, repo, 'oracle'), 'rb'))
141 | duplicate_files = open(os.path.join(base_dir, split, repo, 'duplicates'), 'r').readlines()
142 | all_duplicate_files = [x.strip() for x in duplicate_files]
143 | for file, holes in hole_data.items():
144 | if file not in all_duplicate_files and not file.startswith('rule_classifier_data/val/rsbotownversion/trunk/scripts/'):
145 | hids = [file + '_' + str(h[0]) + '_' + str(h[1]) for h in holes]
146 | repo_holes.extend(hids)
147 | #print(len(repo_holes))
148 | if len(repo_holes) < max_holes:
149 | capped_holes = repo_holes
150 | capped_oracle = oracle
151 | total_holes = len(repo_holes)
152 | else:
153 | capped_holes = random.sample(repo_holes, max_holes)
154 | capped_oracle = {}
155 | for hid, entry in oracle.items():
156 | if hid in capped_holes:
157 | capped_oracle[hid] = entry
158 | total_holes = len(capped_holes)
159 |
160 | get_new_oracle_numbers(capped_oracle, repo, total_holes)
161 | with open(os.path.join(base_dir, split, repo, 'capped_oracle_'+ str(max_holes)), 'wb') as f:
162 | pickle.dump(capped_oracle, f)
163 |
164 | with open(os.path.join(base_dir, split, repo, 'capped_holes_'+ str(max_holes)), 'w') as f:
165 | for item in capped_holes:
166 | f.write("%s\n" %(item,))
167 |             capped_holes = open(os.path.join(base_dir, split, repo, 'capped_holes_' + str(max_holes)), 'r').readlines()
168 | capped_holes = [x.strip() for x in capped_holes]
169 |
170 | rewrite_rule_context_data(os.path.join(base_dir, split, repo), capped_holes, 'codebert_mod')
171 |
172 | is_move(base_dir, split, repo)
173 |
174 | rearrange_data('rule_classifier_data', 'train')
175 | rearrange_data('rule_classifier_data', 'val')
176 | rearrange_data('rule_classifier_data', 'test')
177 |
178 |
179 |
--------------------------------------------------------------------------------
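A minimal sketch of the hole-id convention and the per-repository capping used in rearrange_data.py above, on a toy hole_data dictionary (the file paths and hole positions are made up; the real dictionaries are the pickled hole_data files):

    import random

    # Toy stand-in for a repo's pickled hole_data: file path -> list of (line, column) hole positions.
    hole_data = {
        'rule_classifier_data/train/gfsfa/src/Main.java': [(12, 4), (40, 8)],
        'rule_classifier_data/train/gfsfa/src/Util.java': [(7, 2)],
    }
    max_holes = 2  # the script above caps each repository at 10000

    # A hole id concatenates the file path with the hole's line and column.
    repo_holes = [f + '_' + str(l) + '_' + str(c) for f, holes in hole_data.items() for (l, c) in holes]

    # Cap the number of holes per repository by random sampling.
    capped_holes = repo_holes if len(repo_holes) < max_holes else random.sample(repo_holes, max_holes)
    print(capped_holes)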
/rule_classifier_preprocessed_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import argparse
5 | import random
6 | from torch.utils.data import DataLoader
7 | from torch.utils.tensorboard import SummaryWriter
8 | from torch.autograd import Variable
9 | from torch import nn
10 | from tqdm import tqdm
11 | from preprocessed_data import *
12 | from model_preprocessed_data import RuleModel
13 | from torch import FloatTensor
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 |
18 | def setup_args():
19 | """
20 |     Description: Takes in the command-line arguments from the user
21 | """
22 | parser = argparse.ArgumentParser()
23 |
24 | #data hyperparams
25 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
26 | parser.add_argument("--input_data_dir", type=str, default='rule_classifier_data', help="base directory for the data")
27 | parser.add_argument("--model_output_path", type=str, default='models/', help="base directory for storing the models")
28 |     parser.add_argument("--output_path", type=str, default='outputs/', help="base directory for storing the training logs")
29 | parser.add_argument("--batch_size", type=int, default=64, help="batch size for training the classifier")
30 | parser.add_argument("--emb_model_type", type=str, default='codebert', help="model to obtain embedding from")
31 | #training hyperparams
32 | parser.add_argument("--num_epochs", type=int, default=5000, help="number of epochs for training the classifier")
33 | parser.add_argument("--optimizer", type=str, default='adam', help="optimizer to use for training")
34 |     parser.add_argument("--lr_scheduler", type=str, default='none', help="learning rate scheduler to use for training")
35 | parser.add_argument("--learning_rate", type=float, default=3e-4, help="learning rate for training the classifier")
36 | parser.add_argument("--patience", type=int, default=5000, help="patience for early-stop")
37 | parser.add_argument('--load_from_checkpoint', default=False, action='store_true')
38 | #model hyperparams
39 | parser.add_argument("--mode", type=str, default='rlpg-r', help="rule classifier variant: rlpg-h, rlpg-r")
40 | parser.add_argument("--n_head", type=int, default=4, help="number of heads")
41 | parser.add_argument("--d_k", type=int, default=32, help="depth of projection")
42 |     parser.add_argument("--dropout", type=float, default=0.25, help="dropout probability")
43 |
44 |
45 | return parser.parse_args()
46 |
47 | def save(model, optimizer, epoch, save_dir):
48 | to_save = model.module if hasattr(model, "module") else model
49 | # pytorch_total_params = sum(p.numel() for p in to_save.parameters())
50 | # pytorch_trainable_params = sum(p.numel() for p in to_save.parameters() if p.requires_grad)
51 | # print(pytorch_trainable_params, pytorch_total_params)
52 | torch.save(to_save.state_dict(), os.path.join(save_dir, "best_model.th"))
53 | torch.save({"optimizer": optimizer.state_dict(), "last_epoch": epoch}, os.path.join(save_dir, "optim.th"))
54 |
55 | def get_accuracy(pred, gold, mask):
56 | pred = pred.masked_fill(mask==0, 0)
57 | max_idx = torch.argmax(pred, dim=1, keepdim=True)
58 | rounded_pred = torch.round(pred)
59 | max_idx_gold_vals = torch.gather(gold, 1, max_idx)
60 | mean_highest_success_correct = (max_idx_gold_vals == 1).to(dtype=torch.float).mean()
61 | return mean_highest_success_correct
62 |
63 | def get_prediction(rule_model, info):
64 | pred, mask = rule_model(info)
65 | mask = torch.sum(mask, dim=-1) #(bs, #rules)
66 | return pred, mask
67 |
68 | def calculate_loss(rule_model, criterion, info, gt):
69 |
70 | pred, mask = get_prediction(rule_model, info)
71 | n_valid_entries = torch.sum(mask.view(-1)!=0)
72 | loss = criterion(pred, gt)
73 | loss = loss.masked_fill(mask==0, 0)
74 | loss = torch.sum(loss)/n_valid_entries
75 | mean_highest_success_correct = get_accuracy(pred, gt, mask)
76 | masked_gt = torch.sum(gt.masked_fill(mask==0, 0), dim=-1)
77 | mean_oracle_success = masked_gt.masked_fill(masked_gt!=0, 1.0).mean()
78 |
79 | return {'loss': loss, \
80 | 'mean_highest_success_correct': mean_highest_success_correct}, \
81 | mean_oracle_success
82 |
83 |
84 | if __name__ == '__main__':
85 |
86 | args = setup_args()
87 |
88 | #Fix seeds
89 | np.random.seed(args.seed)
90 | os.environ['PYTHONHASHSEED'] = str(args.seed)
91 | torch.manual_seed(args.seed)
92 | random.seed(args.seed)
93 |
94 | #Define paths for storing tensorboard logs
95 | dir_name = 'optimizer#' + args.optimizer + '#learning_rate#' + str(args.learning_rate) + '#lr_scheduler#' + args.lr_scheduler \
96 | + '#emb_model_type#' + args.emb_model_type + '#n_head#' + str(args.n_head) + '#d_k#' + str(args.d_k) \
97 | + '#mode#' + args.mode + '#dropout#' + str(args.dropout)
98 |
99 | save_dir = os.path.join(args.model_output_path, dir_name)
100 | os.makedirs(save_dir, exist_ok=True)
101 | tb_writer = SummaryWriter(os.path.join(save_dir, "logs"))
102 | os.makedirs(args.output_path, exist_ok=True)
103 | f_out = open(os.path.join(args.output_path, dir_name), 'w')
104 |
105 | # Define train and val dataloaders
106 |     kwargs = {'num_workers': 8, 'pin_memory': True} if device.type == 'cuda' else {}
107 | tokenizer = set_tokenizer(args.emb_model_type)
108 | #print(tokenizer)
109 | train_dataset = RuleDataset(os.path.join(args.input_data_dir, 'train'), tokenizer=tokenizer, emb_model_type=args.emb_model_type)
110 | train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, **kwargs)
111 |
112 | val_dataset = RuleDataset(os.path.join(args.input_data_dir, 'val'), tokenizer=tokenizer, emb_model_type=args.emb_model_type)
113 | val_data_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, **kwargs)
114 |
115 | # Define the model
116 | rule_model = RuleModel(emb_model_type=args.emb_model_type, device=device, n_head=args.n_head, d_k=args.d_k, \
117 | mode = args.mode, dropout=args.dropout)
118 | rule_model.to(device)
119 |
120 | #Define optimizer and loss
121 | if args.lr_scheduler == 'none':
122 | if args.optimizer == 'adam':
123 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate)
124 | if args.optimizer == 'sgd':
125 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate)
126 |
127 | if args.lr_scheduler == 'cosine':
128 | if args.optimizer == 'adam':
129 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate)
130 | if args.optimizer == 'sgd':
131 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate)
132 | lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
133 |
134 | if args.lr_scheduler == 'cosinewarm':
135 | if args.optimizer == 'adam':
136 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate)
137 | if args.optimizer == 'sgd':
138 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate)
139 | lr_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10)
140 |
141 | if args.lr_scheduler == 'tr2':
142 | if args.optimizer == 'adam':
143 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate)
144 |         if args.optimizer == 'sgd':
145 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate)
146 | lr_sched = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.learning_rate/2.0, max_lr=0.1,
147 | step_size_up=5, mode="triangular2", cycle_momentum=False)
148 |
149 | if args.lr_scheduler == 'reduceonplateau':
150 | if args.optimizer == 'adam':
151 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate)
152 | if args.optimizer == 'sgd':
153 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate)
154 | lr_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
155 |
156 |
157 | if args.load_from_checkpoint:
158 | print("=> loading checkpoint '{}'".format(save_dir))
159 | model_path = os.path.join(save_dir, 'best_model.th')
160 | opt_path = os.path.join(save_dir, 'optim.th')
161 | status_dict = torch.load(opt_path, map_location=torch.device('cpu'))
162 | rule_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
163 | optimizer.load_state_dict(status_dict['optimizer'])
164 | print("=> loaded checkpoint '{}' (epoch {})".format(save_dir, status_dict['last_epoch']))
165 |
166 | criterion = nn.BCELoss(reduction='none')
167 |
168 |
169 |     best_val_acc, patience_ctr = 0, 0
170 |
171 | for epoch in range(args.num_epochs):
172 | print("Epoch %d" % epoch)
173 | f_out.write("Epoch: " + str(epoch)+"\n")
174 |
175 | ########################Training Loop#############################################
176 | total_highest_success_correct, total_loss = 0.0, 0.0
177 |
178 | total_batches = 0
179 | total_oracle_success = 0.0
180 |
181 | rule_model.train()
182 | train_count = 0
183 |
184 | for batch in tqdm(train_data_loader):
185 |
186 | hole_context = Variable(batch[0]).to(device)
187 | hole_attention_mask = Variable(batch[1]).to(device)
188 | rule_context = Variable(batch[2]).to(device)
189 | gt = Variable(batch[3]).to(device)
190 | failure_flag = Variable(batch[5]).to(device)
191 |
192 | train_count+= torch.sum(failure_flag)
193 |
194 | optimizer.zero_grad()
195 |
196 | batch_metrices, oracle_success = calculate_loss(rule_model, \
197 | criterion, \
198 | (hole_context, hole_attention_mask, rule_context), \
199 | gt)
200 |
201 | batch_loss = batch_metrices['loss']
202 | batch_loss.backward()
203 | optimizer.step()
204 |
205 | total_highest_success_correct += batch_metrices['mean_highest_success_correct']
206 | total_oracle_success += oracle_success
207 | total_loss += batch_loss.item()
208 | total_batches += 1
209 |
210 | avg_train_loss = total_loss/ total_batches
211 | avg_highest_success_accuracy = total_highest_success_correct*100/ total_batches
212 | avg_oracle_success_accuracy = total_oracle_success*100/ total_batches
213 |
214 | tb_writer.add_scalar("metrics/train_loss", avg_train_loss, epoch)
215 | tb_writer.add_scalar("metrics/train_highest_success_accuracy", avg_highest_success_accuracy, epoch)
216 |
217 | print("Train loss: Total %f" % avg_train_loss)
218 | f_out.write("Train loss: " + str(avg_train_loss) + "\n")
219 | print("Train oracle success accuracy: %f" % avg_oracle_success_accuracy)
220 | f_out.write("Train oracle success accuracy: " + str(avg_oracle_success_accuracy) + "\n")
221 | print("Train highest success accuracy: %f" % avg_highest_success_accuracy)
222 | f_out.write("Train highest success accuracy: " + str(avg_highest_success_accuracy) + "\n")
223 |
224 |
225 | ######################################Evaluation Loop############################################
226 | rule_model.eval()
227 |
228 | with torch.no_grad():
229 |
230 | total_highest_success_correct, total_loss = 0.0, 0.0
231 | total_batches = 0
232 | total_oracle_success = 0.0
233 | val_count = 0
234 |
235 | for batch in tqdm(val_data_loader):
236 |
237 | hole_context = Variable(batch[0]).to(device)
238 | hole_attention_mask = Variable(batch[1]).to(device)
239 | rule_context = Variable(batch[2]).to(device)
240 | gt = Variable(batch[3]).to(device)
241 | failure_flag = Variable(batch[5]).to(device)
242 |
243 | val_count+= torch.sum(failure_flag)
244 |
245 |
246 | batch_metrices, oracle_success = calculate_loss(rule_model, \
247 | criterion, \
248 | (hole_context, hole_attention_mask, rule_context), \
249 | gt)
250 |
251 |
252 | batch_loss = batch_metrices['loss']
253 | total_highest_success_correct += batch_metrices['mean_highest_success_correct']
254 | total_oracle_success+= oracle_success
255 | total_loss += batch_loss.item()
256 | total_batches += 1
257 |
258 | avg_val_loss = total_loss/ total_batches
259 | avg_highest_success_accuracy = total_highest_success_correct*100/ total_batches
260 | avg_oracle_success_accuracy = total_oracle_success*100/total_batches
261 |
262 | tb_writer.add_scalar("metrics/val_loss", avg_val_loss, epoch)
263 | tb_writer.add_scalar("metrics/val_highest_success_accuracy", avg_highest_success_accuracy, epoch)
264 |
265 | print("Val loss: Total %f" % avg_val_loss)
266 | f_out.write("Val loss: " + str(avg_val_loss) + "\n")
267 | print("Val oracle success accuracy: %f" % avg_oracle_success_accuracy)
268 | f_out.write("Val oracle success accuracy: " + str(avg_oracle_success_accuracy) + "\n")
269 | print("Val highest success accuracy: %f" % avg_highest_success_accuracy)
270 | f_out.write("Val highest success accuracy: " + str(avg_highest_success_accuracy) + "\n")
271 |
272 | if args.lr_scheduler =='reduceonplateau':
273 | lr_sched.step(avg_highest_success_accuracy)
274 | elif args.lr_scheduler !='none':
275 | lr_sched.step()
276 |
277 | if avg_highest_success_accuracy > best_val_acc:
278 | print("Found new best model")
279 | f_out.write("Found new best model\n")
280 | best_val_acc = avg_highest_success_accuracy
281 | save(rule_model, optimizer, epoch, save_dir)
282 | patience_ctr = 0
283 | else:
284 | patience_ctr += 1
285 | if patience_ctr == args.patience:
286 | print("Ran out of patience. Stopping training early...")
287 | f_out.write("Ran out of patience. Stopping training early...\n")
288 | print("Best Val Acc: ", best_val_acc)
289 | f_out.write("Best Val Acc: " + str(best_val_acc))
290 | break
291 | f_out.write("\n\n")
292 | f_out.flush()
293 | print("Best Val Acc: ", best_val_acc)
294 | f_out.write("Best Val Acc: " + str(best_val_acc))
295 | f_out.close()
296 |
297 |
--------------------------------------------------------------------------------
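A minimal sketch of the masking logic in calculate_loss/get_accuracy above, on toy tensors (the 4-rule width, probabilities, and labels are made up; the real model outputs one probability per prompt proposal):

    import torch
    from torch import nn

    # Toy batch: 2 holes x 4 prompt proposals.
    pred = torch.tensor([[0.9, 0.2, 0.6, 0.1],
                         [0.3, 0.8, 0.4, 0.5]])  # predicted success probabilities
    gt   = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                         [0.0, 0.0, 1.0, 0.0]])  # 1 = the prompt proposal led to a successful completion
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 1, 1]])          # 0 = rule not applicable to this hole

    criterion = nn.BCELoss(reduction='none')
    loss = criterion(pred, gt).masked_fill(mask == 0, 0)
    loss = loss.sum() / (mask.view(-1) != 0).sum()      # average only over valid (hole, rule) entries

    # "Highest success" accuracy: does the top-scoring applicable rule actually succeed?
    masked_pred = pred.masked_fill(mask == 0, 0)
    top_rule = torch.argmax(masked_pred, dim=1, keepdim=True)
    acc = (torch.gather(gt, 1, top_rule) == 1).float().mean()
    print(loss.item(), acc.item())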
/rule_config.py:
--------------------------------------------------------------------------------
1 |
2 | # context_location and context_type define a rule
3 | context_location = [
4 | 'in_file', \
5 | 'parent_class_file', \
6 | 'import_file',\
7 | 'sibling_file', \
8 | 'similar_name_file', \
9 | 'child_class_file', \
10 | 'import_of_sibling_file', \
11 | 'import_of_similar_name_file', \
12 | 'import_of_parent_class_file', \
13 | 'import_of_child_class_file'
14 | ]
15 |
16 |
17 | all_context_types = [
18 | 'method_names_and_bodies',\
19 | 'method_names',\
20 | 'identifiers', \
21 | 'type_identifiers',\
22 | 'string_literals',\
23 | 'field_declarations', \
24 | 'lines'
25 | ]
26 |
27 | context_type_dict = {}
28 | for con_loc in context_location:
29 | if con_loc == 'in_file':
30 | context_types = all_context_types[1:]
31 | else:
32 | context_types = all_context_types[:-1]
33 | context_type_dict[con_loc] = context_types
34 |
35 | # rule-specific hyperparams to run. Make changes here to run different configurations
36 | rule_hyperparams = {
37 | 'lines':
38 | {
39 | 'context_ratio': [0.5, 0.25, 0.75],
40 | 'top_k': [-1],
41 | 'prompt_separator': ['space'],
42 | 'top_k_type':['first'],
43 | 'rule_context_formatting':['space']
44 | },
45 |
46 | 'identifiers':
47 | {
48 | 'context_ratio': [0.5],
49 | 'top_k': [-1],
50 | 'prompt_separator': ['newline'],
51 | 'top_k_type':['first'],
52 | 'rule_context_formatting':['class_name']
53 | },
54 |
55 | 'type_identifiers':
56 | {
57 | 'context_ratio': [0.5],
58 | 'top_k': [-1],
59 | 'prompt_separator': ['newline'],
60 | 'top_k_type':['first'],
61 | 'rule_context_formatting':['class_name']
62 | },
63 |
64 | 'string_literals':
65 | {
66 | 'context_ratio': [0.5],
67 | 'top_k': [-1],
68 | 'prompt_separator': ['newline'],
69 | 'top_k_type':['first'],
70 | 'rule_context_formatting':['class_name']
71 | },
72 |
73 | 'method_names':
74 | {
75 | 'context_ratio': [0.5],
76 | 'top_k': [-1],
77 | 'prompt_separator': ['newline'],
78 | 'top_k_type':['first'],
79 | 'rule_context_formatting':['class_name']
80 | },
81 |
82 | 'field_declarations':
83 | {
84 | 'context_ratio': [0.5],
85 | 'top_k': [-1],
86 | 'prompt_separator': ['newline'],
87 | 'top_k_type':['first'],
88 | 'rule_context_formatting':['class_name']
89 | },
90 |
91 | 'method_names_and_bodies':
92 | {
93 | 'context_ratio': [0.5],
94 | 'top_k': [-1],
95 | 'prompt_separator': ['newline'],
96 | 'top_k_type':['first'],
97 | 'rule_context_formatting':['class_method_name']
98 | }
99 |
100 |
101 | }
102 |
--------------------------------------------------------------------------------
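Together, context_location, context_type_dict, and the context_ratio lists above define the full set of prompt proposals. A small sketch that enumerates them by importing this module (the count in the comment follows from the configuration as written):

    from rule_config import context_type_dict, rule_hyperparams

    proposals = []
    for loc, types in context_type_dict.items():
        for ctype in types:
            for ratio in rule_hyperparams[ctype]['context_ratio']:
                proposals.append((loc, ctype, ratio))

    print(len(proposals))  # 8 in-file combinations + 9 x 6 other-file combinations = 62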
/rule_inference_preprocessed_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import torch
5 | import pickle
6 | import argparse
7 | import random
8 | from torch.utils.data import DataLoader
9 | from torch.utils.tensorboard import SummaryWriter
10 | from torch.autograd import Variable
11 | from torch import nn
12 | from tqdm import tqdm
13 | from preprocessed_data import *
14 | from model_preprocessed_data import RuleModel
15 | from torch import FloatTensor
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 |
19 |
20 | def setup_args():
21 | """
22 |     Description: Takes in the command-line arguments from the user
23 | """
24 | parser = argparse.ArgumentParser()
25 |
26 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility")
27 | parser.add_argument("--input_data_dir", type=str, default='rule_classifier_data', help="base directory for the data")
28 |     parser.add_argument("--data_split", type=str, default='val', help="data split to run inference on (train/val/test)")
29 |     parser.add_argument("--model_path", type=str, default='models/rlpg-h', help="directory containing the trained model checkpoint")
30 |     parser.add_argument("--batch_size", type=int, default=32, help="batch size for inference")
31 | return parser.parse_args()
32 |
33 | def get_accuracy(pred, gold, mask):
34 | pred = pred.masked_fill(mask==0, 0)
35 | max_idx = torch.argmax(pred, dim=1, keepdim=True)
36 | rounded_pred = torch.round(pred)
37 | max_idx_gold_vals = torch.gather(gold, 1, max_idx)
38 | mean_highest_success_correct = (max_idx_gold_vals == 1).to(dtype=torch.float).mean()
39 | return mean_highest_success_correct, pred
40 |
41 | def get_prediction(rule_model, info):
42 | pred, mask = rule_model(info)
43 | mask = torch.sum(mask, dim=-1) #(bs, #rules)
44 | return pred, mask
45 |
46 | def calculate_loss(rule_model, criterion, info, gt, hole_ids, hole_stats):
47 |
48 | pred, mask = get_prediction(rule_model, info)
49 | n_valid_entries = torch.sum(mask.view(-1)!=0)
50 | loss = criterion(pred, gt)
51 | loss = loss.masked_fill(mask==0, 0)
52 | mean_highest_success_correct, pred = get_accuracy(pred, gt, mask)
53 | masked_gt = torch.sum(gt.masked_fill(mask==0, 0), dim=-1)
54 | mean_oracle_success = masked_gt.masked_fill(masked_gt!=0, 1.0).mean()
55 |
56 | for i in range(len(hole_ids)):
57 | hid = hole_ids[i]
58 | hole_loss = torch.sum(loss[i])
59 | n_valid_hole_rules = torch.sum(loss[i]!=0)
60 | hole_loss = hole_loss/n_valid_hole_rules
61 | hole_prediction = pred[i]
62 | hole_stats[hid] = (hole_loss, hole_prediction)
63 |
64 | return {'loss': torch.sum(loss)/n_valid_entries, \
65 | 'mean_highest_success_correct': mean_highest_success_correct}, \
66 | mean_oracle_success, \
67 | hole_stats
68 |
69 | if __name__ == '__main__':
70 |
71 | args = setup_args()
72 |
73 | #Fix seeds
74 | np.random.seed(args.seed)
75 | os.environ['PYTHONHASHSEED'] = str(args.seed)
76 | torch.manual_seed(args.seed)
77 | random.seed(args.seed)
78 |
79 | os.makedirs(os.path.join('outputs', args.data_split), exist_ok=True)
80 | f_out = open(os.path.join('outputs', args.data_split + '_inference'), 'a')
81 |
82 | model_path = args.model_path
83 | mode = model_path.split('/')[-1]
84 |
85 | # Define the model
86 | if mode == 'rlpg-h':
87 | emb_model_type = 'codebert'
88 | rule_model = RuleModel(emb_model_type=emb_model_type, device=device, mode=mode)
89 | if mode == 'rlpg-r':
90 | emb_model_type = 'codebert'
91 | rule_model = RuleModel(emb_model_type=emb_model_type, device=device, mode=mode, n_head=4, d_k=32)
92 |
93 | # Define train and val dataloaders
94 |     kwargs = {'num_workers': 8, 'pin_memory': True} if device.type == 'cuda' else {}
95 | tokenizer = set_tokenizer(emb_model_type)
96 | dataset = RuleDataset(os.path.join(args.input_data_dir, args.data_split), tokenizer=tokenizer, emb_model_type=emb_model_type)
97 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, **kwargs)
98 |
99 | print("=> loading checkpoint '{}'".format(model_path))
100 | best_model_path = os.path.join(model_path, 'best_model.th')
101 | rule_model.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu')), strict=False)
102 | print("=> loaded checkpoint '{}'".format(model_path))
103 | rule_model.to(device)
104 |
105 | rule_model.eval()
106 | criterion = nn.BCELoss(reduction='none')
107 |
108 | with torch.no_grad():
109 |
110 | total_highest_success_correct, total_loss = 0.0, 0.0
111 | total_batches = 0
112 | total_oracle_success = 0.0
113 | hole_stats = {}
114 | count = 0
115 |
116 | for batch in tqdm(data_loader):
117 | hole_context = Variable(batch[0]).to(device)
118 | hole_attention_mask = Variable(batch[1]).to(device)
119 | rule_context = Variable(batch[2]).to(device)
120 | gt = Variable(batch[3]).to(device)
121 | hole_id = batch[4]
122 | failure_flag = Variable(batch[5]).to(device)
123 |
124 | count+= torch.sum(failure_flag)
125 |
126 | batch_metrices, oracle_success, hole_stats = calculate_loss(rule_model, \
127 | criterion, \
128 | (hole_context, hole_attention_mask, rule_context), \
129 | gt, \
130 | hole_id, \
131 | hole_stats)
132 |
133 | batch_loss = batch_metrices['loss']
134 | total_highest_success_correct += batch_metrices['mean_highest_success_correct']
135 | total_oracle_success+= oracle_success
136 | total_loss += batch_loss.item()
137 | total_batches += 1
138 |
139 | avg_loss = total_loss/ total_batches
140 | avg_highest_success_accuracy = total_highest_success_correct*100/ total_batches
141 | avg_oracle_success_accuracy = total_oracle_success*100/total_batches
142 |
143 | f_out.write("\n********************************\n")
144 | f_out.write(model_path + "\n")
145 | print("Loss: Total %f" % avg_loss)
146 | f_out.write("Loss: " + str(avg_loss) + "\n")
147 | print("Oracle success accuracy: %f" % avg_oracle_success_accuracy)
148 | f_out.write("Oracle success accuracy: " + str(avg_oracle_success_accuracy) + "\n")
149 | print("Highest success accuracy: %f" % avg_highest_success_accuracy)
150 | f_out.write("Highest success accuracy: " + str(avg_highest_success_accuracy) + "\n")
151 | f_out.write("\n********************************\n")
152 | f_out.flush()
153 |
154 |     with open(os.path.join('outputs', args.data_split, '/'.join(model_path.split('/')[1:])), 'wb') as f:
155 | pickle.dump(hole_stats, f)
--------------------------------------------------------------------------------
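The script above pickles hole_stats (hole id -> (per-hole loss, per-rule probabilities)) under outputs/<data_split>/<model name>. A minimal sketch of consuming that file, assuming the defaults --data_split val and --model_path models/rlpg-h and an environment matching the one that produced the pickle (the stored values are torch tensors):

    import pickle
    import torch

    with open('outputs/val/rlpg-h', 'rb') as f:
        hole_stats = pickle.load(f)

    # Print the five highest-scoring rule indices for a few holes.
    for hole_id, (loss, probs) in list(hole_stats.items())[:3]:
        top = torch.topk(probs, k=5)
        print(hole_id, loss.item(), top.indices.tolist())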
/rule_representation_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import pickle
5 | from torch.utils.data import Dataset
6 | from transformers import AutoModel
7 | from transformers import GPT2TokenizerFast, AutoTokenizer
8 | from transformers import pipeline
9 | from utils import *
10 | from data_utils import RuleDatasetUtils
11 |
12 | class RuleReprDataset(Dataset):
13 |
14 | def __init__(self, input_data_dir, emb_model_type, tokenizer):
15 | # get all relevant files (in raw form)
16 | files = []
17 | oracles = {}
18 | hole_datas = {}
19 | parse_datas = {}
20 | all_duplicate_files = []
21 | data_type = input_data_dir.split('/')[-1]
22 | for dp, dn, filenames in os.walk(input_data_dir):
23 | for f in filenames:
24 | if f == 'hole_data':
25 | hole_data = pickle.load(open(os.path.join(dp, f), 'rb'))
26 | hole_data = self.update_dict(hole_data, data_type)
27 | hole_datas = {**hole_datas, **hole_data}
28 | if f == 'parsed_data':
29 | parse_data = pickle.load(open(os.path.join(dp, f), 'rb'))
30 | parse_data = self.update_dict(parse_data, data_type)
31 | parse_datas = {**parse_datas, **parse_data}
32 | if f == 'duplicates':
33 | duplicate_files = open(os.path.join(dp, f), 'r').readlines()
34 | all_duplicate_files.extend([x.strip() for x in duplicate_files])
35 | if os.path.splitext(f)[1] == '.java':
36 | files.append(os.path.join(dp, f))
37 | print(len(all_duplicate_files))
38 | self.holes = []
39 | for file in files:
40 | if file in hole_datas and \
41 | file not in all_duplicate_files and \
42 | not file.startswith('rule_classifier_data/train/rsbotownversion/trunk/scripts/'):
43 | for (l,c) in hole_datas[file]:
44 | hole_identity = file + '_' + str(l) + '_' + str(c)
45 | self.holes.append(hole_identity)
46 |
47 | print(len(self.holes))
48 | self.num_rules = len(combined_to_index)
49 | self.tokenizer = tokenizer
50 |
51 | self.parse_datas = parse_datas
52 | self.model_max_length = self.tokenizer.model_max_length
53 | self.rule_repr_cache = {}
54 | self.emb_model_type = emb_model_type
55 | self.set_embedding_model()
56 | self.repr_size = 768
57 | self.start = 0
58 | self.end = 500000
59 |
60 | def update_dict(self, dic, data_type):
61 | mod_dic = {}
62 | for k,v in dic.items():
63 | mod_k = '/'. join(['rule_classifier_data', data_type] + k.split('/')[2:])
64 | mod_dic[mod_k] = v
65 | return mod_dic
66 |
67 | def __len__(self):
68 | return len(self.holes)
69 |
70 | def __getitem__(self, idx):
71 |         if self.start <= idx <= self.end:
72 | return self.generate_data(self.holes[idx])
73 | else:
74 | return None, None, None
75 |
76 | def get_start_index(self, repo, start_offset=0, interval=0):
77 | count=0
78 | for i in range(len(self.holes)):
79 | hole = self.holes[i]
80 | repo_name = hole.split('/')[2]
81 | if repo_name == repo:
82 | count+=1
83 | repo_end_idx = i
84 |
85 | self.start = repo_end_idx - count + 1
86 | self.start = self.start + start_offset
87 | if interval!=0 :
88 | self.end = self.start + interval
89 | else:
90 | self.end = repo_end_idx
91 | return self.start, self.end
92 |
93 | def is_clear_cache(self):
94 | if len(self.rule_repr_cache) < 30:
95 | self.clear_cache = False
96 | else:
97 | self.clear_cache = True
98 | self.rule_repr_cache = {}
99 |
100 | def get_representation(self, inputs, mask):
101 | outputs = self.emb_model(inputs, attention_mask=mask)
102 | try:
103 | representation = outputs.pooler_output
104 |         except AttributeError: # encoder output without a pooler
105 | representation = outputs.last_hidden_state[:, 0]
106 | #print(representation.shape)
107 | return representation
108 |
109 | def get_context_embedding(self, context, attn_mask):
110 | context_embedding = self.get_representation(context, attn_mask)
111 | return context_embedding
112 |
113 | def get_rule_context(self, file, hole_pos):
114 | self.is_clear_cache()
115 | rule_dataset_util = RuleDatasetUtils(file, self.parse_datas, hole_pos, self.tokenizer)
116 | rule_prompts, rule_indexes = rule_dataset_util.get_all_rules_context()
117 | rule_contexts = self.tokenizer(rule_prompts, truncation=True, padding='max_length')
118 | rule_inputs = torch.tensor(rule_contexts['input_ids'])
119 | rule_masks = torch.tensor(rule_contexts['attention_mask'])
120 | rule_indexes = torch.tensor(rule_indexes)
121 |
122 | # remove rules that are already cached
123 | rule_prompts = self.tokenizer.batch_decode(rule_inputs)
124 | filtered_rule_context = []
125 | filtered_rule_mask = []
126 | filtered_rule_prompts = []
127 | filtered_rule_indexes = []
128 | for i in range(len(rule_prompts)):
129 | rule_prompt = rule_prompts[i]
130 | if rule_prompt not in self.rule_repr_cache:
131 | filtered_rule_indexes.append(rule_indexes[i])
132 | filtered_rule_context.append(rule_inputs[i])
133 | filtered_rule_mask.append(rule_masks[i])
134 | filtered_rule_prompts.append(rule_prompt)
135 |
136 | if filtered_rule_context:
137 | filtered_rule_context = torch.stack(filtered_rule_context)
138 | filtered_rule_mask = torch.stack(filtered_rule_mask)
139 |
140 | # get rule representations
141 | filtered_representations = self.get_context_embedding(filtered_rule_context, filtered_rule_mask)
142 | # cache the representations
143 | for i in range(len(filtered_representations)):
144 | f_repr = filtered_representations[i]
145 | rule_prompt = filtered_rule_prompts[i]
146 | self.rule_repr_cache[rule_prompt] = f_repr
147 |
148 | # obtain full representations
149 | keys = []
150 | j = 0
151 | for ind in range(self.num_rules):
152 | if ind in rule_indexes:
153 | prompt = rule_prompts[j]
154 | j+=1
155 | if prompt in self.rule_repr_cache:
156 | keys.append(self.rule_repr_cache[prompt])
157 | else:
158 | keys.append(torch.zeros(self.repr_size))
159 | else:
160 | keys.append(torch.zeros(self.repr_size))
161 |
162 | keys = torch.stack(keys)
163 | return keys
164 |
165 | def generate_data(self, hole):
166 |
167 | hole_parts = hole.split('/')[-1].split('_')
168 | repo_name = hole.split('/')[2]
169 | if len(hole_parts) > 3:
170 | new_hole_parts = hole_parts[:-2]
171 | filename = '_'.join(new_hole_parts)
172 | filename = [filename]
173 | else:
174 | filename = [hole_parts[0]]
175 | file = '/'.join(hole.split('/')[:-1] + filename)
176 | hole_pos = (int(hole_parts[-2]), int(hole_parts[-1]))
177 | rule_contexts = self.get_rule_context(file, hole_pos)
178 | return rule_contexts, hole, repo_name
179 |
180 | def set_tokenizer(self):
181 | if self.emb_model_type == 'codebert':
182 | self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
183 | if self.emb_model_type == 'graphcodebert':
184 | self.tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
185 | if self.emb_model_type == 'gpt-2':
186 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
187 | self.tokenizer.pad_token = self.tokenizer.eos_token
188 |
189 | def set_embedding_model(self):
190 | # CodeBERT
191 | if self.emb_model_type == 'codebert':
192 | self.emb_model = AutoModel.from_pretrained("microsoft/codebert-base")
193 | # GraphCodeBERT
194 | if self.emb_model_type == 'graphcodebert':
195 | self.emb_model = AutoModel.from_pretrained("microsoft/graphcodebert-base")
196 |
197 | def set_tokenizer(emb_model_type):
198 |
199 | if emb_model_type == 'codebert':
200 | tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
201 | if emb_model_type == 'graphcodebert':
202 | tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
203 | if emb_model_type == 'gpt-2':
204 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
205 | tokenizer.pad_token = tokenizer.eos_token
206 | return tokenizer
207 |
208 |
209 |
--------------------------------------------------------------------------------
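A minimal sketch of the representation step in get_representation above, in isolation: embed one toy rule context string with CodeBERT (the string is made up; the real rule contexts come from RuleDatasetUtils.get_all_rules_context()):

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    model = AutoModel.from_pretrained("microsoft/codebert-base")

    rule_prompt = "public int add(int a, int b) { return a + b; }"
    enc = tokenizer([rule_prompt], truncation=True, padding='max_length', return_tensors='pt')

    with torch.no_grad():
        out = model(enc['input_ids'], attention_mask=enc['attention_mask'])

    # Same fallback as get_representation above: pooler output if present, else the first ([CLS]) hidden state.
    repr_vec = out.pooler_output if getattr(out, 'pooler_output', None) is not None else out.last_hidden_state[:, 0]
    print(repr_vec.shape)  # (1, 768)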
/script_analyze_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | base_dir = 'rule_classifier_data'
3 |
4 | projects = { 'train': [
5 | 'gfsfa',
6 | 'sol-agent-platform',
7 | 'gloodb',
8 | 'rsbotownversion',
9 | 'jjskit',
10 | 'ftpserverremoteadmin',
11 | 'openprocesslogger',
12 | 'strudem-sicsa',
13 | 'seamlets',
14 | 'healpix-rangeset',
15 | 'quidsee',
16 | 'mobileexpensetracker',
17 | 'swe574-group3',
18 | 'largemail',
19 | 'soap-dtc',
20 | 'designpatternjavapedro',
21 | 'myt5lib',
22 | 'exogdx',
23 | 'tapestry-sesame'
24 | ],
25 |
26 | 'val': [
27 | 'javasummerframework',
28 | 'tinwiki',
29 | 'teammates-shakthi',
30 | 'jcontenedor',
31 | 'jloogle',
32 | 'swinagile',
33 | 'math-mech-eshop',
34 | 'jata4test',
35 | 'affinity_propagation_java',
36 | 'navigablep2p',
37 | 'springlime',
38 | 'sohocms',
39 | 'tyrond',
40 | 'infinispan-storage-service',
41 | ],
42 |
43 | 'test': [
44 | 'project-pt-diaoc',
45 | 'dovetaildb',
46 | 'robotsimulator2009w',
47 | 'ircrpgbot',
48 | 'xfuze',
49 | 'realtimegc',
50 | 'fswuniceubtemplates',
51 | 'glperaudsimon',
52 | 'apiitfriends',
53 | 'qwikioffice-java',
54 | 'xiaonei-java-api',
55 | 'wicketbits',
56 | 'hucourses',
57 | 'gwt-plugindetect'
58 | ]
59 | }
60 |
61 | commands = []
62 | for data_split, data_split_repos in projects.items():
63 | for proj in data_split_repos:
64 | proj_name = proj.strip()
65 | command = "python analyze_results.py --proj_name " + proj_name \
66 | + " --base_dir " + base_dir + " --data_split " + data_split
67 | commands.append(command)
68 |
69 | with open("commands_analyze_results", 'w') as f:
70 | f.writelines("%s\n" % command for command in commands)
71 | f.close()
72 |
73 |
--------------------------------------------------------------------------------
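script_analyze_results.py only writes the commands to a file; it does not execute them. A minimal sketch of running them sequentially with subprocess (the same pattern applies to the commands_completion and commands_gen_and_preprocess files written by the other script_* helpers below):

    import subprocess

    with open('commands_analyze_results') as f:
        for line in f:
            command = line.strip()
            if command:
                subprocess.run(command, shell=True, check=True)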
/script_completions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import glob
4 | import numpy as np
5 | from rule_config import *
6 |
7 | base_dirs = os.listdir('rule_classifier_data')
8 |
9 | modes = ['codex', 'rule']
10 | context_locations = [
11 | 'in_file', \
12 | 'parent_class_file', \
13 | 'import_file',\
14 | 'sibling_file', \
15 | 'similar_name_file', \
16 | 'child_class_file', \
17 | 'import_of_sibling_file', \
18 | 'import_of_similar_name_file', \
19 | 'import_of_parent_class_file', \
20 | 'import_of_child_class_file'
21 | ]
22 |
23 |
24 | batch_size = 20
25 | total_context_length = 4072
26 |
27 | def main():
28 | commands = []
29 | for base_repo in base_dirs:
30 | base_dir = os.path.join('rule_classifier_data', base_repo)
31 | for repo in os.listdir(base_dir):
32 | for mode in modes:
33 | if mode == 'codex':
34 | command = "python generate_completions.py --mode " + mode \
35 | + " --total_context_len " + str(total_context_length)\
36 | + " --base_dir " + base_dir\
37 | + " --repo_name " + repo\
38 | + " --batch_size " + str(batch_size)
39 | commands.append(command)
40 |
41 | if mode == 'rule':
42 | for context_location in context_locations:
43 | context_types = context_type_dict[context_location]
44 | for context_type in context_types:
45 | rule_specific_hyperparams = rule_hyperparams[context_type]
46 | for context_ratio in rule_specific_hyperparams['context_ratio']:
47 | for prompt_separator in rule_specific_hyperparams['prompt_separator']:
48 | for top_k in rule_specific_hyperparams['top_k']:
49 | for rule_context_format in rule_specific_hyperparams['rule_context_formatting']:
50 | if top_k == -1:
51 | command = "python generate_completions.py --mode " + mode\
52 | + " --context_location " + context_location\
53 | + " --context_type " + context_type\
54 | + " --context_division_ratio " + str(context_ratio) \
55 | + " --prompt_separator " + prompt_separator \
56 | + " --top_k " + str(top_k)\
57 | + " --total_context_len " + str(total_context_length)\
58 | + " --base_dir " + base_dir\
59 | + " --repo_name " + repo\
60 | + " --batch_size " + str(batch_size)\
61 | + " --rule_context_formatting " + rule_context_format\
62 |
63 | commands.append(command)
64 |
65 | else:
66 | for top_k_type in rule_specific_hyperparams['top_k_type']:
67 | final_command = command + " --top_k_type " + top_k_type
68 | commands.append(final_command)
69 |
70 |
71 | with open("commands_completion", 'w') as f:
72 | f.writelines("%s\n" % command for command in commands)
73 | f.close()
74 |
75 | if __name__ == '__main__':
76 | main()
--------------------------------------------------------------------------------
/script_gen_and_preprocess_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | base_data_dir = 'gcode-data'
3 |
4 | projects = { 'train': [
5 | 'gfsfa',
6 | 'sol-agent-platform',
7 | 'gloodb',
8 | 'rsbotownversion',
9 | 'jjskit',
10 | 'ftpserverremoteadmin',
11 | 'openprocesslogger',
12 | 'strudem-sicsa',
13 | 'seamlets',
14 | 'healpix-rangeset',
15 | 'quidsee',
16 | 'mobileexpensetracker',
17 | 'swe574-group3',
18 | 'largemail',
19 | 'soap-dtc',
20 | 'designpatternjavapedro',
21 | 'myt5lib',
22 | 'exogdx',
23 | 'tapestry-sesame'
24 | ],
25 |
26 | 'val': [
27 | 'javasummerframework',
28 | 'tinwiki',
29 | 'teammates-shakthi',
30 | 'jcontenedor',
31 | 'jloogle',
32 | 'swinagile',
33 | 'math-mech-eshop',
34 | 'jata4test',
35 | 'affinity_propagation_java',
36 | 'navigablep2p',
37 | 'springlime',
38 | 'sohocms',
39 | 'tyrond',
40 | 'infinispan-storage-service',
41 | ],
42 |
43 | 'test': [
44 | 'project-pt-diaoc',
45 | 'dovetaildb',
46 | 'robotsimulator2009w',
47 | 'ircrpgbot',
48 | 'xfuze',
49 | 'realtimegc',
50 | 'fswuniceubtemplates',
51 | 'glperaudsimon',
52 | 'apiitfriends',
53 | 'qwikioffice-java',
54 | 'xiaonei-java-api',
55 | 'wicketbits',
56 | 'hucourses',
57 | 'gwt-plugindetect'
58 | ]
59 | }
60 |
61 | commands = []
62 | for data_split, data_split_repos in projects.items():
63 | for proj in data_split_repos:
64 | proj_name = proj.strip()
65 | command = "python create_hole_data.py --proj_name " + proj_name \
66 | + " --base_dir " + base_data_dir + " --data_split " + data_split
67 | commands.append(command)
68 | command = "python parse_tree.py --proj_name " + proj_name \
69 | + " --base_dir " + os.path.join('rule_classifier_data', data_split)
70 | commands.append(command)
71 | command = "python check_duplication.py --proj_name " + proj_name \
72 | + " --base_dir " + os.path.join('rule_classifier_data', data_split)
73 | commands.append(command)
74 |
75 | with open("commands_gen_and_preprocess", 'w') as f:
76 | f.writelines("%s\n" % command for command in commands)
77 | f.close()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import product
3 | from rule_config import *
4 |
5 | promptseparator2str = {'space': " ", \
6 | 'newline': "\n", \
7 | 'class_names': "class_names",\
8 | 'class_method_names': "class_method_names",\
9 | 'method_names': "method_names"}
10 |
11 | context_location_to_index = {
12 | 'in_file':0, \
13 | 'parent_class_file':1, \
14 | 'import_file':2,\
15 | 'sibling_file':3, \
16 | 'similar_name_file':4, \
17 | 'child_class_file':5, \
18 | 'import_of_sibling_file':6, \
19 | 'import_of_similar_name_file':7, \
20 | 'import_of_parent_class_file':8, \
21 | 'import_of_child_class_file':9, \
22 | 'codex': 10 #codex
23 | }
24 |
25 |
26 | context_types_to_index = {
27 | 'method_names_and_bodies':0,\
28 | 'method_names':1,\
29 | 'identifiers':2, \
30 | 'type_identifiers':3,\
31 | 'string_literals':4,\
32 | 'field_declarations':5, \
33 | 'lines':6, \
34 | 'codex': 7 #codex
35 | }
36 |
37 | count = 0
38 | combined_to_index = {}
39 | cl_keys = list(context_location_to_index.keys())
40 | ct_keys = list(context_types_to_index.keys())
41 | cl_keys.remove('codex')
42 | ct_keys.remove('codex')
43 | for (k1, k2) in product(cl_keys, ct_keys):
44 | if (k1 != 'in_file' and k2 in ct_keys[:-1]) or (k1 == 'in_file' and k2 in ct_keys[1:]):
45 | cr_keys = rule_hyperparams[k2]['context_ratio']
46 | for k3 in cr_keys:
47 | key = k1 + '#' + k2 + '#' + str(k3)
48 | combined_to_index[key] = count
49 | count +=1
50 | combined_to_index['codex'] = count
51 | #print(combined_to_index)
52 |
53 | def get_multi_hot_vector(lst, type):
54 | set_lst = list(set(lst))
55 | if type == 'cl':
56 | index_dict = context_location_to_index
57 | if type == 'ct':
58 | index_dict = context_types_to_index
59 | if type == 'com':
60 | index_dict = combined_to_index
61 | vector_size = len(index_dict)
62 | multi_hot_vector = np.zeros(vector_size)
63 | for entry in lst:
64 | multi_hot_vector[index_dict[entry]] = 1
65 | return multi_hot_vector
66 |
67 | def is_valid_hole(hole, duplicate_files):
68 | hole_parts = hole.split('/')[-1].split('_')
69 | if len(hole_parts) > 3:
70 | new_hole_parts = hole_parts[:-2]
71 | filename = '_'.join(new_hole_parts)
72 | filename = [filename]
73 | else:
74 | filename = [hole_parts[0]]
75 | file = '/'.join(hole.split('/')[:-1] + filename)
76 | if file in duplicate_files:
77 | return False
78 | else:
79 | return True
80 |
81 | def find_intersection(lst1, lst2):
82 | set_lst1 = set(lst1)
83 | set_lst2 = set(lst2)
84 | return set_lst1.intersection(set_lst2)
85 |
86 | def alter_hid(orig_hid, hid):
87 | data_split = hid.split('/')[1]
88 | if 'gcode-data' in orig_hid:
89 | new_id = orig_hid.replace('data/gcode-data', 'rule_classifier_data/' + data_split)
90 | return new_id
91 | elif 'java-other' in orig_hid:
92 | new_id = orig_hid.replace('data/java-other', 'rule_classifier_data/' + data_split)
93 | return new_id
94 | else:
95 | return orig_hid
96 |
97 | def find_usages(query_att, query_file, lst_key_att, key_file):
98 | usages = []
99 | query_str = get_string(query_file, query_att[0], query_att[1])
100 | for key_att in lst_key_att:
101 | key_str = get_string(key_file, key_att[0], key_att[1])
102 | if key_str == query_str:
103 | usages.append(key_att)
104 | return usages
105 |
106 | def update_list(src_lst, tgt_lst, f, return_type='str'):
107 | for elem in src_lst:
108 | elem_str = get_string(f, elem[0], elem[1])
109 | if elem_str not in tgt_lst:
110 | if return_type == 'pos':
111 | tgt_lst.append(elem)
112 | else:
113 | tgt_lst.append(elem_str)
114 | return tgt_lst
115 |
116 | def find_similar_intersection(file1, file2):
117 | lst1 = [x.split('/')[-1] for x in file1]
118 | lst2 = [x.split('/')[-1] for x in file2]
119 | #print(lst1, lst2)
120 | return find_intersection(lst1, lst2)
121 |
122 | def get_codex_tokenized_string(tokenizer, input_str, context_len, type='back'):
123 | '''
124 | get the codex tokenized string
125 | '''
126 | if input_str:
127 | codex_tokens = tokenizer(input_str)['input_ids']
128 | if type == 'front':
129 | truncated_codex_tokens = codex_tokens[:context_len]
130 | else:
131 | truncated_codex_tokens = codex_tokens[-context_len:]
132 | out_str = tokenizer.decode(truncated_codex_tokens)
133 | return out_str, len(truncated_codex_tokens)
134 | else:
135 | return '', 0
136 |
137 | def join_lines(lst):
138 | return ''.join(lst)
139 |
140 | # take start line as the first non-commented and non-empty line
141 | def modified_start_line(lines):
142 | for i in range(len(lines)):
143 | line = lines[i]
144 | if line and not (line.startswith('/') or line.startswith('*')): # not part of the license text or empty line
145 | return i
146 |     return -1 # no usable start line found; get_string then returns an empty span
147 | def get_string(filename, start, end):
148 | '''
149 | get the string corresponding to the start and end positions in the parse tree
150 | '''
151 | lines = open(filename, encoding="utf8", errors='backslashreplace').readlines()
152 | start_line, start_char = start
153 | span_str = ''
154 | if start_line == 0:
155 | start_line = modified_start_line(lines)
156 | end_line, end_char = end
157 | if start_line <= end_line and start_line < len(lines) and start_line!= -1:
158 | if start_line == end_line:
159 | if end_char == -1:
160 | span_str = lines[start_line]
161 | else:
162 | span_str = lines[start_line][start_char:end_char]
163 | else:
164 | if start_line + 1 < len(lines):
165 | span_str = lines[start_line][start_char:] + \
166 | join_lines(lines[start_line+1: end_line]) + \
167 | lines[end_line][:end_char]
168 | return span_str
169 |
170 |
171 |
172 | def get_context_from_java_github(out_context_len, tokenizer): # tokenizer must be supplied by the caller
173 |     import os, pickle # local imports: os and pickle are not imported at the top of this module
174 |     data = pickle.load(open(os.path.join('preprocessed_data/java_github', 'holes_1.val'), 'rb'))
175 | out_context_prompts = []
176 | for i in range(len(data)):
177 | for j in range(len(data[i])):
178 | for k in range(len(data[i][j])):
179 | file_data = data[i][j][k]
180 | if file_data[0]:
181 | file_token_str = file_data[0]
182 | out_context_prompts.append(get_codex_tokenized_string(tokenizer, file_token_str, out_context_len))
183 | return out_context_prompts
--------------------------------------------------------------------------------
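A small sketch of how combined_to_index and get_multi_hot_vector above fit together: turning a list of successful rule keys for a hole into the multi-hot vector that rearrange_data.py reads back from the oracle's 'com' field. The example keys are illustrative but follow the '<context_location>#<context_type>#<context_ratio>' pattern built in this file:

    from utils import combined_to_index, get_multi_hot_vector

    # 'codex' is the extra entry for the plain prompt with no repo-level context.
    successful_rules = ['in_file#lines#0.75', 'import_file#identifiers#0.5', 'codex']
    vec = get_multi_hot_vector(successful_rules, 'com')

    print(len(combined_to_index), int(vec.sum()))  # 63 rule indices in total; 3 of them are set here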