├── LICENSE ├── README.md ├── analyze_results.py ├── att_model.py ├── block_diagram.png ├── check_duplication.py ├── context.py ├── create_hole_data.py ├── data_utils.py ├── generate_completions.py ├── generate_rule_representations.py ├── get_info_from_hole_predictions.py ├── model_preprocessed_data.py ├── parse_tree.py ├── preprocessed_data.py ├── projects.txt ├── rearrange_data.py ├── rule_classifier_preprocessed_data.py ├── rule_config.py ├── rule_inference_preprocessed_data.py ├── rule_representation_data.py ├── script_analyze_results.py ├── script_completions.py ├── script_gen_and_preprocess_data.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Disha Shrivastava 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Repository-Level Prompt Generation for Large Language Models of Code 2 | Disha Shrivastava, Hugo Larochelle, Daniel Tarlow 3 | 4 | This repository contains implementation and data for our work [Repository-Level Prompt Generation for Large Language Models of Code](https://arxiv.org/abs/2206.12839). A block diagram of our approach can be found below. For more details, refer to the paper. 5 | 6 |

7 | ![Block diagram of our approach](block_diagram.png) 8 |

9 | 10 | ## Dependencies 11 | * Access to the OpenAI Codex API: https://openai.com/blog/openai-codex/. The key should be placed in a file named `openai_api_key` 12 | * pytorch: https://pytorch.org/ 13 | * Huggingface's transformers: https://pypi.org/project/transformers/ 14 | * tree-sitter-java: https://github.com/tree-sitter/tree-sitter-java 15 | * tqdm 16 | * tensorboard 17 | 18 | ## Updates 19 | * We have released the trained checkpoints for the RLPG-H and RLPG-R models. You can download them [here](https://drive.google.com/file/d/1txmObhsA_Cs8paj1x8IGsoUqX7oHNxbw/view?usp=share_link). While loading the state_dict, please set the `strict` parameter to `False`. Example usage can be found in `rule_inference_preprocessed_data.py` 20 | * We are also releasing the prediction probabilities of the RLPG-H and RLPG-R models for each hole in our validation and test data. You can download them [here](https://drive.google.com/file/d/1WSPf4p0tfWs2nLgbpk53Kh5qJG3_33-b/view?usp=share_link). Each file contains the probabilities for all prompt proposals as given by the corresponding trained RLPG models. Example usage can be found in `get_info_from_hole_predictions.py` 21 | 22 | ## Code 23 | ### Data preprocessing 24 | The web URLs for all the repositories used in our work are provided in projects.txt. Download and store them in a folder called gcode-data. Then run `script_gen_and_preprocess_data.py`. This script will produce an output file called commands_gen_and_preprocess. Running it will execute three scripts: 25 | - `create_hole_data.py`: creates the hole completion data by choosing the midpoint of each line as the hole position. 26 | - `parse_tree.py`: creates a parse tree for each file and stores the repo-level meta-info needed to get the rule context. 27 | - `check_duplication.py`: checks for duplicates within a repo. 28 | Running this will create a new folder called rule_classifier_data that has train, val and test subfolders. Inside each folder, there will be a folder for each repository containing the following: 29 | 30 | * The repository's .java files, preserving the original directory structure. 31 | * hole_data 32 | * file_class_data 33 | * parsed_data 34 | * duplicates 35 | 36 | ### Generating completions using Codex, i.e., obtaining the ground truth for training the rule classifier 37 | `script_completions.py` 38 | Generates a file commands_completion. Running this will create a new folder called results that has train, val and test subfolders. Inside each folder, there will be ten folders corresponding to the rule context locations. Each folder contains .json files corresponding to rule context types. Each row of a file contains data about the application of that particular rule to a hole: the target hole, the predicted hole, the prompt and the validity of the rule. 39 | 40 | ### Generating the oracle 41 | `script_analyze_results.py` 42 | Generates a file commands_analyze_results. Running this file will create a file called oracle inside each repo folder in rule_classifier_data. This file contains the collated information about the success of each rule for a particular target hole. 43 | 44 | ### Generating the rule context representations for the rule classifier 45 | `generate_rule_representations.py` 46 | 47 | Example usage: `python generate_rule_representations.py --data_split=val --repo=jata4test --emb_model_type=codebert` 48 | 49 | This will create a codebert_mod folder inside rule_classifier_data/val/jata4test. Each file in this folder contains the rule context representation obtained from CodeBERT for each hole.
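As a reference for what these representations look like, below is a minimal sketch of embedding a single rule context string with CodeBERT through Hugging Face `transformers`. It assumes the public `microsoft/codebert-base` checkpoint and mean pooling over the last hidden states; the exact pooling, batching and I/O in `generate_rule_representations.py` may differ, so treat this only as an illustration of the 768-dimensional vectors stored in codebert_mod.

```python
# Illustrative sketch only: embed one rule context string with CodeBERT.
# Assumes the public microsoft/codebert-base checkpoint and mean pooling;
# generate_rule_representations.py may pool and batch differently.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
model = AutoModel.from_pretrained("microsoft/codebert-base")
model.eval()

def embed_rule_context(rule_context: str) -> torch.Tensor:
    # Truncate to CodeBERT's 512-token window and run a forward pass.
    inputs = tokenizer(rule_context, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # (1, seq_len, 768)
    # Mean-pool over the token dimension to get one 768-d vector per rule context.
    return hidden.mean(dim=1).squeeze(0)

# Hypothetical Java snippet standing in for a real rule context.
vec = embed_rule_context("public int getHoleCount(String path) { return holes.size(); }")
print(vec.shape)  # torch.Size([768])
```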
50 | 51 | ### Capping the number of holes 52 | `rearrange_data.py` 53 | This script caps the maximum contribution from a repo at 10000 holes. After this, each repo folder will contain the files capped_holes_10000, capped_codebert_mod and capped_oracle_10000. 54 | 55 | ### Training the rule classifier 56 | `rule_classifier_preprocessed_data.py` 57 | This needs the capped_codebert_mod folder (rule context representations) to be present inside each repo folder, as well as the capped_oracle_10000 file. 58 | The best model is stored in the models directory along with the tensorboard logs. The output from each epoch is stored in the outputs folder. 59 | 60 | ### Inference with the rule classifier 61 | `rule_inference_preprocessed_data.py` 62 | This needs the capped_codebert_mod folder (rule context representations) to be present inside each repo folder, as well as the capped_oracle_10000 file. 63 | This produces a file inside the outputs folder that contains the prediction of the classifier for each hole. 64 | 65 | ### Getting results for variation with k 66 | `get_info_from_hole_predictions.py` 67 | This needs a hole_stats_file (generated in the previous step) and a value of k as input. 68 | -------------------------------------------------------------------------------- /analyze_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | from utils import * 5 | import argparse 6 | from rule_config import rule_hyperparams 7 | 8 | def get_hole_identities(hole_filename, duplicate_filename): 9 | hole_data = pickle.load(open(hole_filename, 'rb')) 10 | duplicate_files = open(duplicate_filename, 'r').readlines() 11 | duplicate_files = [ x.strip() for x in duplicate_files] 12 | hole_identities = [] 13 | for (k,v) in hole_data.items(): 14 | if k not in duplicate_files and not k.startswith('rule_classifier_data/train/rsbotownversion/trunk/scripts/'): 15 | for hole in v: 16 | h_id = k + '_' + str(hole[0]) + '_' + str(hole[1]) 17 | hole_identities.append(h_id) 18 | return hole_identities, duplicate_files 19 | 20 | def obtain_full_rule_results(result_file, codex_results, hole_rule_mapping): 21 | rule_results = read_result_file(result_file, hole_rule_mapping) 22 | j = 0 23 | all_results = [] 24 | for i in range(len(codex_results)): 25 | codex_res = codex_results[i] 26 | hid = codex_res['hole_identity'] 27 | if j < len(rule_results): 28 | rule_res_hid = rule_results[j]['hole_identity'] 29 | if hid == rule_res_hid: 30 | all_results.append(rule_results[j]) 31 | j+=1 32 | else: 33 | all_results.append(codex_res) 34 | else: 35 | all_results.append(codex_res) 36 | return all_results 37 | 38 | def is_match(res_line): 39 | res_line = json.loads(res_line) 40 | prediction = res_line['prediction'] 41 | hole = res_line['ground-truth hole'] 42 | pred = prediction.rstrip() 43 | hole = hole.rstrip() 44 | # there is an exact match corresponding to this hole id 45 | if pred == hole: 46 | return True 47 | else: 48 | return False 49 | 50 | def update_hole_rule_mapping(res_line, hid, hole_rule_mapping, rule_parts): 51 | if is_match(res_line): 52 | if hid in hole_rule_mapping: 53 | hole_rule_mapping[hid].append(rule_parts) 54 | else: 55 | hole_rule_mapping[hid] = [rule_parts] 56 | return hole_rule_mapping 57 | 58 | def modify_results(result_lines, duplicate_files): 59 | if not duplicate_files: 60 | return result_lines 61 | else: 62 |
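# keep only the result lines whose hole does not come from a file flagged as a duplicate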
#print(len(duplicate_files), len(result_lines)) 63 | mod_result_lines = [] 64 | for i in range(len(result_lines)): 65 | res_hid = json.loads(result_lines[i])['hole_identity'] 66 | if is_valid_hole(res_hid, duplicate_files): 67 | mod_result_lines.append(result_lines[i]) 68 | return mod_result_lines 69 | 70 | def read_result_file(rule_result_file, codex_file, hole_identities, hole_rule_mapping, rule_parts, duplicate_files): 71 | #print(rule_result_file) 72 | rule_lines = open(rule_result_file, 'r').readlines() 73 | codex_lines = open(codex_file, 'r').readlines() 74 | rule_lines = modify_results(rule_lines, duplicate_files) 75 | codex_lines = modify_results(codex_lines, duplicate_files) 76 | j = 0 77 | for i in range(len(hole_identities)): 78 | hid = hole_identities[i] 79 | codex_hid = json.loads(codex_lines[i])['hole_identity'] 80 | codex_hid = alter_hid(codex_hid, hid) 81 | 82 | if j < len(rule_lines): 83 | try: 84 | rule_hid = json.loads(rule_lines[j])['hole_identity'] 85 | except: 86 | print(rule_result_file) 87 | rule_hid = alter_hid(rule_hid, hid) 88 | # use rule result 89 | if hid == rule_hid: 90 | hole_rule_mapping = update_hole_rule_mapping(rule_lines[j], hid, hole_rule_mapping, rule_parts) 91 | j+=1 92 | else: 93 | # use codex result 94 | hole_rule_mapping = update_hole_rule_mapping(codex_lines[i], hid, hole_rule_mapping, rule_parts) 95 | else: 96 | # use codex result 97 | hole_rule_mapping = update_hole_rule_mapping(codex_lines[i], hid, hole_rule_mapping, rule_parts) 98 | 99 | return hole_rule_mapping 100 | 101 | def get_results(base_result_dir, context_location, exclude_codex=True): 102 | context_result_dir = os.path.join(base_result_dir, context_location) 103 | result_files = next(os.walk(context_result_dir), (None, None, []))[2] # [] if no file 104 | if result_files and exclude_codex and 'codex_4072.json' in result_files: 105 | result_files.remove('codex_4072.json') 106 | mod_result_files = [os.path.join(context_result_dir, result_file) for result_file in result_files if result_file] 107 | result_files = [f for f in mod_result_files if os.path.getsize(f)>0] 108 | return result_files 109 | 110 | def check_validity_by_rule_parts(rule): 111 | valid = False 112 | if 'codex_4072' in rule: 113 | context_location = 'codex' 114 | context_type = 'codex' 115 | context_ratio = 0.5 116 | valid = True 117 | else: 118 | context_location = rule.split("/")[-2] 119 | rule_parts = rule.split("/")[-1].split("_") 120 | i = 5 121 | ct = '_'.join(rule_parts[1:5]) 122 | # keep removing the parts joined by _ till it matches a valid context_type 123 | while(ct not in context_types_to_index): 124 | i-=1 125 | ct = '_'.join(rule_parts[1:i]) 126 | context_type = ct 127 | 128 | mod_rule_parts = rule_parts[i:] 129 | try: 130 | context_ratio = float(mod_rule_parts[3])/4072 131 | if check_rule_validity(context_type, mod_rule_parts): 132 | valid =True 133 | except: 134 | valid = False 135 | if valid: 136 | return context_location, context_type, context_ratio 137 | else: 138 | return '', '', '' 139 | 140 | def get_all_hole_rule_mapping(base_result_dir, hole_identities, duplicate_files): 141 | in_file_files = get_results(base_result_dir, 'in_file', exclude_codex=False) 142 | parent_class_files = get_results(base_result_dir, 'parent_class_file') 143 | import_files = get_results(base_result_dir, 'import_file') 144 | sibling_files = get_results(base_result_dir, 'sibling_file') 145 | similar_name_files = get_results(base_result_dir, 'similar_name_file') 146 | child_class_files = get_results(base_result_dir, 
'child_class_file') 147 | import_of_similar_name_files = get_results(base_result_dir, 'import_of_similar_name_file') 148 | import_of_sibling_files = get_results(base_result_dir, 'import_of_sibling_file') 149 | import_of_parent_class_files = get_results(base_result_dir, 'import_of_parent_class_file') 150 | import_of_child_class_files = get_results(base_result_dir, 'import_of_child_class_file') 151 | codex_file = os.path.join(base_result_dir, 'in_file', 'codex_4072.json') 152 | 153 | result_files = in_file_files + parent_class_files + import_files + sibling_files + similar_name_files \ 154 | + child_class_files + import_of_similar_name_files + import_of_sibling_files + import_of_child_class_files + import_of_parent_class_files 155 | 156 | # print(len(in_file_files), len(parent_class_files), len(import_files), len(sibling_files), len(similar_name_files), \ 157 | # len(child_class_files) , len(import_of_sibling_files), len(import_of_similar_name_files), len(import_of_child_class_files),\ 158 | # len(import_of_parent_class_files), len(result_files)) 159 | 160 | hole_rule_mapping = {} 161 | for result_file in result_files: 162 | context_location, context_type, context_ratio = check_validity_by_rule_parts(result_file) 163 | if context_location: 164 | hole_rule_mapping = read_result_file(result_file, codex_file, hole_identities, hole_rule_mapping, \ 165 | (context_location, context_type, context_ratio), duplicate_files) 166 | return hole_rule_mapping 167 | 168 | def get_failed_holes(successful_holes, hole_identities): 169 | failed_holes = [] 170 | for hole_identity in hole_identities: 171 | if hole_identity not in successful_holes: 172 | failed_holes.append(hole_identity) 173 | return failed_holes 174 | 175 | def find_rule_pattern(rule_pattern, rules): 176 | found = False 177 | for rule in rules: 178 | if rule_pattern in rule: 179 | found = True 180 | break 181 | return found 182 | 183 | def find_rule_specific_success(successful_holes, query_file_pattern=''): 184 | count = 0 185 | for h_id, rules in successful_holes.items(): 186 | if find_rule_pattern(query_file_pattern, rules): 187 | count +=1 188 | return count 189 | 190 | def find_complementary_rules(successful_holes): 191 | other_rules = [] 192 | not_lines_not_iden_codex = [] 193 | not_lines_iden = [] 194 | lines = [] 195 | for h_id, rules in successful_holes.items(): 196 | if not find_rule_pattern('lines', rules): 197 | if not find_rule_pattern('identifiers', rules): 198 | if not find_rule_pattern('codex', rules): 199 | other_rules.append((h_id, rules)) 200 | else: 201 | not_lines_not_iden_codex.append((h_id, rules)) 202 | else: 203 | not_lines_iden.append((h_id, rules)) 204 | else: 205 | lines.append((h_id, rules)) 206 | return lines, not_lines_iden, not_lines_not_iden_codex, other_rules 207 | 208 | def check_rule_validity(context_type, rule_parts): 209 | valid = False 210 | valid_hyperparams = rule_hyperparams[context_type] 211 | try: 212 | rule_context_ratio = float(rule_parts[3])/4072 213 | except: 214 | return False 215 | rule_prompt_separator = rule_parts[-1] 216 | rule_rule_context_formatting = '_'.join(rule_parts[4:-1]) 217 | if rule_context_ratio in valid_hyperparams['context_ratio']: 218 | if rule_prompt_separator in valid_hyperparams['prompt_separator']: 219 | if rule_rule_context_formatting in valid_hyperparams['rule_context_formatting']: 220 | valid = True 221 | return valid 222 | 223 | def get_rule_templated_version(oracle): 224 | mod_oracle = {} 225 | for hid, rules in oracle.items(): 226 | context_locations = [] 227 | 
context_types = [] 228 | combined = [] 229 | for rule in rules: 230 | context_location, context_type, context_ratio = rule 231 | context_locations.append(context_location) 232 | context_types.append(context_type) 233 | if context_location != 'codex': 234 | combined.append(context_location + '#' + context_type + '#' + str(context_ratio)) 235 | else: 236 | combined.append('codex') 237 | context_location = get_multi_hot_vector(context_locations, 'cl') 238 | context_type = get_multi_hot_vector(context_types, 'ct') 239 | comb = get_multi_hot_vector(combined, 'com') 240 | mod_oracle[hid] = {'cl': context_location, 'ct': context_type, 'com': comb} 241 | return mod_oracle 242 | 243 | 244 | def find_rule_mapping(successful_holes): 245 | rule_mapping = {} 246 | for hid, rules in successful_holes.items(): 247 | for rule in rules: 248 | if rule not in rule_mapping: 249 | rule_mapping[rule] = [hid] 250 | else: 251 | rule_mapping[rule].append(hid) 252 | return rule_mapping 253 | 254 | def find_single_best_rule_success(rule_mapping): 255 | best_single_rule_success = 0 256 | for k, v in rule_mapping.items(): 257 | if len(v)> best_single_rule_success: 258 | best_rule_parts = k 259 | best_single_rule_success = len(v) 260 | best_rule = best_rule_parts[0] + '_' + best_rule_parts[1] + '_' + str(best_rule_parts[2]) 261 | return best_rule, best_single_rule_success 262 | 263 | 264 | def setup_args(): 265 | """ 266 | Description: Takes in the command-line arguments from user 267 | """ 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data', help="base directory for the data") 270 | parser.add_argument("--data_split", type=str, default='test', help="data split to store the data") 271 | parser.add_argument("--proj_name", type=str, default='dovetaildb', help="name of the input repo") 272 | 273 | return parser.parse_args() 274 | 275 | if __name__ == '__main__': 276 | 277 | args = setup_args() 278 | hole_filename = os.path.join(args.base_dir, args.data_split, args.proj_name, 'hole_data') 279 | duplicate_filename = os.path.join(args.base_dir, args.data_split, args.proj_name, 'duplicates') 280 | hole_identities, duplicate_files = get_hole_identities(hole_filename, duplicate_filename) 281 | print("Total number of holes:", len(hole_identities)) 282 | base_result_dir = os.path.join('results', args.base_dir, args.data_split, args.proj_name) 283 | print(len(duplicate_files)) 284 | successful_holes = get_all_hole_rule_mapping(base_result_dir, hole_identities, duplicate_files) 285 | print("Number of holes that got atleast one rule successful: ", len(successful_holes)) 286 | 287 | oracle = get_rule_templated_version(successful_holes) 288 | with open(os.path.join(args.base_dir, args.data_split, args.proj_name, 'oracle'), 'wb') as f: 289 | pickle.dump(oracle, f) 290 | assert len(successful_holes) == len(oracle) 291 | rule_mapping = find_rule_mapping(successful_holes) 292 | codex_success = len(rule_mapping[('codex', 'codex', 0.5)]) 293 | best_rule, best_rule_success = find_single_best_rule_success(rule_mapping) 294 | best_single_rule_success = len(rule_mapping[('in_file', 'lines', 0.75)]) 295 | # print(rule_mapping) 296 | print(codex_success, best_single_rule_success, best_rule, best_rule_success) 297 | print( 298 | args.proj_name + ", " + \ 299 | str(float(len(successful_holes)*100/len(hole_identities))) + ", " + \ 300 | str(float(codex_success*100/len(hole_identities))) + ", " + \ 301 | best_rule + ", " +\ 302 | 
str(float(best_rule_success*100/len(hole_identities))) + ", " + \ 303 | "in_file_lines_0.75" + ", " +\ 304 | str(float(best_single_rule_success*100/len(hole_identities))) 305 | ) 306 | 307 | # failed_holes = get_failed_holes(successful_holes, hole_identities) 308 | # print("Number of holes that got no rule successful: ", len(failed_holes)) 309 | # with open(os.path.join(base_result_dir, 'failed_cases'), 'wb') as f: 310 | # pickle.dump(failed_holes, f) 311 | 312 | # post_lines_success = find_rule_specific_success(successful_holes, 'lines') 313 | # codex_success = find_rule_specific_success(successful_holes, 'codex') 314 | # identifiers_success = find_rule_specific_success(successful_holes, 'identifiers') 315 | # print("Number of post lines successes: ", post_lines_success) 316 | # print("Number of codex successes: ", codex_success) 317 | # print("Number of identifiers successes: ", identifiers_success) 318 | 319 | # lines, not_lines_iden, not_lines_not_iden_codex, other_rules = find_complementary_rules(successful_holes) 320 | # print("Post Lines: ", len(lines), end=", ") 321 | # print("Not Post Lines, Identifiers: ", len(not_lines_iden), end=", ") 322 | # print("Not Post Lines, Not Identifiers, Codex: ", len(not_lines_not_iden_codex), end=", ") 323 | # print("Other Rules: ", len(other_rules), end=", ") 324 | # print("No Rules: ", len(failed_holes), end="\n") 325 | 326 | -------------------------------------------------------------------------------- /att_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import torch 5 | import math 6 | 7 | 8 | class ScaledDotProductAttention(nn.Module): 9 | ''' Scaled Dot-Product Attention ''' 10 | 11 | def __init__(self, temperature, attn_dropout=0.0): 12 | super().__init__() 13 | self.temperature = temperature 14 | self.dropout = nn.Dropout(attn_dropout) 15 | 16 | def forward(self, q, k, mask=None): 17 | 18 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) #(bs, num_heads, max_len_q, max_len_k) 19 | #print("attn:", attn.shape) 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask == 0, -1e9) 23 | 24 | attn = self.dropout(F.softmax(attn, dim=-1))#(bs, num_heads, max_len_q, max_len_k) 25 | return attn 26 | 27 | class MultiHeadAttention(nn.Module): 28 | def __init__(self, n_head=4, d_model=768, d_k=32, d_v=32,dropout=0.0, include_res_ln=True, return_att_weights=False): 29 | super(MultiHeadAttention, self).__init__() 30 | 31 | self.include_res_ln = include_res_ln 32 | self.return_att_weights = return_att_weights 33 | self.n_head = n_head 34 | self.d_k = d_k 35 | self.d_v = d_v 36 | self.d_model = d_model 37 | self.w_qs = nn.Linear(d_model, n_head * d_k) 38 | self.w_ks = nn.Linear(d_model, n_head * d_k) 39 | self.w_vs = nn.Linear(d_model, n_head * d_v) 40 | self.fc = nn.Linear(n_head * d_v, d_model) 41 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5, attn_dropout=dropout) 42 | self.dropout = nn.Dropout(dropout) 43 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 44 | 45 | def forward(self, query, keys, values, mask): 46 | 47 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 48 | sz_b = query.size(0) #batch_size 49 | q = query #(bs, 1, d_model) 50 | k = keys #(bs, #rules, d_model) 51 | v = values #(bs, #rules, d_model) 52 | len_q, len_k, len_v = q.size(1), k.size(1), v.size(1) #(1, n_rules, n_rules) 53 | 54 | residual = q 55 | 56 | # Pass through the pre-attention projection: b x lq x 
(n*dv) 57 | # Separate different heads: b x lq x n x dv 58 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) #(bs, 1, n_head, d_k) 59 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) #(bs, n_rules, n_head, d_k) 60 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) #(bs, n_rules , n_head, d_k) 61 | 62 | # Transpose for attention dot product: b x n x lq x dv 63 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) #(bs, n_head, 1, d_k), (bs, n_head, n_rules, d_k), (bs, n_head, n_rules, d_k) 64 | 65 | if mask is not None: 66 | mask = mask.unsqueeze(1) # For head axis broadcasting. 67 | 68 | # calculate attention scores 69 | attn = self.attention(q, k, mask)#(bs, num_heads, 1, n_rules) 70 | 71 | pred = torch.matmul(attn, v) #(bs, num_heads, 1, d_model) 72 | pred = pred.transpose(1, 2).contiguous().view(sz_b, len_q, -1) #(bs, 1, num_heads*d_model) 73 | pred = self.dropout(self.fc(pred)) #(bs, 1, d_model) 74 | if self.include_res_ln: 75 | pred += residual 76 | pred = self.layer_norm(pred) 77 | 78 | pred = pred.squeeze() 79 | attn = attn.squeeze() 80 | if self.return_att_weights: 81 | return pred, attn 82 | else: 83 | return pred, None 84 | 85 | class PositionwiseFeedForward(nn.Module): 86 | ''' A two-feed-forward-layer module ''' 87 | 88 | def __init__(self, d_in, d_hid, dropout=0.0): 89 | super().__init__() 90 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 91 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 92 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 93 | self.dropout = nn.Dropout(dropout) 94 | 95 | def forward(self, x): 96 | residual = x 97 | x = self.w_2(F.relu(self.w_1(x))) 98 | x = self.dropout(x) 99 | x += residual 100 | x = self.layer_norm(x) 101 | return x 102 | 103 | class BasicAggModel(nn.Module): 104 | 105 | def __init__(self, include_ff=True, include_res_ln=True, dropout=0.0, d_inner=2048, d_model=768, return_att_weights=False, n_head=8, 106 | d_k=96, n_rules=63, device='cpu', is_dense_bias=True): 107 | 108 | super(BasicAggModel, self).__init__() 109 | 110 | self.include_ff = include_ff 111 | self.include_res_ln = include_res_ln 112 | self.d_inner = d_inner 113 | self.d_model = d_model 114 | self.n_rules = n_rules 115 | self.device = device 116 | self.return_att_weights = return_att_weights 117 | self.mha = MultiHeadAttention(n_head=n_head, d_k=d_k, d_v=d_k, dropout=dropout,\ 118 | return_att_weights=return_att_weights, include_res_ln=include_res_ln) 119 | self.ff = PositionwiseFeedForward(self.d_model, self.d_inner, dropout=dropout) 120 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 121 | self.dropout = nn.Dropout(p=dropout) 122 | 123 | self.rule_encoder = nn.Embedding(self.n_rules, self.d_model) 124 | if is_dense_bias: 125 | self.dense_proj = nn.Linear(d_model, n_rules) 126 | else: 127 | self.dense_proj = nn.Linear(d_model, n_rules, bias=False) 128 | 129 | for name, p in self.named_parameters(): 130 | if name!='dense_proj.weight' and p.dim() > 1: 131 | nn.init.xavier_uniform_(p) 132 | 133 | 134 | def forward(self, query, keys, mask): 135 | 136 | query = torch.unsqueeze(query, 1) #(bs, 1, d_model) 137 | bs = query.size(0) #bs 138 | 139 | mask = torch.sum(mask, dim=-1) #(bs, n_rules) 140 | mask = mask.unsqueeze(1) #(bs, n_rules) 141 | 142 | values = keys 143 | 144 | query = self.layer_norm(query) #(bs, 1, d_model) 145 | keys = self.layer_norm(keys) #(bs, n_rules, d_model) 146 | values = self.layer_norm(values) #(bs, n_rules, d_model) 147 | 148 | #Multi-Head Attention 149 | pred, att_weights = self.mha(query, keys, values, mask) #(bs, d_model), #(bs, 
num_heads, n_rules) 150 | 151 | #Positionwise FeedForward 152 | if self.include_ff: 153 | pred = self.ff(pred)#(bs, d_model) 154 | 155 | #Final projection to get logits 156 | pred = self.dense_proj(pred)#(bs, num_rules) 157 | 158 | if self.return_att_weights: 159 | return pred, att_weights 160 | else: 161 | return pred, None 162 | 163 | 164 | -------------------------------------------------------------------------------- /block_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shrivastavadisha/repo_level_prompt_generation/3af5f3424740448d8e325b3726e61944f6eec8b6/block_diagram.png -------------------------------------------------------------------------------- /check_duplication.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import hashlib 4 | import numpy as np 5 | import argparse 6 | 7 | comments = ['*', '/'] 8 | 9 | def chunk_reader(fobj, chunk_size=1024): 10 | """Generator that reads a file in chunks of bytes""" 11 | while True: 12 | chunk = fobj.read(chunk_size) 13 | if not chunk: 14 | return 15 | yield chunk 16 | 17 | 18 | def get_hole_count(full_path): 19 | hole_count = 0 20 | file_lines = open(full_path, encoding="utf8", errors='backslashreplace').readlines() 21 | for line in file_lines: 22 | line = line.strip() 23 | # omitting comments and empty lines (heuristic: NEED TO DOUBLE CHECK FOR WIDE APPLICABILITY) 24 | if line and not (np.any([line.startswith(comment) for comment in comments])): 25 | hole_count+=1 26 | return hole_count 27 | 28 | def check_for_duplicates(paths, hash=hashlib.sha1): 29 | duplicate_files = {} 30 | file_count = 0 31 | hole_count = 0 32 | hashes = {} 33 | duplicate_file_paths = [] 34 | for path in paths: 35 | for dirpath, dirnames, filenames in os.walk(path): 36 | for filename in filenames: 37 | if os.path.splitext(filename)[1] == '.java': 38 | full_path = os.path.join(dirpath, filename) 39 | hashobj = hash() 40 | for chunk in chunk_reader(open(full_path, 'rb')): 41 | hashobj.update(chunk) 42 | file_id = (hashobj.digest(), os.path.getsize(full_path)) 43 | duplicate = hashes.get(file_id, None) 44 | if duplicate: 45 | if duplicate not in duplicate_files: 46 | duplicate_files[duplicate] = [full_path] 47 | else: 48 | duplicate_files[duplicate].append(full_path) 49 | # print("Duplicate found: ") 50 | 51 | else: 52 | hashes[file_id] = full_path 53 | 54 | for k,v in duplicate_files.items(): 55 | num_files = len(v) 56 | file_count+= num_files 57 | ind_hole_count = get_hole_count(k) 58 | hole_count += num_files * ind_hole_count 59 | duplicate_file_paths.append(k) 60 | for path in v: 61 | duplicate_file_paths.append(path) 62 | 63 | 64 | print(paths) 65 | print(len(duplicate_file_paths)) 66 | print(str(file_count) + ", " + str(hole_count)) 67 | with open(os.path.join(paths[0], "duplicates"), "w") as outfile: 68 | outfile.write("\n".join(duplicate_file_paths)) 69 | 70 | def setup_args(): 71 | """ 72 | Description: Takes in the command-line arguments from user 73 | """ 74 | parser = argparse.ArgumentParser() 75 | 76 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 77 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data/val', \ 78 | help="base directory for the data") 79 | parser.add_argument("--proj_name", type=str, default='rsbotownversion', \ 80 | help="name of the input repo") 81 | 82 | return parser.parse_args() 83 | 84 | if __name__ == '__main__': 85 | 86 | args 
= setup_args() 87 | 88 | #Fix seeds 89 | np.random.seed(args.seed) 90 | os.environ['PYTHONHASHSEED']=str(args.seed) 91 | 92 | repo_path = os.path.join(args.base_dir, args.proj_name) 93 | if os.path.isdir(repo_path): 94 | check_for_duplicates([repo_path]) 95 | -------------------------------------------------------------------------------- /context.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import * 3 | import numpy as np 4 | 5 | class getContext(): 6 | 7 | def __init__(self, context_location='in_file', tokenizer=None, file='', 8 | context_len=4072, parse_data=None, 9 | context_type='lines', context_scope='pre', top_k=-1, top_k_type='first', 10 | attention_scores=None, rule_context_formatting='space', 11 | file_lines=None): 12 | 13 | super(getContext, self).__init__() 14 | 15 | self.file = file 16 | self.context_len = context_len 17 | self.context_location = context_location 18 | self.tokenizer = tokenizer 19 | self.top_k = top_k 20 | self.top_k_type = top_k_type 21 | self.context_type = context_type 22 | self.parse_data = parse_data 23 | self.attention_scores = attention_scores 24 | self.rule_context_formatting = rule_context_formatting 25 | if file_lines !=None: 26 | self.file_lines = file_lines 27 | else: 28 | self.file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines() 29 | if self.context_location == 'in_file': 30 | self.set_context_scope_and_inclusion_type(context_scope) 31 | elif self.context_location == 'parent_class_file': 32 | self.codex_completion_inclusion_type = 'back' 33 | else: 34 | self.import_overlap_type_files = {} 35 | self.codex_completion_inclusion_type = 'front' 36 | 37 | def is_non_empty(self, parse_data_attribute_str): 38 | parse_data_attribute = self.parse_data[self.file][parse_data_attribute_str] 39 | if parse_data_attribute: 40 | return True 41 | else: 42 | return False 43 | 44 | def is_out_files(self): 45 | if self.context_location == 'parent_class_file' or self.context_location == 'import_of_parent_class_file': 46 | return self.is_non_empty('parent_class_filenames') 47 | if self.context_location == 'import_file': 48 | return self.is_non_empty('imports') 49 | if self.context_location == 'sibling_file' or self.context_location == 'reverse_sibling_file' or self.context_location == 'import_of_sibling_file': 50 | return self.is_non_empty('sibling_files') 51 | if self.context_location == 'similar_name_file' or self.context_location == 'reverse_similar_name_file' or self.context_location == 'import_of_similar_name_file': 52 | return self.is_non_empty('similar_name_files') 53 | if self.context_location == 'child_class_file' or self.context_location == 'import_of_child_class_file': 54 | return self.is_non_empty('child_class_filenames') 55 | 56 | def set_context_scope_and_inclusion_type(self, new_scope): 57 | self.context_scope = new_scope 58 | if self.context_scope == 'pre' or self.context_scope =='pre_post': 59 | self.codex_completion_inclusion_type ='back' 60 | if self.context_scope == 'post': 61 | self.codex_completion_inclusion_type ='front' 62 | 63 | def set_hole_pos(self, hole_pos): 64 | self.hole_pos = hole_pos 65 | 66 | def get_context_len(self): 67 | return self.context_len 68 | 69 | def set_context_len(self, new_context_len): 70 | self.context_len = new_context_len 71 | 72 | def get_rule_context_format(self, lst): 73 | if self.rule_context_formatting == 'space': 74 | context = " ".join(lst) 75 | 76 | if self.rule_context_formatting == 'newline': 77 | context = 
"\n".join(lst) 78 | 79 | if self.rule_context_formatting == 'method_name'\ 80 | or self.rule_context_formatting == 'class_name'\ 81 | or self.rule_context_formatting == 'class_method_name': 82 | context = " ".join(lst) 83 | 84 | if self.rule_context_formatting == 'comment': 85 | context = " ".join(lst) 86 | context = "/**" + context + "*/" 87 | 88 | return context 89 | 90 | def get_nearest_attribute_index(self, attribute_names, type='class'): 91 | hole_start_line = self.hole_pos[0] 92 | min_pos_diff = 100000 93 | min_pos_diff_index = -1 94 | for i in range(len(attribute_names)): 95 | if attribute_names[i]: 96 | attribute_start_line = attribute_names[i][0][0] 97 | if type == 'class': 98 | pos_diff = hole_start_line - attribute_start_line 99 | if type == 'import': 100 | pos_diff = np.abs(hole_start_line - attribute_start_line) 101 | if pos_diff < min_pos_diff: 102 | if type == 'class': 103 | if pos_diff > 0: 104 | min_pos_diff = pos_diff 105 | min_pos_diff_index = i 106 | #print(pos_diff, min_pos_diff, min_pos_diff_index) 107 | if type == 'import': 108 | if pos_diff != 0 or (pos_diff == 0 and attribute_names[i][1][1] < self.hole_pos[1]): 109 | min_pos_diff = pos_diff 110 | min_pos_diff_index = i 111 | if type == 'class': 112 | return min_pos_diff_index 113 | if type == 'import': 114 | return pos_diff 115 | 116 | def get_relevant_import_of_att_files(self, att_type): 117 | att_import_ranking = {} 118 | att_files = self.parse_data[self.file][att_type] 119 | #att_files = list(set(att_files)) 120 | for att_file, att_file_overlap in att_files: 121 | if 'small_' in self.file.split('/')[1]: 122 | att_file = '/'.join([att_file.split('/')[0]] + [self.file.split('/')[1]] + att_file.split('/')[2:]) 123 | att_file_imports = list(self.parse_data[att_file]['imports'].keys()) 124 | for att_file_import in att_file_imports: 125 | if att_file_import in att_import_ranking: 126 | att_import_ranking[att_file_import]+=1 127 | else: 128 | att_import_ranking[att_file_import] = 1 129 | sorted_att_import_ranking = sorted(att_import_ranking.items(), key=lambda x: x[1], reverse=True) 130 | sorted_att_import_ranking = [imp_file for imp_file, _ in sorted_att_import_ranking] 131 | return sorted_att_import_ranking 132 | 133 | def get_relevant_import_files(self): 134 | all_imports = self.parse_data[self.file]['imports'] 135 | import_distances_from_hole = {} 136 | for import_file, import_identifier_loc in all_imports.items(): 137 | pos_diff = self.get_nearest_attribute_index(import_identifier_loc, type='import') 138 | if pos_diff != -1: 139 | import_distances_from_hole[import_file] = pos_diff 140 | 141 | # less the position difference from the hole, higher the ranking 142 | sorted_import_distances_from_hole = sorted(import_distances_from_hole.items(), key=lambda x: x[1]) 143 | sorted_import_files = [imp_file for imp_file, _ in sorted_import_distances_from_hole] 144 | return sorted_import_files 145 | 146 | def get_relevant_files(self, type_str, sort_order='descending'): 147 | all_type_files = self.parse_data[self.file][type_str] 148 | if not all_type_files: 149 | return all_type_files 150 | 151 | sorted_file_imports = self.get_relevant_import_files() 152 | overlapping_type_files = {} 153 | found = False 154 | # find type files(e.g. sibling files) with imports common with current file based on the position with the hole. 155 | # in case multiple such files exist, sort based on number of common import statements. 
156 | for imp_file in sorted_file_imports: 157 | if imp_file in self.import_overlap_type_files: 158 | return self.import_overlap_type_files[imp_file] 159 | if found: 160 | break 161 | for type_file, type_file_overlap in all_type_files: 162 | if type_file_overlap > 0: 163 | if 'small_' in self.file.split('/')[1]: 164 | type_file = '/'.join([type_file.split('/')[0]] + [self.file.split('/')[1]] + type_file.split('/')[2:]) 165 | type_file_import_files = list(self.parse_data[type_file]['imports'].keys()) 166 | if imp_file in type_file_import_files: 167 | overlapping_type_files[type_file] = type_file_overlap 168 | found = True 169 | 170 | if not found: 171 | type_files = [x[0] for x in all_type_files] 172 | return type_files 173 | 174 | # more the overlap, higher the ranking 175 | if sort_order == 'descending': 176 | sorted_overlapping_type_files = sorted(overlapping_type_files.items(), key=lambda x: x[1], reverse=True) 177 | else: 178 | sorted_overlapping_type_files = sorted(overlapping_type_files.items(), key=lambda x: x[1]) 179 | 180 | sorted_type_files = [type_file for type_file, _ in sorted_overlapping_type_files] 181 | self.import_overlap_type_files[imp_file] = sorted_type_files 182 | return sorted_type_files 183 | 184 | def get_parent_class_filename(self): 185 | """ 186 | Return the parent class filename that corresponds to the immediate scope of the hole location 187 | """ 188 | file_parsed_data = self.parse_data[self.file] 189 | parent_class_filenames = file_parsed_data['parent_class_filenames'] 190 | parent_class_names = file_parsed_data['parent_class_names'] 191 | relevant_index = self.get_nearest_attribute_index(parent_class_names) 192 | if relevant_index != -1: 193 | return parent_class_filenames[relevant_index][0], parent_class_names[relevant_index] 194 | else: 195 | return '', '' 196 | 197 | def get_method_names_and_bodies(self, method_names, method_bodies, file): 198 | 199 | if self.top_k != -1 and len(method_names) >= self.top_k: 200 | method_context_len = int(self.context_len/self.top_k) 201 | else: 202 | method_context_len = int(self.context_len/len(method_names)) 203 | 204 | method_contexts = [] 205 | context_len = 0 206 | for method_name in method_names: 207 | if method_name: 208 | found = False 209 | for method_body in method_bodies: 210 | # for each method name, find the corresponding method_body 211 | if method_body and method_body[0][0] == method_name[0][0]: 212 | ms, me = method_body 213 | full_ms = (ms[0], 0) 214 | full_me = me 215 | found = True 216 | break 217 | 218 | if found == False: 219 | ms, me = method_name 220 | full_ms = (ms[0], 0) 221 | full_me = (ms[0], -1) 222 | 223 | method_name_and_body = get_string(file, full_ms, full_me) 224 | method_context, method_context_len = get_codex_tokenized_string(self.tokenizer, method_name_and_body, \ 225 | method_context_len) 226 | if self.rule_context_formatting == 'method_name'\ 227 | or self.rule_context_formatting =='class_method_name': 228 | method_name_str = "[" + get_string(file, method_name[0], method_name[1]) + "]" 229 | method_contexts.append(method_name_str) 230 | method_contexts.append(method_context) 231 | context_len += method_context_len 232 | 233 | context = self.get_rule_context_format(method_contexts) 234 | return context, context_len 235 | 236 | def get_context_string(self, candidate_attributes): 237 | 238 | if self.top_k == -1: 239 | attributes_str = self.get_rule_context_format(candidate_attributes) 240 | else: 241 | if self.top_k_type == 'first': 242 | attributes_str = 
self.get_rule_context_format(candidate_attributes[:self.top_k]) 243 | if self.top_k_type == 'last': 244 | attributes_str = self.get_rule_context_format(candidate_attributes[-self.top_k:]) 245 | 246 | context, context_len = get_codex_tokenized_string(self.tokenizer, attributes_str, self.context_len, 247 | type=self.codex_completion_inclusion_type) 248 | 249 | return context, context_len 250 | 251 | def get_attribute_context(self, attributes, file): 252 | candidate_attributes = [] 253 | for attribute in attributes: 254 | if attribute: 255 | start, end = attribute 256 | 257 | if self.context_location == 'in_file': 258 | start_line, start_char = start 259 | end_line, end_char = end 260 | hole_pos_line, hole_pos_char = self.hole_pos 261 | #assert start_line == end_line, "attribute doesn't span a single line" 262 | 263 | if self.context_scope == 'pre' or self.context_scope == 'pre_post': 264 | if end_line < hole_pos_line or (end_line == hole_pos_line and end_char < hole_pos_char): 265 | attribute_string = get_string(file, start, end) 266 | candidate_attributes.append(attribute_string.strip()) 267 | 268 | if self.context_scope == 'post' or self.context_scope == 'pre_post': 269 | if start_line > hole_pos_line: 270 | attribute_string = get_string(file, start, end) 271 | candidate_attributes.append(attribute_string.strip()) 272 | 273 | else: 274 | # checking for overlap with the hole is not needed here as it is a different file 275 | attribute_string = get_string(file, start, end) 276 | candidate_attributes.append(attribute_string.strip()) 277 | 278 | context, context_len = self.get_context_string(candidate_attributes) 279 | return context, context_len 280 | 281 | def get_line_context(self, num_of_lines_to_exclude=0): 282 | num_of_lines_to_be_taken = self.top_k 283 | pre_context = '' 284 | post_context = '' 285 | if self.context_scope == 'pre' or self.context_scope == 'pre_post': 286 | end = self.hole_pos 287 | if num_of_lines_to_be_taken == -1: 288 | start = (0, 0) 289 | else: 290 | hole_pos_line = self.hole_pos[0] 291 | start_line = hole_pos_line - num_of_lines_to_be_taken 292 | if start_line < 0: 293 | start_line = 0 294 | start = (start_line, 0) 295 | pre_context = get_string(self.file, start, end) 296 | 297 | if self.context_scope == 'post' or self.context_scope == 'pre_post': 298 | hole_pos_line = self.hole_pos[0] 299 | start = (hole_pos_line + 1 + num_of_lines_to_exclude, 0) 300 | if num_of_lines_to_be_taken != -1: 301 | end_line = hole_pos_line + num_of_lines_to_be_taken + num_of_lines_to_exclude 302 | if end_line >= len(self.file_lines): 303 | end_line = len(self.file_lines) - 1 304 | end_char = len(self.file_lines[end_line]) 305 | end = (end_line, end_char) 306 | else: 307 | end_line = len(self.file_lines)-1 308 | end_char = len(self.file_lines[end_line]) 309 | end = (end_line, end_char) 310 | 311 | post_context = get_string(self.file, start, end) 312 | 313 | if self.context_scope == 'pre': 314 | context, context_len = get_codex_tokenized_string(self.tokenizer, pre_context, self.context_len, 315 | type=self.codex_completion_inclusion_type) 316 | if self.context_scope == 'post': 317 | context, context_len = get_codex_tokenized_string(self.tokenizer, post_context, self.context_len, 318 | type=self.codex_completion_inclusion_type) 319 | if self.context_scope == 'pre_post': 320 | pre_context, pre_context_len = get_codex_tokenized_string(self.tokenizer, pre_context, int(self.context_len/2), 321 | type='back') 322 | post_context, post_context_len = get_codex_tokenized_string(self.tokenizer, 
post_context, int(self.context_len/2), 323 | type='front') 324 | 325 | context = pre_context + "\n" + post_context 326 | context_len = pre_context_len + post_context_len 327 | 328 | return context, context_len 329 | 330 | def get_base_context(self): 331 | base_class_names = self.parse_data[self.file]['class_names'] 332 | class_index = self.get_nearest_attribute_index(base_class_names) 333 | if class_index != -1: 334 | base_class_name = get_string(self.file, base_class_names[class_index][0], base_class_names[class_index][1]) 335 | base_context = "[" + base_class_name + "]" 336 | else: 337 | base_context = '' 338 | return base_context 339 | 340 | def get_attribute_context_from_context_type(self, file_type): 341 | if 'small_' in self.file.split('/')[1]: 342 | file_type = '/'.join([file_type.split('/')[0]] + [self.file.split('/')[1]] + file_type.split('/')[2:]) 343 | if self.context_type == 'identifiers': 344 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['identifiers'], file_type) 345 | 346 | if self.context_type == 'type_identifiers': 347 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['type_identifiers'], file_type) 348 | 349 | if self.context_type == 'string_literals': 350 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['string_literals'], file_type) 351 | 352 | if self.context_type == 'method_names': 353 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['all_method_names'], file_type) 354 | 355 | if self.context_type == 'method_names_and_bodies': 356 | method_names = self.parse_data[file_type]['all_method_names'] 357 | method_bodies = self.parse_data[file_type]['all_method_bodies'] 358 | if method_names: 359 | context, context_len = self.get_method_names_and_bodies(method_names, method_bodies, file_type) 360 | else: 361 | context= '' 362 | context_len = 0 363 | 364 | if self.context_type == 'field_declarations': 365 | context, context_len = self.get_attribute_context(self.parse_data[file_type]['field_declarations'], file_type) 366 | 367 | return context, context_len 368 | 369 | def get_context_from_multiple_files(self, files): 370 | total_context_len = 0 371 | total_context = '' 372 | if files: 373 | for file in files: 374 | if total_context_len < self.get_context_len(): 375 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name': 376 | base_context = self.get_base_context() 377 | file_name = file.split('/')[-1].split('.')[0] 378 | file_name = "[" + file_name + "]" 379 | # get context 380 | context, context_len = self.get_attribute_context_from_context_type(file) 381 | 382 | # import contexts are added to the front based on decreasing priority 383 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name': 384 | total_context = file_name + " " + context + " " + total_context 385 | else: 386 | total_context = context + " " + total_context 387 | total_context_len += context_len 388 | else: 389 | break 390 | 391 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name': 392 | total_context = total_context + "\n" + base_context 393 | return total_context, total_context_len 394 | 395 | def get_in_file_context(self, num_of_lines_to_exclude=0): 396 | """ 397 | for in_file only post lines makes sense. 398 | for others, first post is tried, if post is not successful in finding any context the pre_post is tried. 
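Returns a (context, context_len) pair: the extracted context string and its token length.
Illustrative usage (hypothetical file path, tokenizer, parse_data and hole position supplied by the caller):
    ctx = getContext(context_location='in_file', tokenizer=tokenizer, file=java_file,
                     parse_data=parse_data, context_type='lines', top_k=10)
    ctx.set_hole_pos((42, 17))
    context, context_len = ctx.get_in_file_context()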
399 | """ 400 | if self.context_type == 'lines': 401 | self.set_context_scope_and_inclusion_type('post') 402 | context, context_len = self.get_line_context(num_of_lines_to_exclude) 403 | # doesn't mean much to have pre_post in this setting 404 | 405 | if self.context_type == 'identifiers': 406 | self.set_context_scope_and_inclusion_type('post') 407 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['identifiers'], self.file) 408 | if not context: 409 | self.set_context_scope_and_inclusion_type('pre_post') 410 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['identifiers'], self.file) 411 | 412 | if self.context_type == 'type_identifiers': 413 | self.set_context_scope_and_inclusion_type('post') 414 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['type_identifiers'], self.file) 415 | if not context: 416 | self.set_context_scope_and_inclusion_type('pre_post') 417 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['type_identifiers'], self.file) 418 | 419 | if self.context_type == 'string_literals': 420 | self.set_context_scope_and_inclusion_type('post') 421 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['string_literals'], self.file) 422 | if not context: 423 | self.set_context_scope_and_inclusion_type('pre_post') 424 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['string_literals'], self.file) 425 | 426 | if self.context_type == 'method_names': 427 | self.set_context_scope_and_inclusion_type('post') 428 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['all_method_names'], self.file) 429 | if not context: 430 | self.set_context_scope_and_inclusion_type('pre_post') 431 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['all_method_names'], self.file) 432 | 433 | if self.context_type == 'field_declarations': 434 | self.set_context_scope_and_inclusion_type('post') 435 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['field_declarations'], self.file) 436 | if not context: 437 | self.set_context_scope_and_inclusion_type('pre_post') 438 | context, context_len = self.get_attribute_context(self.parse_data[self.file]['field_declarations'], self.file) 439 | 440 | return context, context_len 441 | 442 | def get_parent_class_file_context(self): 443 | self.parent_class_file, self.parent_class_name = self.get_parent_class_filename() 444 | if self.parent_class_file: 445 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name': 446 | base_context = self.get_base_context() 447 | parent_class_name = get_string(self.file, self.parent_class_name[0], self.parent_class_name[1]) 448 | parent_context = "[" + parent_class_name + "]" 449 | # get context 450 | context, context_len = self.get_attribute_context_from_context_type(self.parent_class_file) 451 | if self.rule_context_formatting == 'class_name' or self.rule_context_formatting == 'class_method_name': 452 | context = parent_context + " " + context + "\n" + base_context 453 | else: 454 | context = '' 455 | context_len = 0 456 | 457 | return context, context_len 458 | 459 | def get_import_file_context(self): 460 | import_files = self.get_relevant_import_files() 461 | return self.get_context_from_multiple_files(import_files) 462 | 463 | def get_sibling_file_context(self): 464 | if self.context_location.startswith('reverse'): 465 | sort_order ='ascending' 
466 | else: 467 | sort_order = 'descending' 468 | sibling_files = self.get_relevant_files(type_str='sibling_files', sort_order=sort_order) 469 | return self.get_context_from_multiple_files(sibling_files) 470 | 471 | def get_similar_name_file_context(self): 472 | if self.context_location.startswith('reverse'): 473 | sort_order ='ascending' 474 | else: 475 | sort_order = 'descending' 476 | similar_name_files = self.get_relevant_files(type_str='similar_name_files', sort_order=sort_order) 477 | return self.get_context_from_multiple_files(similar_name_files) 478 | 479 | def get_child_class_file_context(self): 480 | child_class_files = self.get_relevant_files(type_str='child_class_filenames') 481 | return self.get_context_from_multiple_files(child_class_files) 482 | 483 | def get_import_of_sibling_file_context(self): 484 | imports_of_sibling_files = self.get_relevant_import_of_att_files('sibling_files') 485 | return self.get_context_from_multiple_files(imports_of_sibling_files) 486 | 487 | def get_import_of_similar_name_file_context(self): 488 | imports_of_similar_name_files = self.get_relevant_import_of_att_files('similar_name_files') 489 | return self.get_context_from_multiple_files(imports_of_similar_name_files) 490 | 491 | def get_import_of_parent_class_file_context(self): 492 | imports_of_parent_class_files = self.get_relevant_import_of_att_files('parent_class_filenames') 493 | return self.get_context_from_multiple_files(imports_of_parent_class_files) 494 | 495 | def get_import_of_child_class_file_context(self): 496 | imports_of_child_class_files = self.get_relevant_import_of_att_files('child_class_filenames') 497 | return self.get_context_from_multiple_files(imports_of_child_class_files) 498 | -------------------------------------------------------------------------------- /create_hole_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | import pickle 5 | import javac_parser 6 | import shutil 7 | 8 | 9 | ''' 10 | For each line in the repo (which is not blank or comment) choose the midpoint of those lines(character wise not token wise) as hole position 11 | ''' 12 | 13 | def choose_holes(project_lines, comments): 14 | data = {} 15 | count = 0 16 | repeated_holes = 0 17 | chosen_lines = [] 18 | 19 | selected_lines = np.arange(0, len(project_lines)) 20 | 21 | for proj_line_id in selected_lines: 22 | file, file_line_id, line = project_lines[proj_line_id] 23 | # removing leading and trailing whitespaces 24 | line = line.strip() 25 | # omitting comments and empty lines 26 | if line and not (np.any([line.startswith(comment) for comment in comments])): 27 | if proj_line_id in chosen_lines: 28 | repeated_holes+=1 29 | else: 30 | chosen_lines.append(proj_line_id) 31 | count+=1 32 | #get holes from the middle of the lines 33 | mid_point = int(len(line)/2) 34 | chosen_position = mid_point 35 | 36 | if file in data: 37 | data[file].append((file_line_id, chosen_position)) 38 | else: 39 | data[file] = [(file_line_id, chosen_position)] 40 | 41 | #total number of holes, #number of repeated holes, number of relevant lines, number of files 42 | return data, len(chosen_lines), len(data) 43 | 44 | def setup_args(): 45 | """ 46 | Description: Takes in the command-line arguments from user 47 | """ 48 | parser = argparse.ArgumentParser() 49 | 50 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 51 | parser.add_argument("--base_dir", type=str, default='gcode-data', \ 52 | help="base 
directory for the data") 53 | parser.add_argument("--data_split", type=str, default='train', \ 54 | help="data split to store the data") 55 | parser.add_argument("--language", type=str, default='java', help="java, cpp") 56 | parser.add_argument("--proj_name", type=str, default='javasummerframework', \ 57 | help="name of the input repo") 58 | 59 | return parser.parse_args() 60 | 61 | if __name__ == '__main__': 62 | 63 | args = setup_args() 64 | 65 | #Fix seeds 66 | np.random.seed(args.seed) 67 | os.environ['PYTHONHASHSEED']=str(args.seed) 68 | 69 | if args.language == 'java': 70 | file_extensions = ['.java'] 71 | comments = ['*', '/'] 72 | if args.language == 'lua': 73 | file_extensions = ['.lua'] 74 | comments = ['--'] 75 | if args.language == 'cpp': 76 | file_extensions == ['.cc', '.cpp', '.h'] 77 | comments = ['/'] 78 | 79 | source_data_path = os.path.join(args.base_dir, args.proj_name) 80 | os.makedirs(os.path.join('rule_classifier_data', args.data_split), exist_ok=True) 81 | destination_data_path = os.path.join('rule_classifier_data', args.data_split, args.proj_name) 82 | shutil.move(source_data_path, destination_data_path) 83 | 84 | files = [] 85 | for dp, dn, filenames in os.walk(destination_data_path): 86 | for f in filenames: 87 | for file_ext in file_extensions: 88 | if os.path.splitext(f)[1] == file_ext: 89 | files.append(os.path.join(dp, f)) 90 | 91 | project_lines = [] 92 | for file in files: 93 | file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines() 94 | parsed_file_lines = [] 95 | for l in range(len(file_lines)): 96 | line = file_lines[l] 97 | parsed_file_lines.append((file, l, line)) 98 | project_lines.extend(parsed_file_lines) 99 | num_of_lines_in_proj = len(project_lines) # number of lines in the project 100 | data, num_of_holes, num_of_files = choose_holes(project_lines, comments=comments) 101 | 102 | with open(os.path.join(destination_data_path, 'hole_data'), 'wb') as f: 103 | pickle.dump(data, f) 104 | 105 | print(args.proj_name + ", " + str(num_of_files) + ", " \ 106 | + str(num_of_lines_in_proj) + ", " + str(num_of_holes)) 107 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | from context import * 2 | from utils import * 3 | from rule_config import * 4 | from transformers import GPT2TokenizerFast 5 | 6 | context_location_conversion = { 7 | 'in_file':'in_file', \ 8 | 'parent_class_file':'parent_class_file', \ 9 | 'import_file':'import_file',\ 10 | 'sibling_file':'sibling_files', \ 11 | 'similar_name_file':'similar_name_files', \ 12 | 'child_class_file':'child_class_filenames', \ 13 | 'import_of_sibling_file':'sibling_files', \ 14 | 'import_of_similar_name_file':'similar_name_files', \ 15 | 'import_of_parent_class_file':'parent_class_filenames', \ 16 | 'import_of_child_class_file':'child_class_filenames' 17 | } 18 | 19 | 20 | class RuleDatasetUtils(): 21 | def __init__(self, file, parse_datas, hole_pos, tokenizer): 22 | super(RuleDatasetUtils, self).__init__() 23 | #mod_file = '/'. 
join(['data', 'gcode-data'] + file.split('/')[2:]) 24 | self.file = file 25 | self.parse_datas = parse_datas 26 | self.hole_pos = hole_pos 27 | self.tokenizer = tokenizer 28 | #self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 29 | 30 | def get_relevant_files(self, context_location): 31 | if context_location == 'in_file': 32 | return [self.file] 33 | else: 34 | rule_context_obj = getContext(context_location = context_location, file=self.file, parse_data = self.parse_datas) 35 | rule_context_obj.set_hole_pos(self.hole_pos) 36 | if rule_context_obj.is_out_files(): 37 | if context_location == 'parent_class_file': 38 | relevant_file, _ = rule_context_obj.get_parent_class_filename() 39 | relevant_files = [] 40 | elif context_location == 'import_file': 41 | relevant_files = rule_context_obj.get_relevant_import_files() 42 | elif 'import_of_' in context_location: 43 | relevant_files = rule_context_obj.get_relevant_import_of_att_files(context_location_conversion[context_location]) 44 | else: 45 | relevant_files = rule_context_obj.get_relevant_files(context_location_conversion[context_location]) 46 | return relevant_files 47 | else: 48 | return [] 49 | 50 | def get_usages_from_context_location(self, hole_attributes, context_location): 51 | cl_files = self.get_relevant_files(context_location) 52 | for hole_att in hole_attributes: 53 | for cl_file in cl_files: 54 | print(cl_file) 55 | file_attributes = self.parse_datas[cl_file]['identifiers'] 56 | att_usages = find_usages(hole_att, self.file, file_attributes, cl_file) 57 | if not att_usages: 58 | continue 59 | else: 60 | return (cl_file, att_usages) 61 | 62 | def get_all_usages(self, hole_attributes): 63 | usages = {} 64 | for context_location in context_location_conversion.keys(): 65 | print(context_location) 66 | usages[context_location] = self.get_usages_from_context_location(hole_attributes, context_location) 67 | return usages 68 | 69 | def get_default_prompt(self, context_len): 70 | 71 | default_context_obj = getContext(context_location='in_file', 72 | tokenizer=self.tokenizer, 73 | file=self.file, 74 | context_len=context_len, 75 | context_scope='pre',\ 76 | context_type='lines',\ 77 | top_k=-1) 78 | 79 | default_context_obj.set_hole_pos(self.hole_pos) 80 | default_prompt, default_prompt_len = default_context_obj.get_line_context() 81 | return default_prompt, default_prompt_len 82 | 83 | 84 | def get_all_rules_context(self, num_of_lines_to_exclude=0): 85 | rule_prompts = [] 86 | rule_indexes = [] 87 | total_context_len = self.tokenizer.model_max_length 88 | 89 | 90 | for key, val in combined_to_index.items(): 91 | rule_prompt = '' 92 | if key == 'codex': 93 | rule_prompt, rule_prompt_len = self.get_default_prompt(context_len=total_context_len) 94 | if key != 'codex': 95 | context_location, context_type, context_division_ratio = key.split('#') 96 | rule_context_formatting = rule_hyperparams[context_type]['rule_context_formatting'][0] 97 | rule_context_obj = getContext(context_location = context_location, file=self.file, parse_data = self.parse_datas, \ 98 | context_type=context_type, rule_context_formatting=rule_context_formatting, \ 99 | tokenizer = self.tokenizer, context_len = total_context_len) 100 | 101 | allocated_rule_context_len = int(rule_context_obj.get_context_len()*float(context_division_ratio)) 102 | rule_context_obj.set_context_len(allocated_rule_context_len) 103 | rule_context_obj.set_hole_pos(self.hole_pos) 104 | 105 | if context_location == 'in_file': 106 | rule_prompt, rule_prompt_len = 
rule_context_obj.get_in_file_context(num_of_lines_to_exclude) 107 | 108 | # there are files for this context location except in_file context location 109 | is_out_files = rule_context_obj.is_out_files() 110 | if is_out_files: 111 | if context_location == 'parent_class_file': 112 | rule_prompt, rule_prompt_len = rule_context_obj.get_parent_class_file_context() 113 | if context_location == 'import_file': 114 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_file_context() 115 | if context_location == 'sibling_file': 116 | rule_prompt, rule_prompt_len = rule_context_obj.get_sibling_file_context() 117 | if context_location == 'similar_name_file': 118 | rule_prompt, rule_prompt_len = rule_context_obj.get_similar_name_file_context() 119 | if context_location == 'child_class_file': 120 | rule_prompt, rule_prompt_len = rule_context_obj.get_child_class_file_context() 121 | if context_location == 'import_of_similar_name_file': 122 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_similar_name_file_context() 123 | if context_location == 'import_of_parent_class_file': 124 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_parent_class_file_context() 125 | if context_location == 'import_of_child_class_file': 126 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_child_class_file_context() 127 | if context_location == 'import_of_sibling_file': 128 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_sibling_file_context() 129 | 130 | if rule_prompt: 131 | rule_prompts.append(rule_prompt) 132 | rule_indexes.append(val) 133 | 134 | return rule_prompts, rule_indexes 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /generate_completions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import os 4 | import copy 5 | import pickle 6 | import openai 7 | import argparse 8 | import json 9 | import random 10 | from utils import * 11 | from context import * 12 | from transformers import GPT2TokenizerFast 13 | 14 | def setup_args(): 15 | """ 16 | Description: Takes in the command-line arguments from user 17 | """ 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 21 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data/val', help="base directory for the data") 22 | parser.add_argument("--repo_name", type=str, default='ircrpgbot', help="name of the repo") 23 | 24 | # completion-related hyperparams 25 | parser.add_argument("--mode", type=str, default='rule', help="codex, rule") 26 | 27 | parser.add_argument("--batch_size", type=int, default=20, \ 28 | help="batch size of the prompts to be given to Codex") 29 | parser.add_argument("--completion_len", type=int, default=24, \ 30 | help=" length of the 24 completions, so total size will be 4096") 31 | parser.add_argument("--is_run_rule_full", default=False, action='store_true', \ 32 | help="whether to run rule based method for the full cases or only rule_triggered cases") 33 | 34 | # context related hyperparams 35 | parser.add_argument("--total_context_len", type=int, default=100, \ 36 | help="total size of the context: 24 for completions, so total size will be 4096") 37 | parser.add_argument("--context_division_ratio", type=float, default=0.5, \ 38 | help="ratio in which the in-file and out-file context are divided") 39 | 40 | parser.add_argument("--context_location", 
type=str, default='in_file', \ 41 | help="where to take context from\ 42 | NOTE that this is always in addition to the previous prompt, \ 43 | i.e., in addition to the default prompt for a Codex model") 44 | 45 | parser.add_argument("--context_type", type=str, default='field_declarations',\ 46 | help="the type of context to be taken. \ 47 | For possible values for each context_locations see rules file") 48 | 49 | # rule-related hyperparameters 50 | parser.add_argument("--top_k", type=int, default=-1,\ 51 | help="k value. A value of -1 indicates taking full context of context_type") 52 | parser.add_argument("--top_k_type", type=str, default='first', \ 53 | help="first, last") 54 | parser.add_argument("--prompt_separator", type=str, default='space', \ 55 | help="space, newline") 56 | parser.add_argument("--rule_context_formatting", type=str, default='space', \ 57 | help="space, newline, method_name, class_name, comment, class_method_name") 58 | return parser.parse_args() 59 | 60 | def generate_prediction(prompt): 61 | ''' 62 | generate predictions using Codex 63 | ''' 64 | try: 65 | response = openai.Completion.create(engine='code-davinci-001',\ 66 | prompt=prompt,stop='\n',\ 67 | temperature=0.0) 68 | 69 | except: 70 | print ("Waiting") 71 | response = None 72 | return response 73 | 74 | def check_hole_scope(hole_pos, class_spans): 75 | ''' 76 | return the class span of the base class where the cursor is present. If there is no base class, return None 77 | ''' 78 | for class_span in class_spans: 79 | cs = int(class_span.split('_')[0]) 80 | ce = int(class_span.split('_')[1]) 81 | l, c = hole_pos 82 | if l == cs or cs == -1: 83 | return None 84 | if cs < l <= ce: 85 | return class_span 86 | 87 | def get_default_prompt(hole_pos=(0,0), context_len=0, tokenizer=None, file=''): 88 | 89 | default_context_obj = getContext(context_location='in_file', 90 | tokenizer=tokenizer, 91 | file=file, 92 | context_len=context_len, 93 | context_scope='pre',\ 94 | context_type='lines',\ 95 | top_k=-1) 96 | 97 | default_context_obj.set_hole_pos(hole_pos) 98 | default_prompt, default_prompt_len = default_context_obj.get_line_context() 99 | return default_prompt, default_prompt_len 100 | 101 | def get_prompt(rule_context_obj=None, context_location='in_file', 102 | total_context_len=4072, rule_triggered=False, parent_class_filename='', 103 | context_division_ratio=0.5, num_of_lines_to_exclude=0): 104 | 105 | # start by assigning half of the total_context_len to the rule prompt 106 | rule_context_obj.set_context_len(total_context_len) 107 | allocated_rule_context_len = int(rule_context_obj.get_context_len()*context_division_ratio) 108 | rule_context_obj.set_context_len(allocated_rule_context_len) 109 | 110 | if context_location == 'in_file': 111 | rule_prompt, rule_prompt_len = rule_context_obj.get_in_file_context(num_of_lines_to_exclude) 112 | if context_location == 'parent_class_file': 113 | rule_prompt, rule_prompt_len = rule_context_obj.get_parent_class_file_context() 114 | if context_location == 'import_file': 115 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_file_context() 116 | if context_location == 'sibling_file' or context_location == 'reverse_sibling_file': 117 | rule_prompt, rule_prompt_len = rule_context_obj.get_sibling_file_context() 118 | if context_location == 'similar_name_file' or context_location == 'reverse_similar_name_file': 119 | rule_prompt, rule_prompt_len = rule_context_obj.get_similar_name_file_context() 120 | if context_location == 'child_class_file': 121 | rule_prompt, 
rule_prompt_len = rule_context_obj.get_child_class_file_context() 122 | if context_location == 'import_of_similar_name_file': 123 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_similar_name_file_context() 124 | if context_location == 'import_of_parent_class_file': 125 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_parent_class_file_context() 126 | if context_location == 'import_of_child_class_file': 127 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_child_class_file_context() 128 | if context_location == 'import_of_sibling_file': 129 | rule_prompt, rule_prompt_len = rule_context_obj.get_import_of_sibling_file_context() 130 | 131 | # if the rule_prompt_len is shorter than the allocated space, use the extra space for the default_prompt 132 | if rule_prompt_len < allocated_rule_context_len: 133 | default_context_len = total_context_len - rule_prompt_len 134 | else: 135 | default_context_len = total_context_len - allocated_rule_context_len 136 | # if something is returned by the rule, it means that the rule is triggered 137 | if rule_prompt_len > 0: 138 | rule_triggered = True 139 | default_prompt, default_prompt_len = get_default_prompt( 140 | hole_pos=getattr(rule_context_obj, 'hole_pos'), 141 | context_len=default_context_len, 142 | tokenizer=getattr(rule_context_obj, 'tokenizer'), 143 | file=getattr(rule_context_obj, 'file') 144 | ) 145 | return rule_prompt, default_prompt, rule_triggered 146 | 147 | if __name__ == '__main__': 148 | 149 | args = setup_args() 150 | 151 | #Fix seeds 152 | np.random.seed(args.seed) 153 | os.environ['PYTHONHASHSEED'] = str(args.seed) 154 | 155 | os.environ["OPENAI_API_KEY"] = open('openai_api_key', 'r').read().strip() 156 | openai.api_key = os.getenv("OPENAI_API_KEY") 157 | 158 | #directory for storing results 159 | input_data_dir = os.path.join(args.base_dir, args.repo_name) 160 | result_path = os.path.join('results', args.base_dir, args.repo_name) 161 | os.makedirs(result_path, exist_ok=True) 162 | 163 | 164 | # get tokenizer 165 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 166 | 167 | #get stored parsed data 168 | parsed_data_filename = os.path.join(args.base_dir, args.repo_name, 'parsed_data') 169 | parse_data = pickle.load(open(parsed_data_filename, 'rb')) 170 | #get the holes 171 | hole_filename = os.path.join(args.base_dir, args.repo_name, 'hole_data') 172 | hole_data = pickle.load(open(hole_filename, 'rb')) 173 | 174 | # get all relevant files (in raw form) 175 | files = [os.path.join(dp, f) \ 176 | for dp, dn, filenames in os.walk(input_data_dir) \ 177 | for f in filenames \ 178 | if os.path.splitext(f)[1] == '.java'] 179 | 180 | # Create the result file for writing the predictions 181 | result_dir = os.path.join(result_path, args.context_location) 182 | os.makedirs(result_dir, exist_ok=True) 183 | 184 | if args.mode == 'rule': 185 | result_filename = args.mode + '_' + args.context_type + '_' + str(args.top_k_type)\ 186 | + '_' + str(args.top_k) \ 187 | + '_' + str(int(args.total_context_len * args.context_division_ratio)) \ 188 | + '_' + args.rule_context_formatting \ 189 | + '_' + args.prompt_separator 190 | 191 | if args.mode =='codex': 192 | result_filename = args.mode + '_' + str(args.total_context_len) 193 | 194 | full_result_filename = os.path.join(result_dir, result_filename + '.json') 195 | print(full_result_filename) 196 | if os.path.isfile(full_result_filename) and os.path.getsize(full_result_filename) > 0: 197 | print(full_result_filename, " : result file already exists") 198 | 
exit() 199 | 200 | f = open(full_result_filename, 'w') 201 | 202 | all_holes = [] 203 | rule_prompts = [] 204 | default_prompts = [] 205 | all_hole_identities = [] 206 | all_rules_triggered = [] 207 | total_count = 0 208 | rule_triggered_count = 0 209 | 210 | # get the prompts for all files 211 | for file in files: 212 | if file in hole_data: 213 | file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines() 214 | 215 | # define the rule context object. Depends on the file 216 | rule_context_obj = getContext(context_location = args.context_location, \ 217 | tokenizer=tokenizer, file=file, 218 | parse_data = parse_data, 219 | context_type=args.context_type, 220 | top_k=args.top_k,top_k_type=args.top_k_type, 221 | rule_context_formatting=args.rule_context_formatting, 222 | file_lines=file_lines) 223 | 224 | is_out_file = rule_context_obj.is_out_files() 225 | 226 | # go through the holes in the file 227 | for (l,c) in hole_data[file]: # l = line no, c = character offset within line l 228 | if total_count%1000 == 0: 229 | print("Total Count:", total_count) 230 | hole = file_lines[l][c:] 231 | hole_identity = file + '_' + str(l) + '_' + str(c) 232 | hole_pos = (l, c) 233 | 234 | # if mode is codex or we have no parent_class_files or import_files, 235 | # then get the default prompt directly 236 | if args.mode == 'codex' or \ 237 | (args.mode == 'rule' and args.context_location != 'in_file' and not is_out_file): 238 | default_prompt, default_prompt_len = get_default_prompt(hole_pos, args.total_context_len, 239 | tokenizer, file) 240 | rule_triggered = False 241 | rule_prompt = '' 242 | 243 | else: 244 | rule_context_obj.set_hole_pos(hole_pos) 245 | rule_prompt, default_prompt, rule_triggered = \ 246 | get_prompt( 247 | rule_context_obj=rule_context_obj, \ 248 | context_location=args.context_location, \ 249 | total_context_len=args.total_context_len, \ 250 | context_division_ratio=args.context_division_ratio) 251 | 252 | #print("RP: ", rule_prompt) 253 | #print("DP: ", default_prompt) 254 | if rule_triggered == True: 255 | rule_triggered_count+=1 256 | rule_prompts.append(rule_prompt) 257 | default_prompts.append(default_prompt) 258 | all_holes.append(hole) 259 | all_hole_identities.append(hole_identity) 260 | all_rules_triggered.append(rule_triggered) 261 | #all_parent_class_filenames.append(parent_class_filename) 262 | 263 | total_count += 1 264 | 265 | print(total_count, rule_triggered_count) 266 | print(len(all_holes)) 267 | 268 | 269 | # create prompts only for the cases where the rules are triggered. 
270 |     # other cases will be the same as codex, so they can be directly copied from the pre results
271 |     prompts = []
272 |     for i in range(len(all_holes)):
273 |         if (args.mode != 'codex' and args.is_run_rule_full) \
274 |             or (args.mode != 'codex' and not args.is_run_rule_full and all_rules_triggered[i])\
275 |             or (args.mode == 'codex'):
276 |             rule_p = rule_prompts[i]
277 |             def_p = default_prompts[i]
278 |             prompt_separator = args.prompt_separator
279 |             # if rule is empty
280 |             if not rule_p and prompt_separator == 'newline':
281 |                 prompt_separator = 'space'
282 |             prompt = rule_p + promptseparator2str[prompt_separator] + def_p
283 | 
284 |             # make sure that the length of the prompt is less than or equal to the total_context_len
285 |             codex_tokens = tokenizer(prompt)['input_ids']
286 |             if len(codex_tokens) > args.total_context_len:
287 |                 codex_tokens = codex_tokens[-args.total_context_len:]
288 |                 prompt = tokenizer.decode(codex_tokens)
289 |             if prompt:
290 |                 assert len(codex_tokens) <= args.total_context_len, 'prompt length exceeds the maximum length'
291 |                 #print("Hole:", all_holes[i])
292 |                 #print("Prompt:", prompt)
293 |                 prompts.append((i, prompt))
294 | 
295 |     print(len(prompts))
296 | 
297 |     # prompt the codex model in batches to generate completions with the prompts created before
298 |     count = 0
299 |     i = 0
300 |     while (i < len(prompts)):
301 |         print(i)
302 |         batch_prompts = prompts[i:i+args.batch_size]
303 |         batch_prompt_texts = [x[1] for x in batch_prompts]
304 |         #print(batch_prompt_post_texts)
305 |         batch_prompt_indexes = [x[0] for x in batch_prompts] # index within the all_* arrays
306 |         batch_responses = generate_prediction(batch_prompt_texts)
307 |         if batch_responses is not None:
308 |             for j in range(len(batch_prompts)):
309 |                 response = batch_responses.choices[j]
310 |                 prediction = response.text
311 |                 #prediction_tokens = response.logprobs.tokens
312 |                 #prediction_token_logprobs = response.logprobs.token_logprobs
313 |                 hole = all_holes[batch_prompt_indexes[j]]
314 |                 hole_identity = all_hole_identities[batch_prompt_indexes[j]]
315 |                 rule_triggered = all_rules_triggered[batch_prompt_indexes[j]]
316 |                 if rule_triggered:
317 |                     count+=1
318 |                 batch_suffix = ''
319 |                 result = {
320 |                     'hole_identity': hole_identity, \
321 |                     'prediction': prediction, \
322 |                     'ground-truth hole': hole, \
323 |                     'prompt': batch_prompt_texts[j], \
324 |                     'post_prompt': batch_suffix, \
325 |                     'rule_triggered': rule_triggered, \
326 |                     'index': batch_prompt_indexes[j] + 1 # this index corresponds to the global index
327 |                 }
328 |                 f.write(json.dumps(result))
329 |                 f.write("\n")
330 |                 f.flush()
331 |             i = i + args.batch_size
332 |         else:
333 |             # wait for 60s before calling the API again
334 |             time.sleep(60)
335 | 
336 |     f.close()
337 |     print(i, j, count, len(prompts))
338 | 
-------------------------------------------------------------------------------- /generate_rule_representations.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import argparse
5 | import random
6 | from torch.utils.data import DataLoader
7 | from torch import nn
8 | from tqdm import tqdm
9 | from rule_representation_data import *
10 | from torch import FloatTensor
11 | 
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 | 
14 | 
15 | def setup_args():
16 |     """
17 |     Description: Takes in the command-line arguments from user
18 |     """
19 |     parser = argparse.ArgumentParser()
20 | 
21 |     # data related hyperparameters
22 | 
parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 23 | parser.add_argument("--input_data_dir", type=str, default='rule_classifier_data', help="base directory for the data") 24 | parser.add_argument("--data_split", type=str, default='val', help="train, val, test") 25 | 26 | # model related hyperparameters 27 | parser.add_argument("--emb_model_type", type=str, default='codebert', help="model to obtain embedding from") 28 | parser.add_argument("--repo", type=str, default='jata4test', help="model to obtain embedding from") 29 | return parser.parse_args() 30 | 31 | if __name__ == '__main__': 32 | 33 | args = setup_args() 34 | 35 | #Fix seeds 36 | np.random.seed(args.seed) 37 | os.environ['PYTHONHASHSEED'] = str(args.seed) 38 | torch.manual_seed(args.seed) 39 | random.seed(args.seed) 40 | 41 | 42 | # Define dataloaders 43 | kwargs = {'num_workers': 8, 'pin_memory': True} if device=='cuda' else {} 44 | tokenizer = set_tokenizer(args.emb_model_type) 45 | base_dir = os.path.join(args.input_data_dir, args.data_split) 46 | dataset = RuleReprDataset(base_dir, emb_model_type = args.emb_model_type, tokenizer=tokenizer) 47 | #for repo in os.listdir(base_dir): 48 | start, end = dataset.get_start_index(args.repo, start_offset=0, interval=0) 49 | print(args.repo, start, end) 50 | for batch, (rule_context, hole, repo_name) in enumerate(dataset): 51 | if batch > end: 52 | break 53 | if repo_name == args.repo: 54 | save_dir = os.path.join(base_dir, repo_name, args.emb_model_type +'_mod') 55 | os.makedirs(save_dir, exist_ok=True) 56 | rule_representation = {hole: rule_context} 57 | with open(os.path.join(save_dir, str(batch)) , 'wb') as f: 58 | pickle.dump(rule_representation, f) 59 | -------------------------------------------------------------------------------- /get_info_from_hole_predictions.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import io 4 | import torch 5 | import argparse 6 | 7 | device = 'cpu' 8 | 9 | # This order was obtained based on the decreasing order of success rate on the validation set 10 | rule_order = [5, 7, 6, 1, 20, 22, 2, 0, 25, 3, 23, 24, 28, 26, 4, 21, 62, 31, 27, 29, 30, 8, 34, 32, \ 11 | 10, 33, 9, 35, 13, 11, 12, 36, 46, 44, 16, 14, 49, 45, 48, 40, 38, 19, 15, 18, 39, 43, 47,\ 12 | 17, 42, 41, 58, 56, 57, 61, 59, 60, 37, 52, 50, 53, 55, 54, 51] 13 | 14 | projects = { 'train': [ 15 | ('largemail' , 1653), 16 | ('ftpserverremoteadmin' , 7323), 17 | ('myt5lib' , 838), 18 | ('seamlets' , 4890), 19 | ('gloodb' , 10000), 20 | ('jjskit' , 9043), 21 | ('mobileexpensetracker' , 2298), 22 | ('gfsfa' , 10000), 23 | ('swe574-group3' , 2029), 24 | ('strudem-sicsa' , 6131), 25 | ('soap-dtc' , 1370), 26 | ('openprocesslogger' , 7191), 27 | ('tapestry-sesame', 397), 28 | ('exogdx' , 735), 29 | ('designpatternjavapedro' , 1069), 30 | ('quidsee' , 3020), 31 | ('healpix-rangeset' , 4734), 32 | ('sol-agent-platform' , 10000), 33 | ('rsbotownversion' , 10000), 34 | 35 | ], 36 | 37 | 'val': [ 38 | ('tyrond', 721), 39 | ('math-mech-eshop', 2225), 40 | ('infinispan-storage-service', 373), 41 | ('teammates-shakthi', 7665), 42 | ('javasummerframework', 10000), 43 | ('tinwiki', 10000), 44 | ('jloogle', 3145), 45 | ('jcontenedor', 5464), 46 | ('sohocms', 772), 47 | ('affinity_propagation_java', 1466), 48 | ('jata4test', 1921), 49 | ('swinagile', 2595), 50 | ('navigablep2p', 1322), 51 | ('springlime', 879), 52 | ], 53 | 54 | 'test': [ 55 | ('dovetaildb', 10000), 56 | ('project-pt-diaoc', 10000), 57 
| ('realtimegc', 2513), 58 | ('fswuniceubtemplates', 2070), 59 | ('qwikioffice-java', 1138), 60 | ('glperaudsimon', 1766), 61 | ('xiaonei-java-api', 839), 62 | ('ircrpgbot', 6591), 63 | ('robotsimulator2009w', 7514), 64 | ('gwt-plugindetect', 73), 65 | ('apiitfriends', 1385), 66 | ('wicketbits', 754), 67 | ('hucourses', 590), 68 | ('xfuze', 3055), 69 | ] 70 | } 71 | 72 | def setup_args(): 73 | """ 74 | Description: Takes in the command-line arguments from user 75 | """ 76 | parser = argparse.ArgumentParser() 77 | 78 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 79 | parser.add_argument("--hole_stats_file", type=str, default='hole', help="name of the prediction file to consider") 80 | parser.add_argument("--data_split", type=str, default='val', help="data_split") 81 | parser.add_argument("--base_dir", type=str, default='outputs', help="base dir") 82 | parser.add_argument("--k", type=int, default=1, help="how many rules to draw") 83 | return parser.parse_args() 84 | 85 | class CPU_Unpickler(pickle.Unpickler): 86 | """To load a pickle file stored in torch GPU setting to CPU.""" 87 | def find_class(self, module, name): 88 | if module == 'torch.storage' and name == '_load_from_bytes': 89 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 90 | else: 91 | return super().find_class(module, name) 92 | 93 | def get_repo_name(hid): 94 | return hid.split('/')[2] 95 | 96 | def update_dict(dic, data_type): 97 | if 'small_' in data_type: 98 | return dic 99 | else: 100 | mod_dic = {} 101 | for k,v in dic.items(): 102 | mod_k = '/'. join(['rule_classifier_data', data_type] + k.split('/')[2:]) 103 | mod_dic[mod_k] = v 104 | return mod_dic 105 | 106 | def get_top_k_acc(hole_pred, hole_gt, k=1): 107 | top_preds, top_pred_indices = torch.topk(hole_pred, k) 108 | for top_pred_idx in top_pred_indices: 109 | if hole_gt[top_pred_idx] == 1: 110 | return 1.0 111 | return 0.0 112 | 113 | def get_rule_wise_nums(oracle): 114 | rule_success = {} 115 | for hid, entry in oracle.items(): 116 | hole_gt= entry['com'] 117 | for i in range(len(hole_gt)): 118 | if hole_gt[i] == 1: 119 | if i in rule_success: 120 | rule_success[i]+=1 121 | else: 122 | rule_success[i] = 1 123 | return rule_success 124 | 125 | def get_single_rule_acc(hole_gt, k): 126 | rules = rule_order[:k] 127 | for rule in rules: 128 | if hole_gt[rule] == 1: 129 | return 1.0 130 | return 0.0 131 | 132 | if __name__ == '__main__': 133 | 134 | args = setup_args() 135 | 136 | #Fix seeds 137 | os.environ['PYTHONHASHSEED'] = str(args.seed) 138 | torch.manual_seed(args.seed) 139 | 140 | k = args.k 141 | repo_stats={} 142 | single_rule_stats = {} 143 | # get rlpg predictions 144 | data = CPU_Unpickler(open(os.path.join(args.base_dir, args.data_split, args.hole_stats_file), 'rb')).load() 145 | oracle_dir = 'rule_classifier_data/' + args.data_split 146 | for repo, repo_count in projects[args.data_split ]: 147 | oracle = pickle.load(open(os.path.join(oracle_dir, repo, 'capped_oracle_10000'), 'rb')) 148 | oracle = update_dict(oracle, args.data_split ) 149 | for hid, entry in oracle.items(): 150 | em = get_top_k_acc(data[hid][1], oracle[hid]['com'], k) 151 | single_rule_em = get_single_rule_acc(oracle[hid]['com'], k) 152 | if repo in repo_stats: 153 | repo_stats[repo]+= em 154 | else: 155 | repo_stats[repo] = em 156 | if repo in single_rule_stats: 157 | single_rule_stats[repo]+= single_rule_em 158 | else: 159 | single_rule_stats[repo] = single_rule_em 160 | 161 | repo_success = 0.0 162 | single_rule_repo_success = 0.0 
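    # Aggregate the per-repository accuracy below: each repo contributes its success percentage
    # (RLPG top-k from repo_stats, fixed rule-order baseline from single_rule_stats), and the final
    # print reports the macro-average over repositories alongside the hole-level average.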
163 | for repo, repo_count in projects[args.data_split]: 164 | repo_success += repo_stats[repo]*100/repo_count 165 | single_rule_repo_success += single_rule_stats[repo]*100/repo_count 166 | 167 | 168 | total_count = 0 169 | total_success = 0.0 170 | total_single_rule_success = 0.0 171 | for repo, repo_count in projects[args.data_split ]: 172 | total_count+= repo_count 173 | total_success += repo_stats[repo] 174 | total_single_rule_success += single_rule_stats[repo] 175 | 176 | print(args.hole_stats_file + "," + str(k) + "," + str(repo_success/len(projects[args.data_split ])) \ 177 | + "," + str(single_rule_repo_success/len(projects[args.data_split ]))\ 178 | + "," + str(total_success*100/total_count) + "," + str(total_single_rule_success*100/total_count)) 179 | 180 | -------------------------------------------------------------------------------- /model_preprocessed_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from utils import * 6 | from transformers import AutoModel 7 | from transformers import GPT2Model 8 | from att_model import BasicAggModel 9 | 10 | 11 | class RuleModel(nn.Module): 12 | def __init__(self, emb_model_type, repr_size=768, device='cpu', n_head=1, d_k=128, d_proj=512, 13 | mode='rlpg-h', dropout=0.0): 14 | super(RuleModel, self).__init__() 15 | 16 | 17 | self.n_rules = len(combined_to_index) 18 | self.set_embedding_model(emb_model_type) 19 | self.repr_size = repr_size 20 | self.device = device 21 | self.mode = mode 22 | 23 | if self.mode == 'rlpg-h': 24 | self.hole_dense1 = nn.Linear(self.repr_size, d_proj) 25 | self.hole_dense2 = nn.Linear(d_proj, self.n_rules) 26 | 27 | 28 | if self.mode == 'rlpg-r': 29 | self.att_model = BasicAggModel(d_model=repr_size, n_head=n_head, d_k=d_k, n_rules=self.n_rules, device=self.device, \ 30 | dropout=dropout) 31 | 32 | 33 | def get_representation(self, inputs, mask): 34 | outputs = self.emb_model(inputs, attention_mask=mask) 35 | try: 36 | representation = outputs.pooler_output 37 | except: 38 | representation = outputs.last_hidden_state[:, 0] 39 | return representation 40 | 41 | def get_context_embedding(self, context, attn_mask): 42 | context_embedding = self.get_representation(context, attn_mask) 43 | return context_embedding 44 | 45 | def forward(self, info): 46 | 47 | hole_inputs, hole_mask, rule_context_repr = info 48 | batch_size = hole_inputs.shape[0] 49 | # get hole window representation 50 | hole_window_repr = self.get_context_embedding(hole_inputs, hole_mask) 51 | 52 | # mask for invalid rules. It is 0 at positions where the rule is not valid. 
53 | valid_rules_mask = (rule_context_repr != 0) #(bs, n_rules, repr_size) 54 | 55 | #get prediction from hole window 56 | if self.mode == 'rlpg-h': 57 | hole_pred = self.hole_dense2(F.relu(self.hole_dense1(hole_window_repr))) 58 | if len(hole_pred.shape)==1: 59 | hole_pred = torch.unsqueeze(hole_pred, dim=0) 60 | hole_pred = torch.sigmoid(hole_pred) 61 | return hole_pred, valid_rules_mask 62 | 63 | if self.mode == 'rlpg-r': 64 | rule_pred, att_weights = self.att_model(hole_window_repr, rule_context_repr, valid_rules_mask) 65 | if len(rule_pred.shape)==1: 66 | rule_pred = torch.unsqueeze(rule_pred, dim=0) 67 | rule_pred = torch.sigmoid(rule_pred) 68 | return rule_pred, valid_rules_mask 69 | 70 | def set_embedding_model(self, emb_model_type): 71 | if emb_model_type == 'gpt-2': 72 | self.emb_model = GPT2Model.from_pretrained("gpt2") 73 | 74 | # CodeBERT 75 | if emb_model_type == 'codebert': 76 | self.emb_model = AutoModel.from_pretrained("microsoft/codebert-base") 77 | 78 | # GraphCodeBERT 79 | if emb_model_type == 'graphcodebert': 80 | self.emb_model = AutoModel.from_pretrained("microsoft/graphcodebert-base") 81 | 82 | # freeze the parameters of the pretrained emb_model 83 | for param in self.emb_model.parameters(): 84 | param.requires_grad = False -------------------------------------------------------------------------------- /parse_tree.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import argparse 4 | from tree_sitter import Language, Parser 5 | from utils import * 6 | import copy 7 | 8 | """ 9 | Obtain the parse tree for individual files and collate data at repo-level for rules. 10 | """ 11 | 12 | Language.build_library('build/my-languages.so', ['tree-sitter-java']) 13 | 14 | JAVA_LANGUAGE = Language('build/my-languages.so', 'java') 15 | 16 | parser = Parser() 17 | parser.set_language(JAVA_LANGUAGE) 18 | 19 | 20 | def get_sibling_files(file, all_files): 21 | file_parts = file.split('/') 22 | root_dir = '/'.join(file_parts[:-1]) 23 | sibling_files = [] 24 | for f in os.listdir(root_dir): 25 | if os.path.splitext(f)[1] == '.java' and f != file_parts[-1]: 26 | sib_file = os.path.join(root_dir, f) 27 | sibling_files.append(sib_file) 28 | return sibling_files 29 | 30 | def camel_case_split(str): 31 | start_idx = [i for i, e in enumerate(str) 32 | if e.isupper()] + [len(str)] 33 | 34 | start_idx = [0] + start_idx 35 | return [str[x: y] for x, y in zip(start_idx, start_idx[1:])] 36 | 37 | def match_by_parts(file1, file2, split_type): 38 | # omit .java in the end 39 | f1 = file1.split('.')[0] 40 | f2 = file2.split('.')[0] 41 | 42 | if split_type == 'camel-case': 43 | f1_parts = camel_case_split(f1) 44 | f2_parts = camel_case_split(f2) 45 | 46 | if split_type == 'underscore': 47 | f1_parts = f1.split('_') 48 | f2_parts = f2.split('_') 49 | 50 | for p1 in f1_parts: 51 | if p1 and p1 in f2_parts: 52 | #print(split_type, file1, file2, p1, f1_parts, f2_parts) 53 | return True 54 | return False 55 | 56 | 57 | def match_similar_filenames(file1, file2): 58 | # exactly same name 59 | if file1 == file2: 60 | return True 61 | 62 | #camelcase split similar parts 63 | return match_by_parts(file1, file2, 'camel-case') 64 | 65 | #underscore split similar parts 66 | return match_by_parts(file1, file2, 'underscore') 67 | 68 | 69 | def get_similar_name_files(file, all_files): 70 | filename = file.split('/')[-1] 71 | similar_name_files = [] 72 | for f in all_files: 73 | if f != file and match_similar_filenames(f.split('/')[-1], filename): 
74 | similar_name_files.append(f) 75 | return similar_name_files 76 | 77 | def get_tree(filename): 78 | """ 79 | obtain parse tree for a file 80 | """ 81 | file_str = open(filename, encoding="utf8", errors='backslashreplace').read() 82 | tree = parser.parse(bytes(file_str, "utf-8")) 83 | root_node = tree.root_node 84 | return root_node 85 | 86 | def parse_captures(captures, filename): 87 | text_spans = [] 88 | for capture in captures: 89 | #capture[1] = property_name 90 | start, end = capture[0].start_point, capture[0].end_point 91 | #text = get_string(filename, start, end) 92 | text_spans.append((start, end)) 93 | return text_spans 94 | 95 | def get_query(attribute_type): 96 | 97 | if attribute_type == 'class_name': 98 | query = JAVA_LANGUAGE.query("""(class_declaration 99 | name: (identifier) @class_name)""") 100 | 101 | if attribute_type == 'class_body': 102 | query = JAVA_LANGUAGE.query("""(class_declaration 103 | body: (class_body) @class_body)""") 104 | 105 | if attribute_type == 'parent_class_name': 106 | query = JAVA_LANGUAGE.query("""(class_declaration 107 | name: (identifier) 108 | superclass: (superclass (type_identifier) @superclass_name))""") 109 | 110 | if attribute_type == 'all_method_name': 111 | query = JAVA_LANGUAGE.query("""(method_declaration 112 | name: (identifier) @all_method_name)""") 113 | 114 | if attribute_type == 'all_method_body': 115 | query = JAVA_LANGUAGE.query("""(method_declaration body: (block) @all_method_block)""") 116 | 117 | if attribute_type == 'import_statement': 118 | query = JAVA_LANGUAGE.query("""(import_declaration ( 119 | scoped_identifier 120 | name: (identifier)) @import_statement)""") 121 | 122 | if attribute_type == 'all_field_declaration': 123 | query = JAVA_LANGUAGE.query("""(field_declaration) @field_declaration""") 124 | 125 | if attribute_type == 'all_string_literal': 126 | query = JAVA_LANGUAGE.query("""(string_literal) @string_literal""") 127 | 128 | if attribute_type == 'all_identifier': 129 | query = JAVA_LANGUAGE.query("""(identifier) @identifier""") 130 | 131 | if attribute_type == 'all_type_identifier': 132 | query = JAVA_LANGUAGE.query("""(type_identifier) @type_identifier""") 133 | 134 | return query 135 | 136 | def get_attribute(root_node, filename, attribute_type): 137 | 138 | query = get_query(attribute_type) 139 | captures = query.captures(root_node) 140 | if captures: 141 | attributes = parse_captures(captures, filename) 142 | else: 143 | attributes = [((-1, -1), (-1, -1))] 144 | return attributes 145 | 146 | def get_import_path(import_stat, file): 147 | import_stat_str = get_string(file, import_stat[0], import_stat[1]) 148 | #print(import_stat_str, file) 149 | import_path_parts = import_stat_str.split(".") 150 | absolute_import_path = [] 151 | import_path_part = import_path_parts[0] 152 | if import_path_part != 'java': 153 | file_path_parts = file.split("/") 154 | try: 155 | index_pos = len(file_path_parts) - file_path_parts[::-1].index(import_path_part) - 1 156 | absolute_import_path = file_path_parts[:index_pos] + import_path_parts 157 | except ValueError as e: 158 | print('') 159 | #print(absolute_import_path) 160 | if absolute_import_path: 161 | import_path = '/'.join(absolute_import_path) 162 | import_path = import_path + '.java' 163 | return import_path 164 | else: 165 | return '' 166 | 167 | def get_parent_class_filename(parent_class_name, file_class_info, file): 168 | parent_class_filename = '' 169 | if parent_class_name: 170 | parent_class_name_text = get_string(file, parent_class_name[0], parent_class_name[1]) 
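        # parent_class_name is a (start, end) span in the current file; get_string resolves it to the
        # superclass identifier text so it can be matched against the class names declared in the
        # other files of the repository (file_class_info)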
171 | # we don't want the current file to be the parent class file 172 | copy_file_class_info = copy.deepcopy(file_class_info) 173 | del copy_file_class_info[file] 174 | 175 | if parent_class_name_text: 176 | # search for the parent class name in all files 177 | found = False 178 | for (k,v) in copy_file_class_info.items(): 179 | for val in v: 180 | if val==parent_class_name_text: 181 | parent_class_filename = k 182 | found = True 183 | break 184 | return parent_class_filename 185 | 186 | def find_relevant_file_identifier(import_identifier, file_identifiers, file): 187 | candidate_file_identifiers = [] 188 | for file_identifier in file_identifiers: 189 | if file_identifier: 190 | file_identifier_str = get_string(file, file_identifier[0], file_identifier[1]) 191 | if file_identifier_str == import_identifier: 192 | candidate_file_identifiers.append(file_identifier) 193 | return candidate_file_identifiers[1:] 194 | 195 | def get_imports(import_statements, file, all_identifiers, all_type_identifiers): 196 | imports = {} 197 | file_identifiers = all_identifiers 198 | file_identifiers.extend(all_type_identifiers) 199 | for import_stat in import_statements: 200 | import_file_path = get_import_path(import_stat, file) 201 | if import_file_path and os.path.isfile(import_file_path): 202 | import_identifier = import_file_path.split('/')[-1].split('.')[0] 203 | candidate_file_identifiers = find_relevant_file_identifier(import_identifier, file_identifiers, file) 204 | if candidate_file_identifiers: 205 | imports[import_file_path] = candidate_file_identifiers 206 | return imports 207 | 208 | def check_empty_attribute(attribute): 209 | if len(attribute) == 1 and attribute[0][0][0] == -1: 210 | attribute = [] 211 | return attribute 212 | 213 | def update_attribute(parse_data, att_type, files): 214 | count = 0 215 | for file in files: 216 | current_file_imports = list(parse_data[file]['imports'].keys()) 217 | att_files = parse_data[file][att_type] 218 | att_info = [] 219 | for att_file in att_files: 220 | if att_file: 221 | att_file_imports = list(parse_data[att_file]['imports'].keys()) 222 | overlapping_imports = find_similar_intersection(att_file_imports, current_file_imports) 223 | #if len(overlapping_imports) > 0: 224 | att_info.append((att_file, len(overlapping_imports))) 225 | #print(file, att_file, overlapping_imports) 226 | parse_data[file][att_type] = att_info 227 | if att_info: 228 | count+=1 229 | #print(file, parse_data[file][att_type]) 230 | #print(count) 231 | return parse_data 232 | 233 | def update_child_class_info(parse_data, child_class_info): 234 | for file, file_parse_data in parse_data.items(): 235 | if file in child_class_info: 236 | parse_data[file]['child_class_filenames'] = child_class_info[file] 237 | else: 238 | parse_data[file]['child_class_filenames'] = [] 239 | return parse_data 240 | 241 | def setup_args(): 242 | """ 243 | Description: Takes in the command-line arguments from user 244 | """ 245 | parser = argparse.ArgumentParser() 246 | 247 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 248 | parser.add_argument("--base_dir", type=str, default='rule_classifier_data/val', \ 249 | help="base directory for the data") 250 | parser.add_argument("--proj_name", type=str, default='rsbotownversion', \ 251 | help="name of the input repo") 252 | 253 | return parser.parse_args() 254 | 255 | if __name__ == '__main__': 256 | 257 | args = setup_args() 258 | 259 | #Fix seeds 260 | np.random.seed(args.seed) 261 | 
os.environ['PYTHONHASHSEED']=str(args.seed) 262 | 263 | input_data_path = os.path.join(args.base_dir, args.proj_name) 264 | os.makedirs(input_data_path, exist_ok=True) 265 | 266 | files = [os.path.join(dp, f) \ 267 | for dp, dn, filenames in os.walk(input_data_path) \ 268 | for f in filenames \ 269 | if os.path.splitext(f)[1] == '.java'] 270 | 271 | file_class_info = {} 272 | for file in files: 273 | root_node = get_tree(file) 274 | class_names = get_attribute(root_node, file, 'class_name') 275 | file_class_names = [] 276 | for cn in class_names: 277 | start, end = cn 278 | class_name = get_string(file, start, end) 279 | file_class_names.append(class_name) 280 | file_class_info[file] = file_class_names 281 | #print(file_class_info) 282 | 283 | with open(os.path.join(input_data_path, 'file_class_data'), 'wb') as f: 284 | pickle.dump(file_class_info, f) 285 | 286 | parse_data = {} 287 | child_class_info = {} 288 | 289 | similar_count = 0 290 | sibling_count = 0 291 | 292 | for file in files: 293 | root_node = get_tree(file) 294 | sibling_files = get_sibling_files(file, files) 295 | similar_name_files = get_similar_name_files(file, files) 296 | if len(similar_name_files) > 0: 297 | similar_count +=1 298 | if len(sibling_files) > 0: 299 | sibling_count +=1 300 | 301 | class_names = get_attribute(root_node, file, 'class_name') 302 | class_bodies = get_attribute(root_node, file, 'class_body') 303 | parent_class_names = get_attribute(root_node, file, 'parent_class_name') 304 | all_field_declarations = get_attribute(root_node, file, 'all_field_declaration') 305 | all_string_literals = get_attribute(root_node, file, 'all_string_literal') 306 | all_identifiers = get_attribute(root_node, file, 'all_identifier') 307 | all_type_identifiers = get_attribute(root_node, file, 'all_type_identifier') 308 | all_method_names = get_attribute(root_node, file, 'all_method_name') 309 | all_method_bodies = get_attribute(root_node, file, 'all_method_body') 310 | import_statements = get_attribute(root_node, file, 'import_statement') 311 | 312 | class_names = check_empty_attribute(class_names) 313 | class_bodies = check_empty_attribute(class_bodies) 314 | parent_class_names = check_empty_attribute(parent_class_names) 315 | all_field_declarations = check_empty_attribute(all_field_declarations) 316 | all_identifiers = check_empty_attribute(all_identifiers) 317 | all_type_identifiers = check_empty_attribute(all_type_identifiers) 318 | all_string_literals = check_empty_attribute(all_string_literals) 319 | all_method_names = check_empty_attribute(all_method_names) 320 | all_method_bodies = check_empty_attribute(all_method_bodies) 321 | import_statements = check_empty_attribute(import_statements) 322 | 323 | # get imports 324 | imports = get_imports(import_statements, file, all_identifiers, all_type_identifiers) 325 | 326 | parent_class_filenames = [] 327 | mod_parent_class_names = [] 328 | for parent_class_name in parent_class_names: 329 | parent_class_filename = get_parent_class_filename(parent_class_name, file_class_info, file) 330 | if parent_class_filename: 331 | mod_parent_class_names.append(parent_class_name) 332 | if parent_class_filename in child_class_info: 333 | child_class_info[parent_class_filename].append(file) 334 | else: 335 | child_class_info[parent_class_filename] = [file] 336 | parent_class_filenames.append(parent_class_filename) 337 | 338 | #print(parent_class_names, parent_class_filenames) 339 | assert len(mod_parent_class_names) == len(parent_class_filenames) 340 | 341 | #store the data in dict form 
342 | parse_data[file] = { 343 | 'class_names': class_names,\ 344 | 'class_bodies': class_bodies, \ 345 | 'parent_class_names': mod_parent_class_names, \ 346 | 'parent_class_filenames': parent_class_filenames, \ 347 | 'imports': imports, \ 348 | 'field_declarations': all_field_declarations, \ 349 | 'string_literals': all_string_literals, \ 350 | 'identifiers': all_identifiers, \ 351 | 'type_identifiers': all_type_identifiers, \ 352 | 'all_method_names': all_method_names, \ 353 | 'all_method_bodies': all_method_bodies, \ 354 | 'sibling_files': sibling_files, \ 355 | 'similar_name_files': similar_name_files} 356 | 357 | print(len(files), sibling_count, similar_count) 358 | print("updating sibling files") 359 | parse_data = update_attribute(parse_data, 'sibling_files', files) 360 | print("updating similar_name_files") 361 | parse_data = update_attribute(parse_data, 'similar_name_files', files) 362 | print("updating child class filenames") 363 | parse_data = update_child_class_info(parse_data, child_class_info) 364 | parse_data = update_attribute(parse_data, 'child_class_filenames', files) 365 | print("updating parent class filenames") 366 | parse_data = update_attribute(parse_data, 'parent_class_filenames', files) 367 | 368 | print("Writing parse data...") 369 | with open(os.path.join(input_data_path, 'parsed_data'), 'wb') as f: 370 | pickle.dump(parse_data, f) 371 | -------------------------------------------------------------------------------- /preprocessed_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import pickle 5 | from torch.utils.data import Dataset 6 | from transformers import GPT2TokenizerFast, AutoTokenizer 7 | from transformers import DataCollatorWithPadding 8 | from utils import * 9 | import re 10 | 11 | 12 | class RuleDataset(Dataset): 13 | 14 | def __init__(self, input_data_dir, tokenizer=None, emb_model_type='codebert'): 15 | 16 | self.input_data_dir = input_data_dir 17 | self.tokenizer = tokenizer 18 | data_type = input_data_dir.split('/')[-1] 19 | oracles = {} 20 | self.data_files = [] 21 | for dp, dn, filenames in os.walk(input_data_dir): 22 | for f in filenames: 23 | if f == 'capped_oracle_10000': 24 | oracle = pickle.load(open(os.path.join(dp, f), 'rb')) 25 | oracle = self.update_dict(oracle, data_type) 26 | oracles = {**oracles, **oracle} 27 | 28 | for dp, dn, filenames in os.walk(input_data_dir): 29 | if dp.split('/')[-1] == 'capped_'+ emb_model_type + '_mod': 30 | for f in filenames: 31 | self.data_files.append(os.path.join(dp, f)) 32 | 33 | self.oracles = oracles 34 | #print(self.oracles.keys()) 35 | print("oracle") 36 | print(data_type, len(self.oracles)) 37 | print("data_files") 38 | print(data_type, len(self.data_files)) 39 | self.data_type = data_type 40 | self.num_combined = len(combined_to_index) 41 | 42 | def __len__(self): 43 | return len(self.data_files) 44 | 45 | def __getitem__(self, idx): 46 | return self.generate_data(self.data_files[idx]) 47 | 48 | def update_dict(self, dic, data_type): 49 | if 'small_' in data_type: 50 | return dic 51 | else: 52 | mod_dic = {} 53 | for k,v in dic.items(): 54 | mod_k = '/'. 
join(['rule_classifier_data', data_type] + k.split('/')[2:]) 55 | mod_dic[mod_k] = v 56 | return mod_dic 57 | 58 | def generate_data(self, data_file): 59 | data = pickle.load(open(data_file, 'rb')) 60 | data = self.update_dict(data, self.data_type) 61 | for hole, rule_context in data.items(): 62 | hole_context = self.get_hole_context(hole) 63 | if hole in self.oracles: 64 | combined = self.oracles[hole]['com'] 65 | failure_flag = 1 66 | else: 67 | combined = np.zeros(self.num_combined) 68 | failure_flag = 0 69 | return hole_context, rule_context, combined, hole, failure_flag 70 | 71 | def get_hole_context(self, hole, num_of_prev_lines=2, num_of_post_lines=2): 72 | ''' 73 | return the pre_context_len tokens from the current file based on codex tokenization from the position of the cursor 74 | ''' 75 | hole_parts = hole.split('/')[-1].split('_') 76 | repo_name = hole.split('/')[2] 77 | if len(hole_parts) > 3: 78 | new_hole_parts = hole_parts[:-2] 79 | filename = '_'.join(new_hole_parts) 80 | filename = [filename] 81 | else: 82 | filename = [hole_parts[0]] 83 | file = '/'.join(hole.split('/')[:-1] + filename) 84 | pos = (int(hole_parts[-2]), int(hole_parts[-1])) 85 | 86 | pre_end = pos 87 | pre_start_line = pos[0] - num_of_prev_lines 88 | if pre_start_line < 0: 89 | pre_start_line = 0 90 | pre_start = (pre_start_line, 0) 91 | pre_hole_context = get_string(file, pre_start, pre_end) 92 | 93 | post_hole_context = "" 94 | if num_of_post_lines > 0: 95 | file_lines = open(file, encoding="utf8", errors='backslashreplace').readlines() 96 | post_start_line = pos[0] + 1 97 | if post_start_line < len(file_lines): 98 | post_end_line = pos[0] + num_of_post_lines 99 | if post_end_line >= len(file_lines): 100 | post_end_line = len(file_lines) - 1 101 | post_start = (post_start_line, 0) 102 | post_end = (post_end_line, len(file_lines[post_end_line])) 103 | post_hole_context = get_string(file, post_start, post_end) 104 | hole_context = post_hole_context + " " + pre_hole_context 105 | hole_context = self.tokenizer(hole_context, truncation=True) 106 | return hole_context 107 | 108 | def collate_fn(data): 109 | hole_context, rule_contexts, gt_com, hole_id, failure_flag = zip(*data) 110 | hole_context = data_collator(hole_context) 111 | rule_contexts = torch.stack(rule_contexts, dim=0) 112 | #print("rule_contexts:", torch.sum(rule_contexts, dim=-1)) 113 | gt_com = torch.FloatTensor(gt_com) 114 | failure_flag = torch.IntTensor(failure_flag) 115 | return hole_context['input_ids'], hole_context['attention_mask'], \ 116 | rule_contexts, \ 117 | gt_com, \ 118 | hole_id, \ 119 | failure_flag 120 | 121 | def set_tokenizer(emb_model_type): 122 | global data_collator 123 | if emb_model_type == 'codebert': 124 | tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base") 125 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 126 | if emb_model_type == 'graphcodebert': 127 | tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base") 128 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 129 | if emb_model_type == 'gpt-2': 130 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 131 | tokenizer.pad_token = tokenizer.eos_token 132 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 133 | return tokenizer 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /projects.txt: -------------------------------------------------------------------------------- 1 | 
https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/rsbotownversion/source-archive.zip 2 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/sol-agent-platform/source-archive.zip 3 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/project-pt-diaoc/source-archive.zip 4 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/gloodb/source-archive.zip 5 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/dovetaildb/source-archive.zip 6 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/tinwiki/source-archive.zip 7 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jjskit/source-archive.zip 8 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/openprocesslogger/source-archive.zip 9 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/teammates-shakthi/source-archive.zip 10 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/ftpserverremoteadmin/source-archive.zip 11 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/ircrpgbot/source-archive.zip 12 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/strudem-sicsa/source-archive.zip 13 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/seamlets/source-archive.zip 14 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/robotsimulator2009w/source-archive.zip 15 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/healpix-rangeset/source-archive.zip 16 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jcontenedor/source-archive.zip 17 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/qwikioffice-java/source-archive.zip 18 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jloogle/source-archive.zip 19 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/xiaonei-java-api/source-archive.zip 20 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/xfuze/source-archive.zip 21 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/realtimegc/source-archive.zip 22 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/swinagile/source-archive.zip 23 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/swe574-group3/source-archive.zip 24 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/math-mech-eshop/source-archive.zip 25 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/quidsee/source-archive.zip 26 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/glperaudsimon/source-archive.zip 27 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/fswuniceubtemplates/source-archive.zip 28 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/mobileexpensetracker/source-archive.zip 29 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/jata4test/source-archive.zip 30 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/largemail/source-archive.zip 31 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/apiitfriends/source-archive.zip 32 | 
https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/navigablep2p/source-archive.zip 33 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/wicketbits/source-archive.zip 34 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/soap-dtc/source-archive.zip 35 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/designpatternjavapedro/source-archive.zip 36 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/exogdx/source-archive.zip 37 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/tyrond/source-archive.zip 38 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/springlime/source-archive.zip 39 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/hucourses/source-archive.zip 40 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/sohocms/source-archive.zip 41 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/infinispan-storage-service/source-archive.zip 42 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/myt5lib/source-archive.zip 43 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/tapestry-sesame/source-archive.zip 44 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/gwt-plugindetect/source-archive.zip 45 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/javasummerframework/source-archive.zip 46 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/affinity_propagation_java/source-archive.zip 47 | https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/gfsfa/source-archive.zip 48 | -------------------------------------------------------------------------------- /rearrange_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import pickle 4 | import random 5 | import numpy as np 6 | 7 | seed = 9 8 | 9 | #Fix seeds 10 | os.environ['PYTHONHASHSEED'] = str(seed) 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | 14 | projects = { 'train': [ 15 | 'gfsfa', 16 | 'sol-agent-platform', 17 | 'gloodb', 18 | 'rsbotownversion', 19 | 'jjskit', 20 | 'ftpserverremoteadmin', 21 | 'openprocesslogger', 22 | 'strudem-sicsa', 23 | 'seamlets', 24 | 'healpix-rangeset', 25 | 'quidsee', 26 | 'mobileexpensetracker', 27 | 'swe574-group3', 28 | 'largemail', 29 | 'soap-dtc', 30 | 'designpatternjavapedro', 31 | 'myt5lib', 32 | 'exogdx', 33 | 'tapestry-sesame' 34 | ], 35 | 36 | 'val': [ 37 | 'javasummerframework', 38 | 'tinwiki', 39 | 'teammates-shakthi', 40 | 'jcontenedor', 41 | 'jloogle', 42 | 'swinagile', 43 | 'math-mech-eshop', 44 | 'jata4test', 45 | 'affinity_propagation_java', 46 | 'navigablep2p', 47 | 'springlime', 48 | 'sohocms', 49 | 'tyrond', 50 | 'infinispan-storage-service', 51 | ], 52 | 53 | 'test': [ 54 | 'project-pt-diaoc', 55 | 'dovetaildb', 56 | 'robotsimulator2009w', 57 | 'ircrpgbot', 58 | 'xfuze', 59 | 'realtimegc', 60 | 'fswuniceubtemplates', 61 | 'glperaudsimon', 62 | 'apiitfriends', 63 | 'qwikioffice-java', 64 | 'xiaonei-java-api', 65 | 'wicketbits', 66 | 'hucourses', 67 | 'gwt-plugindetect' 68 | ] 69 | } 70 | 71 | 72 | 73 | repo_split_map = {} 74 | for split, repos in projects.items(): 75 | for repo in repos: 76 | repo_split_map[repo] = split 77 | 78 | max_holes = 10000 79 | 80 | def is_move(base_dir, 
split, repo): 81 | new_split = repo_split_map[repo] 82 | if new_split != split: 83 | shutil.move(os.path.join(base_dir, split, repo), os.path.join(base_dir, new_split, repo)) 84 | 85 | def find_single_best_rule_success(rule_mapping): 86 | best_single_rule_success = 0 87 | for k, v in rule_mapping.items(): 88 | if len(v)> best_single_rule_success: 89 | best_rule = k 90 | best_single_rule_success = len(v) 91 | return best_rule, best_single_rule_success 92 | 93 | def find_rule_mapping(oracle): 94 | rule_mapping = {} 95 | for hid, entry in oracle.items(): 96 | rules = entry['com'] 97 | success_rule_positions = np.where(rules == 1)[0] 98 | for s_r_p in success_rule_positions: 99 | if s_r_p not in rule_mapping: 100 | rule_mapping[s_r_p] = [hid] 101 | else: 102 | rule_mapping[s_r_p].append(hid) 103 | return rule_mapping 104 | 105 | def get_new_oracle_numbers(capped_oracle, repo, total_holes): 106 | rule_mapping = find_rule_mapping(capped_oracle) 107 | codex_success = len(rule_mapping[62]) 108 | best_rule, best_rule_success = find_single_best_rule_success(rule_mapping) 109 | best_single_rule_success = len(rule_mapping[7]) 110 | print( 111 | repo + ", " + \ 112 | str(total_holes) + ", " + \ 113 | str(float(len(capped_oracle)*100/total_holes)) + ", " + \ 114 | str(float(codex_success*100/total_holes)) + ", " + \ 115 | str(best_rule) + ", " +\ 116 | str(float(best_rule_success*100/total_holes)) + ", " + \ 117 | "in_file_lines_0.75" + ", " +\ 118 | str(float(best_single_rule_success*100/total_holes)) 119 | ) 120 | 121 | def rewrite_rule_context_data(repo_path, capped_holes, emb_model_type): 122 | all_files = os.listdir(os.path.join(repo_path, emb_model_type)) 123 | os.makedirs(os.path.join(repo_path, 'capped_'+ emb_model_type), exist_ok=True) 124 | for file in all_files: 125 | hole_path = os.path.join(repo_path, emb_model_type, file) 126 | data = pickle.load(open(hole_path, 'rb')) 127 | hole = list(data.keys())[0] 128 | if hole in capped_holes: 129 | dest_hole_path = os.path.join(repo_path, 'capped_'+ emb_model_type, file) 130 | shutil.copy(hole_path, dest_hole_path) 131 | 132 | def rearrange_data(base_dir, split): 133 | print(split) 134 | all_dirs = os.listdir(os.path.join(base_dir, split)) 135 | for repo in all_dirs: 136 | if repo in repo_split_map: 137 | print(base_dir, split, repo) 138 | repo_holes = [] 139 | hole_data = pickle.load(open(os.path.join(base_dir, split, repo, 'hole_data'), 'rb')) 140 | oracle = pickle.load(open(os.path.join(base_dir, split, repo, 'oracle'), 'rb')) 141 | duplicate_files = open(os.path.join(base_dir, split, repo, 'duplicates'), 'r').readlines() 142 | all_duplicate_files = [x.strip() for x in duplicate_files] 143 | for file, holes in hole_data.items(): 144 | if file not in all_duplicate_files and not file.startswith('rule_classifier_data/val/rsbotownversion/trunk/scripts/'): 145 | hids = [file + '_' + str(h[0]) + '_' + str(h[1]) for h in holes] 146 | repo_holes.extend(hids) 147 | #print(len(repo_holes)) 148 | if len(repo_holes) < max_holes: 149 | capped_holes = repo_holes 150 | capped_oracle = oracle 151 | total_holes = len(repo_holes) 152 | else: 153 | capped_holes = random.sample(repo_holes, max_holes) 154 | capped_oracle = {} 155 | for hid, entry in oracle.items(): 156 | if hid in capped_holes: 157 | capped_oracle[hid] = entry 158 | total_holes = len(capped_holes) 159 | 160 | get_new_oracle_numbers(capped_oracle, repo, total_holes) 161 | with open(os.path.join(base_dir, split, repo, 'capped_oracle_'+ str(max_holes)), 'wb') as f: 162 | pickle.dump(capped_oracle, f) 
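# the capped oracle dumped above and the capped hole ids written below describe the same
# subsample of at most max_holes holes, so later stages operate on a consistent subset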
163 | 164 | with open(os.path.join(base_dir, split, repo, 'capped_holes_'+ str(max_holes)), 'w') as f: 165 | for item in capped_holes: 166 | f.write("%s\n" %(item,)) 167 | capped_holes = open(os.path.join(base_dir, split, repo, 'capped_holes_'+ str(max_holes)), 'r').readlines() 168 | capped_holes = [x.strip() for x in capped_holes] 169 | 170 | rewrite_rule_context_data(os.path.join(base_dir, split, repo), capped_holes, 'codebert_mod') 171 | 172 | is_move(base_dir, split, repo) 173 | 174 | rearrange_data('rule_classifier_data', 'train') 175 | rearrange_data('rule_classifier_data', 'val') 176 | rearrange_data('rule_classifier_data', 'test') 177 | 178 | 179 | -------------------------------------------------------------------------------- /rule_classifier_preprocessed_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import argparse 5 | import random 6 | from torch.utils.data import DataLoader 7 | from torch.utils.tensorboard import SummaryWriter 8 | from torch.autograd import Variable 9 | from torch import nn 10 | from tqdm import tqdm 11 | from preprocessed_data import * 12 | from model_preprocessed_data import RuleModel 13 | from torch import FloatTensor 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def setup_args(): 19 | """ 20 | Description: Takes in the command-line arguments from user 21 | """ 22 | parser = argparse.ArgumentParser() 23 | 24 | #data hyperparams 25 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 26 | parser.add_argument("--input_data_dir", type=str, default='rule_classifier_data', help="base directory for the data") 27 | parser.add_argument("--model_output_path", type=str, default='models/', help="base directory for storing the models") 28 | parser.add_argument("--output_path", type=str, default='outputs/', help="base directory for storing the outputs") 29 | parser.add_argument("--batch_size", type=int, default=64, help="batch size for training the classifier") 30 | parser.add_argument("--emb_model_type", type=str, default='codebert', help="model to obtain embedding from") 31 | #training hyperparams 32 | parser.add_argument("--num_epochs", type=int, default=5000, help="number of epochs for training the classifier") 33 | parser.add_argument("--optimizer", type=str, default='adam', help="optimizer to use for training") 34 | parser.add_argument("--lr_scheduler", type=str, default='none', help="learning rate scheduler to use for training") 35 | parser.add_argument("--learning_rate", type=float, default=3e-4, help="learning rate for training the classifier") 36 | parser.add_argument("--patience", type=int, default=5000, help="patience for early-stop") 37 | parser.add_argument('--load_from_checkpoint', default=False, action='store_true') 38 | #model hyperparams 39 | parser.add_argument("--mode", type=str, default='rlpg-r', help="rule classifier variant: rlpg-h, rlpg-r") 40 | parser.add_argument("--n_head", type=int, default=4, help="number of heads") 41 | parser.add_argument("--d_k", type=int, default=32, help="depth of projection") 42 | parser.add_argument("--dropout", type=float, default=0.25, help="dropout probability") 43 | 44 | 45 | return parser.parse_args() 46 | 47 | def save(model, optimizer, epoch, save_dir): 48 | to_save = model.module if hasattr(model, "module") else model 49 | # pytorch_total_params = sum(p.numel() for p in to_save.parameters()) 50 | # pytorch_trainable_params = sum(p.numel() for p in 
to_save.parameters() if p.requires_grad) 51 | # print(pytorch_trainable_params, pytorch_total_params) 52 | torch.save(to_save.state_dict(), os.path.join(save_dir, "best_model.th")) 53 | torch.save({"optimizer": optimizer.state_dict(), "last_epoch": epoch}, os.path.join(save_dir, "optim.th")) 54 | 55 | def get_accuracy(pred, gold, mask): 56 | pred = pred.masked_fill(mask==0, 0) 57 | max_idx = torch.argmax(pred, dim=1, keepdim=True) 58 | rounded_pred = torch.round(pred) 59 | max_idx_gold_vals = torch.gather(gold, 1, max_idx) 60 | mean_highest_success_correct = (max_idx_gold_vals == 1).to(dtype=torch.float).mean() 61 | return mean_highest_success_correct 62 | 63 | def get_prediction(rule_model, info): 64 | pred, mask = rule_model(info) 65 | mask = torch.sum(mask, dim=-1) #(bs, #rules) 66 | return pred, mask 67 | 68 | def calculate_loss(rule_model, criterion, info, gt): 69 | 70 | pred, mask = get_prediction(rule_model, info) 71 | n_valid_entries = torch.sum(mask.view(-1)!=0) 72 | loss = criterion(pred, gt) 73 | loss = loss.masked_fill(mask==0, 0) 74 | loss = torch.sum(loss)/n_valid_entries 75 | mean_highest_success_correct = get_accuracy(pred, gt, mask) 76 | masked_gt = torch.sum(gt.masked_fill(mask==0, 0), dim=-1) 77 | mean_oracle_success = masked_gt.masked_fill(masked_gt!=0, 1.0).mean() 78 | 79 | return {'loss': loss, \ 80 | 'mean_highest_success_correct': mean_highest_success_correct}, \ 81 | mean_oracle_success 82 | 83 | 84 | if __name__ == '__main__': 85 | 86 | args = setup_args() 87 | 88 | #Fix seeds 89 | np.random.seed(args.seed) 90 | os.environ['PYTHONHASHSEED'] = str(args.seed) 91 | torch.manual_seed(args.seed) 92 | random.seed(args.seed) 93 | 94 | #Define paths for storing tensorboard logs 95 | dir_name = 'optimizer#' + args.optimizer + '#learning_rate#' + str(args.learning_rate) + '#lr_scheduler#' + args.lr_scheduler \ 96 | + '#emb_model_type#' + args.emb_model_type + '#n_head#' + str(args.n_head) + '#d_k#' + str(args.d_k) \ 97 | + '#mode#' + args.mode + '#dropout#' + str(args.dropout) 98 | 99 | save_dir = os.path.join(args.model_output_path, dir_name) 100 | os.makedirs(save_dir, exist_ok=True) 101 | tb_writer = SummaryWriter(os.path.join(save_dir, "logs")) 102 | os.makedirs(args.output_path, exist_ok=True) 103 | f_out = open(os.path.join(args.output_path, dir_name), 'w') 104 | 105 | # Define train and val dataloaders 106 | kwargs = {'num_workers': 8, 'pin_memory': True} if device=='cuda' else {} 107 | tokenizer = set_tokenizer(args.emb_model_type) 108 | #print(tokenizer) 109 | train_dataset = RuleDataset(os.path.join(args.input_data_dir, 'train'), tokenizer=tokenizer, emb_model_type=args.emb_model_type) 110 | train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, **kwargs) 111 | 112 | val_dataset = RuleDataset(os.path.join(args.input_data_dir, 'val'), tokenizer=tokenizer, emb_model_type=args.emb_model_type) 113 | val_data_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn, **kwargs) 114 | 115 | # Define the model 116 | rule_model = RuleModel(emb_model_type=args.emb_model_type, device=device, n_head=args.n_head, d_k=args.d_k, \ 117 | mode = args.mode, dropout=args.dropout) 118 | rule_model.to(device) 119 | 120 | #Define optimizer and loss 121 | if args.lr_scheduler == 'none': 122 | if args.optimizer == 'adam': 123 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate) 124 | if args.optimizer == 'sgd': 125 | optimizer = torch.optim.SGD(rule_model.parameters(), 
lr=args.learning_rate) 126 | 127 | if args.lr_scheduler == 'cosine': 128 | if args.optimizer == 'adam': 129 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate) 130 | if args.optimizer == 'sgd': 131 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate) 132 | lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0) 133 | 134 | if args.lr_scheduler == 'cosinewarm': 135 | if args.optimizer == 'adam': 136 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate) 137 | if args.optimizer == 'sgd': 138 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate) 139 | lr_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10) 140 | 141 | if args.lr_scheduler == 'tr2': 142 | if args.optimizer == 'adam': 143 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate) 144 | if args.optimizer == 'sgd': 145 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate) 146 | lr_sched = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.learning_rate/2.0, max_lr=0.1, 147 | step_size_up=5, mode="triangular2", cycle_momentum=False) 148 | 149 | if args.lr_scheduler == 'reduceonplateau': 150 | if args.optimizer == 'adam': 151 | optimizer = torch.optim.Adam(rule_model.parameters(), lr=args.learning_rate) 152 | if args.optimizer == 'sgd': 153 | optimizer = torch.optim.SGD(rule_model.parameters(), lr=args.learning_rate) 154 | lr_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max') 155 | 156 | 157 | if args.load_from_checkpoint: 158 | print("=> loading checkpoint '{}'".format(save_dir)) 159 | model_path = os.path.join(save_dir, 'best_model.th') 160 | opt_path = os.path.join(save_dir, 'optim.th') 161 | status_dict = torch.load(opt_path, map_location=torch.device('cpu')) 162 | rule_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) 163 | optimizer.load_state_dict(status_dict['optimizer']) 164 | print("=> loaded checkpoint '{}' (epoch {})".format(save_dir, status_dict['last_epoch'])) 165 | 166 | criterion = nn.BCELoss(reduction='none') 167 | 168 | 169 | best_val_acc = 0 170 | 171 | for epoch in range(args.num_epochs): 172 | print("Epoch %d" % epoch) 173 | f_out.write("Epoch: " + str(epoch)+"\n") 174 | 175 | ########################Training Loop############################################# 176 | total_highest_success_correct, total_loss = 0.0, 0.0 177 | 178 | total_batches = 0 179 | total_oracle_success = 0.0 180 | 181 | rule_model.train() 182 | train_count = 0 183 | 184 | for batch in tqdm(train_data_loader): 185 | 186 | hole_context = Variable(batch[0]).to(device) 187 | hole_attention_mask = Variable(batch[1]).to(device) 188 | rule_context = Variable(batch[2]).to(device) 189 | gt = Variable(batch[3]).to(device) 190 | failure_flag = Variable(batch[5]).to(device) 191 | 192 | train_count+= torch.sum(failure_flag) 193 | 194 | optimizer.zero_grad() 195 | 196 | batch_metrices, oracle_success = calculate_loss(rule_model, \ 197 | criterion, \ 198 | (hole_context, hole_attention_mask, rule_context), \ 199 | gt) 200 | 201 | batch_loss = batch_metrices['loss'] 202 | batch_loss.backward() 203 | optimizer.step() 204 | 205 | total_highest_success_correct += batch_metrices['mean_highest_success_correct'] 206 | total_oracle_success += oracle_success 207 | total_loss += batch_loss.item() 208 | total_batches += 1 209 | 210 | avg_train_loss = total_loss/ total_batches 211 | 
avg_highest_success_accuracy = total_highest_success_correct*100/ total_batches 212 | avg_oracle_success_accuracy = total_oracle_success*100/ total_batches 213 | 214 | tb_writer.add_scalar("metrics/train_loss", avg_train_loss, epoch) 215 | tb_writer.add_scalar("metrics/train_highest_success_accuracy", avg_highest_success_accuracy, epoch) 216 | 217 | print("Train loss: Total %f" % avg_train_loss) 218 | f_out.write("Train loss: " + str(avg_train_loss) + "\n") 219 | print("Train oracle success accuracy: %f" % avg_oracle_success_accuracy) 220 | f_out.write("Train oracle success accuracy: " + str(avg_oracle_success_accuracy) + "\n") 221 | print("Train highest success accuracy: %f" % avg_highest_success_accuracy) 222 | f_out.write("Train highest success accuracy: " + str(avg_highest_success_accuracy) + "\n") 223 | 224 | 225 | ######################################Evaluation Loop############################################ 226 | rule_model.eval() 227 | 228 | with torch.no_grad(): 229 | 230 | total_highest_success_correct, total_loss = 0.0, 0.0 231 | total_batches = 0 232 | total_oracle_success = 0.0 233 | val_count = 0 234 | 235 | for batch in tqdm(val_data_loader): 236 | 237 | hole_context = Variable(batch[0]).to(device) 238 | hole_attention_mask = Variable(batch[1]).to(device) 239 | rule_context = Variable(batch[2]).to(device) 240 | gt = Variable(batch[3]).to(device) 241 | failure_flag = Variable(batch[5]).to(device) 242 | 243 | val_count+= torch.sum(failure_flag) 244 | 245 | 246 | batch_metrices, oracle_success = calculate_loss(rule_model, \ 247 | criterion, \ 248 | (hole_context, hole_attention_mask, rule_context), \ 249 | gt) 250 | 251 | 252 | batch_loss = batch_metrices['loss'] 253 | total_highest_success_correct += batch_metrices['mean_highest_success_correct'] 254 | total_oracle_success+= oracle_success 255 | total_loss += batch_loss.item() 256 | total_batches += 1 257 | 258 | avg_val_loss = total_loss/ total_batches 259 | avg_highest_success_accuracy = total_highest_success_correct*100/ total_batches 260 | avg_oracle_success_accuracy = total_oracle_success*100/total_batches 261 | 262 | tb_writer.add_scalar("metrics/val_loss", avg_val_loss, epoch) 263 | tb_writer.add_scalar("metrics/val_highest_success_accuracy", avg_highest_success_accuracy, epoch) 264 | 265 | print("Val loss: Total %f" % avg_val_loss) 266 | f_out.write("Val loss: " + str(avg_val_loss) + "\n") 267 | print("Val oracle success accuracy: %f" % avg_oracle_success_accuracy) 268 | f_out.write("Val oracle success accuracy: " + str(avg_oracle_success_accuracy) + "\n") 269 | print("Val highest success accuracy: %f" % avg_highest_success_accuracy) 270 | f_out.write("Val highest success accuracy: " + str(avg_highest_success_accuracy) + "\n") 271 | 272 | if args.lr_scheduler =='reduceonplateau': 273 | lr_sched.step(avg_highest_success_accuracy) 274 | elif args.lr_scheduler !='none': 275 | lr_sched.step() 276 | 277 | if avg_highest_success_accuracy > best_val_acc: 278 | print("Found new best model") 279 | f_out.write("Found new best model\n") 280 | best_val_acc = avg_highest_success_accuracy 281 | save(rule_model, optimizer, epoch, save_dir) 282 | patience_ctr = 0 283 | else: 284 | patience_ctr += 1 285 | if patience_ctr == args.patience: 286 | print("Ran out of patience. Stopping training early...") 287 | f_out.write("Ran out of patience. 
Stopping training early...\n") 288 | print("Best Val Acc: ", best_val_acc) 289 | f_out.write("Best Val Acc: " + str(best_val_acc)) 290 | break 291 | f_out.write("\n\n") 292 | f_out.flush() 293 | print("Best Val Acc: ", best_val_acc) 294 | f_out.write("Best Val Acc: " + str(best_val_acc)) 295 | f_out.close() 296 | 297 | -------------------------------------------------------------------------------- /rule_config.py: -------------------------------------------------------------------------------- 1 | 2 | # context_location and context_type define a rule 3 | context_location = [ 4 | 'in_file', \ 5 | 'parent_class_file', \ 6 | 'import_file',\ 7 | 'sibling_file', \ 8 | 'similar_name_file', \ 9 | 'child_class_file', \ 10 | 'import_of_sibling_file', \ 11 | 'import_of_similar_name_file', \ 12 | 'import_of_parent_class_file', \ 13 | 'import_of_child_class_file' 14 | ] 15 | 16 | 17 | all_context_types = [ 18 | 'method_names_and_bodies',\ 19 | 'method_names',\ 20 | 'identifiers', \ 21 | 'type_identifiers',\ 22 | 'string_literals',\ 23 | 'field_declarations', \ 24 | 'lines' 25 | ] 26 | 27 | context_type_dict = {} 28 | for con_loc in context_location: 29 | if con_loc == 'in_file': 30 | context_types = all_context_types[1:] 31 | else: 32 | context_types = all_context_types[:-1] 33 | context_type_dict[con_loc] = context_types 34 | 35 | # rule-specific hyperparams to run. Make changes here to run different configurations 36 | rule_hyperparams = { 37 | 'lines': 38 | { 39 | 'context_ratio': [0.5, 0.25, 0.75], 40 | 'top_k': [-1], 41 | 'prompt_separator': ['space'], 42 | 'top_k_type':['first'], 43 | 'rule_context_formatting':['space'] 44 | }, 45 | 46 | 'identifiers': 47 | { 48 | 'context_ratio': [0.5], 49 | 'top_k': [-1], 50 | 'prompt_separator': ['newline'], 51 | 'top_k_type':['first'], 52 | 'rule_context_formatting':['class_name'] 53 | }, 54 | 55 | 'type_identifiers': 56 | { 57 | 'context_ratio': [0.5], 58 | 'top_k': [-1], 59 | 'prompt_separator': ['newline'], 60 | 'top_k_type':['first'], 61 | 'rule_context_formatting':['class_name'] 62 | }, 63 | 64 | 'string_literals': 65 | { 66 | 'context_ratio': [0.5], 67 | 'top_k': [-1], 68 | 'prompt_separator': ['newline'], 69 | 'top_k_type':['first'], 70 | 'rule_context_formatting':['class_name'] 71 | }, 72 | 73 | 'method_names': 74 | { 75 | 'context_ratio': [0.5], 76 | 'top_k': [-1], 77 | 'prompt_separator': ['newline'], 78 | 'top_k_type':['first'], 79 | 'rule_context_formatting':['class_name'] 80 | }, 81 | 82 | 'field_declarations': 83 | { 84 | 'context_ratio': [0.5], 85 | 'top_k': [-1], 86 | 'prompt_separator': ['newline'], 87 | 'top_k_type':['first'], 88 | 'rule_context_formatting':['class_name'] 89 | }, 90 | 91 | 'method_names_and_bodies': 92 | { 93 | 'context_ratio': [0.5], 94 | 'top_k': [-1], 95 | 'prompt_separator': ['newline'], 96 | 'top_k_type':['first'], 97 | 'rule_context_formatting':['class_method_name'] 98 | } 99 | 100 | 101 | } 102 | -------------------------------------------------------------------------------- /rule_inference_preprocessed_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import torch 5 | import pickle 6 | import argparse 7 | import random 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | from torch.autograd import Variable 11 | from torch import nn 12 | from tqdm import tqdm 13 | from preprocessed_data import * 14 | from model_preprocessed_data import RuleModel 15 | from torch import 
FloatTensor 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | def setup_args(): 21 | """ 22 | Description: Takes in the command-line arguments from user 23 | """ 24 | parser = argparse.ArgumentParser() 25 | 26 | parser.add_argument("--seed", type=int, default=9, help="seed for reproducibility") 27 | parser.add_argument("--input_data_dir", type=str, default='rule_classifier_data', help="base directory for the data") 28 | parser.add_argument("--data_split", type=str, default='val', help="data_split") 29 | parser.add_argument("--model_path", type=str, default='models/rlpg-h', help="base directory for storing the models") 30 | parser.add_argument("--batch_size", type=int, default=32, help="batch size for training the classifier") 31 | return parser.parse_args() 32 | 33 | def get_accuracy(pred, gold, mask): 34 | pred = pred.masked_fill(mask==0, 0) 35 | max_idx = torch.argmax(pred, dim=1, keepdim=True) 36 | rounded_pred = torch.round(pred) 37 | max_idx_gold_vals = torch.gather(gold, 1, max_idx) 38 | mean_highest_success_correct = (max_idx_gold_vals == 1).to(dtype=torch.float).mean() 39 | return mean_highest_success_correct, pred 40 | 41 | def get_prediction(rule_model, info): 42 | pred, mask = rule_model(info) 43 | mask = torch.sum(mask, dim=-1) #(bs, #rules) 44 | return pred, mask 45 | 46 | def calculate_loss(rule_model, criterion, info, gt, hole_ids, hole_stats): 47 | 48 | pred, mask = get_prediction(rule_model, info) 49 | n_valid_entries = torch.sum(mask.view(-1)!=0) 50 | loss = criterion(pred, gt) 51 | loss = loss.masked_fill(mask==0, 0) 52 | mean_highest_success_correct, pred = get_accuracy(pred, gt, mask) 53 | masked_gt = torch.sum(gt.masked_fill(mask==0, 0), dim=-1) 54 | mean_oracle_success = masked_gt.masked_fill(masked_gt!=0, 1.0).mean() 55 | 56 | for i in range(len(hole_ids)): 57 | hid = hole_ids[i] 58 | hole_loss = torch.sum(loss[i]) 59 | n_valid_hole_rules = torch.sum(loss[i]!=0) 60 | hole_loss = hole_loss/n_valid_hole_rules 61 | hole_prediction = pred[i] 62 | hole_stats[hid] = (hole_loss, hole_prediction) 63 | 64 | return {'loss': torch.sum(loss)/n_valid_entries, \ 65 | 'mean_highest_success_correct': mean_highest_success_correct}, \ 66 | mean_oracle_success, \ 67 | hole_stats 68 | 69 | if __name__ == '__main__': 70 | 71 | args = setup_args() 72 | 73 | #Fix seeds 74 | np.random.seed(args.seed) 75 | os.environ['PYTHONHASHSEED'] = str(args.seed) 76 | torch.manual_seed(args.seed) 77 | random.seed(args.seed) 78 | 79 | os.makedirs(os.path.join('outputs', args.data_split), exist_ok=True) 80 | f_out = open(os.path.join('outputs', args.data_split + '_inference'), 'a') 81 | 82 | model_path = args.model_path 83 | mode = model_path.split('/')[-1] 84 | 85 | # Define the model 86 | if mode == 'rlpg-h': 87 | emb_model_type = 'codebert' 88 | rule_model = RuleModel(emb_model_type=emb_model_type, device=device, mode=mode) 89 | if mode == 'rlpg-r': 90 | emb_model_type = 'codebert' 91 | rule_model = RuleModel(emb_model_type=emb_model_type, device=device, mode=mode, n_head=4, d_k=32) 92 | 93 | # Define train and val dataloaders 94 | kwargs = {'num_workers': 8, 'pin_memory': True} if device=='cuda' else {} 95 | tokenizer = set_tokenizer(emb_model_type) 96 | dataset = RuleDataset(os.path.join(args.input_data_dir, args.data_split), tokenizer=tokenizer, emb_model_type=emb_model_type) 97 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, **kwargs) 98 | 99 | print("=> loading checkpoint '{}'".format(model_path)) 
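# strict=False in load_state_dict below tolerates keys that are absent from the released checkpoint
# (presumably the weights of the pretrained embedding backbone, which are not stored with the classifier)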
100 | best_model_path = os.path.join(model_path, 'best_model.th') 101 | rule_model.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu')), strict=False) 102 | print("=> loaded checkpoint '{}'".format(model_path)) 103 | rule_model.to(device) 104 | 105 | rule_model.eval() 106 | criterion = nn.BCELoss(reduction='none') 107 | 108 | with torch.no_grad(): 109 | 110 | total_highest_success_correct, total_loss = 0.0, 0.0 111 | total_batches = 0 112 | total_oracle_success = 0.0 113 | hole_stats = {} 114 | count = 0 115 | 116 | for batch in tqdm(data_loader): 117 | hole_context = Variable(batch[0]).to(device) 118 | hole_attention_mask = Variable(batch[1]).to(device) 119 | rule_context = Variable(batch[2]).to(device) 120 | gt = Variable(batch[3]).to(device) 121 | hole_id = batch[4] 122 | failure_flag = Variable(batch[5]).to(device) 123 | 124 | count+= torch.sum(failure_flag) 125 | 126 | batch_metrices, oracle_success, hole_stats = calculate_loss(rule_model, \ 127 | criterion, \ 128 | (hole_context, hole_attention_mask, rule_context), \ 129 | gt, \ 130 | hole_id, \ 131 | hole_stats) 132 | 133 | batch_loss = batch_metrices['loss'] 134 | total_highest_success_correct += batch_metrices['mean_highest_success_correct'] 135 | total_oracle_success+= oracle_success 136 | total_loss += batch_loss.item() 137 | total_batches += 1 138 | 139 | avg_loss = total_loss/ total_batches 140 | avg_highest_success_accuracy = total_highest_success_correct*100/ total_batches 141 | avg_oracle_success_accuracy = total_oracle_success*100/total_batches 142 | 143 | f_out.write("\n********************************\n") 144 | f_out.write(model_path + "\n") 145 | print("Loss: Total %f" % avg_loss) 146 | f_out.write("Loss: " + str(avg_loss) + "\n") 147 | print("Oracle success accuracy: %f" % avg_oracle_success_accuracy) 148 | f_out.write("Oracle success accuracy: " + str(avg_oracle_success_accuracy) + "\n") 149 | print("Highest success accuracy: %f" % avg_highest_success_accuracy) 150 | f_out.write("Highest success accuracy: " + str(avg_highest_success_accuracy) + "\n") 151 | f_out.write("\n********************************\n") 152 | f_out.flush() 153 | 154 | with open(os.path.join('outputs', args.data_split, '/'.join(model_path.split('/')[1:])) , 'wb') as f: 155 | pickle.dump(hole_stats, f) -------------------------------------------------------------------------------- /rule_representation_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import pickle 5 | from torch.utils.data import Dataset 6 | from transformers import AutoModel 7 | from transformers import GPT2TokenizerFast, AutoTokenizer 8 | from transformers import pipeline 9 | from utils import * 10 | from data_utils import RuleDatasetUtils 11 | 12 | class RuleReprDataset(Dataset): 13 | 14 | def __init__(self, input_data_dir, emb_model_type, tokenizer): 15 | # get all relevant files (in raw form) 16 | files = [] 17 | oracles = {} 18 | hole_datas = {} 19 | parse_datas = {} 20 | all_duplicate_files = [] 21 | data_type = input_data_dir.split('/')[-1] 22 | for dp, dn, filenames in os.walk(input_data_dir): 23 | for f in filenames: 24 | if f == 'hole_data': 25 | hole_data = pickle.load(open(os.path.join(dp, f), 'rb')) 26 | hole_data = self.update_dict(hole_data, data_type) 27 | hole_datas = {**hole_datas, **hole_data} 28 | if f == 'parsed_data': 29 | parse_data = pickle.load(open(os.path.join(dp, f), 'rb')) 30 | parse_data = self.update_dict(parse_data, data_type) 
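# update_dict re-roots the pickled file paths under rule_classifier_data/<data_type>/ so they
# match the directory layout that the rest of this class assumes when splitting hole ids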
31 | parse_datas = {**parse_datas, **parse_data} 32 | if f == 'duplicates': 33 | duplicate_files = open(os.path.join(dp, f), 'r').readlines() 34 | all_duplicate_files.extend([x.strip() for x in duplicate_files]) 35 | if os.path.splitext(f)[1] == '.java': 36 | files.append(os.path.join(dp, f)) 37 | print(len(all_duplicate_files)) 38 | self.holes = [] 39 | for file in files: 40 | if file in hole_datas and \ 41 | file not in all_duplicate_files and \ 42 | not file.startswith('rule_classifier_data/train/rsbotownversion/trunk/scripts/'): 43 | for (l,c) in hole_datas[file]: 44 | hole_identity = file + '_' + str(l) + '_' + str(c) 45 | self.holes.append(hole_identity) 46 | 47 | print(len(self.holes)) 48 | self.num_rules = len(combined_to_index) 49 | self.tokenizer = tokenizer 50 | 51 | self.parse_datas = parse_datas 52 | self.model_max_length = self.tokenizer.model_max_length 53 | self.rule_repr_cache = {} 54 | self.emb_model_type = emb_model_type 55 | self.set_embedding_model() 56 | self.repr_size = 768 57 | self.start = 0 58 | self.end = 500000 59 | 60 | def update_dict(self, dic, data_type): 61 | mod_dic = {} 62 | for k,v in dic.items(): 63 | mod_k = '/'. join(['rule_classifier_data', data_type] + k.split('/')[2:]) 64 | mod_dic[mod_k] = v 65 | return mod_dic 66 | 67 | def __len__(self): 68 | return len(self.holes) 69 | 70 | def __getitem__(self, idx): 71 | if idx >=self.start and idx <= self.end: 72 | return self.generate_data(self.holes[idx]) 73 | else: 74 | return None, None, None 75 | 76 | def get_start_index(self, repo, start_offset=0, interval=0): 77 | count=0 78 | for i in range(len(self.holes)): 79 | hole = self.holes[i] 80 | repo_name = hole.split('/')[2] 81 | if repo_name == repo: 82 | count+=1 83 | repo_end_idx = i 84 | 85 | self.start = repo_end_idx - count + 1 86 | self.start = self.start + start_offset 87 | if interval!=0 : 88 | self.end = self.start + interval 89 | else: 90 | self.end = repo_end_idx 91 | return self.start, self.end 92 | 93 | def is_clear_cache(self): 94 | if len(self.rule_repr_cache) < 30: 95 | self.clear_cache = False 96 | else: 97 | self.clear_cache = True 98 | self.rule_repr_cache = {} 99 | 100 | def get_representation(self, inputs, mask): 101 | outputs = self.emb_model(inputs, attention_mask=mask) 102 | try: 103 | representation = outputs.pooler_output 104 | except: 105 | representation = outputs.last_hidden_state[:, 0] 106 | #print(representation.shape) 107 | return representation 108 | 109 | def get_context_embedding(self, context, attn_mask): 110 | context_embedding = self.get_representation(context, attn_mask) 111 | return context_embedding 112 | 113 | def get_rule_context(self, file, hole_pos): 114 | self.is_clear_cache() 115 | rule_dataset_util = RuleDatasetUtils(file, self.parse_datas, hole_pos, self.tokenizer) 116 | rule_prompts, rule_indexes = rule_dataset_util.get_all_rules_context() 117 | rule_contexts = self.tokenizer(rule_prompts, truncation=True, padding='max_length') 118 | rule_inputs = torch.tensor(rule_contexts['input_ids']) 119 | rule_masks = torch.tensor(rule_contexts['attention_mask']) 120 | rule_indexes = torch.tensor(rule_indexes) 121 | 122 | # remove rules that are already cached 123 | rule_prompts = self.tokenizer.batch_decode(rule_inputs) 124 | filtered_rule_context = [] 125 | filtered_rule_mask = [] 126 | filtered_rule_prompts = [] 127 | filtered_rule_indexes = [] 128 | for i in range(len(rule_prompts)): 129 | rule_prompt = rule_prompts[i] 130 | if rule_prompt not in self.rule_repr_cache: 131 | 
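# cache miss: keep this rule context so its embedding is computed once below and memoized in rule_repr_cache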
filtered_rule_indexes.append(rule_indexes[i]) 132 | filtered_rule_context.append(rule_inputs[i]) 133 | filtered_rule_mask.append(rule_masks[i]) 134 | filtered_rule_prompts.append(rule_prompt) 135 | 136 | if filtered_rule_context: 137 | filtered_rule_context = torch.stack(filtered_rule_context) 138 | filtered_rule_mask = torch.stack(filtered_rule_mask) 139 | 140 | # get rule representations 141 | filtered_representations = self.get_context_embedding(filtered_rule_context, filtered_rule_mask) 142 | # cache the representations 143 | for i in range(len(filtered_representations)): 144 | f_repr = filtered_representations[i] 145 | rule_prompt = filtered_rule_prompts[i] 146 | self.rule_repr_cache[rule_prompt] = f_repr 147 | 148 | # obtain full representations 149 | keys = [] 150 | j = 0 151 | for ind in range(self.num_rules): 152 | if ind in rule_indexes: 153 | prompt = rule_prompts[j] 154 | j+=1 155 | if prompt in self.rule_repr_cache: 156 | keys.append(self.rule_repr_cache[prompt]) 157 | else: 158 | keys.append(torch.zeros(self.repr_size)) 159 | else: 160 | keys.append(torch.zeros(self.repr_size)) 161 | 162 | keys = torch.stack(keys) 163 | return keys 164 | 165 | def generate_data(self, hole): 166 | 167 | hole_parts = hole.split('/')[-1].split('_') 168 | repo_name = hole.split('/')[2] 169 | if len(hole_parts) > 3: 170 | new_hole_parts = hole_parts[:-2] 171 | filename = '_'.join(new_hole_parts) 172 | filename = [filename] 173 | else: 174 | filename = [hole_parts[0]] 175 | file = '/'.join(hole.split('/')[:-1] + filename) 176 | hole_pos = (int(hole_parts[-2]), int(hole_parts[-1])) 177 | rule_contexts = self.get_rule_context(file, hole_pos) 178 | return rule_contexts, hole, repo_name 179 | 180 | def set_tokenizer(self): 181 | if self.emb_model_type == 'codebert': 182 | self.tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base") 183 | if self.emb_model_type == 'graphcodebert': 184 | self.tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base") 185 | if self.emb_model_type == 'gpt-2': 186 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 187 | self.tokenizer.pad_token = self.tokenizer.eos_token 188 | 189 | def set_embedding_model(self): 190 | # CodeBERT 191 | if self.emb_model_type == 'codebert': 192 | self.emb_model = AutoModel.from_pretrained("microsoft/codebert-base") 193 | # GraphCodeBERT 194 | if self.emb_model_type == 'graphcodebert': 195 | self.emb_model = AutoModel.from_pretrained("microsoft/graphcodebert-base") 196 | 197 | def set_tokenizer(emb_model_type): 198 | 199 | if emb_model_type == 'codebert': 200 | tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base") 201 | if emb_model_type == 'graphcodebert': 202 | tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base") 203 | if emb_model_type == 'gpt-2': 204 | tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 205 | tokenizer.pad_token = tokenizer.eos_token 206 | return tokenizer 207 | 208 | 209 | -------------------------------------------------------------------------------- /script_analyze_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | base_dir = 'rule_classifier_data' 3 | 4 | projects = { 'train': [ 5 | 'gfsfa', 6 | 'sol-agent-platform', 7 | 'gloodb', 8 | 'rsbotownversion', 9 | 'jjskit', 10 | 'ftpserverremoteadmin', 11 | 'openprocesslogger', 12 | 'strudem-sicsa', 13 | 'seamlets', 14 | 'healpix-rangeset', 15 | 'quidsee', 16 | 'mobileexpensetracker', 17 | 'swe574-group3', 18 | 'largemail', 19 | 'soap-dtc', 
20 | 'designpatternjavapedro', 21 | 'myt5lib', 22 | 'exogdx', 23 | 'tapestry-sesame' 24 | ], 25 | 26 | 'val': [ 27 | 'javasummerframework', 28 | 'tinwiki', 29 | 'teammates-shakthi', 30 | 'jcontenedor', 31 | 'jloogle', 32 | 'swinagile', 33 | 'math-mech-eshop', 34 | 'jata4test', 35 | 'affinity_propagation_java', 36 | 'navigablep2p', 37 | 'springlime', 38 | 'sohocms', 39 | 'tyrond', 40 | 'infinispan-storage-service', 41 | ], 42 | 43 | 'test': [ 44 | 'project-pt-diaoc', 45 | 'dovetaildb', 46 | 'robotsimulator2009w', 47 | 'ircrpgbot', 48 | 'xfuze', 49 | 'realtimegc', 50 | 'fswuniceubtemplates', 51 | 'glperaudsimon', 52 | 'apiitfriends', 53 | 'qwikioffice-java', 54 | 'xiaonei-java-api', 55 | 'wicketbits', 56 | 'hucourses', 57 | 'gwt-plugindetect' 58 | ] 59 | } 60 | 61 | commands = [] 62 | for data_split, data_split_repos in projects.items(): 63 | for proj in data_split_repos: 64 | proj_name = proj.strip() 65 | command = "python analyze_results.py --proj_name " + proj_name \ 66 | + " --base_dir " + base_dir + " --data_split " + data_split 67 | commands.append(command) 68 | 69 | with open("commands_analyze_results", 'w') as f: 70 | f.writelines("%s\n" % command for command in commands) 71 | f.close() 72 | 73 | -------------------------------------------------------------------------------- /script_completions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import numpy as np 5 | from rule_config import * 6 | 7 | base_dirs = os.listdir('rule_classifier_data') 8 | 9 | modes = ['codex', 'rule'] 10 | context_locations = [ 11 | 'in_file', \ 12 | 'parent_class_file', \ 13 | 'import_file',\ 14 | 'sibling_file', \ 15 | 'similar_name_file', \ 16 | 'child_class_file', \ 17 | 'import_of_sibling_file', \ 18 | 'import_of_similar_name_file', \ 19 | 'import_of_parent_class_file', \ 20 | 'import_of_child_class_file' 21 | ] 22 | 23 | 24 | batch_size = 20 25 | total_context_length = 4072 26 | 27 | def main(): 28 | commands = [] 29 | for base_repo in base_dirs: 30 | base_dir = os.path.join('rule_classifier_data', base_repo) 31 | for repo in os.listdir(base_dir): 32 | for mode in modes: 33 | if mode == 'codex': 34 | command = "python generate_completions.py --mode " + mode \ 35 | + " --total_context_len " + str(total_context_length)\ 36 | + " --base_dir " + base_dir\ 37 | + " --repo_name " + repo\ 38 | + " --batch_size " + str(batch_size) 39 | commands.append(command) 40 | 41 | if mode == 'rule': 42 | for context_location in context_locations: 43 | context_types = context_type_dict[context_location] 44 | for context_type in context_types: 45 | rule_specific_hyperparams = rule_hyperparams[context_type] 46 | for context_ratio in rule_specific_hyperparams['context_ratio']: 47 | for prompt_separator in rule_specific_hyperparams['prompt_separator']: 48 | for top_k in rule_specific_hyperparams['top_k']: 49 | for rule_context_format in rule_specific_hyperparams['rule_context_formatting']: 50 | if top_k == -1: 51 | command = "python generate_completions.py --mode " + mode\ 52 | + " --context_location " + context_location\ 53 | + " --context_type " + context_type\ 54 | + " --context_division_ratio " + str(context_ratio) \ 55 | + " --prompt_separator " + prompt_separator \ 56 | + " --top_k " + str(top_k)\ 57 | + " --total_context_len " + str(total_context_length)\ 58 | + " --base_dir " + base_dir\ 59 | + " --repo_name " + repo\ 60 | + " --batch_size " + str(batch_size)\ 61 | + " --rule_context_formatting " + rule_context_format\ 62 | 63 | 
commands.append(command) 64 | 65 | else: 66 | for top_k_type in rule_specific_hyperparams['top_k_type']: 67 | final_command = command + " --top_k_type " + top_k_type 68 | commands.append(final_command) 69 | 70 | 71 | with open("commands_completion", 'w') as f: 72 | f.writelines("%s\n" % command for command in commands) 73 | f.close() 74 | 75 | if __name__ == '__main__': 76 | main() -------------------------------------------------------------------------------- /script_gen_and_preprocess_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | base_data_dir = 'gcode-data' 3 | 4 | projects = { 'train': [ 5 | 'gfsfa', 6 | 'sol-agent-platform', 7 | 'gloodb', 8 | 'rsbotownversion', 9 | 'jjskit', 10 | 'ftpserverremoteadmin', 11 | 'openprocesslogger', 12 | 'strudem-sicsa', 13 | 'seamlets', 14 | 'healpix-rangeset', 15 | 'quidsee', 16 | 'mobileexpensetracker', 17 | 'swe574-group3', 18 | 'largemail', 19 | 'soap-dtc', 20 | 'designpatternjavapedro', 21 | 'myt5lib', 22 | 'exogdx', 23 | 'tapestry-sesame' 24 | ], 25 | 26 | 'val': [ 27 | 'javasummerframework', 28 | 'tinwiki', 29 | 'teammates-shakthi', 30 | 'jcontenedor', 31 | 'jloogle', 32 | 'swinagile', 33 | 'math-mech-eshop', 34 | 'jata4test', 35 | 'affinity_propagation_java', 36 | 'navigablep2p', 37 | 'springlime', 38 | 'sohocms', 39 | 'tyrond', 40 | 'infinispan-storage-service', 41 | ], 42 | 43 | 'test': [ 44 | 'project-pt-diaoc', 45 | 'dovetaildb', 46 | 'robotsimulator2009w', 47 | 'ircrpgbot', 48 | 'xfuze', 49 | 'realtimegc', 50 | 'fswuniceubtemplates', 51 | 'glperaudsimon', 52 | 'apiitfriends', 53 | 'qwikioffice-java', 54 | 'xiaonei-java-api', 55 | 'wicketbits', 56 | 'hucourses', 57 | 'gwt-plugindetect' 58 | ] 59 | } 60 | 61 | commands = [] 62 | for data_split, data_split_repos in projects.items(): 63 | for proj in data_split_repos: 64 | proj_name = proj.strip() 65 | command = "python create_hole_data.py --proj_name " + proj_name \ 66 | + " --base_dir " + base_data_dir + " --data_split " + data_split 67 | commands.append(command) 68 | command = "python parse_tree.py --proj_name " + proj_name \ 69 | + " --base_dir " + os.path.join('rule_classifier_data', data_split) 70 | commands.append(command) 71 | command = "python check_duplication.py --proj_name " + proj_name \ 72 | + " --base_dir " + os.path.join('rule_classifier_data', data_split) 73 | commands.append(command) 74 | 75 | with open("commands_gen_and_preprocess", 'w') as f: 76 | f.writelines("%s\n" % command for command in commands) 77 | f.close() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import product 3 | from rule_config import * 4 | 5 | promptseparator2str = {'space': " ", \ 6 | 'newline': "\n", \ 7 | 'class_names': "class_names",\ 8 | 'class_method_names': "class_method_names",\ 9 | 'method_names': "method_names"} 10 | 11 | context_location_to_index = { 12 | 'in_file':0, \ 13 | 'parent_class_file':1, \ 14 | 'import_file':2,\ 15 | 'sibling_file':3, \ 16 | 'similar_name_file':4, \ 17 | 'child_class_file':5, \ 18 | 'import_of_sibling_file':6, \ 19 | 'import_of_similar_name_file':7, \ 20 | 'import_of_parent_class_file':8, \ 21 | 'import_of_child_class_file':9, \ 22 | 'codex': 10 #codex 23 | } 24 | 25 | 26 | context_types_to_index = { 27 | 'method_names_and_bodies':0,\ 28 | 'method_names':1,\ 29 | 'identifiers':2, \ 30 | 'type_identifiers':3,\ 31 | 'string_literals':4,\ 32 | 
'field_declarations':5, \ 33 | 'lines':6, \ 34 | 'codex': 7 #codex 35 | } 36 | 37 | count = 0 38 | combined_to_index = {} 39 | cl_keys = list(context_location_to_index.keys()) 40 | ct_keys = list(context_types_to_index.keys()) 41 | cl_keys.remove('codex') 42 | ct_keys.remove('codex') 43 | for (k1, k2) in product(cl_keys, ct_keys): 44 | if (k1 != 'in_file' and k2 in ct_keys[:-1]) or (k1 == 'in_file' and k2 in ct_keys[1:]): 45 | cr_keys = rule_hyperparams[k2]['context_ratio'] 46 | for k3 in cr_keys: 47 | key = k1 + '#' + k2 + '#' + str(k3) 48 | combined_to_index[key] = count 49 | count +=1 50 | combined_to_index['codex'] = count 51 | #print(combined_to_index) 52 | 53 | def get_multi_hot_vector(lst, type): 54 | set_lst = list(set(lst)) 55 | if type == 'cl': 56 | index_dict = context_location_to_index 57 | if type == 'ct': 58 | index_dict = context_types_to_index 59 | if type == 'com': 60 | index_dict = combined_to_index 61 | vector_size = len(index_dict) 62 | multi_hot_vector = np.zeros(vector_size) 63 | for entry in lst: 64 | multi_hot_vector[index_dict[entry]] = 1 65 | return multi_hot_vector 66 | 67 | def is_valid_hole(hole, duplicate_files): 68 | hole_parts = hole.split('/')[-1].split('_') 69 | if len(hole_parts) > 3: 70 | new_hole_parts = hole_parts[:-2] 71 | filename = '_'.join(new_hole_parts) 72 | filename = [filename] 73 | else: 74 | filename = [hole_parts[0]] 75 | file = '/'.join(hole.split('/')[:-1] + filename) 76 | if file in duplicate_files: 77 | return False 78 | else: 79 | return True 80 | 81 | def find_intersection(lst1, lst2): 82 | set_lst1 = set(lst1) 83 | set_lst2 = set(lst2) 84 | return set_lst1.intersection(set_lst2) 85 | 86 | def alter_hid(orig_hid, hid): 87 | data_split = hid.split('/')[1] 88 | if 'gcode-data' in orig_hid: 89 | new_id = orig_hid.replace('data/gcode-data', 'rule_classifier_data/' + data_split) 90 | return new_id 91 | elif 'java-other' in orig_hid: 92 | new_id = orig_hid.replace('data/java-other', 'rule_classifier_data/' + data_split) 93 | return new_id 94 | else: 95 | return orig_hid 96 | 97 | def find_usages(query_att, query_file, lst_key_att, key_file): 98 | usages = [] 99 | query_str = get_string(query_file, query_att[0], query_att[1]) 100 | for key_att in lst_key_att: 101 | key_str = get_string(key_file, key_att[0], key_att[1]) 102 | if key_str == query_str: 103 | usages.append(key_att) 104 | return usages 105 | 106 | def update_list(src_lst, tgt_lst, f, return_type='str'): 107 | for elem in src_lst: 108 | elem_str = get_string(f, elem[0], elem[1]) 109 | if elem_str not in tgt_lst: 110 | if return_type == 'pos': 111 | tgt_lst.append(elem) 112 | else: 113 | tgt_lst.append(elem_str) 114 | return tgt_lst 115 | 116 | def find_similar_intersection(file1, file2): 117 | lst1 = [x.split('/')[-1] for x in file1] 118 | lst2 = [x.split('/')[-1] for x in file2] 119 | #print(lst1, lst2) 120 | return find_intersection(lst1, lst2) 121 | 122 | def get_codex_tokenized_string(tokenizer, input_str, context_len, type='back'): 123 | ''' 124 | get the codex tokenized string 125 | ''' 126 | if input_str: 127 | codex_tokens = tokenizer(input_str)['input_ids'] 128 | if type == 'front': 129 | truncated_codex_tokens = codex_tokens[:context_len] 130 | else: 131 | truncated_codex_tokens = codex_tokens[-context_len:] 132 | out_str = tokenizer.decode(truncated_codex_tokens) 133 | return out_str, len(truncated_codex_tokens) 134 | else: 135 | return '', 0 136 | 137 | def join_lines(lst): 138 | return ''.join(lst) 139 | 140 | # take start line as the first non-commented and non-empty 
line 141 | def modified_start_line(lines): 142 | for i in range(len(lines)): 143 | line = lines[i] 144 | if line and not (line.startswith('/') or line.startswith('*')): # not part of the license text or empty line 145 | return i 146 | 147 | def get_string(filename, start, end): 148 | ''' 149 | get the string corresponding to the start and end positions in the parse tree 150 | ''' 151 | lines = open(filename, encoding="utf8", errors='backslashreplace').readlines() 152 | start_line, start_char = start 153 | span_str = '' 154 | if start_line == 0: 155 | start_line = modified_start_line(lines) 156 | end_line, end_char = end 157 | if start_line <= end_line and start_line < len(lines) and start_line!= -1: 158 | if start_line == end_line: 159 | if end_char == -1: 160 | span_str = lines[start_line] 161 | else: 162 | span_str = lines[start_line][start_char:end_char] 163 | else: 164 | if start_line + 1 < len(lines): 165 | span_str = lines[start_line][start_char:] + \ 166 | join_lines(lines[start_line+1: end_line]) + \ 167 | lines[end_line][:end_char] 168 | return span_str 169 | 170 | 171 | 172 | def get_context_from_java_github(out_context_len): 173 | dataset_filename = os.path.join('preprocessed_data/java_github', 'holes_1.val') 174 | data = pickle.load(open(dataset_filename, 'rb')) 175 | out_context_prompts = [] 176 | for i in range(len(data)): 177 | for j in range(len(data[i])): 178 | for k in range(len(data[i][j])): 179 | file_data = data[i][j][k] 180 | if file_data[0]: 181 | file_token_str = file_data[0] 182 | out_context_prompts.append(get_codex_tokenized_string(tokenizer, file_token_str, out_context_len)) 183 | return out_context_prompts --------------------------------------------------------------------------------