├── .gitignore
├── LICENSE
├── README.md
├── config.yml
├── data_processing
│   ├── __init__.py
│   ├── corpus_extractor.py
│   ├── graph_features.py
│   ├── graph_processing.py
│   └── sample_inf_processing.py
├── detailed_infer.py
├── graph_pb2.py
├── infer.py
├── model
│   ├── __init__.py
│   └── model.py
├── saved_models
│   ├── MethodNaming
│   │   ├── Definition
│   │   │   ├── checkpoint
│   │   │   │   ├── train.ckpt.data-00000-of-00001
│   │   │   │   ├── train.ckpt.index
│   │   │   │   └── train.ckpt.meta
│   │   │   └── tokens.txt
│   │   └── Usage
│   │       ├── checkpoint
│   │       │   ├── train.ckpt.data-00000-of-00001
│   │       │   ├── train.ckpt.index
│   │       │   └── train.ckpt.meta
│   │       └── tokens.txt
│   └── VarNaming
│       ├── checkpoint
│       │   ├── train.ckpt.data-00000-of-00001
│       │   ├── train.ckpt.index
│       │   └── train.ckpt.meta
│       └── tokens.txt
├── train.py
└── utils
    ├── __init__.py
    ├── arg_parser.py
    ├── utils.py
    └── vocabulary_extractor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | *.pyc
3 | *.DS_STORE
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Dmitry Kazhdan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Representing Programs with Graphs
2 | 
3 | This project re-implements the _VarNaming_ task model described in the paper
4 | [_Learning to Represent Programs with Graphs_](https://ml4code.github.io/publications/allamanis2018learning/),
5 | which can predict the name of a variable based on its usage.
6 | 
7 | Furthermore, this project includes functionality for applying the _VarNaming_ model to the _MethodNaming_ task
8 | (predicting the name of a method from its usage or definition).
9 | 
10 | ## Citing
11 | 
12 | If you use the provided implementation in your research, please cite the [_Learning to Represent Programs with Graphs_](https://ml4code.github.io/publications/allamanis2018learning/) paper, and include a link to this repository as a footnote.
13 | 
14 | 
15 | ## Setup
16 | ### Prerequisites
17 | 
18 | Ensure you have the following packages installed
19 | (these can all be installed with pip3):
20 | 
21 | - numpy
22 | - pyYAML
23 | - tensorflow-gpu (or tensorflow)
24 | - dpu_utils
25 | - protobuf
26 | 
27 | 
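All of them can be installed with a single pip3 call, for example (a sketch using the CPU build of TensorFlow — substitute `tensorflow-gpu` if you have a CUDA-capable GPU):

```bash
pip3 install numpy pyYAML tensorflow dpu_utils protobuf
```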
28 | ### Dataset Format
29 | 
30 | The corpus pre-processing functions are designed to work with _.proto_
31 | graph files, which can be extracted from program source code using the feature
32 | extractor available [here](https://github.com/acr31/features-javac).
33 | 
34 | 
35 | 
36 | 
37 | ### Dataset Parsing
38 | 
39 | Once you have obtained a corpus of .proto graph files, it can be parsed
40 | using the _corpus_extractor.py_ script located in the _data_processing_ folder.
41 | 
42 | - Create empty directories for the training, validation and test datasets
43 | - Specify their paths, as well as the corpus path, in the
44 | _config.yml_ file:
45 | ```yaml
46 | corpus_path: "path-to-corpus"
47 | train_path: "path-to-train-data-output"
48 | val_path: "path-to-val-data-output"
49 | test_path: "path-to-test-data-output"
50 | ```
51 | - Navigate into the repository directory
52 | - Run _corpus_extractor.py_:
53 | 
54 | ```bash
55 | python3 ./data_processing/corpus_extractor.py
56 | ```
57 | 
58 | This will extract all samples from the corpus, randomly shuffle them,
59 | split them into train/val/test partitions, and copy these partitions into the specified
60 | train, val and test folders.
61 | 
62 | 
63 | 
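To sanity-check the extracted data, an individual _.proto_ graph file can be parsed with the _graph_pb2.py_ module shipped in the repository root. A minimal sketch (the sample path is illustrative):

```python
from graph_pb2 import Graph

# Parse a single extracted graph sample and print basic statistics
with open("path-to-train-data/Example.java.proto", "rb") as f:
    g = Graph()
    g.ParseFromString(f.read())

print(g.sourceFile)
print(len(g.node), "nodes,", len(g.edge), "edges")
```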
64 | ## Usage
65 | 
66 | ### Training
67 | 
68 | In order to train the model:
69 | 
70 | - Prepare the training and validation dataset directories,
71 | as described in the _Dataset Parsing_ section above
72 | - Specify their paths in the _config.yml_ file:
73 | ```yaml
74 | train_path: "path-to-train-data"
75 | val_path: "path-to-val-data"
76 | ```
77 | - Specify the token file path
78 | (where the extracted token vocabulary will be saved)
79 | and the checkpoint folder path (where the model checkpoint will be saved) in the _config.yml_ file
80 | (note the fixed specification of the 'train.ckpt' file):
81 | ```yaml
82 | checkpoint_path: "path-to-checkpoint-folder/train.ckpt"
83 | token_path: "path-to-vocabulary-txt-file"
84 | ```
85 | - Navigate into the repository directory
86 | - Run _train.py_:
87 | 
88 | ```bash
89 | python3 ./train.py
90 | ```
91 | 
92 | 
93 | ### Inference
94 | 
95 | In order to use the model for inference:
96 | 
97 | - Prepare the test dataset directory
98 | as described in the _Dataset Parsing_ section above
99 | - Specify its path in the _config.yml_ file:
100 | ```yaml
101 | test_path: "path-to-test-data"
102 | ```
103 | - Specify the token file path
104 | (where the extracted token vocabulary will be loaded from)
105 | and the checkpoint path (where the trained model will be loaded from) in the _config.yml_ file:
106 | ```yaml
107 | checkpoint_path: "path-to-checkpoint-folder/train.ckpt"
108 | token_path: "path-to-vocabulary-txt-file"
109 | ```
110 | - Navigate into the repository directory
111 | - Run _infer.py_:
112 | 
113 | ```bash
114 | python3 ./infer.py
115 | ```
116 | 
117 | 
118 | ### Detailed inference
119 | 
120 | In order to use the model for inference,
121 | as well as for computing extra sample information
122 | (including variable usage information and type information):
123 | 
124 | - Prepare the test dataset directory
125 | as described in the _Dataset Parsing_ section above
126 | - Specify its path in the _config.yml_ file:
127 | ```yaml
128 | test_path: "path-to-test-data"
129 | ```
130 | - Specify the token file path
131 | (where the extracted token vocabulary will be loaded from)
132 | and the checkpoint path (where the trained model will be loaded from) in the _config.yml_ file:
133 | ```yaml
134 | checkpoint_path: "path-to-checkpoint-folder/train.ckpt"
135 | token_path: "path-to-vocabulary-txt-file"
136 | ```
137 | - Navigate into the repository directory
138 | - Run _detailed_infer.py_:
139 | 
140 | ```bash
141 | python3 ./detailed_infer.py
142 | ```
143 | 
144 | 
145 | 
146 | ### MethodNaming Task Selection
147 | The type of task you want the model to run can be specified by passing
148 | appropriate input arguments as follows:
149 | 
150 | - To run training/inference using the VarNaming task (computing variable usage information),
151 | no input arguments are required
152 | - To run training/inference using the MethodNaming usage task (computing method usage information),
153 | add the string "_mth_usage_" as an input argument when calling the scripts
154 | - To run training/inference using the MethodNaming definition task (computing method body information),
155 | add the string "_mth_def_" as an input argument when calling the scripts
156 | 
157 | For example, in order to train the model for the MethodNaming task using
158 | definition information, the script call will be the following:
159 | 
160 | ```bash
161 | python3 ./train.py mth_def
162 | ```
163 | 
164 | Similarly, for running inference using the MethodNaming usage task,
165 | the script call will be the following:
166 | ```bash
167 | python3 ./infer.py mth_usage
168 | ```
169 | 
170 | ### Loading Saved Models
171 | 
172 | The _saved_models_ directory includes pre-trained models, which can
173 | be used to run inference directly, without any training.
174 | The paths to the saved checkpoint and vocabulary files need to be specified
175 | in the _config.yml_ file
176 | in the usual way, as described in the "Inference" section above.
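For example, to run inference with the pre-trained MethodNaming definition model, the relevant _config.yml_ entries could look as follows (a sketch — the paths assume the repository root as the working directory):

```yaml
checkpoint_path: "saved_models/MethodNaming/Definition/checkpoint/train.ckpt"
token_path: "saved_models/MethodNaming/Definition/tokens.txt"
```

Inference can then be run with `python3 ./infer.py mth_def`.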
177 | 
178 | 
179 | 
180 | 
181 | 
182 | ## Files/Directories
183 | 
184 | - data_processing: includes code for processing graph samples and corpus files
185 | - model: includes the implementation of the VarNaming model
186 | - saved_models: pre-trained models for the VarNaming and MethodNaming tasks
187 | - utils: auxiliary code implementing various functionality, such as input
188 | argument parsing and vocabulary extraction
189 | - train.py, infer.py, detailed_infer.py: files for running training and inference
190 | using the model, as described in the previous sections
191 | - config.yml: configuration file storing the dataset, vocabulary and checkpoint paths
192 | - graph_pb2.py: used for parsing .proto sample files
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # Configuration properties
4 | corpus_path:
5 | checkpoint_path:
6 | train_path:
7 | val_path:
8 | test_path:
9 | token_path:
10 | 
11 | 
12 | 
13 | 
14 | 
15 | 
16 | 
17 | 
18 | 
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/data_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/data_processing/__init__.py
--------------------------------------------------------------------------------
/data_processing/corpus_extractor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from random import shuffle
4 | from shutil import copyfile
5 | 
6 | 
7 | 
8 | def split_samples():
9 | 
10 |     with open("config.yml", 'r') as ymlfile:
11 |         cfg = yaml.safe_load(ymlfile)
12 | 
13 |     corpus_path = cfg['corpus_path']
14 |     train_path = cfg['train_path']
15 |     val_path = cfg['val_path']
16 |     test_path = cfg['test_path']
17 | 
18 |     f_names = []
19 |     ignore = ("Test.java.proto", "TestCase.java.proto", "Tests.java.proto")   # Ignore test cases
20 |     max_size_mb = 100    # maximum file size in MB
21 |     min_size_mb = 0.05   # minimum file size in MB
22 | 
23 | 
24 |     # Extract all filenames from corpus folders
25 |     for dirpath, dirs, files in os.walk(corpus_path):
26 |         for filename in files:
27 |             if filename.endswith('proto') and not filename.endswith(ignore):
28 | 
29 |                 fname = os.path.join(dirpath, filename)
30 | 
31 |                 f_size_mb = os.path.getsize(fname) / 1000000
32 | 
33 |                 if min_size_mb < f_size_mb < max_size_mb:
34 |                     f_names.append(fname)
35 | 
36 | 
37 | 
38 | 
39 |     # Copy subset of samples into training/validation/testing directories (~72%/13%/15% split)
40 |     n_samples = len(f_names)
41 |     n_train_and_val = round(n_samples * 0.85)
42 |     n_train = round(n_train_and_val * 0.85)
43 | 
44 |     shuffle(f_names)
45 | 
46 |     train_samples = f_names[:n_train]
47 |     val_samples = f_names[n_train:n_train_and_val]
48 |     test_samples = f_names[n_train_and_val:n_samples]
49 | 
50 |     copy_samples(train_samples, train_path)
51 |     copy_samples(val_samples, val_path)
52 |     copy_samples(test_samples, test_path)
53 | 
54 | 
55 | 
56 | def copy_samples(sample_names, base_path):
57 | 
58 |     for src in sample_names:
59 |         dst = os.path.join(base_path, os.path.basename(src))
60 |         copyfile(src, dst)
61 | 
62 | 
63 | 
64 | 
65 | split_samples()
66 | 
--------------------------------------------------------------------------------
/data_processing/graph_features.py:
--------------------------------------------------------------------------------
1 | from graph_pb2 import FeatureNode, FeatureEdge
2 | 
3 | 
4 | def get_used_edges_type():
5 | 
6 |     used_edge_types = [FeatureEdge.NEXT_TOKEN, FeatureEdge.AST_CHILD, FeatureEdge.LAST_WRITE,
7 |                        FeatureEdge.LAST_USE, FeatureEdge.COMPUTED_FROM, FeatureEdge.RETURNS_TO,
8 |                        FeatureEdge.FORMAL_ARG_NAME, FeatureEdge.GUARDED_BY, FeatureEdge.GUARDED_BY_NEGATION,
9 |                        FeatureEdge.LAST_LEXICAL_USE,
10 |                        FeatureEdge.ASSIGNABLE_TO, FeatureEdge.ASSOCIATED_TOKEN,
11 |                        FeatureEdge.HAS_TYPE, FeatureEdge.ASSOCIATED_SYMBOL]
12 | 
13 |     return used_edge_types
14 | 
15 | 
16 | 
17 | def get_used_nodes_type():
18 | 
19 |     used_node_types = [FeatureNode.TOKEN, FeatureNode.AST_ELEMENT, FeatureNode.IDENTIFIER_TOKEN,
20 |                        FeatureNode.FAKE_AST,
21 |                        FeatureNode.SYMBOL_TYP, FeatureNode.COMMENT_LINE,
22 |                        FeatureNode.TYPE]
23 | 
24 |     return used_node_types
25 | 
--------------------------------------------------------------------------------
/data_processing/graph_processing.py:
--------------------------------------------------------------------------------
1 | from graph_pb2 import FeatureNode, FeatureEdge
2 | from collections import defaultdict
3 | import numpy as np
4 | from dpu_utils.codeutils import split_identifier_into_parts
5 | from data_processing.graph_features import get_used_edges_type, get_used_nodes_type
6 | 
7 | 
8 | '''
9 | Extract usage information from a given graph
10 | 
11 | :graph: input graph sample
12 | :max_path_len: number of GGNN timesteps (used to drop nodes/edges that cannot be reached from the identifier nodes within this number of timesteps)
13 | :max_usages: maximum number of method/variable usage nodes
14 | :node_rep_len: maximum number of subtokens in node representation
15 | :pad_token: vocabulary pad token
16 | :slot_token: vocabulary slot token
17 | :vocabulary: corpus token vocabulary
18 | :get_method_data: whether to compute variable usage data (get_method_data=False), or method usage data (get_method_data=True)
19 | '''
20 | def get_usage_samples(graph, max_path_len, max_usages, node_rep_len, pad_token, slot_token, vocabulary, get_method_data=False):
21 | 
22 |     successor_table = defaultdict(set)
23 |     predecessor_table = defaultdict(set)
24 |     edge_table = defaultdict(list)
25 |     node_table = {}
26 |     sym_node_ids = []
27 |     samples = []
28 |     non_empty_sym_nodes = []
29 | 
30 | 
31 |     if get_method_data:
32 |         parent_usage_node_type = FeatureNode.SYMBOL_MTH
33 |     else:
34 |         parent_usage_node_type = FeatureNode.SYMBOL_VAR
35 | 
36 | 
37 |     for node in graph.node:
38 | 
39 |         node_table[node.id] = node
40 | 
41 |         if node.type == parent_usage_node_type:
42 |             sym_node_ids.append(node.id)
43 | 
44 | 
45 |     for edge in graph.edge:
46 |         successor_table[edge.sourceId].add(edge.destinationId)
47 |         predecessor_table[edge.destinationId].add(edge.sourceId)
48 |         edge_table[edge.sourceId].append(edge)
49 | 
50 | 
51 | 
52 |     for sym_node_id in sym_node_ids:
53 | 
54 |         successor_ids = successor_table[sym_node_id]
55 | 
56 |         identifier_node_ids = [node_id for node_id in successor_ids
57 |                                if node_table[node_id].type == FeatureNode.IDENTIFIER_TOKEN]
58 | 
59 |         decl_id_nodes = []
60 | 
61 |         # If doing method processing, need to also check for presence of the method declaration
62 |         if get_method_data:
63 |             ast_elem_successors = [node_id for node_id in successor_ids
64 |                                    if node_table[node_id].type==FeatureNode.AST_ELEMENT and node_table[node_id].contents=='METHOD']
65 | 
66 |             if len(ast_elem_successors) > 0:
67 | 
68 |                 method_decl_node_id = ast_elem_successors[0]
69 | 
70 |                 decl_id_nodes = 
[node_id for node_id in successor_table[method_decl_node_id] if node_table[node_id].type == FeatureNode.IDENTIFIER_TOKEN] 71 | 72 | 73 | 74 | if len(identifier_node_ids) == 0 or len(identifier_node_ids) > max_usages: 75 | continue 76 | 77 | 78 | # Compute subgraph of nodes/edges reachable from identifier nodes in the given amount of path steps 79 | reachable_node_ids = [] 80 | successor_ids = identifier_node_ids 81 | predecessor_ids = identifier_node_ids 82 | 83 | for _ in range(max_path_len): 84 | reachable_node_ids += successor_ids 85 | reachable_node_ids += predecessor_ids 86 | successor_ids = list(set([elem for n_id in successor_ids for elem in successor_table[n_id]])) 87 | predecessor_ids = list(set([elem for n_id in predecessor_ids for elem in predecessor_table[n_id]])) 88 | 89 | reachable_node_ids += successor_ids 90 | reachable_node_ids += predecessor_ids 91 | reachable_node_ids = set(reachable_node_ids) 92 | 93 | 94 | sub_nodes = [node_table[node_id] for node_id in reachable_node_ids] 95 | 96 | sub_edges = [edge for node in sub_nodes for edge in edge_table[node.id] 97 | if edge.sourceId in reachable_node_ids and edge.destinationId in reachable_node_ids] 98 | 99 | sub_graph = (sub_nodes, sub_edges) 100 | 101 | sample_data = compute_sample_data(sub_graph, identifier_node_ids, node_rep_len, pad_token, slot_token, vocabulary, decl_id_nodes) 102 | samples.append(sample_data) 103 | non_empty_sym_nodes.append(sym_node_id) 104 | 105 | return samples, non_empty_sym_nodes 106 | 107 | 108 | 109 | ''' 110 | Used to create input samples from a given graph 111 | 112 | :sub_graph: input graph 113 | :identifier_token_node_ids: usage/declaration node ids 114 | :seq_length: length of node representation 115 | :pad_token: vocabulary pad token 116 | :slot_token: vocabulary slot token 117 | :vocabulary: corpus token vocabulary 118 | :exception_node_ids: when computing method usage information, there is a chance that a method declaration node is reachable 119 | from one of the usage nodes. 
This node should also be masked with a <SLOT> token, but should not be used in subsequent decoding steps
120 | (because it is not a usage node, but a declaration node), thus it is marked as an exception
121 | '''
122 | def compute_sample_data(sub_graph, identifier_token_node_ids, seq_length, pad_token, slot_token, vocabulary, exception_node_ids = []):
123 | 
124 |     used_node_types = get_used_nodes_type()
125 |     used_edge_types = get_used_edges_type()
126 | 
127 |     node_representations = []
128 |     id_to_index_map = {}
129 |     ind = 0
130 | 
131 |     (sub_nodes, sub_edges) = sub_graph
132 | 
133 |     for node in sub_nodes:
134 |         if node.type in used_node_types:
135 |             if node.id in exception_node_ids:
136 |                 node_representation = [pad_token for _ in range(seq_length)]
137 |                 node_representation[0] = slot_token
138 |             else:
139 |                 node_representation = vocabulary.get_id_or_unk_multiple(split_identifier_into_parts(node.contents), seq_length, pad_token)
140 | 
141 |             node_representations.append(node_representation)
142 |             id_to_index_map[node.id] = ind
143 |             ind += 1
144 | 
145 |     n_nodes = len(node_representations)
146 |     n_types = len(used_edge_types)
147 |     node_representations = np.array(node_representations)
148 |     num_incoming_edges_per_type = np.zeros((n_nodes, n_types))
149 |     num_outgoing_edges_per_type = np.zeros((n_nodes, n_types))
150 |     adj_lists = defaultdict(list)
151 | 
152 |     for edge in sub_edges:
153 |         if edge.type in used_edge_types \
154 |                 and edge.sourceId in id_to_index_map \
155 |                 and edge.destinationId in id_to_index_map:
156 | 
157 |             type_id = used_edge_types.index(edge.type)
158 |             adj_lists[type_id].append([id_to_index_map[edge.sourceId], id_to_index_map[edge.destinationId]])
159 |             num_incoming_edges_per_type[id_to_index_map[edge.destinationId], type_id] += 1
160 |             num_outgoing_edges_per_type[id_to_index_map[edge.sourceId], type_id] += 1
161 | 
162 |     final_adj_lists = {edge_type: np.array(sorted(adj_list), dtype=np.int32)
163 |                        for edge_type, adj_list in adj_lists.items()}
164 | 
165 |     # Add empty entries for types with no adjacency lists
166 |     for i in range(len(used_edge_types)):
167 |         if i not in final_adj_lists:
168 |             final_adj_lists[i] = np.zeros((0, 2), dtype=np.int32)
169 | 
170 | 
171 |     identifier_nodes = [id_to_index_map[node_id] for node_id in identifier_token_node_ids]
172 | 
173 |     return (identifier_nodes, node_representations, final_adj_lists, \
174 |             num_incoming_edges_per_type, num_outgoing_edges_per_type)
175 | 
176 | 
177 | 
178 | 
179 | '''
180 | Extract method body information from a given graph
181 | 
182 | :graph: input graph sample
183 | :node_seq_length: maximum number of subtokens in node representation
184 | :pad_token: vocabulary pad token
185 | :slot_token: vocabulary slot token
186 | :vocabulary: corpus token vocabulary
187 | '''
188 | 
189 | def get_method_body_samples(graph, node_seq_length, pad_token, slot_token, vocabulary):
190 | 
191 |     successor_table = defaultdict(set)
192 |     predecessor_table = defaultdict(set)
193 |     edge_table = defaultdict(list)
194 |     node_table = {}
195 |     ast_elem_node_ids = []
196 |     samples = []
197 |     non_empty_ast_nodes = []
198 | 
199 | 
200 |     for node in graph.node:
201 | 
202 |         node_table[node.id] = node
203 | 
204 |         if node.type==FeatureNode.AST_ELEMENT and node.contents=='METHOD':
205 |             ast_elem_node_ids.append(node.id)
206 | 
207 | 
208 |     for edge in graph.edge:
209 |         successor_table[edge.sourceId].add(edge.destinationId)
210 |         predecessor_table[edge.destinationId].add(edge.sourceId)
211 |         edge_table[edge.sourceId].append(edge)
212 | 
213 | 
214 | 
215 |     for 
ast_elem_node_id in ast_elem_node_ids: 216 | 217 | successor_ids = successor_table[ast_elem_node_id] 218 | predecessor_ids = predecessor_table[ast_elem_node_id] 219 | 220 | method_name_ids = [node_id for node_id in successor_ids 221 | if node_table[node_id].type == FeatureNode.IDENTIFIER_TOKEN] 222 | 223 | 224 | sym_mth_parents = [node_id for node_id in predecessor_ids if node_table[node_id].type == FeatureNode.SYMBOL_MTH] 225 | 226 | 227 | if len(sym_mth_parents) > 0: 228 | 229 | usage_node_ids = [node_id for sym_mth_parent in sym_mth_parents 230 | for node_id in successor_table[sym_mth_parent] 231 | if node_table[node_id].type == FeatureNode.IDENTIFIER_TOKEN] 232 | 233 | else: 234 | continue 235 | 236 | method_name_ids += usage_node_ids 237 | 238 | # Compute all nodes reachable from an AST_ELEMENT METHOD node through any edges except node token edges 239 | # (hence stop computing successors when an IDENTIFIER_TOKEN/TOKEN node is reached) 240 | reachable_node_ids = [ast_elem_node_id] 241 | successor_ids = list(set([elem for elem in successor_table[ast_elem_node_id]])) 242 | 243 | while len(successor_ids) != 0: 244 | 245 | reachable_node_ids += successor_ids 246 | 247 | new_successors = [] 248 | 249 | for n_id in successor_ids: 250 | 251 | if node_table[n_id].type != FeatureNode.IDENTIFIER_TOKEN and node_table[n_id].type != FeatureNode.TOKEN: 252 | 253 | for elem in successor_table[n_id]: 254 | 255 | if elem not in reachable_node_ids: 256 | new_successors.append(elem) 257 | 258 | successor_ids = list(set(new_successors)) 259 | 260 | reachable_node_ids = list(set(reachable_node_ids)) 261 | 262 | 263 | # Compute all reachable nodes that include the method name 264 | method_name_ids = list(set(reachable_node_ids).intersection(set(method_name_ids))) 265 | 266 | if len(method_name_ids) == 0: continue 267 | 268 | sub_nodes = [node_table[node_id] for node_id in reachable_node_ids] 269 | 270 | sub_edges = [edge for node in sub_nodes for edge in edge_table[node.id] 271 | if edge.sourceId in reachable_node_ids and edge.destinationId in reachable_node_ids] 272 | 273 | sub_graph = (sub_nodes, sub_edges) 274 | 275 | sample_data = compute_sample_data(sub_graph, method_name_ids, node_seq_length, pad_token, slot_token, vocabulary) 276 | samples.append(sample_data) 277 | non_empty_ast_nodes.append(ast_elem_node_id) 278 | 279 | return samples, non_empty_ast_nodes 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | -------------------------------------------------------------------------------- /data_processing/sample_inf_processing.py: -------------------------------------------------------------------------------- 1 | from graph_pb2 import Graph 2 | from graph_pb2 import FeatureNode 3 | from collections import defaultdict 4 | from utils.utils import compute_successors_and_predecessors, compute_node_table 5 | 6 | 7 | class SampleMetaInformation(): 8 | 9 | def __init__(self, sample_fname, node_id): 10 | 11 | self.fname = sample_fname 12 | self.node_id = node_id 13 | self.predicted_correctly = None 14 | self.empty_type = "undefined" 15 | self.type = self.empty_type 16 | self.num_usages = None 17 | self.usage_rep = None 18 | self.true_label = None 19 | self.predicted_label = None 20 | self.seen_in_training = None 21 | 22 | 23 | def compute_var_type(self): 24 | 25 | if self.type != self.empty_type: return self.type 26 | 27 | 28 | with open(self.fname, "rb") as f: 29 | 30 | g = Graph() 31 | g.ParseFromString(f.read()) 32 | 33 | var_type = get_var_type(g, self.node_id, self.empty_type) 34 | 35 | self.type = 
var_type
36 | 
37 |         return var_type
38 | 
39 | 
40 | 
41 |     def compute_var_usages(self):
42 | 
43 |         if self.num_usages is not None: return self.num_usages
44 | 
45 |         with open(self.fname, "rb") as f:
46 | 
47 |             g = Graph()
48 |             g.ParseFromString(f.read())
49 | 
50 |             n_usages = get_var_usages(g, self.node_id)
51 | 
52 |             self.num_usages = n_usages
53 | 
54 |         return n_usages
55 | 
56 | 
57 | 
58 | 
59 | class CorpusMetaInformation():
60 | 
61 |     def __init__(self, _sample_meta_infs):
62 |         self.sample_meta_infs = _sample_meta_infs
63 | 
64 | 
65 |     def add_sample_inf(self, sample_inf):
66 |         self.sample_meta_infs.append(sample_inf)
67 | 
68 | 
69 |     def process_sample_inf(self):
70 | 
71 |         incorr_usage_classes, corr_usage_classes = defaultdict(int), defaultdict(int)
72 |         incorr_type_classes, corr_type_classes = defaultdict(int), defaultdict(int)
73 | 
74 | 
75 |         # Compute and print usage and type information from entire corpus
76 |         for sample_inf in self.sample_meta_infs:
77 | 
78 |             if sample_inf.seen_in_training:
79 | 
80 |                 sample_inf.compute_var_usages()
81 |                 sample_inf.compute_var_type()
82 | 
83 |                 if sample_inf.predicted_correctly:
84 |                     corr_usage_classes[sample_inf.num_usages] += 1
85 |                     corr_type_classes[sample_inf.type] += 1
86 |                 else:
87 |                     incorr_usage_classes[sample_inf.num_usages] += 1
88 |                     incorr_type_classes[sample_inf.type] += 1
89 | 
90 | 
91 | 
92 |         # Print the computed information:
93 |         all_usage_keys = list(set(incorr_usage_classes.keys()).union(corr_usage_classes.keys()))
94 | 
95 |         for usage_key in all_usage_keys:
96 |             print(str(usage_key) + " usages: ", incorr_usage_classes[usage_key], " (incorrect) ", corr_usage_classes[usage_key], " (correct)")
97 |         print("")
98 | 
99 | 
100 | 
101 |         all_type_keys = list(set(incorr_type_classes.keys()).union(corr_type_classes.keys()))
102 | 
103 |         for type_key in all_type_keys:
104 |             print(str(type_key), incorr_type_classes[type_key], " (incorrect) ", corr_type_classes[type_key], " (correct)")
105 |         print("")
106 | 
107 | 
108 | 
109 | 
110 | def get_var_type(graph, sym_var_node_id, empty_type):
111 | 
112 |     node_table = compute_node_table(graph)
113 |     successors, predecessors = compute_successors_and_predecessors(graph)
114 | 
115 |     id_token_nodes = [n_id for n_id in successors[sym_var_node_id] if node_table[n_id].type == FeatureNode.IDENTIFIER_TOKEN]
116 | 
117 |     ast_parent = -1
118 | 
119 |     for id_token_node in id_token_nodes:
120 |         for parent_id in predecessors[id_token_node]:
121 | 
122 |             if node_table[parent_id].type == FeatureNode.AST_ELEMENT and node_table[parent_id].contents == "VARIABLE":
123 |                 ast_parent = parent_id
124 |                 break
125 | 
126 |         if ast_parent != -1: break
127 | 
128 | 
129 |     if ast_parent == -1: return empty_type
130 | 
131 | 
132 |     fake_ast_type_nodes = [n for n in successors[ast_parent]
133 |                            if node_table[n].type == FeatureNode.FAKE_AST and node_table[n].contents == "TYPE"]
134 | 
135 |     if len(fake_ast_type_nodes) == 0:
136 |         return empty_type
137 | 
138 |     else:
139 |         fake_ast_type_node = fake_ast_type_nodes[0]
140 | 
141 | 
142 |     fake_ast_type_successors = list(successors[fake_ast_type_node])
143 | 
144 |     if len(fake_ast_type_successors) == 0:
145 |         return empty_type
146 | 
147 |     else:
148 |         fake_ast_type_successor = fake_ast_type_successors[0]
149 | 
150 | 
151 |     type_contents = [node_table[n].contents for n in successors[fake_ast_type_successor] if node_table[n].type == FeatureNode.TYPE]
152 | 
153 |     if len(type_contents) == 0:
154 |         return empty_type
155 | 
156 |     else:
157 |         type_content = type_contents[0]
158 | 
159 |     return type_content
160 | 
161 | 
162 | 
163 | 
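# Example usage (a sketch; the file name is illustrative): counting the usages
# of every SYM_VAR node in a parsed graph with get_var_usages() defined below:
#
#   with open("Example.java.proto", "rb") as f:
#       g = Graph()
#       g.ParseFromString(f.read())
#
#   for node in g.node:
#       if node.type == FeatureNode.SYMBOL_VAR:
#           print(node.contents, get_var_usages(g, node.id))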
164 | # Count all Identifier Token successors of a SYM_VAR node (i.e. its usages)
165 | def get_var_usages(graph, var_id):
166 | 
167 |     node_table = compute_node_table(graph)
168 | 
169 |     usages = 0
170 | 
171 |     for edge in graph.edge:
172 |         if edge.sourceId == var_id:
173 |             dest_id = edge.destinationId
174 |             child_node = node_table[dest_id]
175 | 
176 |             if child_node.type == FeatureNode.IDENTIFIER_TOKEN:
177 |                 usages += 1
178 | 
179 |     return usages
180 | 
181 | 
182 | 
183 | 
184 | 
185 | 
--------------------------------------------------------------------------------
/detailed_infer.py:
--------------------------------------------------------------------------------
1 | from utils import vocabulary_extractor
2 | from model.model import Model
3 | import yaml
4 | import sys
5 | from utils.arg_parser import parse_input_args
6 | 
7 | 
8 | def detailed_inference(task_id):
9 | 
10 |     with open("config.yml", 'r') as ymlfile:
11 |         cfg = yaml.safe_load(ymlfile)
12 | 
13 |     checkpoint_path = cfg['checkpoint_path']
14 |     train_path = cfg['train_path']
15 |     test_path = cfg['test_path']
16 |     token_path = cfg['token_path']
17 | 
18 |     # Run inference
19 |     vocabulary = vocabulary_extractor.load_vocabulary(token_path)
20 |     m = Model(mode='infer', task_id=task_id, vocabulary=vocabulary)
21 |     m.metrics_on_seen_vars(train_path, test_path, checkpoint_path=checkpoint_path)
22 | 
23 |     print("Inference ran successfully...")
24 | 
25 | 
26 | # Select the task from the command-line arguments, as in infer.py
27 | args = sys.argv[1:]
28 | task_id = parse_input_args(args)
29 | 
30 | detailed_inference(task_id)
31 | 
--------------------------------------------------------------------------------
/graph_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler.  DO NOT EDIT!
2 | # source: graph.proto
3 | 
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | # @@protoc_insertion_point(imports)
11 | 
12 | _sym_db = _symbol_database.Default()
13 | 
14 | 
15 | 
16 | 
17 | DESCRIPTOR = _descriptor.FileDescriptor(
18 |   name='graph.proto',
19 |   package='protobuf',
20 |   syntax='proto2',
21 |   serialized_options=_b('\n$uk.ac.cam.acr31.features.javac.protoB\013GraphProtos'),
22 |   serialized_pb=_b('\n\x0bgraph.proto\x12\x08protobuf\"\x82\x03\n\x0b\x46\x65\x61tureNode\x12\n\n\x02id\x18\x01 \x01(\x03\x12,\n\x04type\x18\x02 \x01(\x0e\x32\x1e.protobuf.FeatureNode.NodeType\x12\x10\n\x08\x63ontents\x18\x03 \x01(\t\x12\x15\n\rstartPosition\x18\x04 \x01(\x05\x12\x13\n\x0b\x65ndPosition\x18\x05 \x01(\x05\x12\x17\n\x0fstartLineNumber\x18\x06 \x01(\x05\x12\x15\n\rendLineNumber\x18\x07 \x01(\x05\"\xca\x01\n\x08NodeType\x12\t\n\x05TOKEN\x10\x01\x12\x0f\n\x0b\x41ST_ELEMENT\x10\x02\x12\x10\n\x0c\x43OMMENT_LINE\x10\x03\x12\x11\n\rCOMMENT_BLOCK\x10\x04\x12\x13\n\x0f\x43OMMENT_JAVADOC\x10\x05\x12\x14\n\x10IDENTIFIER_TOKEN\x10\x07\x12\x0c\n\x08\x46\x41KE_AST\x10\x08\x12\n\n\x06SYMBOL\x10\t\x12\x0e\n\nSYMBOL_TYP\x10\n\x12\x0e\n\nSYMBOL_VAR\x10\x0b\x12\x0e\n\nSYMBOL_MTH\x10\x0c\x12\x08\n\x04TYPE\x10\r\"\x8a\x03\n\x0b\x46\x65\x61tureEdge\x12\x10\n\x08sourceId\x18\x01 \x01(\x03\x12\x15\n\rdestinationId\x18\x02 \x01(\x03\x12,\n\x04type\x18\x03 
\x01(\x0e\x32\x1e.protobuf.FeatureEdge.EdgeType\"\xa3\x02\n\x08\x45\x64geType\x12\x14\n\x10\x41SSOCIATED_TOKEN\x10\x01\x12\x0e\n\nNEXT_TOKEN\x10\x02\x12\r\n\tAST_CHILD\x10\x03\x12\x08\n\x04NONE\x10\x04\x12\x0e\n\nLAST_WRITE\x10\x05\x12\x0c\n\x08LAST_USE\x10\x06\x12\x11\n\rCOMPUTED_FROM\x10\x07\x12\x0e\n\nRETURNS_TO\x10\x08\x12\x13\n\x0f\x46ORMAL_ARG_NAME\x10\t\x12\x0e\n\nGUARDED_BY\x10\n\x12\x17\n\x13GUARDED_BY_NEGATION\x10\x0b\x12\x14\n\x10LAST_LEXICAL_USE\x10\x0c\x12\x0b\n\x07\x43OMMENT\x10\r\x12\x15\n\x11\x41SSOCIATED_SYMBOL\x10\x0e\x12\x0c\n\x08HAS_TYPE\x10\x0f\x12\x11\n\rASSIGNABLE_TO\x10\x10\"\xba\x01\n\x05Graph\x12#\n\x04node\x18\x01 \x03(\x0b\x32\x15.protobuf.FeatureNode\x12#\n\x04\x65\x64ge\x18\x02 \x03(\x0b\x32\x15.protobuf.FeatureEdge\x12\x12\n\nsourceFile\x18\x03 \x01(\t\x12*\n\x0b\x66irst_token\x18\x04 \x01(\x0b\x32\x15.protobuf.FeatureNode\x12\'\n\x08\x61st_root\x18\x05 \x01(\x0b\x32\x15.protobuf.FeatureNodeB3\n$uk.ac.cam.acr31.features.javac.protoB\x0bGraphProtos') 23 | ) 24 | 25 | 26 | 27 | _FEATURENODE_NODETYPE = _descriptor.EnumDescriptor( 28 | name='NodeType', 29 | full_name='protobuf.FeatureNode.NodeType', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | values=[ 33 | _descriptor.EnumValueDescriptor( 34 | name='TOKEN', index=0, number=1, 35 | serialized_options=None, 36 | type=None), 37 | _descriptor.EnumValueDescriptor( 38 | name='AST_ELEMENT', index=1, number=2, 39 | serialized_options=None, 40 | type=None), 41 | _descriptor.EnumValueDescriptor( 42 | name='COMMENT_LINE', index=2, number=3, 43 | serialized_options=None, 44 | type=None), 45 | _descriptor.EnumValueDescriptor( 46 | name='COMMENT_BLOCK', index=3, number=4, 47 | serialized_options=None, 48 | type=None), 49 | _descriptor.EnumValueDescriptor( 50 | name='COMMENT_JAVADOC', index=4, number=5, 51 | serialized_options=None, 52 | type=None), 53 | _descriptor.EnumValueDescriptor( 54 | name='IDENTIFIER_TOKEN', index=5, number=7, 55 | serialized_options=None, 56 | type=None), 57 | _descriptor.EnumValueDescriptor( 58 | name='FAKE_AST', index=6, number=8, 59 | serialized_options=None, 60 | type=None), 61 | _descriptor.EnumValueDescriptor( 62 | name='SYMBOL', index=7, number=9, 63 | serialized_options=None, 64 | type=None), 65 | _descriptor.EnumValueDescriptor( 66 | name='SYMBOL_TYP', index=8, number=10, 67 | serialized_options=None, 68 | type=None), 69 | _descriptor.EnumValueDescriptor( 70 | name='SYMBOL_VAR', index=9, number=11, 71 | serialized_options=None, 72 | type=None), 73 | _descriptor.EnumValueDescriptor( 74 | name='SYMBOL_MTH', index=10, number=12, 75 | serialized_options=None, 76 | type=None), 77 | _descriptor.EnumValueDescriptor( 78 | name='TYPE', index=11, number=13, 79 | serialized_options=None, 80 | type=None), 81 | ], 82 | containing_type=None, 83 | serialized_options=None, 84 | serialized_start=210, 85 | serialized_end=412, 86 | ) 87 | _sym_db.RegisterEnumDescriptor(_FEATURENODE_NODETYPE) 88 | 89 | _FEATUREEDGE_EDGETYPE = _descriptor.EnumDescriptor( 90 | name='EdgeType', 91 | full_name='protobuf.FeatureEdge.EdgeType', 92 | filename=None, 93 | file=DESCRIPTOR, 94 | values=[ 95 | _descriptor.EnumValueDescriptor( 96 | name='ASSOCIATED_TOKEN', index=0, number=1, 97 | serialized_options=None, 98 | type=None), 99 | _descriptor.EnumValueDescriptor( 100 | name='NEXT_TOKEN', index=1, number=2, 101 | serialized_options=None, 102 | type=None), 103 | _descriptor.EnumValueDescriptor( 104 | name='AST_CHILD', index=2, number=3, 105 | serialized_options=None, 106 | type=None), 107 | _descriptor.EnumValueDescriptor( 108 | 
name='NONE', index=3, number=4, 109 | serialized_options=None, 110 | type=None), 111 | _descriptor.EnumValueDescriptor( 112 | name='LAST_WRITE', index=4, number=5, 113 | serialized_options=None, 114 | type=None), 115 | _descriptor.EnumValueDescriptor( 116 | name='LAST_USE', index=5, number=6, 117 | serialized_options=None, 118 | type=None), 119 | _descriptor.EnumValueDescriptor( 120 | name='COMPUTED_FROM', index=6, number=7, 121 | serialized_options=None, 122 | type=None), 123 | _descriptor.EnumValueDescriptor( 124 | name='RETURNS_TO', index=7, number=8, 125 | serialized_options=None, 126 | type=None), 127 | _descriptor.EnumValueDescriptor( 128 | name='FORMAL_ARG_NAME', index=8, number=9, 129 | serialized_options=None, 130 | type=None), 131 | _descriptor.EnumValueDescriptor( 132 | name='GUARDED_BY', index=9, number=10, 133 | serialized_options=None, 134 | type=None), 135 | _descriptor.EnumValueDescriptor( 136 | name='GUARDED_BY_NEGATION', index=10, number=11, 137 | serialized_options=None, 138 | type=None), 139 | _descriptor.EnumValueDescriptor( 140 | name='LAST_LEXICAL_USE', index=11, number=12, 141 | serialized_options=None, 142 | type=None), 143 | _descriptor.EnumValueDescriptor( 144 | name='COMMENT', index=12, number=13, 145 | serialized_options=None, 146 | type=None), 147 | _descriptor.EnumValueDescriptor( 148 | name='ASSOCIATED_SYMBOL', index=13, number=14, 149 | serialized_options=None, 150 | type=None), 151 | _descriptor.EnumValueDescriptor( 152 | name='HAS_TYPE', index=14, number=15, 153 | serialized_options=None, 154 | type=None), 155 | _descriptor.EnumValueDescriptor( 156 | name='ASSIGNABLE_TO', index=15, number=16, 157 | serialized_options=None, 158 | type=None), 159 | ], 160 | containing_type=None, 161 | serialized_options=None, 162 | serialized_start=518, 163 | serialized_end=809, 164 | ) 165 | _sym_db.RegisterEnumDescriptor(_FEATUREEDGE_EDGETYPE) 166 | 167 | 168 | _FEATURENODE = _descriptor.Descriptor( 169 | name='FeatureNode', 170 | full_name='protobuf.FeatureNode', 171 | filename=None, 172 | file=DESCRIPTOR, 173 | containing_type=None, 174 | fields=[ 175 | _descriptor.FieldDescriptor( 176 | name='id', full_name='protobuf.FeatureNode.id', index=0, 177 | number=1, type=3, cpp_type=2, label=1, 178 | has_default_value=False, default_value=0, 179 | message_type=None, enum_type=None, containing_type=None, 180 | is_extension=False, extension_scope=None, 181 | serialized_options=None, file=DESCRIPTOR), 182 | _descriptor.FieldDescriptor( 183 | name='type', full_name='protobuf.FeatureNode.type', index=1, 184 | number=2, type=14, cpp_type=8, label=1, 185 | has_default_value=False, default_value=1, 186 | message_type=None, enum_type=None, containing_type=None, 187 | is_extension=False, extension_scope=None, 188 | serialized_options=None, file=DESCRIPTOR), 189 | _descriptor.FieldDescriptor( 190 | name='contents', full_name='protobuf.FeatureNode.contents', index=2, 191 | number=3, type=9, cpp_type=9, label=1, 192 | has_default_value=False, default_value=_b("").decode('utf-8'), 193 | message_type=None, enum_type=None, containing_type=None, 194 | is_extension=False, extension_scope=None, 195 | serialized_options=None, file=DESCRIPTOR), 196 | _descriptor.FieldDescriptor( 197 | name='startPosition', full_name='protobuf.FeatureNode.startPosition', index=3, 198 | number=4, type=5, cpp_type=1, label=1, 199 | has_default_value=False, default_value=0, 200 | message_type=None, enum_type=None, containing_type=None, 201 | is_extension=False, extension_scope=None, 202 | serialized_options=None, 
file=DESCRIPTOR), 203 | _descriptor.FieldDescriptor( 204 | name='endPosition', full_name='protobuf.FeatureNode.endPosition', index=4, 205 | number=5, type=5, cpp_type=1, label=1, 206 | has_default_value=False, default_value=0, 207 | message_type=None, enum_type=None, containing_type=None, 208 | is_extension=False, extension_scope=None, 209 | serialized_options=None, file=DESCRIPTOR), 210 | _descriptor.FieldDescriptor( 211 | name='startLineNumber', full_name='protobuf.FeatureNode.startLineNumber', index=5, 212 | number=6, type=5, cpp_type=1, label=1, 213 | has_default_value=False, default_value=0, 214 | message_type=None, enum_type=None, containing_type=None, 215 | is_extension=False, extension_scope=None, 216 | serialized_options=None, file=DESCRIPTOR), 217 | _descriptor.FieldDescriptor( 218 | name='endLineNumber', full_name='protobuf.FeatureNode.endLineNumber', index=6, 219 | number=7, type=5, cpp_type=1, label=1, 220 | has_default_value=False, default_value=0, 221 | message_type=None, enum_type=None, containing_type=None, 222 | is_extension=False, extension_scope=None, 223 | serialized_options=None, file=DESCRIPTOR), 224 | ], 225 | extensions=[ 226 | ], 227 | nested_types=[], 228 | enum_types=[ 229 | _FEATURENODE_NODETYPE, 230 | ], 231 | serialized_options=None, 232 | is_extendable=False, 233 | syntax='proto2', 234 | extension_ranges=[], 235 | oneofs=[ 236 | ], 237 | serialized_start=26, 238 | serialized_end=412, 239 | ) 240 | 241 | 242 | _FEATUREEDGE = _descriptor.Descriptor( 243 | name='FeatureEdge', 244 | full_name='protobuf.FeatureEdge', 245 | filename=None, 246 | file=DESCRIPTOR, 247 | containing_type=None, 248 | fields=[ 249 | _descriptor.FieldDescriptor( 250 | name='sourceId', full_name='protobuf.FeatureEdge.sourceId', index=0, 251 | number=1, type=3, cpp_type=2, label=1, 252 | has_default_value=False, default_value=0, 253 | message_type=None, enum_type=None, containing_type=None, 254 | is_extension=False, extension_scope=None, 255 | serialized_options=None, file=DESCRIPTOR), 256 | _descriptor.FieldDescriptor( 257 | name='destinationId', full_name='protobuf.FeatureEdge.destinationId', index=1, 258 | number=2, type=3, cpp_type=2, label=1, 259 | has_default_value=False, default_value=0, 260 | message_type=None, enum_type=None, containing_type=None, 261 | is_extension=False, extension_scope=None, 262 | serialized_options=None, file=DESCRIPTOR), 263 | _descriptor.FieldDescriptor( 264 | name='type', full_name='protobuf.FeatureEdge.type', index=2, 265 | number=3, type=14, cpp_type=8, label=1, 266 | has_default_value=False, default_value=1, 267 | message_type=None, enum_type=None, containing_type=None, 268 | is_extension=False, extension_scope=None, 269 | serialized_options=None, file=DESCRIPTOR), 270 | ], 271 | extensions=[ 272 | ], 273 | nested_types=[], 274 | enum_types=[ 275 | _FEATUREEDGE_EDGETYPE, 276 | ], 277 | serialized_options=None, 278 | is_extendable=False, 279 | syntax='proto2', 280 | extension_ranges=[], 281 | oneofs=[ 282 | ], 283 | serialized_start=415, 284 | serialized_end=809, 285 | ) 286 | 287 | 288 | _GRAPH = _descriptor.Descriptor( 289 | name='Graph', 290 | full_name='protobuf.Graph', 291 | filename=None, 292 | file=DESCRIPTOR, 293 | containing_type=None, 294 | fields=[ 295 | _descriptor.FieldDescriptor( 296 | name='node', full_name='protobuf.Graph.node', index=0, 297 | number=1, type=11, cpp_type=10, label=3, 298 | has_default_value=False, default_value=[], 299 | message_type=None, enum_type=None, containing_type=None, 300 | is_extension=False, extension_scope=None, 
301 | serialized_options=None, file=DESCRIPTOR), 302 | _descriptor.FieldDescriptor( 303 | name='edge', full_name='protobuf.Graph.edge', index=1, 304 | number=2, type=11, cpp_type=10, label=3, 305 | has_default_value=False, default_value=[], 306 | message_type=None, enum_type=None, containing_type=None, 307 | is_extension=False, extension_scope=None, 308 | serialized_options=None, file=DESCRIPTOR), 309 | _descriptor.FieldDescriptor( 310 | name='sourceFile', full_name='protobuf.Graph.sourceFile', index=2, 311 | number=3, type=9, cpp_type=9, label=1, 312 | has_default_value=False, default_value=_b("").decode('utf-8'), 313 | message_type=None, enum_type=None, containing_type=None, 314 | is_extension=False, extension_scope=None, 315 | serialized_options=None, file=DESCRIPTOR), 316 | _descriptor.FieldDescriptor( 317 | name='first_token', full_name='protobuf.Graph.first_token', index=3, 318 | number=4, type=11, cpp_type=10, label=1, 319 | has_default_value=False, default_value=None, 320 | message_type=None, enum_type=None, containing_type=None, 321 | is_extension=False, extension_scope=None, 322 | serialized_options=None, file=DESCRIPTOR), 323 | _descriptor.FieldDescriptor( 324 | name='ast_root', full_name='protobuf.Graph.ast_root', index=4, 325 | number=5, type=11, cpp_type=10, label=1, 326 | has_default_value=False, default_value=None, 327 | message_type=None, enum_type=None, containing_type=None, 328 | is_extension=False, extension_scope=None, 329 | serialized_options=None, file=DESCRIPTOR), 330 | ], 331 | extensions=[ 332 | ], 333 | nested_types=[], 334 | enum_types=[ 335 | ], 336 | serialized_options=None, 337 | is_extendable=False, 338 | syntax='proto2', 339 | extension_ranges=[], 340 | oneofs=[ 341 | ], 342 | serialized_start=812, 343 | serialized_end=998, 344 | ) 345 | 346 | _FEATURENODE.fields_by_name['type'].enum_type = _FEATURENODE_NODETYPE 347 | _FEATURENODE_NODETYPE.containing_type = _FEATURENODE 348 | _FEATUREEDGE.fields_by_name['type'].enum_type = _FEATUREEDGE_EDGETYPE 349 | _FEATUREEDGE_EDGETYPE.containing_type = _FEATUREEDGE 350 | _GRAPH.fields_by_name['node'].message_type = _FEATURENODE 351 | _GRAPH.fields_by_name['edge'].message_type = _FEATUREEDGE 352 | _GRAPH.fields_by_name['first_token'].message_type = _FEATURENODE 353 | _GRAPH.fields_by_name['ast_root'].message_type = _FEATURENODE 354 | DESCRIPTOR.message_types_by_name['FeatureNode'] = _FEATURENODE 355 | DESCRIPTOR.message_types_by_name['FeatureEdge'] = _FEATUREEDGE 356 | DESCRIPTOR.message_types_by_name['Graph'] = _GRAPH 357 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 358 | 359 | FeatureNode = _reflection.GeneratedProtocolMessageType('FeatureNode', (_message.Message,), dict( 360 | DESCRIPTOR = _FEATURENODE, 361 | __module__ = 'graph_pb2' 362 | # @@protoc_insertion_point(class_scope:protobuf.FeatureNode) 363 | )) 364 | _sym_db.RegisterMessage(FeatureNode) 365 | 366 | FeatureEdge = _reflection.GeneratedProtocolMessageType('FeatureEdge', (_message.Message,), dict( 367 | DESCRIPTOR = _FEATUREEDGE, 368 | __module__ = 'graph_pb2' 369 | # @@protoc_insertion_point(class_scope:protobuf.FeatureEdge) 370 | )) 371 | _sym_db.RegisterMessage(FeatureEdge) 372 | 373 | Graph = _reflection.GeneratedProtocolMessageType('Graph', (_message.Message,), dict( 374 | DESCRIPTOR = _GRAPH, 375 | __module__ = 'graph_pb2' 376 | # @@protoc_insertion_point(class_scope:protobuf.Graph) 377 | )) 378 | _sym_db.RegisterMessage(Graph) 379 | 380 | 381 | DESCRIPTOR._options = None 382 | # @@protoc_insertion_point(module_scope) 383 | 
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | from utils import vocabulary_extractor
2 | from model.model import Model
3 | import yaml
4 | import sys
5 | from utils.arg_parser import parse_input_args
6 | 
7 | def infer(task_id):
8 | 
9 |     with open("config.yml", 'r') as ymlfile:
10 |         cfg = yaml.safe_load(ymlfile)
11 | 
12 |     checkpoint_path = cfg['checkpoint_path']
13 |     test_path = cfg['test_path']
14 |     token_path = cfg['token_path']
15 | 
16 | 
17 |     vocabulary = vocabulary_extractor.load_vocabulary(token_path)
18 |     m = Model(mode='infer', task_id=task_id, vocabulary=vocabulary)
19 | 
20 |     m.infer(corpus_path=test_path, checkpoint_path=checkpoint_path)
21 |     print("Inference ran successfully...")
22 | 
23 | 
24 | 
25 | args = sys.argv[1:]
26 | task_id = parse_input_args(args)
27 | 
28 | infer(task_id)
29 | 
30 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/model/__init__.py
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from graph_pb2 import Graph
3 | from dpu_utils.tfmodels import SparseGGNN
4 | from data_processing.sample_inf_processing import SampleMetaInformation, CorpusMetaInformation
5 | import numpy as np
6 | import os
7 | from data_processing import graph_processing
8 | from data_processing.graph_features import get_used_edges_type
9 | from random import shuffle
10 | from utils.utils import compute_f1_score
11 | 
12 | 
13 | class Model:
14 | 
15 |     def __init__(self, mode, task_id, vocabulary):
16 | 
17 |         # Initialize parameter values
18 |         self.max_node_seq_len = 32                      # Maximum number of node subtokens
19 |         self.max_var_seq_len = 16                       # Maximum number of variable subtokens
20 |         self.max_slots = 64                             # Maximum number of variable occurrences
21 |         self.batch_size = 20000                         # Number of nodes per batch sample
22 |         self.enable_batching = True
23 |         self.learning_rate = 0.001
24 |         self.ggnn_params = self.get_gnn_params()
25 |         self.vocabulary = vocabulary
26 |         self.voc_size = len(vocabulary)
27 |         self.slot_id = self.vocabulary.get_id_or_unk('<SLOT>')   # token used to mask the target symbol's occurrences
28 |         self.sos_token_id = self.vocabulary.get_id_or_unk('sos_token')
29 |         self.pad_token_id = self.vocabulary.get_id_or_unk(self.vocabulary.get_pad())
30 |         self.embedding_size = self.ggnn_params['hidden_size']
31 |         self.ggnn_dropout = 1.0
32 |         self.task_type = task_id
33 | 
34 |         if mode != 'train' and mode != 'infer':
35 |             raise ValueError("Invalid mode. 
Please specify \'train\' or \'infer\'...") 36 | 37 | 38 | self.graph = tf.Graph() 39 | self.mode = mode 40 | 41 | with self.graph.as_default(): 42 | 43 | self.placeholders = {} 44 | self.make_model() 45 | self.sess = tf.Session(graph=self.graph, config=tf.ConfigProto()) 46 | 47 | if self.mode == 'train': 48 | self.make_train_step() 49 | self.sess.run(tf.global_variables_initializer()) 50 | 51 | 52 | print ("Model built successfully...") 53 | 54 | 55 | 56 | 57 | def get_gnn_params(self): 58 | 59 | gnn_params = {} 60 | gnn_params["n_edge_types"] = len(get_used_edges_type()) 61 | gnn_params["hidden_size"] = 64 62 | gnn_params["edge_features_size"] = {} 63 | gnn_params["add_backwards_edges"] = True 64 | gnn_params["message_aggregation_type"] = "sum" 65 | gnn_params["layer_timesteps"] = [8] 66 | gnn_params["use_propagation_attention"] = False 67 | gnn_params["use_edge_bias"] = False 68 | gnn_params["graph_rnn_activation"] = "relu" 69 | gnn_params["graph_rnn_cell"] = "gru" 70 | gnn_params["residual_connections"] = {} 71 | gnn_params["use_edge_msg_avg_aggregation"] = False 72 | 73 | return gnn_params 74 | 75 | 76 | 77 | def make_inputs(self): 78 | 79 | # Node token sequences 80 | self.placeholders['unique_node_labels'] = tf.placeholder(name='unique_labels', shape=[None, self.max_node_seq_len], dtype=tf.int32 ) 81 | self.placeholders['unique_node_labels_mask'] = tf.placeholder(name='unique_node_labels_mask', shape=[None, self.max_node_seq_len], dtype=tf.float32) 82 | self.placeholders['node_label_indices'] = tf.placeholder(name='node_label_indices', shape=[None], dtype=tf.int32) 83 | 84 | # Graph edge matrices 85 | self.placeholders['adjacency_lists'] = [tf.placeholder(tf.int32, [None, 2]) for _ in range(self.ggnn_params['n_edge_types'])] 86 | self.placeholders['num_incoming_edges_per_type'] = tf.placeholder(tf.float32, [None, self.ggnn_params['n_edge_types']]) 87 | self.placeholders['num_outgoing_edges_per_type'] = tf.placeholder(tf.float32, [None, self.ggnn_params['n_edge_types']]) 88 | 89 | # Decoder sequence placeholders 90 | self.placeholders['decoder_targets'] = tf.placeholder(dtype=tf.int32, shape=(None, self.max_var_seq_len), name='dec_targets') 91 | self.placeholders['decoder_inputs'] = tf.placeholder(shape=(self.max_var_seq_len, self.placeholders['decoder_targets'].shape[0]), dtype=tf.int32, name='dec_inputs') 92 | self.placeholders['target_mask'] = tf.placeholder(tf.float32, [self.placeholders['decoder_targets'].shape[0], self.max_var_seq_len], name='target_mask') 93 | self.placeholders['sos_tokens'] = tf.placeholder(shape=(self.placeholders['decoder_targets'].shape[0]), dtype=tf.int32, name='sos_tokens') 94 | self.placeholders['decoder_targets_length'] = tf.placeholder(shape=(self.placeholders['decoder_targets'].shape[0]), dtype=tf.int32) 95 | 96 | # Node identifiers of all graph nodes of the target variable 97 | self.placeholders['slot_ids'] = tf.placeholder(tf.int32, [self.placeholders['decoder_targets'].shape[0], self.max_slots], name='slot_ids') 98 | self.placeholders['slot_ids_mask'] = tf.placeholder(tf.float32, [self.placeholders['decoder_targets'].shape[0], self.max_slots], name='slot_mask') 99 | 100 | # Record number of graph samples in given batch (used during loss computation) 101 | self.placeholders['num_samples_in_batch'] = tf.placeholder(dtype=tf.float32, shape=(1), name='num_samples_in_batch') 102 | 103 | 104 | 105 | def get_initial_node_representation(self): 106 | 107 | # Compute the embedding of input node sub-tokens 108 | self.embedding_encoder = 
tf.get_variable('embedding_encoder', [self.voc_size, self.embedding_size]) 109 | 110 | subtoken_embedding = tf.nn.embedding_lookup(params=self.embedding_encoder, ids=self.placeholders['unique_node_labels']) 111 | 112 | subtoken_ids_mask = tf.reshape(self.placeholders['unique_node_labels_mask'], [-1, self.max_node_seq_len, 1]) 113 | 114 | subtoken_embedding = subtoken_ids_mask * subtoken_embedding 115 | 116 | unique_label_representations = tf.reduce_sum(subtoken_embedding, axis=1) 117 | 118 | num_subtokens = tf.reduce_sum(subtoken_ids_mask, axis=1) 119 | 120 | unique_label_representations /= num_subtokens 121 | 122 | self.node_label_representations = tf.gather(params=unique_label_representations, 123 | indices=self.placeholders['node_label_indices']) 124 | 125 | 126 | 127 | 128 | def make_model(self): 129 | 130 | # Create inputs and compute initial node representations 131 | self.make_inputs() 132 | self.get_initial_node_representation() 133 | 134 | # Run graph through GGNN layer 135 | self.gnn_model = SparseGGNN(self.ggnn_params) 136 | self.gnn_representation = self.gnn_model.sparse_gnn_layer(self.ggnn_dropout, 137 | self.node_label_representations, 138 | self.placeholders['adjacency_lists'], 139 | self.placeholders['num_incoming_edges_per_type'], 140 | self.placeholders['num_outgoing_edges_per_type'], 141 | {}) 142 | 143 | 144 | # Compute average of usage representations 145 | self.avg_representation = tf.gather(self.gnn_representation, self.placeholders['slot_ids']) 146 | slot_mask = tf.reshape(self.placeholders['slot_ids_mask'], [-1, self.max_slots, 1]) 147 | slot_embedding = slot_mask * self.avg_representation 148 | self.avg_representation = tf.reduce_sum(slot_embedding, axis=1) 149 | num_slots = tf.reduce_sum(slot_mask, axis=1) 150 | self.avg_representation /= num_slots 151 | 152 | 153 | # Obtain output sequence by passing through a single GRU layer 154 | self.embedding_decoder = tf.get_variable('embedding_decoder', [self.voc_size, self.embedding_size]) 155 | self.decoder_cell = tf.nn.rnn_cell.GRUCell(self.embedding_size) 156 | decoder_initial_state = self.avg_representation 157 | self.projection_layer = tf.layers.Dense(self.voc_size, use_bias=False) 158 | 159 | 160 | 161 | # Training 162 | decoder_embedding_inputs = tf.nn.embedding_lookup(self.embedding_decoder, self.placeholders['decoder_inputs']) 163 | 164 | self.train_helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding_inputs, 165 | self.placeholders['decoder_targets_length'], 166 | time_major=True) 167 | 168 | self.train_decoder = tf.contrib.seq2seq.BasicDecoder(self.decoder_cell, self.train_helper, 169 | initial_state=decoder_initial_state, 170 | output_layer=self.projection_layer) 171 | 172 | decoder_outputs_train, _, _ = tf.contrib.seq2seq.dynamic_decode(self.train_decoder) 173 | 174 | self.decoder_logits_train = decoder_outputs_train.rnn_output 175 | 176 | 177 | # Inference 178 | end_token = self.pad_token_id 179 | max_iterations = self.max_var_seq_len 180 | 181 | self.inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.embedding_decoder, 182 | start_tokens=self.placeholders['sos_tokens'], 183 | end_token=end_token) 184 | 185 | 186 | self.inference_decoder = tf.contrib.seq2seq.BasicDecoder(self.decoder_cell, self.inference_helper, 187 | initial_state=decoder_initial_state, 188 | output_layer=self.projection_layer) 189 | 190 | outputs_inference, _, _ = tf.contrib.seq2seq.dynamic_decode(self.inference_decoder, 191 | maximum_iterations=max_iterations) 192 | 193 | self.predictions = 
outputs_inference.sample_id
194 | 
195 | 
196 | 
197 | 
198 |     def make_train_step(self):
199 | 
200 |         max_batch_seq_len = tf.reduce_max(self.placeholders['decoder_targets_length'])
201 | 
202 |         self.crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.placeholders['decoder_targets'][:, :max_batch_seq_len],
203 |                                                                        logits=self.decoder_logits_train)
204 | 
205 |         self.train_loss = tf.reduce_sum(self.crossent * self.placeholders['target_mask'][:, :max_batch_seq_len]) / self.placeholders['num_samples_in_batch']
206 | 
207 |         # Calculate and clip gradients
208 |         self.train_vars = tf.trainable_variables()
209 |         self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
210 | 
211 | 
212 |         grads_and_vars = self.optimizer.compute_gradients(self.train_loss, var_list=self.train_vars)
213 | 
214 |         clipped_grads = []
215 | 
216 |         for grad, var in grads_and_vars:
217 |             if grad is not None:
218 |                 clipped_grads.append((tf.clip_by_norm(grad, 5.0), var))
219 |             else:
220 |                 clipped_grads.append((grad, var))
221 | 
222 |         self.train_step = self.optimizer.apply_gradients(clipped_grads)
223 | 
224 | 
225 | 
226 | 
227 | 
228 |     # Set placeholder values using given graph input
229 |     def create_sample(self, slot_row_id_list, node_representation, adj_lists, incoming_edges, outgoing_edges):
230 | 
231 |         # Retrieve variable token sequence
232 |         target_token_seq = node_representation[slot_row_id_list[0]][:self.max_var_seq_len]
233 | 
234 |         # Set all occurrences of the variable to <SLOT>
235 |         slotted_node_representation = node_representation.copy()
236 |         slotted_node_representation[slot_row_id_list, :] = self.pad_token_id
237 |         slotted_node_representation[slot_row_id_list, 0] = self.slot_id
238 | 
239 |         node_rep_mask = (slotted_node_representation != self.pad_token_id).astype(int)
240 | 
241 |         slot_row_ids = np.zeros((1, self.max_slots))
242 |         slot_mask = np.zeros((1, self.max_slots))
243 |         slot_row_ids[0, 0:len(slot_row_id_list)] = slot_row_id_list
244 |         slot_mask[0, 0:len(slot_row_id_list)] = 1
245 | 
246 |         decoder_inputs = np.zeros((self.max_var_seq_len, 1))
247 |         decoder_targets = np.zeros((1, self.max_var_seq_len))
248 |         target_mask = np.zeros((1, self.max_var_seq_len))
249 |         start_tokens = np.ones((1)) * self.sos_token_id
250 | 
251 |         if self.mode == 'train':
252 | 
253 |             # Set decoder inputs and targets
254 |             decoder_inputs = target_token_seq.copy()
255 |             decoder_inputs = np.insert(decoder_inputs, 0, self.sos_token_id)[:-1]
256 |             decoder_inputs = decoder_inputs.reshape(self.max_var_seq_len, 1)
257 | 
258 |             decoder_targets = target_token_seq.copy()
259 |             decoder_targets = decoder_targets.reshape(1, self.max_var_seq_len)
260 | 
261 |             num_non_pads = np.sum(decoder_targets != self.pad_token_id) + 1
262 |             target_mask[0, 0:num_non_pads] = 1
263 | 
264 | 
265 | 
266 |         # If batching is enabled, delay creation of the node representations until batch creation
267 |         if self.enable_batching:
268 |             unique_label_subtokens, unique_label_indices = None, None
269 |             unique_label_inverse_indices = slotted_node_representation
270 |         else:
271 |             unique_label_subtokens, unique_label_indices, unique_label_inverse_indices = \
272 |                 np.unique(slotted_node_representation, return_index=True, return_inverse=True, axis=0)
273 | 
274 | 
275 |         # Create the sample graph
276 |         graph_sample = {
277 |             self.placeholders['unique_node_labels']: unique_label_subtokens,
278 |             self.placeholders['unique_node_labels_mask']: node_rep_mask[unique_label_indices],
279 |             self.placeholders['node_label_indices']: unique_label_inverse_indices,
280 |             self.placeholders['slot_ids']: 
slot_row_ids, 281 | self.placeholders['slot_ids_mask']: slot_mask, 282 | self.placeholders['num_incoming_edges_per_type']: incoming_edges, 283 | self.placeholders['num_outgoing_edges_per_type']: outgoing_edges, 284 | self.placeholders['decoder_targets']: decoder_targets, 285 | self.placeholders['decoder_inputs']: decoder_inputs, 286 | self.placeholders['decoder_targets_length']: np.ones((1)) * np.sum(target_mask), 287 | self.placeholders['sos_tokens']: start_tokens, 288 | self.placeholders['target_mask']: target_mask, 289 | self.placeholders['num_samples_in_batch']: np.ones((1)) 290 | } 291 | 292 | for i in range(self.ggnn_params['n_edge_types']): 293 | graph_sample[self.placeholders['adjacency_lists'][i]] = adj_lists[i] 294 | 295 | target_name = [self.vocabulary.get_name_for_id(token_id) 296 | for token_id in target_token_seq if token_id != self.pad_token_id] 297 | 298 | return graph_sample, target_name 299 | 300 | 301 | # Extract samples from given file 302 | def create_samples(self, filepath): 303 | 304 | with open(filepath, "rb") as f: 305 | 306 | g = Graph() 307 | g.ParseFromString(f.read()) 308 | 309 | max_path_len = 8 310 | 311 | 312 | # Select sample parsing strategy depending on the specified model task 313 | if self.task_type == 0: 314 | graph_samples, slot_node_ids = graph_processing.get_usage_samples(g, max_path_len, self.max_slots, 315 | self.max_node_seq_len, self.pad_token_id, 316 | self.slot_id, self.vocabulary) 317 | 318 | elif self.task_type == 1: 319 | graph_samples, slot_node_ids = graph_processing.get_usage_samples(g, max_path_len, self.max_slots, 320 | self.max_node_seq_len, 321 | self.pad_token_id, 322 | self.slot_id, self.vocabulary, True) 323 | 324 | elif self.task_type == 2: 325 | graph_samples, slot_node_ids = graph_processing.get_method_body_samples(g, 326 | self.max_node_seq_len, 327 | self.pad_token_id, 328 | self.slot_id, self.vocabulary) 329 | 330 | else: 331 | raise ValueError("Invalid task id...") 332 | 333 | 334 | samples, labels = [], [] 335 | 336 | for sample in graph_samples: 337 | new_sample, new_label = self.create_sample(*sample) 338 | samples.append(new_sample) 339 | labels.append(new_label) 340 | 341 | 342 | # Save sample meta-information 343 | samples_meta_inf = [] 344 | 345 | for slot_node_id in slot_node_ids: 346 | new_inf = SampleMetaInformation(filepath, slot_node_id) 347 | samples_meta_inf.append(new_inf) 348 | 349 | return samples, labels, samples_meta_inf 350 | 351 | 352 | 353 | 354 | def make_batch_samples(self, graph_samples, all_labels): 355 | 356 | max_nodes_in_batch = self.batch_size 357 | batch_samples, labels = [], [] 358 | current_batch = [] 359 | nodes_in_curr_batch = 0 360 | 361 | for sample_index, graph_sample in enumerate(graph_samples): 362 | 363 | num_nodes_in_sample = graph_sample[self.placeholders['node_label_indices']].shape[0] 364 | 365 | # Skip sample if it is too big 366 | if num_nodes_in_sample > max_nodes_in_batch: 367 | continue 368 | 369 | # Add to current batch if there is space 370 | if num_nodes_in_sample + nodes_in_curr_batch < max_nodes_in_batch: 371 | current_batch.append(graph_sample) 372 | nodes_in_curr_batch += num_nodes_in_sample 373 | 374 | # Otherwise start creating a new batch 375 | else: 376 | batch_samples.append(self.make_batch(current_batch)) 377 | current_batch = [graph_sample] 378 | nodes_in_curr_batch = num_nodes_in_sample 379 | 380 | labels.append(all_labels[sample_index]) 381 | 382 | 383 | if len(current_batch) > 0: 384 | batch_samples.append(self.make_batch(current_batch)) 385 | 386 | return 
batch_samples, labels 387 | 388 | 389 | 390 | # Merge set of given graph samples into a single batch 391 | def make_batch(self, graph_samples): 392 | 393 | node_offset = 0 394 | node_reps = [] 395 | slot_ids, slot_masks = [], [] 396 | num_incoming_edges_per_type, num_outgoing_edges_per_type = [], [] 397 | decoder_targets, decoder_inputs, decoder_targets_length, decoder_masks = [], [], [], [] 398 | adj_lists = [[] for _ in range(self.ggnn_params['n_edge_types'])] 399 | start_tokens = np.ones((len(graph_samples))) * self.sos_token_id 400 | 401 | for graph_sample in graph_samples: 402 | 403 | num_nodes_in_graph = graph_sample[self.placeholders['node_label_indices']].shape[0] 404 | 405 | node_reps.append(graph_sample[self.placeholders['node_label_indices']]) 406 | 407 | slot_ids.append(graph_sample[self.placeholders['slot_ids']] + graph_sample[self.placeholders['slot_ids_mask']] * node_offset) 408 | 409 | slot_masks.append(graph_sample[self.placeholders['slot_ids_mask']]) 410 | 411 | num_incoming_edges_per_type.append(graph_sample[self.placeholders['num_incoming_edges_per_type']]) 412 | 413 | num_outgoing_edges_per_type.append(graph_sample[self.placeholders['num_outgoing_edges_per_type']]) 414 | 415 | decoder_inputs.append(graph_sample[self.placeholders['decoder_inputs']]) 416 | 417 | decoder_targets.append(graph_sample[self.placeholders['decoder_targets']]) 418 | 419 | decoder_targets_length.append(graph_sample[self.placeholders['decoder_targets_length']]) 420 | 421 | decoder_masks.append(graph_sample[self.placeholders['target_mask']]) 422 | 423 | for i in range(self.ggnn_params['n_edge_types']): 424 | adj_lists[i].append(graph_sample[self.placeholders['adjacency_lists'][i]] + node_offset) 425 | 426 | node_offset += num_nodes_in_graph 427 | 428 | 429 | 430 | all_node_reps = np.vstack(node_reps) 431 | node_rep_mask = (all_node_reps != self.pad_token_id).astype(int) 432 | 433 | unique_label_subtokens, unique_label_indices, unique_label_inverse_indices = \ 434 | np.unique(all_node_reps, return_index=True, return_inverse=True, axis=0) 435 | 436 | batch_sample = { 437 | self.placeholders['unique_node_labels']: unique_label_subtokens, 438 | self.placeholders['unique_node_labels_mask']: node_rep_mask[unique_label_indices], 439 | self.placeholders['node_label_indices']: unique_label_inverse_indices, 440 | self.placeholders['slot_ids']: np.vstack(slot_ids), 441 | self.placeholders['slot_ids_mask']: np.vstack(slot_masks), 442 | self.placeholders['num_incoming_edges_per_type']: np.vstack(num_incoming_edges_per_type), 443 | self.placeholders['num_outgoing_edges_per_type']: np.vstack(num_outgoing_edges_per_type), 444 | self.placeholders['decoder_targets']: np.vstack(decoder_targets), 445 | self.placeholders['decoder_inputs']: np.hstack(decoder_inputs), 446 | self.placeholders['decoder_targets_length']: np.hstack(decoder_targets_length), 447 | self.placeholders['sos_tokens']: start_tokens, 448 | self.placeholders['target_mask']: np.vstack(decoder_masks), 449 | self.placeholders['num_samples_in_batch']: np.ones((1)) * len(decoder_targets) 450 | } 451 | 452 | for i in range(self.ggnn_params['n_edge_types']): 453 | if len(adj_lists[i]) > 0: 454 | adj_list = np.vstack(adj_lists[i]) 455 | else: 456 | adj_list = np.zeros((0, 2), dtype=np.int32) 457 | 458 | batch_sample[self.placeholders['adjacency_lists'][i]] = adj_list 459 | 460 | return batch_sample 461 | 462 | 463 | 464 | 465 | def get_samples(self, dir_path): 466 | 467 | graph_samples, labels, _ = self.get_samples_with_metainf(dir_path) 468 | 469 | return 
graph_samples, labels 470 | 471 | 472 | 473 | def get_samples_with_metainf(self, dir_path): 474 | 475 | graph_samples, labels, metainf = [], [], [] 476 | 477 | n_files = sum([1 for dirpath, dirs, files in os.walk(dir_path) for filename in files if filename.endswith('proto')]) 478 | n_processed = 0 479 | 480 | for dirpath, dirs, files in os.walk(dir_path): 481 | for filename in files: 482 | if filename.endswith('proto'): 483 | 484 | fname = os.path.join(dirpath, filename) 485 | 486 | new_samples, new_labels, new_inf = self.create_samples(fname) 487 | 488 | if len(new_samples) > 0: 489 | graph_samples += new_samples 490 | labels += new_labels 491 | metainf += new_inf 492 | 493 | n_processed += 1 494 | print("Processed ", n_processed/n_files * 100, "% of files...") 495 | 496 | 497 | zipped = list(zip(graph_samples, labels, metainf)) 498 | shuffle(zipped) 499 | graph_samples, labels, metainf = zip(*zipped) 500 | 501 | if self.enable_batching: 502 | graph_samples, labels = self.make_batch_samples(graph_samples, labels) 503 | 504 | return graph_samples, labels, metainf 505 | 506 | 507 | 508 | def train(self, train_path, val_path, n_epochs, checkpoint_path): 509 | 510 | train_samples, train_labels = self.get_samples(train_path) 511 | print("Extracted training samples... ", len(train_samples)) 512 | 513 | val_samples, val_labels, meta_inf = self.get_samples_with_metainf(val_path) 514 | print("Extracted validation samples... ", len(val_samples)) 515 | 516 | 517 | with self.graph.as_default(): 518 | 519 | for epoch in range(n_epochs): 520 | 521 | loss = 0 522 | 523 | for graph in train_samples: 524 | loss += self.sess.run([self.train_loss, self.train_step], feed_dict=graph)[0] 525 | 526 | print("Average Epoch Loss:", (loss/len(train_samples))) 527 | print("Epoch: ", epoch + 1, "/", n_epochs) 528 | print("---------------------------------------------") 529 | 530 | 531 | if (epoch+1) % 5 == 0: 532 | 533 | saver = tf.train.Saver() 534 | saver.save(self.sess, checkpoint_path) 535 | 536 | self.compute_metrics_from_graph_samples(val_samples, val_labels, meta_inf) 537 | 538 | saver = tf.train.Saver() 539 | saver.save(self.sess, checkpoint_path) 540 | 541 | 542 | 543 | 544 | 545 | 546 | def infer(self, corpus_path, checkpoint_path): 547 | 548 | test_samples, test_labels, meta_inf = self.get_samples_with_metainf(corpus_path) 549 | 550 | for i in range(len(test_labels)): 551 | meta_inf[i].true_label = test_labels[i] 552 | 553 | with self.graph.as_default(): 554 | 555 | saver = tf.train.Saver() 556 | saver.restore(self.sess, checkpoint_path) 557 | print("Model loaded successfully...") 558 | 559 | _, _, predicted_names = self.compute_metrics_from_graph_samples(test_samples, test_labels, meta_inf) 560 | 561 | return test_samples, test_labels, meta_inf, predicted_names 562 | 563 | 564 | 565 | def get_predictions(self, graph_samples): 566 | 567 | predicted_names = [] 568 | 569 | for graph in graph_samples: 570 | 571 | predictions = self.sess.run([self.predictions], feed_dict=graph)[0] 572 | 573 | for i in range(len(predictions)): 574 | 575 | predicted_name = [self.vocabulary.get_name_for_id(token_id) for token_id in predictions[i]] 576 | 577 | if self.vocabulary.get_pad() in predicted_name: 578 | pad_index = predicted_name.index(self.vocabulary.get_pad()) 579 | predicted_name = predicted_name[:pad_index] 580 | 581 | predicted_names.append(predicted_name) 582 | 583 | return predicted_names 584 | 585 | 586 | 587 | 588 | def compute_metrics_from_graph_samples(self, graph_samples, test_labels, sample_infs, 
print_labels=False): 589 | 590 | predicted_names = self.get_predictions(graph_samples) 591 | return self.compute_metrics(predicted_names, test_labels, sample_infs, print_labels) 592 | 593 | 594 | 595 | # Compute F1 and accuracy scores 596 | def compute_metrics(self, predicted_names, test_labels, sample_infs, print_labels=False): 597 | 598 | n_correct, n_nonzero, f1 = 0, 0, 0 599 | 600 | print("Predictions: ", len(predicted_names)) 601 | print("Test labels: ", len(test_labels)) 602 | 603 | for i in range(len(predicted_names)): 604 | 605 | if print_labels: 606 | 607 | print("Predicted: ", [sym.encode('utf-8') for sym in predicted_names[i]]) 608 | print("Actual: ", [sym.encode('utf-8') for sym in test_labels[i]]) 609 | print("") 610 | print("") 611 | 612 | 613 | f1 += compute_f1_score(predicted_names[i], test_labels[i]) 614 | 615 | if predicted_names[i] == test_labels[i]: 616 | n_correct += 1 617 | sample_infs[i].predicted_correctly = True 618 | 619 | else: 620 | sample_infs[i].predicted_correctly = False 621 | 622 | 623 | accuracy = n_correct / len(test_labels) * 100 624 | 625 | f1 = f1 * 100 / len(predicted_names) 626 | 627 | print("Absolute accuracy: ", accuracy) 628 | print("F1 score: ", f1) 629 | 630 | return accuracy, f1, predicted_names 631 | 632 | 633 | 634 | 635 | 636 | # Compute F1 and accuracy scores, as well as usage and type information 637 | # using the variables seen during training 638 | def metrics_on_seen_vars(self, train_path, test_path, checkpoint_path): 639 | 640 | train_samples, train_labels = self.get_samples(train_path) 641 | test_samples, test_labels, sample_infs, predicted_names = self.infer(test_path, checkpoint_path) 642 | 643 | seen_correct, seen_incorrect, unseen_correct, unseen_incorrect = 0, 0, 0, 0 644 | 645 | for i, sample_inf in enumerate(sample_infs): 646 | 647 | if test_labels[i] in train_labels: 648 | sample_inf.seen_in_training = True 649 | else: 650 | sample_inf.seen_in_training = False 651 | 652 | 653 | if test_labels[i] in train_labels and sample_inf.predicted_correctly: 654 | seen_correct += 1 655 | elif test_labels[i] in train_labels and not sample_inf.predicted_correctly: 656 | seen_incorrect += 1 657 | elif test_labels[i] not in train_labels and sample_inf.predicted_correctly: 658 | unseen_correct += 1 659 | else: 660 | unseen_incorrect += 1 661 | 662 | seen_predictions = [predicted_names[i] for i in range(len(predicted_names)) 663 | if sample_infs[i].seen_in_training ] 664 | 665 | seen_test_labels = [test_labels[i] for i in range(len(test_labels)) 666 | if sample_infs[i].seen_in_training ] 667 | 668 | 669 | seen_sample_infs = [sample_infs[i] for i in range(len(sample_infs)) 670 | if sample_infs[i].seen_in_training ] 671 | 672 | 673 | print("Metrics on seen variables: ") 674 | accuracy, f1, _ = self.compute_metrics(seen_predictions, seen_test_labels, seen_sample_infs) 675 | 676 | meta_corpus = CorpusMetaInformation(sample_infs) 677 | meta_corpus.process_sample_inf() 678 | 679 | 680 | 681 | 682 | 683 | 684 | 685 | 686 | 687 | 688 | 689 | 690 | -------------------------------------------------------------------------------- /saved_models/MethodNaming/Definition/checkpoint/train.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Definition/checkpoint/train.ckpt.data-00000-of-00001 
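A note on _model/model.py_ above: the same masked-average idiom appears twice, once when pooling subtoken embeddings into node-label representations (`get_initial_node_representation`) and once when pooling GGNN outputs over a variable's slot occurrences (`make_model`). A minimal NumPy sketch of the idiom, with hypothetical shapes and values:

```python
import numpy as np

# Two padded sequences of length 4 with 3-dimensional embeddings.
embeddings = np.random.rand(2, 4, 3)

# 1 marks a real entry, 0 marks padding.
mask = np.array([[1., 1., 0., 0.],
                 [1., 1., 1., 0.]]).reshape(2, 4, 1)

summed = (embeddings * mask).sum(axis=1)  # padded positions contribute zero
counts = mask.sum(axis=1)                 # number of real entries per sequence
pooled = summed / counts                  # masked average, shape (2, 3)
```

Note that the division assumes every sequence contains at least one unmasked entry; the model code relies on this rather than guarding against a zero count.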
-------------------------------------------------------------------------------- /saved_models/MethodNaming/Definition/checkpoint/train.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Definition/checkpoint/train.ckpt.index -------------------------------------------------------------------------------- /saved_models/MethodNaming/Definition/checkpoint/train.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Definition/checkpoint/train.ckpt.meta -------------------------------------------------------------------------------- /saved_models/MethodNaming/Definition/tokens.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Definition/tokens.txt -------------------------------------------------------------------------------- /saved_models/MethodNaming/Usage/checkpoint/train.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Usage/checkpoint/train.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/MethodNaming/Usage/checkpoint/train.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Usage/checkpoint/train.ckpt.index -------------------------------------------------------------------------------- /saved_models/MethodNaming/Usage/checkpoint/train.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Usage/checkpoint/train.ckpt.meta -------------------------------------------------------------------------------- /saved_models/MethodNaming/Usage/tokens.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/MethodNaming/Usage/tokens.txt -------------------------------------------------------------------------------- /saved_models/VarNaming/checkpoint/train.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/VarNaming/checkpoint/train.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /saved_models/VarNaming/checkpoint/train.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/VarNaming/checkpoint/train.ckpt.index -------------------------------------------------------------------------------- /saved_models/VarNaming/checkpoint/train.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/VarNaming/checkpoint/train.ckpt.meta -------------------------------------------------------------------------------- /saved_models/VarNaming/tokens.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/saved_models/VarNaming/tokens.txt -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils import vocabulary_extractor 2 | from model.model import Model 3 | import yaml 4 | import sys 5 | from utils.arg_parser import parse_input_args 6 | 7 | 8 | def train(task_id): 9 | 10 | with open("config.yml", 'r') as ymlfile: 11 | cfg = yaml.safe_load(ymlfile) 12 | 13 | checkpoint_path = cfg['checkpoint_path'] 14 | train_path = cfg['train_path'] 15 | val_path = cfg['val_path'] 16 | token_path = cfg['token_path'] 17 | 18 | 19 | vocabulary = vocabulary_extractor.create_vocabulary_from_corpus(train_path, token_path) 20 | print("Constructed vocabulary...") 21 | 22 | m = Model(mode='train', task_id=task_id, vocabulary=vocabulary) 23 | n_train_epochs = 50 24 | 25 | m.train(train_path=train_path, val_path=val_path, n_epochs=n_train_epochs, checkpoint_path=checkpoint_path) 26 | print("Model trained successfully...") 27 | 28 | 29 | 30 | args = sys.argv[1:] 31 | task_id = parse_input_args(args) 32 | 33 | 34 | train(task_id) 35 | 36 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrykazhdan/Representing-Programs-with-Graphs/a19477bf650832f6d541ff8f8e4586316c97d68d/utils/__init__.py -------------------------------------------------------------------------------- /utils/arg_parser.py: -------------------------------------------------------------------------------- 1 | 2 | def parse_input_args(command_line_args): 3 | 4 | if len(command_line_args) > 1: 5 | raise ValueError("Too many input arguments provided") 6 | 7 | if len(command_line_args) == 0: 8 | return 0 9 | 10 | task = command_line_args[0] 11 | 12 | if task == 'mth_usage': 13 | return 1 14 | elif task == 'mth_def': 15 | return 2 16 | else: 17 | raise ValueError("Invalid argument entered. Expecting: 'mth_usage' or 'mth_def'...") 18 | 19 | 
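Taken together with `create_samples` in _model/model.py_, _arg_parser.py_ fixes the task-id convention: no argument selects the VarNaming task, while `mth_usage` and `mth_def` select the two MethodNaming variants. Illustrative invocations (same run style as the README):

```python
python3 ./train.py            # task id 0: VarNaming
python3 ./train.py mth_usage  # task id 1: MethodNaming from usage sites
python3 ./train.py mth_def    # task id 2: MethodNaming from the method definition
```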
-------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from graph_pb2 import Graph 3 | from dpu_utils.codeutils import split_identifier_into_parts 4 | from graph_pb2 import FeatureNode 5 | from data_processing.graph_features import get_used_nodes_type 6 | from collections import defaultdict 7 | 8 | 9 | 10 | def compute_successors_and_predecessors(graph): 11 | 12 | successor_table = defaultdict(set) 13 | predecessor_table = defaultdict(set) 14 | 15 | for edge in graph.edge: 16 | successor_table[edge.sourceId].add(edge.destinationId) 17 | predecessor_table[edge.destinationId].add(edge.sourceId) 18 | 19 | return successor_table, predecessor_table 20 | 21 | 22 | def compute_node_table(graph): 23 | 24 | id_dict = {} 25 | 26 | for node in graph.node: 27 | id_dict[node.id] = node 28 | 29 | return id_dict 30 | 31 | 32 | 33 | 34 | 35 | def compute_f1_score(prediction, test_label): 36 | 37 | pred_copy = prediction.copy() 38 | tp = 0 39 | 40 | for subtoken in set(test_label): 41 | if subtoken in pred_copy: 42 | tp += 1 43 | pred_copy.remove(subtoken) 44 | 45 | 46 | if len(prediction) > 0: 47 | pr = tp / len(prediction) 48 | else: 49 | pr = 0 50 | 51 | if len(test_label) > 0: 52 | rec = tp / len(test_label) 53 | else: 54 | rec = 0 55 | 56 | 57 | if (pr + rec) > 0: 58 | f1 = 2 * pr * rec / (pr + rec) 59 | else: 60 | f1 = 0 61 | 62 | return f1 63 | 64 | 65 | 66 | 67 | 68 | # Compute corpus metrics in order to make a more informed model hyperparameter selection 69 | def compute_corpus_stats(corpus_path): 70 | 71 | max_node_len, max_var_len, max_var_usage = 0, 0, 0 72 | 73 | for dirpath, dirs, files in os.walk(corpus_path): 74 | for filename in files: 75 | if filename.endswith('proto'): 76 | 77 | fname = os.path.join(dirpath, filename) 78 | 79 | with open(fname, "rb") as f: 80 | 81 | g = Graph() 82 | g.ParseFromString(f.read()) 83 | 84 | var_node_usages = {} 85 | identifier_node_ids = [] 86 | 87 | for node in g.node: 88 | 89 | if node.type not in get_used_nodes_type() \ 90 | and node.type != FeatureNode.SYMBOL_VAR: 91 | continue 92 | 93 | node_len = len(split_identifier_into_parts(node.contents)) 94 | 95 | if node_len > max_node_len: 96 | max_node_len = node_len 97 | 98 | if node.type == FeatureNode.SYMBOL_VAR: 99 | 100 | var_node_usages[node.id] = 0 101 | 102 | if node_len > max_var_len: 103 | max_var_len = node_len 104 | 105 | 106 | elif node.type == FeatureNode.IDENTIFIER_TOKEN: 107 | identifier_node_ids.append(node.id) 108 | 109 | 110 | for edge in g.edge: 111 | 112 | if edge.sourceId in var_node_usages and edge.destinationId in identifier_node_ids: 113 | var_node_usages[edge.sourceId] += 1 114 | 115 | 116 | if len(var_node_usages.values()) > 0: 117 | var_usage = max(var_node_usages.values()) 118 | else: 119 | var_usage = 0 120 | 121 | if var_usage > max_var_usage: max_var_usage = var_usage 122 | 123 | 124 | print("Longest node length: ", max_node_len) 125 | print("Longest variable length: ", max_var_len) 126 | print("Largest variable usage: ", max_var_usage) 127 | 128 | 
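# Worked example for compute_f1_score above, with hypothetical subtoken lists:
#
#   compute_f1_score(["get", "name"], ["get", "file", "name"])
#
# Here tp = 2 ("get" and "name" both occur in the label), so precision = 2/2 = 1.0
# and recall = 2/3, giving F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8
# (up to floating-point rounding).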
129 | 130 | # Used for parsing type information from a type file 131 | def get_type_dists(types_fname): 132 | 133 | content_arr = [] 134 | 135 | with open(types_fname, "r") as f: 136 | 137 | for line in f: 138 | 139 | line_contents = str.split(line.strip()) 140 | 141 | if len(line_contents) == 5: 142 | 143 | line_contents = [line_contents[0], int(line_contents[1]), int(line_contents[3])] 144 | 145 | if line_contents[1] + line_contents[2] > 100: 146 | content_arr.append(line_contents) 147 | 148 | 149 | pred_acc = [[inf[0], 100 * inf[2] / (inf[2] + inf[1])] for inf in content_arr] 150 | pred_acc = sorted(pred_acc, key=lambda x: x[1], reverse=True) 151 | 152 | names = [inf[0] for inf in pred_acc] 153 | percentages = [inf[1] for inf in pred_acc] 154 | 155 | 156 | return names, percentages 157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /utils/vocabulary_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from graph_pb2 import Graph 3 | from dpu_utils.codeutils import split_identifier_into_parts 4 | from dpu_utils.mlutils import Vocabulary 5 | import pickle 6 | from data_processing.graph_features import get_used_nodes_type 7 | 8 | 9 | def create_vocabulary_from_corpus(corpus_path, output_token_path=None): 10 | 11 | all_sub_tokens = [] 12 | node_types = get_used_nodes_type() 13 | 14 | # Extract all subtokens from all nodes of the appropriate type using all graphs in the corpus 15 | for dirpath, dirs, files in os.walk(corpus_path): 16 | for filename in files: 17 | if filename.endswith('proto'): 18 | fname = os.path.join(dirpath, filename) 19 | 20 | with open(fname, "rb") as f: 21 | g = Graph() 22 | g.ParseFromString(f.read()) 23 | 24 | for n in g.node: 25 | if n.type in node_types: 26 | all_sub_tokens += split_identifier_into_parts(n.contents) 27 | 28 | all_sub_tokens = list(set(all_sub_tokens)) 29 | all_sub_tokens.append('<SLOT>') 30 | all_sub_tokens.append('sos_token') 31 | all_sub_tokens.sort() 32 | 33 | vocabulary = __create_voc_from_tokens(all_sub_tokens) 34 | 35 | # Save the vocabulary 36 | if output_token_path is not None: 37 | with open(output_token_path, "wb") as fp: 38 | pickle.dump(vocabulary, fp) 39 | 40 | return vocabulary 41 | 42 | 43 | def load_vocabulary(token_path): 44 | 45 | if not os.path.isfile(token_path): 46 | raise ValueError("Error. File not found...") 47 | 48 | with open(token_path, "rb") as fp: 49 | vocabulary = pickle.load(fp) 50 | 51 | return vocabulary 52 | 53 | 54 | def __create_voc_from_tokens(all_sub_tokens): 55 | 56 | vocabulary = Vocabulary.create_vocabulary(all_sub_tokens, max_size=100000, count_threshold=1, 57 | add_unk=True, add_pad=True) 58 | 59 | return vocabulary 60 | --------------------------------------------------------------------------------
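A closing note on _vocabulary_extractor.py_: despite the `.txt` naming used for token files elsewhere in the repository, the vocabulary is stored as a pickled `dpu_utils.mlutils.Vocabulary` object, so the file written during training is exactly what `load_vocabulary` expects at inference time. A short round-trip sketch with hypothetical paths:

```python
from utils import vocabulary_extractor

# Build the vocabulary from a corpus of .proto graphs and pickle it to disk.
voc = vocabulary_extractor.create_vocabulary_from_corpus(
    "path-to-train-data", "path-to-vocabulary-txt-file")

# Later (e.g. before inference), reload the identical Vocabulary object.
voc = vocabulary_extractor.load_vocabulary("path-to-vocabulary-txt-file")
```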