├── .gitignore ├── LICENSE ├── README.md ├── code2seq ├── LICENSE ├── README.md ├── common.py ├── preprocess.py └── preprocess.sh ├── cpp_parser ├── __init__.py ├── ast_parser.py ├── ast_utils.py ├── context.py ├── path.py └── sample.py ├── data ├── CMakeLists.txt ├── big_op.cc ├── func.cc ├── main.cc ├── method.cc └── template.cc ├── docker ├── README.md └── dockerfile └── src ├── data_set_merge.py ├── merge.py ├── miner.py ├── parser_process.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Kirill
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cppminer
2 | cppminer produces [code2seq](https://github.com/tech-srl/code2seq)-compatible datasets from C++ code bases.
3 |
4 | An experimental [C++](https://drive.google.com/file/d/15BDd6zHFkVJXl95FG4JnnSse48k1UR3E/view?usp=sharing) dataset mined from the Chromium project sources is available.
5 |
6 | This tool consists of three scripts which should be run one after another.
7 |
8 | # 1. Miner
9 | `miner.py` is the main utility which traverses C++ sources, parses them and produces raw dataset files.
10 |
11 | It has the following command line interface:
12 | ~~~
13 | usage: miner.py [-h] [-c contexts-number] [-l path-length] [-d ast-depth] [-p processes-number] [-e libclang-path] path out
14 |
15 | positional arguments:
16 |   path                  the path sources directory
17 |   out                   the output path
18 |
19 | optional arguments:
20 |   -h, --help            show this help message and exit
21 |   -c contexts-number, --max_contexts_num contexts-number
22 |                         maximum number of contexts per sample
23 |   -l path-length, --max_path_len path-length
24 |                         maximum path length (0 - no limit)
25 |   -d ast-depth, --max_ast_depth ast-depth
26 |                         maximum depth of AST (0 - no limit)
27 |   -p processes-number, --processes_num processes-number
28 |                         number of parallel processes
29 |   -e libclang-path, --libclang libclang-path
30 |                         path to libclang.so file
31 | ~~~
32 |
33 | The input path is traversed recursively and all files with the following extensions are parsed: `c, cc, cpp, cxx, c++`.
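For example, a complete run over a CMake-based project might look like the sketch below; the project location, output directory and the chosen limits are placeholders to adapt to your setup, and the libclang path is the example given in `miner.py`:

~~~
# 1. Build a compilation database (CMake example) and place
#    compile_commands.json in the directory that is passed to miner.py.
cmake -S my_project -B my_project/build -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
cp my_project/build/compile_commands.json my_project/

# 2. Mine raw *.c2s files into ./raw_data using 8 parallel processes.
python3 src/miner.py -c 100 -l 8 -d 16 -p 8 \
    -e /usr/lib/llvm-6.0/lib/libclang.so \
    my_project ./raw_data
~~~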
34 | It is recommended to use a [C++ compilation database](https://clang.llvm.org/docs/JSONCompilationDatabase.html), which provides all required compilation flags for the project files.
35 |
36 | The raw dataset files have the following format:
37 |
38 | * Each row is an example.
39 | * Each example is a space-delimited list of fields, where:
40 |
41 | 1. The first field is the target label, internally delimited by the "|" character (for example: compare|ignore|case)
42 | 2. Each of the following fields is a context, where each context has three components separated by commas (","). None of these components can include spaces or commas.
43 |
44 | A context's components are a token, a path, and another token.
45 |
46 | Each `token` component is a token in the code, split into subtokens using the "|" character.
47 |
48 | Each `path` is a path between two tokens, split into path nodes using the "|" character. An example of a context:
49 | ```
50 | my|key,StringExpression|MethodCall|Name,get|value
51 | ```
52 | Here `my|key` and `get|value` are tokens, and `StringExpression|MethodCall|Name` is the syntactic path that connects them.
53 |
54 | # 2. Merge
55 | `merge.py` is the utility which concatenates all raw files, shuffles them and produces three files, `dataset.train.c2s`, `dataset.test.c2s` and `dataset.val.c2s`, in the given directory.
56 | It can also delete the intermediate resource files after merging. The most important setting is `map_file_size`, which defines the size of the database file used for merging;
57 | you should increase the default value of 6 GB for large datasets.
58 |
59 | It has the following command line interface:
60 |
61 | ~~~
62 | usage: merge.py [-h] [-c clear_resources_flag] [-m map_file_size] path
63 |
64 | merge resources generated by cppminer to a code2seq dataset
65 |
66 | positional arguments:
67 |   path                  the dataset sources path
68 |
69 | optional arguments:
70 |   -h, --help            show this help message and exit
71 |   -c clear_resources_flag, --clear_resources clear_resources_flag
72 |                         if True clear resource files
73 |   -m map_file_size, --map_size map_file_size
74 |                         size of the DB file, default(6442450944 bytes)
75 | ~~~
76 |
77 | # 3. Code2seq preprocess
78 |
79 | The third utility is `preprocess.sh` from the `code2seq` folder. It is a modified script from the original project which generates the dataset in a format suitable for the `code2seq` model.
80 | In general it creates new files with a truncated and padded number of paths for each example.
--------------------------------------------------------------------------------
/code2seq/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Technion
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code2seq/README.md: -------------------------------------------------------------------------------- 1 | This folder contatins modifies scripts from code2seq repository. 2 | 3 | Use them to produce final datasets. 4 | -------------------------------------------------------------------------------- /code2seq/common.py: -------------------------------------------------------------------------------- 1 | import re 2 | import subprocess 3 | import sys 4 | 5 | 6 | class Common: 7 | internal_delimiter = '|' 8 | SOS = '' 9 | EOS = '' 10 | PAD = '' 11 | UNK = '' 12 | 13 | @staticmethod 14 | def normalize_word(word): 15 | stripped = re.sub(r'[^a-zA-Z]', '', word) 16 | if len(stripped) == 0: 17 | return word.lower() 18 | else: 19 | return stripped.lower() 20 | 21 | @staticmethod 22 | def load_histogram(path, max_size=None): 23 | histogram = {} 24 | with open(path, 'r') as file: 25 | for line in file.readlines(): 26 | parts = line.split(' ') 27 | if not len(parts) == 2: 28 | continue 29 | histogram[parts[0]] = int(parts[1]) 30 | sorted_histogram = [(k, histogram[k]) for k in sorted(histogram, key=histogram.get, reverse=True)] 31 | return dict(sorted_histogram[:max_size]) 32 | 33 | @staticmethod 34 | def load_vocab_from_dict(word_to_count, add_values=[], max_size=None): 35 | word_to_index, index_to_word = {}, {} 36 | current_index = 0 37 | for value in add_values: 38 | word_to_index[value] = current_index 39 | index_to_word[current_index] = value 40 | current_index += 1 41 | sorted_counts = [(k, word_to_count[k]) for k in sorted(word_to_count, key=word_to_count.get, reverse=True)] 42 | limited_sorted = dict(sorted_counts[:max_size]) 43 | for word, count in limited_sorted.items(): 44 | word_to_index[word] = current_index 45 | index_to_word[current_index] = word 46 | current_index += 1 47 | return word_to_index, index_to_word, current_index 48 | 49 | @staticmethod 50 | def binary_to_string(binary_string): 51 | return binary_string.decode("utf-8") 52 | 53 | @staticmethod 54 | def binary_to_string_list(binary_string_list): 55 | return [Common.binary_to_string(w) for w in binary_string_list] 56 | 57 | @staticmethod 58 | def binary_to_string_matrix(binary_string_matrix): 59 | return [Common.binary_to_string_list(l) for l in binary_string_matrix] 60 | 61 | @staticmethod 62 | def binary_to_string_3d(binary_string_tensor): 63 | return [Common.binary_to_string_matrix(l) for l in binary_string_tensor] 64 | 65 | @staticmethod 66 | def legal_method_names_checker(name): 67 | return not name in [Common.UNK, Common.PAD, Common.EOS] 68 | 69 | @staticmethod 70 | def filter_impossible_names(top_words): 71 | result = list(filter(Common.legal_method_names_checker, top_words)) 72 | return result 73 | 74 | @staticmethod 75 | def unique(sequence): 76 | unique = [] 77 | [unique.append(item) for item in sequence if item not in unique] 78 | return unique 79 | 80 | @staticmethod 81 | def parse_results(result, pc_info_dict, topk=5): 82 | prediction_results = {} 83 | results_counter = 0 84 | for single_method in result: 85 | original_name, top_suggestions, top_scores, attention_per_context = list(single_method) 86 | current_method_prediction_results = 
PredictionResults(original_name) 87 | if attention_per_context is not None: 88 | word_attention_pairs = [(word, attention) for word, attention in 89 | zip(top_suggestions, attention_per_context) if 90 | Common.legal_method_names_checker(word)] 91 | for predicted_word, attention_timestep in word_attention_pairs: 92 | current_timestep_paths = [] 93 | for context, attention in [(key, attention_timestep[key]) for key in 94 | sorted(attention_timestep, key=attention_timestep.get, reverse=True)][ 95 | :topk]: 96 | if context in pc_info_dict: 97 | pc_info = pc_info_dict[context] 98 | current_timestep_paths.append((attention.item(), pc_info)) 99 | 100 | current_method_prediction_results.append_prediction(predicted_word, current_timestep_paths) 101 | else: 102 | for predicted_seq in top_suggestions: 103 | filtered_seq = [word for word in predicted_seq if Common.legal_method_names_checker(word)] 104 | current_method_prediction_results.append_prediction(filtered_seq, None) 105 | 106 | prediction_results[results_counter] = current_method_prediction_results 107 | results_counter += 1 108 | return prediction_results 109 | 110 | @staticmethod 111 | def compute_bleu(ref_file_name, predicted_file_name): 112 | with open(predicted_file_name) as predicted_file: 113 | pipe = subprocess.Popen(["perl", "scripts/multi-bleu.perl", ref_file_name], stdin=predicted_file, 114 | stdout=sys.stdout, stderr=sys.stderr) 115 | 116 | 117 | class PredictionResults: 118 | def __init__(self, original_name): 119 | self.original_name = original_name 120 | self.predictions = list() 121 | 122 | def append_prediction(self, name, current_timestep_paths): 123 | self.predictions.append(SingleTimeStepPrediction(name, current_timestep_paths)) 124 | 125 | class SingleTimeStepPrediction: 126 | def __init__(self, prediction, attention_paths): 127 | self.prediction = prediction 128 | if attention_paths is not None: 129 | paths_with_scores = [] 130 | for attention_score, pc_info in attention_paths: 131 | path_context_dict = {'score': attention_score, 132 | 'path': pc_info.longPath, 133 | 'token1': pc_info.token1, 134 | 'token2': pc_info.token2} 135 | paths_with_scores.append(path_context_dict) 136 | self.attention_paths = paths_with_scores 137 | 138 | 139 | class PathContextInformation: 140 | def __init__(self, context): 141 | self.token1 = context['name1'] 142 | self.longPath = context['path'] 143 | self.shortPath = context['shortPath'] 144 | self.token2 = context['name2'] 145 | 146 | def __str__(self): 147 | return '%s,%s,%s' % (self.token1, self.shortPath, self.token2) 148 | -------------------------------------------------------------------------------- /code2seq/preprocess.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from argparse import ArgumentParser 3 | 4 | import numpy as np 5 | 6 | import common 7 | 8 | ''' 9 | This script preprocesses the data from MethodPaths. It truncates methods with too many contexts, 10 | and pads methods with less paths with spaces. 
11 | ''' 12 | 13 | 14 | def save_dictionaries(dataset_name, subtoken_to_count, node_to_count, target_to_count, max_contexts, num_examples): 15 | save_dict_file_path = '{}.dict.c2s'.format(dataset_name) 16 | with open(save_dict_file_path, 'wb') as file: 17 | pickle.dump(subtoken_to_count, file) 18 | pickle.dump(node_to_count, file) 19 | pickle.dump(target_to_count, file) 20 | pickle.dump(max_contexts, file) 21 | pickle.dump(num_examples, file) 22 | print('Dictionaries saved to: {}'.format(save_dict_file_path)) 23 | 24 | 25 | def process_file(file_path, data_file_role, dataset_name, max_contexts, max_data_contexts): 26 | sum_total = 0 27 | sum_sampled = 0 28 | total = 0 29 | max_unfiltered = 0 30 | max_contexts_to_sample = max_data_contexts if data_file_role == 'train' else max_contexts 31 | output_path = '{}.{}.c2s'.format(dataset_name, data_file_role) 32 | with open(output_path, 'w') as outfile: 33 | with open(file_path, 'r') as file: 34 | for line in file: 35 | parts = line.rstrip('\n').split(' ') 36 | target_name = parts[0] 37 | contexts = parts[1:] 38 | 39 | if len(contexts) > max_unfiltered: 40 | max_unfiltered = len(contexts) 41 | 42 | sum_total += len(contexts) 43 | if len(contexts) > max_contexts_to_sample: 44 | contexts = np.random.choice(contexts, max_contexts_to_sample, replace=False) 45 | 46 | sum_sampled += len(contexts) 47 | 48 | csv_padding = " " * (max_data_contexts - len(contexts)) 49 | total += 1 50 | outfile.write(target_name + ' ' + " ".join(contexts) + csv_padding + '\n') 51 | 52 | print('File: ' + data_file_path) 53 | print('Average total contexts: ' + str(float(sum_total) / total)) 54 | print('Average final (after sampling) contexts: ' + str(float(sum_sampled) / total)) 55 | print('Total examples: ' + str(total)) 56 | print('Max number of contexts per word: ' + str(max_unfiltered)) 57 | return total 58 | 59 | 60 | def context_full_found(context_parts, word_to_count, path_to_count): 61 | return context_parts[0] in word_to_count \ 62 | and context_parts[1] in path_to_count and context_parts[2] in word_to_count 63 | 64 | 65 | def context_partial_found(context_parts, word_to_count, path_to_count): 66 | return context_parts[0] in word_to_count \ 67 | or context_parts[1] in path_to_count or context_parts[2] in word_to_count 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = ArgumentParser() 72 | parser.add_argument("-trd", "--train_data", dest="train_data_path", 73 | help="path to training data file", required=True) 74 | parser.add_argument("-ted", "--test_data", dest="test_data_path", 75 | help="path to test data file", required=True) 76 | parser.add_argument("-vd", "--val_data", dest="val_data_path", 77 | help="path to validation data file", required=True) 78 | parser.add_argument("-mc", "--max_contexts", dest="max_contexts", default=200, 79 | help="number of max contexts to keep in test+validation", required=False) 80 | parser.add_argument("-mdc", "--max_data_contexts", dest="max_data_contexts", default=1000, 81 | help="number of max contexts to keep in the dataset", required=False) 82 | parser.add_argument("-svs", "--subtoken_vocab_size", dest="subtoken_vocab_size", default=186277, 83 | help="Max number of source subtokens to keep in the vocabulary", required=False) 84 | parser.add_argument("-tvs", "--target_vocab_size", dest="target_vocab_size", default=26347, 85 | help="Max number of target words to keep in the vocabulary", required=False) 86 | parser.add_argument("-sh", "--subtoken_histogram", dest="subtoken_histogram", 87 | help="subtoken histogram file", 
metavar="FILE", required=True) 88 | parser.add_argument("-nh", "--node_histogram", dest="node_histogram", 89 | help="node_histogram file", metavar="FILE", required=True) 90 | parser.add_argument("-th", "--target_histogram", dest="target_histogram", 91 | help="target histogram file", metavar="FILE", required=True) 92 | parser.add_argument("-o", "--output_name", dest="output_name", 93 | help="output name - the base name for the created dataset", required=True, default='data') 94 | args = parser.parse_args() 95 | 96 | train_data_path = args.train_data_path 97 | test_data_path = args.test_data_path 98 | val_data_path = args.val_data_path 99 | subtoken_histogram_path = args.subtoken_histogram 100 | node_histogram_path = args.node_histogram 101 | 102 | subtoken_to_count = common.Common.load_histogram(subtoken_histogram_path, 103 | max_size=int(args.subtoken_vocab_size)) 104 | node_to_count = common.Common.load_histogram(node_histogram_path, 105 | max_size=None) 106 | target_to_count = common.Common.load_histogram(args.target_histogram, 107 | max_size=int(args.target_vocab_size)) 108 | print('subtoken vocab size: ', len(subtoken_to_count)) 109 | print('node vocab size: ', len(node_to_count)) 110 | print('target vocab size: ', len(target_to_count)) 111 | 112 | num_training_examples = 0 113 | for data_file_path, data_role in zip([test_data_path, val_data_path, train_data_path], ['test', 'val', 'train']): 114 | num_examples = process_file(file_path=data_file_path, data_file_role=data_role, dataset_name=args.output_name, 115 | max_contexts=int(args.max_contexts), max_data_contexts=int(args.max_data_contexts)) 116 | if data_role == 'train': 117 | num_training_examples = num_examples 118 | 119 | save_dictionaries(dataset_name=args.output_name, subtoken_to_count=subtoken_to_count, 120 | node_to_count=node_to_count, target_to_count=target_to_count, 121 | max_contexts=int(args.max_data_contexts), num_examples=num_training_examples) 122 | -------------------------------------------------------------------------------- /code2seq/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ########################################################### 3 | # Change the following values to preprocess a new dataset. 4 | # BASE_PATH - script parameter point to the directory where data files are placed 5 | # TRAIN_DATA_FILE, VAL_DATA_FILE and TEST_DATA_FILE should be paths to 6 | # files containing files obtained from merge.py script 7 | # DATASET_NAME is just a name for the currently extracted 8 | # dataset. 9 | # MAX_DATA_CONTEXTS is the number of contexts to keep in the dataset for each 10 | # method (by default 1000). At training time, these contexts 11 | # will be downsampled dynamically to MAX_CONTEXTS. 12 | # MAX_CONTEXTS - the number of actual contexts (by default 200) 13 | # that are taken into consideration (out of MAX_DATA_CONTEXTS) 14 | # every training iteration. To avoid randomness at test time, 15 | # for the test and validation sets only MAX_CONTEXTS contexts are kept 16 | # (while for training, MAX_DATA_CONTEXTS are kept and MAX_CONTEXTS are 17 | # selected dynamically during training). 18 | # SUBTOKEN_VOCAB_SIZE, TARGET_VOCAB_SIZE - 19 | # - the number of subtokens and target words to keep 20 | # in the vocabulary (the top occurring words and paths will be kept). 21 | # NUM_THREADS - the number of parallel threads to use. 
It is 22 | # recommended to use a multi-core machine for the preprocessing 23 | # step and set this value to the number of cores. 24 | # PYTHON - python3 interpreter alias. 25 | BASE_PATH=$1 26 | OUT_PATH=$2 27 | DATASET_NAME=dataset 28 | MAX_DATA_CONTEXTS=1000 29 | MAX_CONTEXTS=200 30 | SUBTOKEN_VOCAB_SIZE=186277 31 | TARGET_VOCAB_SIZE=26347 32 | NUM_THREADS=64 33 | PYTHON=python3 34 | ########################################################### 35 | 36 | TRAIN_DATA_FILE=${BASE_PATH}/${DATASET_NAME}.train.c2s 37 | VAL_DATA_FILE=${BASE_PATH}/${DATASET_NAME}.val.c2s 38 | TEST_DATA_FILE=${BASE_PATH}/${DATASET_NAME}.test.c2s 39 | 40 | mkdir -p ${OUT_PATH}/data 41 | mkdir -p ${OUT_PATH}/data/${DATASET_NAME} 42 | 43 | TARGET_HISTOGRAM_FILE=${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2s 44 | SOURCE_SUBTOKEN_HISTOGRAM=${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2s 45 | NODE_HISTOGRAM_FILE=${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME}.histo.node.c2s 46 | 47 | echo "Creating histograms from the training data" 48 | cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE} 49 | cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${SOURCE_SUBTOKEN_HISTOGRAM} 50 | cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${NODE_HISTOGRAM_FILE} 51 | 52 | ${PYTHON} preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \ 53 | --max_contexts ${MAX_CONTEXTS} --max_data_contexts ${MAX_DATA_CONTEXTS} --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \ 54 | --target_vocab_size ${TARGET_VOCAB_SIZE} --subtoken_histogram ${SOURCE_SUBTOKEN_HISTOGRAM} \ 55 | --node_histogram ${NODE_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name ${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME} 56 | 57 | # If all went well, the raw data files can be deleted, because preprocess.py creates new files 58 | # with truncated and padded number of paths for each example. 
59 | rm ${TARGET_HISTOGRAM_FILE} ${SOURCE_SUBTOKEN_HISTOGRAM} ${NODE_HISTOGRAM_FILE} 60 | 61 | -------------------------------------------------------------------------------- /cpp_parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .context import Context 2 | from .path import Path 3 | from .sample import Sample 4 | from .ast_parser import AstParser 5 | -------------------------------------------------------------------------------- /cpp_parser/ast_parser.py: -------------------------------------------------------------------------------- 1 | from clang.cindex import Index 2 | from .sample import Sample 3 | from .context import Context 4 | from .path import Path 5 | from .ast_utils import ast_to_graph, is_function, is_class, is_operator_token, is_namespace, make_ast_err_message 6 | from networkx.algorithms import shortest_path 7 | from networkx.drawing.nx_agraph import to_agraph 8 | from itertools import combinations 9 | import uuid 10 | import os 11 | import re 12 | import random 13 | 14 | 15 | def debug_save_graph(func_node, g): 16 | file_name = func_node.spelling + ".png" 17 | num = 0 18 | while os.path.exists(file_name): 19 | file_name = func_node.spelling + str(num) + ".png" 20 | num += 1 21 | a = to_agraph(g) 22 | a.draw(file_name, prog='dot') 23 | a.clear() 24 | 25 | 26 | def tokenize(name, max_subtokens_num): 27 | if is_operator_token(name): 28 | return [name] 29 | first_tokens = name.split('_') 30 | str_tokens = [] 31 | for token in first_tokens: 32 | internal_tokens = re.findall('[a-z]+|[A-Z]+[a-z]*|[0-9.]+|[-*/&|%=()]+', token) 33 | str_tokens += [t for t in internal_tokens if len(t) > 0] 34 | assert len(str_tokens) > 0, "Can't tokenize expr: {0}".format(name) 35 | if max_subtokens_num != 0: 36 | str_tokens = str_tokens[:max_subtokens_num] 37 | return str_tokens 38 | 39 | 40 | class AstParser: 41 | def __init__(self, max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, out_path): 42 | self.validate = False 43 | self.save_buffer_size = 1000 44 | self.out_path = out_path 45 | self.max_subtokens_num = max_subtokens_num 46 | self.max_contexts_num = max_contexts_num 47 | self.max_path_len = max_path_len 48 | self.max_ast_depth = max_ast_depth 49 | self.index = Index.create() 50 | self.samples = set() 51 | self.header_only_functions = set() 52 | 53 | def __del__(self): 54 | self.save() 55 | 56 | def __parse_node(self, node): 57 | try: 58 | namespaces = [x for x in node.get_children() if is_namespace(x)] 59 | for n in namespaces: 60 | # ignore standard library functions 61 | if n.displayname != 'std' and not n.displayname.startswith('__'): 62 | self.__parse_node(n) 63 | 64 | functions = [x for x in node.get_children() if is_function(x)] 65 | for f in functions: 66 | self.__parse_function(f) 67 | 68 | classes = [x for x in node.get_children() if is_class(x)] 69 | for c in classes: 70 | methods = [x for x in c.get_children() if is_function(x)] 71 | for m in methods: 72 | self.__parse_function(m) 73 | except Exception as e: 74 | if 'Unknown template argument kind' not in str(e): 75 | msg = make_ast_err_message(str(e), node) 76 | raise Exception(msg) 77 | 78 | self.__dump_samples() 79 | 80 | def parse(self, compiler_args, file_path=None): 81 | ast = self.index.parse(file_path, compiler_args) 82 | self.__parse_node(ast.cursor) 83 | 84 | def __dump_samples(self): 85 | if len(self.samples) >= self.save_buffer_size: 86 | self.save() 87 | 88 | def save(self): 89 | if not self.out_path: 90 | return 91 | if not 
os.path.exists(self.out_path): 92 | os.makedirs(self.out_path) 93 | if len(self.samples) > 0: 94 | file_name = os.path.join(self.out_path, str(uuid.uuid4().hex) + ".c2s") 95 | # print(file_name) 96 | with open(file_name, "w") as file: 97 | for sample in self.samples: 98 | file.write(str(sample.source_mark) + str(sample) + "\n") 99 | self.samples.clear() 100 | 101 | def __parse_function(self, func_node): 102 | try: 103 | # ignore standard library functions 104 | if func_node.displayname.startswith('__'): 105 | return 106 | 107 | # detect header only function duplicates 108 | file_name = func_node.location.file.name 109 | source_mark = (file_name, func_node.extent.start.line) 110 | if file_name.endswith('.h') and func_node.is_definition: 111 | # print('Header only function: {0}'.format(func_node.displayname)) 112 | if source_mark in self.header_only_functions: 113 | # print('Duplicate') 114 | return 115 | else: 116 | self.header_only_functions.add(source_mark) 117 | 118 | key = tokenize(func_node.spelling, self.max_subtokens_num) 119 | g = ast_to_graph(func_node, self.max_ast_depth) 120 | 121 | # debug_save_graph(func_node, g) 122 | 123 | terminal_nodes = [node for (node, degree) in g.degree() if degree == 1] 124 | random.shuffle(terminal_nodes) 125 | 126 | contexts = set() 127 | ends = combinations(terminal_nodes, 2) 128 | 129 | for start, end in ends: 130 | path = shortest_path(g, start, end) 131 | if path: 132 | if self.max_path_len != 0 and len(path) > self.max_path_len: 133 | continue # skip too long paths 134 | path = path[1:-1] 135 | start_node = g.nodes[start]['label'] 136 | tokenize_start_node = not g.nodes[start]['is_reserved'] 137 | end_node = g.nodes[end]['label'] 138 | tokenize_end_node = not g.nodes[end]['is_reserved'] 139 | 140 | path_tokens = [] 141 | for path_item in path: 142 | path_node = g.nodes[path_item]['label'] 143 | path_tokens.append(path_node) 144 | 145 | context = Context( 146 | tokenize(start_node, self.max_subtokens_num) if tokenize_start_node else [start_node], 147 | tokenize(end_node, self.max_subtokens_num) if tokenize_end_node else [end_node], 148 | Path(path_tokens, self.validate), self.validate) 149 | contexts.add(context) 150 | if len(contexts) > self.max_contexts_num: 151 | break 152 | 153 | if len(contexts) > 0: 154 | sample = Sample(key, contexts, source_mark, self.validate) 155 | self.samples.add(sample) 156 | except Exception as e: 157 | # skip unknown cursor exceptions 158 | if 'Unknown template argument kind' not in str(e): 159 | print('Failed to parse function : ') 160 | print('Filename : ' + func_node.location.file.name) 161 | print('Start {0}:{1}'.format(func_node.extent.start.line, func_node.extent.start.column)) 162 | print('End {0}:{1}'.format(func_node.extent.end.line, func_node.extent.end.column)) 163 | print(e) 164 | -------------------------------------------------------------------------------- /cpp_parser/ast_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import networkx as nx 4 | import uuid 5 | from clang.cindex import CursorKind, TokenKind 6 | 7 | 8 | def make_ast_err_message(msg, ast_node): 9 | msg = 'Error: {0} in:\n'.format(msg) 10 | msg += 'Filename : {0}\n'.format(ast_node.location.file.name) 11 | msg += 'Start {0}:{1}'.format(ast_node.extent.start.line, ast_node.extent.start.column) 12 | msg += ' End {0}:{1}'.format(ast_node.extent.end.line, ast_node.extent.end.column) 13 | return msg 14 | 15 | 16 | def is_node_kind_safe(node, kinds): 17 | try: 18 | return 
node.kind in kinds 19 | except Exception as e: 20 | msg = make_ast_err_message(str(e), node) 21 | if 'Unknown template argument kind' not in str(e): 22 | print(msg) 23 | return CursorKind.NOT_IMPLEMENTED 24 | 25 | 26 | def is_namespace(node): 27 | if is_node_kind_safe(node, [CursorKind.NAMESPACE]): 28 | return True 29 | return False 30 | 31 | 32 | def is_function(node): 33 | if is_node_kind_safe(node, [CursorKind.FUNCTION_DECL, 34 | CursorKind.FUNCTION_TEMPLATE, 35 | CursorKind.CXX_METHOD, 36 | CursorKind.DESTRUCTOR, 37 | CursorKind.CONSTRUCTOR]): 38 | if node.is_definition(): 39 | not_empty = False 40 | for _ in node.get_children(): 41 | not_empty = True 42 | break 43 | return not_empty 44 | return False 45 | 46 | 47 | def is_class(node): 48 | return is_node_kind_safe(node, [CursorKind.CLASS_TEMPLATE, 49 | CursorKind.CLASS_TEMPLATE_PARTIAL_SPECIALIZATION, 50 | CursorKind.CLASS_DECL]) 51 | 52 | 53 | def is_literal(node): 54 | return is_node_kind_safe(node, [CursorKind.INTEGER_LITERAL, 55 | CursorKind.FLOATING_LITERAL, 56 | CursorKind.IMAGINARY_LITERAL, 57 | CursorKind.STRING_LITERAL, 58 | CursorKind.CHARACTER_LITERAL]) 59 | 60 | 61 | def is_template_parameter(node): 62 | return is_node_kind_safe(node, [CursorKind.TEMPLATE_TYPE_PARAMETER, 63 | CursorKind.TEMPLATE_TEMPLATE_PARAMETER]) 64 | 65 | 66 | def is_reference(node): 67 | return is_node_kind_safe(node, [CursorKind.DECL_REF_EXPR, CursorKind.MEMBER_REF_EXPR]) 68 | 69 | 70 | def is_operator(node): 71 | return is_node_kind_safe(node, [CursorKind.BINARY_OPERATOR, 72 | CursorKind.UNARY_OPERATOR, 73 | CursorKind.COMPOUND_ASSIGNMENT_OPERATOR]) 74 | 75 | 76 | def is_call_expr(node): 77 | return is_node_kind_safe(node, [CursorKind.CALL_EXPR]) 78 | 79 | 80 | binary_operators = ['+', '-', '*', '/', '%', '&', '|'] 81 | unary_operators = ['++', '--'] 82 | comparison_operators = ['==', '<=', '>=', '<', '>', '!=', '&&', '||'] 83 | unary_assignment_operators = [op + '=' for op in binary_operators] 84 | assignment_operators = ['='] + unary_assignment_operators 85 | 86 | 87 | def is_operator_token(token): 88 | if token in binary_operators: 89 | return True 90 | if token in unary_operators: 91 | return True 92 | if token in comparison_operators: 93 | return True 94 | if token in unary_assignment_operators: 95 | return True 96 | if token in assignment_operators: 97 | return True 98 | 99 | 100 | def get_id(): 101 | node_id = uuid.uuid1() 102 | return node_id.int 103 | 104 | 105 | def add_node(ast_node, graph): 106 | try: 107 | node_id = ast_node.hash 108 | kind = ast_node.kind.name 109 | # skip meaningless AST primitives 110 | if ast_node.kind == CursorKind.DECL_STMT or \ 111 | ast_node.kind == CursorKind.UNEXPOSED_EXPR: 112 | return False 113 | 114 | if is_operator(ast_node): 115 | op_name = get_operator(ast_node) 116 | kind = kind.strip() + "_" + "_".join(op_name) 117 | 118 | graph.add_node(node_id, label=kind, is_reserved=True) 119 | 120 | # print("Cursor kind : {0}".format(kind)) 121 | if ast_node.kind.is_declaration(): 122 | add_declaration(node_id, ast_node, graph) 123 | elif is_literal(ast_node): 124 | add_literal(node_id, ast_node, graph) 125 | elif is_reference(ast_node): 126 | add_reference(node_id, ast_node, graph) 127 | elif is_call_expr(ast_node): 128 | add_call_expr(node_id, ast_node, graph) 129 | 130 | return True 131 | except Exception as e: 132 | if 'Unknown template argument kind' not in str(e): 133 | msg = make_ast_err_message(str(e), ast_node) 134 | raise Exception(msg) 135 | 136 | 137 | def add_child(graph, parent_id, name, 
is_reserved=True): 138 | child_id = get_id() 139 | assert len(name) > 0, "Missing node name" 140 | graph.add_node(child_id, label=name, is_reserved=is_reserved) 141 | graph.add_edge(parent_id, child_id) 142 | 143 | 144 | def add_intermediate_node(graph, parent_id, name): 145 | child_id = get_id() 146 | assert len(name) > 0, "Missing node name" 147 | graph.add_node(child_id, label=name, is_reserved=True) 148 | graph.add_edge(parent_id, child_id) 149 | return child_id 150 | 151 | 152 | def add_call_expr(parent_id, ast_node, graph): 153 | expr_type = ast_node.type.spelling 154 | expr_type_node_id = add_intermediate_node(graph, parent_id, "EXPR_TYPE") 155 | add_child(graph, expr_type_node_id, expr_type, is_reserved=False) 156 | 157 | 158 | def fix_cpp_operator_spelling(op_name): 159 | if op_name == '|': 160 | return 'OPERATOR_BINARY_OR' 161 | elif op_name == '||': 162 | return 'OPERATOR_LOGICAL_OR' 163 | elif op_name == '|=': 164 | return 'OPERATOR_ASSIGN_OR' 165 | elif op_name == ',': 166 | return 'OPERATOR_COMMA' 167 | else: 168 | return op_name 169 | 170 | 171 | def get_operator(ast_node): 172 | name_token = None 173 | for token in ast_node.get_tokens(): 174 | if is_operator_token(token.spelling): 175 | name_token = token 176 | break 177 | 178 | if not name_token: 179 | filename = ast_node.location.file.name 180 | with open(filename, 'r') as fh: 181 | contents = fh.read() 182 | code_str = contents[ast_node.extent.start.offset: ast_node.extent.end.offset] 183 | name = [] 184 | for ch in code_str: 185 | if ch in binary_operators: 186 | name.append(fix_cpp_operator_spelling(ch).strip()) 187 | return name 188 | else: 189 | name = name_token.spelling 190 | name = fix_cpp_operator_spelling(name) 191 | # print("\tName : {0}".format(name)) 192 | return [name.strip()] 193 | 194 | 195 | def add_reference(parent_id, ast_node, graph): 196 | is_reserved = True 197 | name = "REFERENCE" 198 | if ast_node.kind in [CursorKind.DECL_REF_EXPR, CursorKind.MEMBER_REF_EXPR]: 199 | for token in ast_node.get_tokens(): 200 | if token.kind == TokenKind.IDENTIFIER: 201 | name = token.spelling 202 | is_reserved = False 203 | break 204 | else: 205 | name = ast_node.spelling 206 | is_reserved = False 207 | 208 | add_child(graph, parent_id, name, is_reserved) 209 | # print("\tName : {0}".format(name)) 210 | 211 | 212 | def add_literal(parent_id, ast_node, graph): 213 | if ast_node.kind in [CursorKind.STRING_LITERAL, 214 | CursorKind.CHARACTER_LITERAL]: 215 | add_child(graph, parent_id, 'STRING_VALUE', is_reserved=True) 216 | else: 217 | token = next(ast_node.get_tokens(), None) 218 | if token: 219 | value = token.spelling 220 | add_child(graph, parent_id, value) 221 | # print("\tValue : {0}".format(value)) 222 | 223 | 224 | def add_declaration(parent_id, ast_node, graph): 225 | is_func = False 226 | if is_function(ast_node): 227 | is_func = True 228 | return_type = ast_node.type.get_result().spelling 229 | if len(return_type) > 0: 230 | return_type_node_id = add_intermediate_node(graph, parent_id, "RETURN_TYPE") 231 | add_child(graph, return_type_node_id, return_type, is_reserved=False) 232 | else: 233 | declaration_type = ast_node.type.spelling 234 | if len(declaration_type) > 0: 235 | declaration_type_node_id = add_intermediate_node(graph, parent_id, "DECLARATION_TYPE") 236 | add_child(graph, declaration_type_node_id, declaration_type, is_reserved=False) 237 | # print("\tDecl type : {0}".format(declaration_type)) 238 | 239 | if not is_template_parameter(ast_node): 240 | is_reserved = False 241 | # we skip function names 
to prevent over-fitting in the code2seq learning tasks 242 | if is_func: 243 | name = "FUNCTION_NAME" 244 | is_reserved = True 245 | else: 246 | name = ast_node.spelling 247 | 248 | # handle unnamed declarations 249 | if len(name) == 0: 250 | name = ast_node.kind.name + "_UNNAMED" 251 | is_reserved = True 252 | 253 | name_node_id = add_intermediate_node(graph, parent_id, "DECLARATION_NAME") 254 | 255 | add_child(graph, name_node_id, name, is_reserved=is_reserved) 256 | # print("\tName : {0}".format(name)) 257 | 258 | 259 | def func_from_pointer(ast_node): 260 | children = list(ast_node.get_children()) 261 | if children: 262 | return func_from_pointer(children[0]) 263 | else: 264 | return ast_node 265 | 266 | 267 | def ast_to_graph(ast_start_node, max_depth): 268 | g = nx.Graph() 269 | stack = [(ast_start_node, 0)] 270 | parent_map = {ast_start_node.hash: None} 271 | while stack: 272 | ast_node, depth = stack.pop() 273 | node_id = ast_node.hash 274 | if not g.has_node(node_id): 275 | parent_id = parent_map[node_id] 276 | if add_node(ast_node, g): 277 | if parent_id is not None: 278 | g.add_edge(parent_id, node_id) 279 | 280 | if is_call_expr(ast_node): 281 | func_name = None 282 | if ast_node.referenced: 283 | func_name = ast_node.referenced.spelling 284 | else: 285 | # pointer to function 286 | func_node = func_from_pointer(list(ast_node.get_children())[0]) 287 | func_name = func_node.spelling 288 | 289 | if not func_name: 290 | func_name = "FUNCTION_CALL" 291 | func_name = re.sub(r'\s+|,+', '', func_name) 292 | 293 | call_expr_id = add_intermediate_node(g, node_id, func_name) 294 | node_id = call_expr_id 295 | else: 296 | node_id = parent_id 297 | 298 | # Ignore too deep trees 299 | if max_depth == 0 or depth <= max_depth: 300 | if is_call_expr(ast_node): 301 | for arg_node in ast_node.get_arguments(): 302 | stack.append((arg_node, depth + 1)) 303 | parent_map[arg_node.hash] = node_id 304 | else: 305 | for child_node in ast_node.get_children(): 306 | stack.append((child_node, depth + 1)) 307 | parent_map[child_node.hash] = node_id 308 | return g 309 | -------------------------------------------------------------------------------- /cpp_parser/context.py: -------------------------------------------------------------------------------- 1 | class Context: 2 | def __init__(self, start_token, end_token, path, validate=False): 3 | self.start_token = start_token 4 | self.end_token = end_token 5 | self.path = path 6 | if validate: 7 | self.__validate() 8 | 9 | def __validate(self): 10 | self.__validate_token(self.start_token) 11 | self.__validate_token(self.end_token) 12 | 13 | def __validate_token(self, token): 14 | assert len(token) > 0, "Invalid token format: {0}".format(token) 15 | for sub_token in token: 16 | assert len(sub_token) > 0, "Invalid sub-token: {0}".format(token) 17 | self.validate_sub_token(sub_token) 18 | 19 | @staticmethod 20 | def validate_sub_token(sub_token): 21 | assert (' ' not in sub_token and 22 | '|' not in sub_token and 23 | ',' not in sub_token), "Invalid sub-token format: {0}".format(sub_token) 24 | -------------------------------------------------------------------------------- /cpp_parser/path.py: -------------------------------------------------------------------------------- 1 | from .context import Context 2 | 3 | 4 | class Path: 5 | def __init__(self, tokens, validate=False): 6 | self.tokens = tokens 7 | if validate: 8 | self.__validate() 9 | 10 | def __validate(self): 11 | for sub_token in self.tokens: 12 | assert len(sub_token) > 0, "Invalid sub-token in the 
path: {0}".format(self.tokens) 13 | Context.validate_sub_token(sub_token) 14 | -------------------------------------------------------------------------------- /cpp_parser/sample.py: -------------------------------------------------------------------------------- 1 | from .context import Context 2 | 3 | 4 | def make_str_key(list_value): 5 | str_value = "" 6 | for item in list_value: 7 | str_value += str(item) 8 | str_value += "|" 9 | str_value = str_value[:-1] 10 | return str_value 11 | 12 | 13 | class Sample: 14 | def __init__(self, key, contexts, source_mark, validate=False): 15 | self.key = key 16 | self.contexts = contexts 17 | self.source_mark = source_mark 18 | if validate: 19 | self.__validate() 20 | 21 | def __validate(self): 22 | assert len(self.key) > 0, "Invalid target key format: {0}".format(self.key) 23 | for sub_token in self.key: 24 | assert len(sub_token) > 0, "Invalid sub-token in the target key: {0}".format(self.key) 25 | Context.validate_sub_token(sub_token) 26 | 27 | def __str__(self): 28 | str_value = make_str_key(self.key) 29 | for context in self.contexts: 30 | str_value += " " + make_str_key(context.start_token) 31 | str_value += "," + make_str_key(context.path.tokens) 32 | str_value += "," + make_str_key(context.end_token) 33 | return str_value 34 | -------------------------------------------------------------------------------- /data/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0) 2 | project(data) 3 | 4 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | set(CMAKE_CXX_EXTENSIONS OFF) 8 | 9 | file(GLOB SRC_FILES 10 | "*.c" 11 | "*.cc" 12 | "*.cpp" 13 | ) 14 | 15 | add_executable(${PROJECT_NAME} ${SRC_FILES}) 16 | -------------------------------------------------------------------------------- /data/big_op.cc: -------------------------------------------------------------------------------- 1 | #define SHA1_DIGEST_LENGTH 20 2 | 3 | int main() { 4 | char expected_sha1str[SHA1_DIGEST_LENGTH * 2 + 1]; 5 | return expected_sha1str[6]; 6 | } 7 | -------------------------------------------------------------------------------- /data/func.cc: -------------------------------------------------------------------------------- 1 | int div(int x, int y) { 2 | float z = 45; 3 | if (y != 0) { 4 | return x /y; 5 | } else { 6 | return 0; 7 | } 8 | } 9 | 10 | int main() { 11 | auto e = 8; 12 | auto r = div(5, e); 13 | auto pr = &r; 14 | *pr = 45; 15 | int& ref = r; 16 | return r; 17 | } 18 | -------------------------------------------------------------------------------- /data/main.cc: -------------------------------------------------------------------------------- 1 | template 2 | class Math { 3 | public: 4 | T Add(T a, T b) { 5 | return AddImpl(a,b); 6 | } 7 | 8 | T Sub(T a, T b) { 9 | return AddImpl(a,b); 10 | } 11 | private: 12 | T AddImpl(T a, T b) { 13 | return a + b; 14 | } 15 | 16 | T SubImpl(T a, T b) { 17 | return a - b; 18 | } 19 | }; 20 | 21 | class Base { 22 | public: 23 | virtual ~Base() { --x; } 24 | virtual void Print() = 0; 25 | virtual void Draw(); 26 | friend bool operator==(const Derived& a, const Derived& b); 27 | 28 | private: 29 | int x = 0; 30 | int y = 8; 31 | }; 32 | 33 | void Base::Draw() { 34 | x = x + y; 35 | } 36 | 37 | class Derived : public Base { 38 | public: 39 | void Print() override { 40 | Draw(); 41 | } 42 | 43 | void operator()() { 44 | Print(); 45 | } 46 | }; 47 | 48 | bool operator==(const Base& 
a, const Base& b) { 49 | return (a.x == b.x) && (a.y == b.y); 50 | } 51 | 52 | 53 | int func(float x, float y){ 54 | Math m; 55 | return static_cast(m.Add(x, y)); 56 | } 57 | 58 | int func_TestAll(float x, float y){ 59 | Math m; 60 | return static_cast(m.Add(x, y)); 61 | } 62 | 63 | template 64 | T subtract(T&& x, T&& y){ 65 | return x - y; 66 | } 67 | 68 | int maxof(int n_args, ...) { 69 | ++n_args; 70 | } 71 | 72 | template 73 | T adder(T v) { 74 | return v; 75 | } 76 | 77 | template 78 | T adder(T first, Args... args) { 79 | return first + adder(args...); 80 | } 81 | 82 | int main(int argc, char* argv[]) { 83 | auto r = subtract(4, 5); 84 | auto s = func(3.4, 5.8); 85 | return r+s; 86 | } 87 | -------------------------------------------------------------------------------- /data/method.cc: -------------------------------------------------------------------------------- 1 | class Base { 2 | public: 3 | virtual void Draw(); 4 | private: 5 | int x = 0; 6 | int y = 8; 7 | }; 8 | 9 | void Base::Draw() { 10 | x = x + y; 11 | } 12 | -------------------------------------------------------------------------------- /data/template.cc: -------------------------------------------------------------------------------- 1 | template 2 | T adder(T v) { 3 | return v; 4 | } 5 | 6 | template 7 | T adder(T first, Args... args) { 8 | return first + adder(args...); 9 | } 10 | 11 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Configure Docker environment for code2seq project 2 | `` 3 | cd docker 4 | docker build -t code2seq:1.0 . 5 | docker run -it -v [host_dataset_path]:[container_dataset_path] code2seq:1.0 bash 6 | cd /code2seq 7 | `` 8 | 9 | -------------------------------------------------------------------------------- /docker/dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | MAINTAINER Kirill rotate@ukr.net 3 | 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | RUN apt-get -y update 7 | RUN apt-get -y upgrade 8 | 9 | RUN apt-get install -y git 10 | RUN apt-get install -y python3 11 | RUN apt-get install -y python3-pip 12 | 13 | RUN pip3 install rouge 14 | RUN pip3 install requests 15 | RUN pip3 install tensorflow==1.12 16 | 17 | RUN git clone https://github.com/tech-srl/code2seq 18 | 19 | 20 | -------------------------------------------------------------------------------- /src/data_set_merge.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from bitarray import bitarray 3 | from tqdm import tqdm 4 | import lmdb 5 | import os 6 | import random 7 | 8 | 9 | class DataSetMerge: 10 | def __init__(self, output_path, map_size): 11 | self.output_path = output_path 12 | self.train_set_file = os.path.join(self.output_path, "dataset.train.c2s") 13 | self.test_set_file = os.path.join(self.output_path, "dataset.test.c2s") 14 | self.validation_set_file = os.path.join(self.output_path, "dataset.val.c2s") 15 | self.samples_db = lmdb.open(os.path.join(self.output_path, 'samples.db'), writemap=True) 16 | self.samples_db.set_mapsize(map_size) 17 | self.total_num = 0 18 | 19 | def merge(self, clear_resources=True): 20 | functions = set() 21 | sample_id = 0 22 | self.total_num = 0 23 | with self.samples_db.begin(write=True) as txn: 24 | files_num = sum(1 for _ in Path(self.output_path).rglob('*.c2s')) 25 | with tqdm(total=files_num) as pbar: 26 | for 
file_path in Path(self.output_path).rglob('*.c2s'): 27 | with file_path.open() as file: 28 | # print('Loading file: ' + file_path.absolute().as_posix()) 29 | for line in file.readlines(): 30 | src_mark_str, _, sample_line = line.partition(')') 31 | src_mark = src_mark_str[2:] 32 | if src_mark not in functions: 33 | txn.put(str(sample_id).encode('ascii'), sample_line.encode('ascii')) 34 | sample_id += 1 35 | functions.add(src_mark) 36 | if clear_resources: 37 | os.remove(file_path.absolute().as_posix()) 38 | pbar.update(1) 39 | self.total_num = sample_id - 1 40 | 41 | def dump_datasets(self, train_set_ratio=0.7): 42 | # split samples into test, validation and training parts 43 | all_samples_num = self.total_num + 1 44 | train_samples_num = int(all_samples_num * train_set_ratio) 45 | test_samples_num = (all_samples_num - train_samples_num) // 2 46 | processed = bitarray(all_samples_num) 47 | processed.setall(False) 48 | train_file = open(self.train_set_file, 'w') 49 | train_index = 0 50 | test_file = open(self.test_set_file, 'w') 51 | test_index = 0 52 | validation_file = open(self.validation_set_file, 'w') 53 | validation_index = 0 54 | try: 55 | with self.samples_db.begin(write=False) as txn: 56 | for _ in tqdm(range(self.total_num + 1)): 57 | index = random.randint(0, self.total_num) 58 | while processed[index]: 59 | index = random.randint(0, self.total_num) 60 | processed[index] = True 61 | sample = txn.get(str(index).encode('ascii')) 62 | if train_index < train_samples_num: 63 | train_file.write(sample.decode('ascii')) 64 | train_index += 1 65 | elif test_index < test_samples_num: 66 | test_file.write(sample.decode('ascii')) 67 | test_index += 1 68 | elif validation_index < test_samples_num: 69 | validation_file.write(sample.decode('ascii')) 70 | validation_index += 1 71 | finally: 72 | print("Closing files ...") 73 | train_file.close() 74 | test_file.close() 75 | validation_file.close() 76 | -------------------------------------------------------------------------------- /src/merge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from data_set_merge import DataSetMerge 4 | from pathlib import Path 5 | 6 | 7 | def main(): 8 | args_parser = argparse.ArgumentParser( 9 | description='merge resources generated by cppminer to a code2seq dataset') 10 | 11 | args_parser.add_argument('DataPath', 12 | metavar='path', 13 | type=str, 14 | help='the dataset sources path') 15 | 16 | args_parser.add_argument('-c', '--clear_resources', 17 | metavar='clear_resources_flag', 18 | type=bool, 19 | help='if True clear resource files', 20 | default=False, 21 | required=False) 22 | 23 | args_parser.add_argument('-m', '--map_size', 24 | metavar='map_file_size', 25 | type=int, 26 | help='size of the DB file, default(6442450944 bytes)', 27 | default=100000000000, 28 | required=False) 29 | 30 | args = args_parser.parse_args() 31 | 32 | output_path = Path(args.DataPath).resolve().as_posix() 33 | print('Path: ' + output_path) 34 | 35 | map_size = args.map_size 36 | print('Map size: ' + str(map_size)) 37 | 38 | print('Clear resources: ' + str(args.clear_resources)) 39 | 40 | # shuffle and merge samples 41 | print("Merging samples ...") 42 | merge = DataSetMerge(output_path, map_size) 43 | merge.merge(args.clear_resources) 44 | print("Dumping datasets ...") 45 | merge.dump_datasets(0.7) 46 | print("Merging done") 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- 
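A minimal `merge.py` invocation might look like the sketch below; the data path and the map size are illustrative, and the resulting `dataset.train.c2s`, `dataset.test.c2s` and `dataset.val.c2s` files are written into the same directory:

~~~
# Merge and shuffle all raw *.c2s files found under ./raw_data,
# deleting them afterwards; use a 12 GiB LMDB map file.
python3 src/merge.py --clear_resources True --map_size 12884901888 ./raw_data
~~~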
/src/miner.py: -------------------------------------------------------------------------------- 1 | from clang.cindex import Config 2 | import argparse 3 | import time 4 | from pathlib import Path 5 | import multiprocessing 6 | import os 7 | from tqdm import tqdm 8 | from parser_process import ParserProcess 9 | 10 | file_types = ('*.c', '*.cc', '*.cpp', '*.cxx', '*.c++') 11 | 12 | 13 | def files(input_path): 14 | if os.path.isfile(input_path): 15 | yield input_path 16 | for file_type in file_types: 17 | for file_path in Path(input_path).rglob(file_type): 18 | yield file_path.as_posix() 19 | 20 | 21 | def main(): 22 | args_parser = argparse.ArgumentParser( 23 | description='cppminer generates a code2seq dataset from C++ sources') 24 | 25 | args_parser.add_argument('Path', 26 | metavar='path', 27 | type=str, 28 | help='the path sources directory') 29 | 30 | args_parser.add_argument('OutPath', 31 | metavar='out', 32 | type=str, 33 | help='the output path') 34 | 35 | args_parser.add_argument('-c', '--max_contexts_num', 36 | metavar='contexts-number', 37 | type=int, 38 | help='maximum number of contexts per sample', 39 | default=100, 40 | required=False) 41 | 42 | args_parser.add_argument('-l', '--max_path_len', 43 | metavar='path-length', 44 | type=int, 45 | help='maximum path length (0 - no limit)', 46 | default=0, 47 | required=False) 48 | 49 | args_parser.add_argument('-s', '--max_subtokens_num', 50 | metavar='subtokens-num', 51 | type=int, 52 | help='maximum number of sub-tokens in a token (0 - no limit)', 53 | default=5, 54 | required=False) 55 | 56 | args_parser.add_argument('-d', '--max_ast_depth', 57 | metavar='ast-depth', 58 | type=int, 59 | help='maximum depth of AST (0 - no limit)', 60 | default=0, 61 | required=False) 62 | 63 | args_parser.add_argument('-p', '--processes_num', 64 | metavar='processes-number', 65 | type=int, 66 | help='number of parallel processes', 67 | default=4, 68 | required=False) 69 | 70 | args_parser.add_argument('-e', '--libclang', 71 | metavar='libclang-path', 72 | type=str, 73 | help='path to libclang.so file', 74 | required=False) 75 | 76 | args = args_parser.parse_args() 77 | 78 | if args.libclang: 79 | # File path example '/usr/lib/llvm-6.0/lib/libclang.so' 80 | Config.set_library_file(args.libclang) 81 | 82 | parallel_processes_num = args.processes_num 83 | print('Parallel processes num: ' + str(parallel_processes_num)) 84 | 85 | max_contexts_num = args.max_contexts_num 86 | print('Max contexts num: ' + str(max_contexts_num)) 87 | 88 | max_path_len = args.max_path_len 89 | print('Max path length: ' + str(max_path_len)) 90 | 91 | max_subtokens_num = args.max_subtokens_num 92 | print('Max sub-tokens num: ' + str(max_subtokens_num)) 93 | 94 | max_ast_depth = args.max_ast_depth 95 | print('Max AST depth: ' + str(max_ast_depth)) 96 | 97 | input_path = Path(args.Path).resolve().as_posix() 98 | print('Input path: ' + input_path) 99 | 100 | output_path = Path(args.OutPath).resolve().as_posix() 101 | print('Output path: ' + output_path) 102 | 103 | print("Parsing files ...") 104 | tasks = multiprocessing.JoinableQueue() 105 | if parallel_processes_num == 1: 106 | parser = ParserProcess(tasks, max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, input_path, 107 | output_path) 108 | for file_path in files(input_path): 109 | print("Parsing : " + file_path) 110 | tasks.put(file_path) 111 | parser.parse_file() 112 | parser.save() 113 | tasks.join() 114 | else: 115 | processes = [ParserProcess(tasks, max_contexts_num, max_path_len, max_subtokens_num, 
max_ast_depth, input_path, 116 | output_path) 117 | for _ in range(parallel_processes_num)] 118 | for p in processes: 119 | p.start() 120 | 121 | for file_path in files(input_path): 122 | tasks.put(file_path) 123 | 124 | # add terminating tasks 125 | for i in range(parallel_processes_num): 126 | tasks.put(None) 127 | 128 | # Wait for all of the tasks to finish 129 | tasks_left = tasks.qsize() 130 | with tqdm(total=tasks_left) as pbar: 131 | while tasks_left > 0: 132 | time.sleep(1) 133 | tasks_num = tasks.qsize() 134 | pbar.update(tasks_left - tasks_num) 135 | tasks_left = tasks_num 136 | 137 | tasks.join() 138 | for p in processes: 139 | p.join() 140 | print("Parsing done") 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /src/parser_process.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from clang.cindex import CompilationDatabase, CompilationDatabaseError 4 | from cpp_parser import AstParser 5 | 6 | 7 | def is_object_file(file_path): 8 | file_name = os.path.basename(file_path) 9 | if '.o.' in file_name: 10 | return True 11 | elif '.o' in file_name[-2:]: 12 | return True 13 | else: 14 | return False 15 | 16 | 17 | class ParserProcess(multiprocessing.Process): 18 | def __init__(self, task_queue, max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, input_path, output_path): 19 | multiprocessing.Process.__init__(self) 20 | self.task_queue = task_queue 21 | self.parser = AstParser(max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, output_path) 22 | try: 23 | self.compdb = CompilationDatabase.fromDirectory(input_path) 24 | except CompilationDatabaseError: 25 | self.compdb = None 26 | 27 | def run(self): 28 | default_compile_args = [] 29 | 30 | while self.parse_file(default_compile_args): 31 | pass 32 | 33 | self.save() 34 | return 35 | 36 | def save(self): 37 | self.parser.save() 38 | 39 | def parse_file(self, default_compile_args=[]): 40 | file_path = self.task_queue.get() 41 | if file_path is None: 42 | self.task_queue.task_done() 43 | return False 44 | # print('Parsing : {0} [{1}]'.format(file_path, os.getpid())) 45 | if not self.compdb: 46 | # print('Compilation database was not found in the input directory, using default args list') 47 | self.parser.parse([file_path] + default_compile_args) 48 | else: 49 | commands = self.compdb.getCompileCommands(file_path) 50 | if commands and len(commands) > 0: 51 | command = commands[0] 52 | cwd = os.getcwd() 53 | os.chdir(command.directory) 54 | args = [] 55 | cmd_args = command.arguments 56 | next(cmd_args) # drop compiler executable path 57 | for arg in cmd_args: 58 | if arg == '-Xclang': 59 | next(cmd_args) # skip clang specific arguments 60 | elif arg == '-c': 61 | continue # drop input filename option 62 | elif arg == '-o': 63 | continue # drop output filename option 64 | elif os.path.isfile(arg) or os.path.isdir(arg): 65 | continue # drop file names 66 | elif is_object_file(arg): 67 | continue # drop file names 68 | else: 69 | args.append(arg) 70 | self.parser.parse(args, command.filename) 71 | os.chdir(cwd) 72 | self.task_queue.task_done() 73 | return True 74 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | bitarray 2 | lmdb 3 | clang 4 | decorator 5 | networkx 6 | pygraphviz 7 | tqdm 8 | 
--------------------------------------------------------------------------------