├── .gitignore
├── LICENSE
├── README.md
├── code2seq
│   ├── LICENSE
│   ├── README.md
│   ├── common.py
│   ├── preprocess.py
│   └── preprocess.sh
├── cpp_parser
│   ├── __init__.py
│   ├── ast_parser.py
│   ├── ast_utils.py
│   ├── context.py
│   ├── path.py
│   └── sample.py
├── data
│   ├── CMakeLists.txt
│   ├── big_op.cc
│   ├── func.cc
│   ├── main.cc
│   ├── method.cc
│   └── template.cc
├── docker
│   ├── README.md
│   └── dockerfile
└── src
    ├── data_set_merge.py
    ├── merge.py
    ├── miner.py
    ├── parser_process.py
    └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Kirill
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cppminer
2 | cppminer produces [code2seq](https://github.com/tech-srl/code2seq) compatible datasets from C++ code bases.
3 |
4 | An experimental [C++](https://drive.google.com/file/d/15BDd6zHFkVJXl95FG4JnnSse48k1UR3E/view?usp=sharing) dataset mined from the Chromium project sources is available for download.
5 |
6 | This tool consists of three scripts which should be run in sequence.
7 |
8 | # 1. Miner
9 | The `miner.py` script is the main utility; it traverses C++ sources, parses them and produces raw dataset files.
10 |
11 | It has the following command-line interface:
12 | ~~~
13 | usage: miner.py [-h] [-c contexts-number] [-l path-length] [-d ast-depth] [-p processes-number] [-e libclang-path] path out
14 |
15 | positional arguments:
16 | path the path sources directory
17 | out the output path
18 |
19 | optional arguments:
20 | -h, --help show this help message and exit
21 | -c contexts-number, --max_contexts_num contexts-number
22 | maximum number of contexts per sample
23 | -l path-length, --max_path_len path-length
24 | maximum path length (0 - no limit)
25 | -d ast-depth, --max_ast_depth ast-depth
26 | maximum depth of AST (0 - no limit)
27 | -p processes-number, --processes_num processes-number
28 | number of parallel processes
29 | -e libclang-path, --libclang libclang-path
30 | path to libclang.so file
31 | ~~~
32 |
33 | The input path is traversed recursively and all files with the following extensions are parsed: `c, cc, cpp, cxx, c++`.
34 | It is recommended to use a [C++ compilation database](https://clang.llvm.org/docs/JSONCompilationDatabase.html), which provides all required compilation flags for the project files.
35 |
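For a CMake-based project, an end-to-end invocation might look like the sketch below; the project layout, the `raw_out` output directory, the libclang location and the flag values are illustrative assumptions, not requirements:

```
# Export a compilation database and place it in the source root,
# where the miner looks for it.
cmake -S my_project -B my_project/build -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
cp my_project/build/compile_commands.json my_project/

# Mine raw *.c2s files into raw_out with 8 worker processes, keeping at most
# 100 contexts per sample and skipping paths longer than 8 nodes.
python3 src/miner.py -c 100 -l 8 -p 8 \
    -e /usr/lib/llvm-6.0/lib/libclang.so \
    my_project raw_out
```
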
36 | The produced raw dataset files have the following format:
37 |
38 | * Each row is an example.
39 | * Each example is a space-delimited list of fields, where:
40 |
41 | 1. The first field is the target label, internally delimited by the "|" character (for example: compare|ignore|case)
42 | 2. Each of the following fields is a context, where each context has three components separated by commas (","). None of these components can contain spaces or commas.
43 |
44 | A context's components are a token, a path, and another token.
45 |
46 | Each `token` component is a token in the code, split into subtokens using the "|" character.
47 |
48 | Each `path` is a path between two tokens, split into path nodes using the "|" character. Example for a context:
49 | ```
50 | my|key,StringExpression|MethodCall|Name,get|value
51 | ```
52 | Here `my|key` and `get|value` are tokens, and `StringExpression|MethodCall|Name` is the syntactic path that connects them.
53 |
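As an illustration, the three comma-separated components of a context can be pulled apart with the same `cut`/`tr` idiom that `preprocess.sh` later uses to build its histograms (the echoed context is just the example above):

```
# Prints the path nodes of the example context, one per line:
# StringExpression, MethodCall and Name.
echo 'my|key,StringExpression|MethodCall|Name,get|value' | cut -d',' -f2 | tr '|' '\n'
```
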
54 | # 2. Merge
55 | The `merge.py` script concatenates all raw files, shuffles them and produces three files, `dataset.train.c2s`, `dataset.test.c2s` and `dataset.val.c2s`, in the given directory.
56 | It can also clean up the source files after merging. The important setting is `map_file_size`, which defines the size of the database file used for merging;
57 | you should increase the default value for large datasets.
58 |
59 | It has the following command-line interface:
60 |
61 | ~~~
62 | usage: merge.py [-h] [-c clear_resources_flag] [-m map_file_size] path
63 |
64 | merge resources generated by cppminer to a code2seq dataset
65 |
66 | positional arguments:
67 | path the dataset sources path
68 |
69 | optional arguments:
70 | -h, --help show this help message and exit
71 | -c clear_resources_flag, --clear_resources clear_resources_flag
72 | if True clear resource files
73 | -m map_file_size, --map_size map_file_size
74 |                         size of the DB file, default(100000000000 bytes)
75 | ~~~
76 |
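A minimal invocation, assuming the raw `*.c2s` files from the previous step were written to `raw_out` (the directory name and the map size are illustrative):

```
# Deduplicate, shuffle and split the raw files; dataset.train.c2s,
# dataset.test.c2s and dataset.val.c2s are written into the same directory.
# -m sets the LMDB map size in bytes.
python3 src/merge.py -m 16000000000 raw_out
```
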
77 | # 3. Code2seq preprocess
78 |
79 | The third utility is `preprocess.sh` from the `code2seq` folder; it is a modified script from the original project which generates the dataset in a format suitable for the `code2seq` model.
80 | In general, it creates new files with a truncated and padded number of paths for each example.
81 |
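A sketch of this final step, assuming the merged files live in `raw_out` and the preprocessed dataset should go to `preprocessed` (both directory names are illustrative):

```
# preprocess.sh resolves preprocess.py relative to its own location,
# so run it from inside the code2seq folder.
cd code2seq
bash preprocess.sh ../raw_out ../preprocessed
```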
--------------------------------------------------------------------------------
/code2seq/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Technion
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/code2seq/README.md:
--------------------------------------------------------------------------------
1 | This folder contains modified scripts from the code2seq repository.
2 |
3 | Use them to produce final datasets.
4 |
--------------------------------------------------------------------------------
/code2seq/common.py:
--------------------------------------------------------------------------------
1 | import re
2 | import subprocess
3 | import sys
4 |
5 |
6 | class Common:
7 | internal_delimiter = '|'
8 |     SOS = '<S>'
9 |     EOS = '</S>'
10 |     PAD = '<PAD>'
11 |     UNK = '<UNK>'
12 |
13 | @staticmethod
14 | def normalize_word(word):
15 | stripped = re.sub(r'[^a-zA-Z]', '', word)
16 | if len(stripped) == 0:
17 | return word.lower()
18 | else:
19 | return stripped.lower()
20 |
21 | @staticmethod
22 | def load_histogram(path, max_size=None):
23 | histogram = {}
24 | with open(path, 'r') as file:
25 | for line in file.readlines():
26 | parts = line.split(' ')
27 | if not len(parts) == 2:
28 | continue
29 | histogram[parts[0]] = int(parts[1])
30 | sorted_histogram = [(k, histogram[k]) for k in sorted(histogram, key=histogram.get, reverse=True)]
31 | return dict(sorted_histogram[:max_size])
32 |
33 | @staticmethod
34 | def load_vocab_from_dict(word_to_count, add_values=[], max_size=None):
35 | word_to_index, index_to_word = {}, {}
36 | current_index = 0
37 | for value in add_values:
38 | word_to_index[value] = current_index
39 | index_to_word[current_index] = value
40 | current_index += 1
41 | sorted_counts = [(k, word_to_count[k]) for k in sorted(word_to_count, key=word_to_count.get, reverse=True)]
42 | limited_sorted = dict(sorted_counts[:max_size])
43 | for word, count in limited_sorted.items():
44 | word_to_index[word] = current_index
45 | index_to_word[current_index] = word
46 | current_index += 1
47 | return word_to_index, index_to_word, current_index
48 |
49 | @staticmethod
50 | def binary_to_string(binary_string):
51 | return binary_string.decode("utf-8")
52 |
53 | @staticmethod
54 | def binary_to_string_list(binary_string_list):
55 | return [Common.binary_to_string(w) for w in binary_string_list]
56 |
57 | @staticmethod
58 | def binary_to_string_matrix(binary_string_matrix):
59 | return [Common.binary_to_string_list(l) for l in binary_string_matrix]
60 |
61 | @staticmethod
62 | def binary_to_string_3d(binary_string_tensor):
63 | return [Common.binary_to_string_matrix(l) for l in binary_string_tensor]
64 |
65 | @staticmethod
66 | def legal_method_names_checker(name):
67 | return not name in [Common.UNK, Common.PAD, Common.EOS]
68 |
69 | @staticmethod
70 | def filter_impossible_names(top_words):
71 | result = list(filter(Common.legal_method_names_checker, top_words))
72 | return result
73 |
74 | @staticmethod
75 | def unique(sequence):
76 | unique = []
77 | [unique.append(item) for item in sequence if item not in unique]
78 | return unique
79 |
80 | @staticmethod
81 | def parse_results(result, pc_info_dict, topk=5):
82 | prediction_results = {}
83 | results_counter = 0
84 | for single_method in result:
85 | original_name, top_suggestions, top_scores, attention_per_context = list(single_method)
86 | current_method_prediction_results = PredictionResults(original_name)
87 | if attention_per_context is not None:
88 | word_attention_pairs = [(word, attention) for word, attention in
89 | zip(top_suggestions, attention_per_context) if
90 | Common.legal_method_names_checker(word)]
91 | for predicted_word, attention_timestep in word_attention_pairs:
92 | current_timestep_paths = []
93 | for context, attention in [(key, attention_timestep[key]) for key in
94 | sorted(attention_timestep, key=attention_timestep.get, reverse=True)][
95 | :topk]:
96 | if context in pc_info_dict:
97 | pc_info = pc_info_dict[context]
98 | current_timestep_paths.append((attention.item(), pc_info))
99 |
100 | current_method_prediction_results.append_prediction(predicted_word, current_timestep_paths)
101 | else:
102 | for predicted_seq in top_suggestions:
103 | filtered_seq = [word for word in predicted_seq if Common.legal_method_names_checker(word)]
104 | current_method_prediction_results.append_prediction(filtered_seq, None)
105 |
106 | prediction_results[results_counter] = current_method_prediction_results
107 | results_counter += 1
108 | return prediction_results
109 |
110 | @staticmethod
111 | def compute_bleu(ref_file_name, predicted_file_name):
112 | with open(predicted_file_name) as predicted_file:
113 | pipe = subprocess.Popen(["perl", "scripts/multi-bleu.perl", ref_file_name], stdin=predicted_file,
114 | stdout=sys.stdout, stderr=sys.stderr)
115 |
116 |
117 | class PredictionResults:
118 | def __init__(self, original_name):
119 | self.original_name = original_name
120 | self.predictions = list()
121 |
122 | def append_prediction(self, name, current_timestep_paths):
123 | self.predictions.append(SingleTimeStepPrediction(name, current_timestep_paths))
124 |
125 | class SingleTimeStepPrediction:
126 | def __init__(self, prediction, attention_paths):
127 | self.prediction = prediction
128 | if attention_paths is not None:
129 | paths_with_scores = []
130 | for attention_score, pc_info in attention_paths:
131 | path_context_dict = {'score': attention_score,
132 | 'path': pc_info.longPath,
133 | 'token1': pc_info.token1,
134 | 'token2': pc_info.token2}
135 | paths_with_scores.append(path_context_dict)
136 | self.attention_paths = paths_with_scores
137 |
138 |
139 | class PathContextInformation:
140 | def __init__(self, context):
141 | self.token1 = context['name1']
142 | self.longPath = context['path']
143 | self.shortPath = context['shortPath']
144 | self.token2 = context['name2']
145 |
146 | def __str__(self):
147 | return '%s,%s,%s' % (self.token1, self.shortPath, self.token2)
148 |
--------------------------------------------------------------------------------
/code2seq/preprocess.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from argparse import ArgumentParser
3 |
4 | import numpy as np
5 |
6 | import common
7 |
8 | '''
9 | This script preprocesses the data from MethodPaths. It truncates methods with too many contexts,
10 | and pads methods that have fewer paths with spaces.
11 | '''
12 |
13 |
14 | def save_dictionaries(dataset_name, subtoken_to_count, node_to_count, target_to_count, max_contexts, num_examples):
15 | save_dict_file_path = '{}.dict.c2s'.format(dataset_name)
16 | with open(save_dict_file_path, 'wb') as file:
17 | pickle.dump(subtoken_to_count, file)
18 | pickle.dump(node_to_count, file)
19 | pickle.dump(target_to_count, file)
20 | pickle.dump(max_contexts, file)
21 | pickle.dump(num_examples, file)
22 | print('Dictionaries saved to: {}'.format(save_dict_file_path))
23 |
24 |
25 | def process_file(file_path, data_file_role, dataset_name, max_contexts, max_data_contexts):
26 | sum_total = 0
27 | sum_sampled = 0
28 | total = 0
29 | max_unfiltered = 0
30 | max_contexts_to_sample = max_data_contexts if data_file_role == 'train' else max_contexts
31 | output_path = '{}.{}.c2s'.format(dataset_name, data_file_role)
32 | with open(output_path, 'w') as outfile:
33 | with open(file_path, 'r') as file:
34 | for line in file:
35 | parts = line.rstrip('\n').split(' ')
36 | target_name = parts[0]
37 | contexts = parts[1:]
38 |
39 | if len(contexts) > max_unfiltered:
40 | max_unfiltered = len(contexts)
41 |
42 | sum_total += len(contexts)
43 | if len(contexts) > max_contexts_to_sample:
44 | contexts = np.random.choice(contexts, max_contexts_to_sample, replace=False)
45 |
46 | sum_sampled += len(contexts)
47 |
48 | csv_padding = " " * (max_data_contexts - len(contexts))
49 | total += 1
50 | outfile.write(target_name + ' ' + " ".join(contexts) + csv_padding + '\n')
51 |
52 |     print('File: ' + file_path)
53 | print('Average total contexts: ' + str(float(sum_total) / total))
54 | print('Average final (after sampling) contexts: ' + str(float(sum_sampled) / total))
55 | print('Total examples: ' + str(total))
56 | print('Max number of contexts per word: ' + str(max_unfiltered))
57 | return total
58 |
59 |
60 | def context_full_found(context_parts, word_to_count, path_to_count):
61 | return context_parts[0] in word_to_count \
62 | and context_parts[1] in path_to_count and context_parts[2] in word_to_count
63 |
64 |
65 | def context_partial_found(context_parts, word_to_count, path_to_count):
66 | return context_parts[0] in word_to_count \
67 | or context_parts[1] in path_to_count or context_parts[2] in word_to_count
68 |
69 |
70 | if __name__ == '__main__':
71 | parser = ArgumentParser()
72 | parser.add_argument("-trd", "--train_data", dest="train_data_path",
73 | help="path to training data file", required=True)
74 | parser.add_argument("-ted", "--test_data", dest="test_data_path",
75 | help="path to test data file", required=True)
76 | parser.add_argument("-vd", "--val_data", dest="val_data_path",
77 | help="path to validation data file", required=True)
78 | parser.add_argument("-mc", "--max_contexts", dest="max_contexts", default=200,
79 | help="number of max contexts to keep in test+validation", required=False)
80 | parser.add_argument("-mdc", "--max_data_contexts", dest="max_data_contexts", default=1000,
81 | help="number of max contexts to keep in the dataset", required=False)
82 | parser.add_argument("-svs", "--subtoken_vocab_size", dest="subtoken_vocab_size", default=186277,
83 | help="Max number of source subtokens to keep in the vocabulary", required=False)
84 | parser.add_argument("-tvs", "--target_vocab_size", dest="target_vocab_size", default=26347,
85 | help="Max number of target words to keep in the vocabulary", required=False)
86 | parser.add_argument("-sh", "--subtoken_histogram", dest="subtoken_histogram",
87 | help="subtoken histogram file", metavar="FILE", required=True)
88 | parser.add_argument("-nh", "--node_histogram", dest="node_histogram",
89 | help="node_histogram file", metavar="FILE", required=True)
90 | parser.add_argument("-th", "--target_histogram", dest="target_histogram",
91 | help="target histogram file", metavar="FILE", required=True)
92 | parser.add_argument("-o", "--output_name", dest="output_name",
93 | help="output name - the base name for the created dataset", required=True, default='data')
94 | args = parser.parse_args()
95 |
96 | train_data_path = args.train_data_path
97 | test_data_path = args.test_data_path
98 | val_data_path = args.val_data_path
99 | subtoken_histogram_path = args.subtoken_histogram
100 | node_histogram_path = args.node_histogram
101 |
102 | subtoken_to_count = common.Common.load_histogram(subtoken_histogram_path,
103 | max_size=int(args.subtoken_vocab_size))
104 | node_to_count = common.Common.load_histogram(node_histogram_path,
105 | max_size=None)
106 | target_to_count = common.Common.load_histogram(args.target_histogram,
107 | max_size=int(args.target_vocab_size))
108 | print('subtoken vocab size: ', len(subtoken_to_count))
109 | print('node vocab size: ', len(node_to_count))
110 | print('target vocab size: ', len(target_to_count))
111 |
112 | num_training_examples = 0
113 | for data_file_path, data_role in zip([test_data_path, val_data_path, train_data_path], ['test', 'val', 'train']):
114 | num_examples = process_file(file_path=data_file_path, data_file_role=data_role, dataset_name=args.output_name,
115 | max_contexts=int(args.max_contexts), max_data_contexts=int(args.max_data_contexts))
116 | if data_role == 'train':
117 | num_training_examples = num_examples
118 |
119 | save_dictionaries(dataset_name=args.output_name, subtoken_to_count=subtoken_to_count,
120 | node_to_count=node_to_count, target_to_count=target_to_count,
121 | max_contexts=int(args.max_data_contexts), num_examples=num_training_examples)
122 |
--------------------------------------------------------------------------------
/code2seq/preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | ###########################################################
3 | # Change the following values to preprocess a new dataset.
4 | # BASE_PATH - script parameter point to the directory where data files are placed
5 | # TRAIN_DATA_FILE, VAL_DATA_FILE and TEST_DATA_FILE should be paths to
6 | # files containing files obtained from merge.py script
7 | # DATASET_NAME is just a name for the currently extracted
8 | # dataset.
9 | # MAX_DATA_CONTEXTS is the number of contexts to keep in the dataset for each
10 | # method (by default 1000). At training time, these contexts
11 | # will be downsampled dynamically to MAX_CONTEXTS.
12 | # MAX_CONTEXTS - the number of actual contexts (by default 200)
13 | # that are taken into consideration (out of MAX_DATA_CONTEXTS)
14 | # every training iteration. To avoid randomness at test time,
15 | # for the test and validation sets only MAX_CONTEXTS contexts are kept
16 | # (while for training, MAX_DATA_CONTEXTS are kept and MAX_CONTEXTS are
17 | # selected dynamically during training).
18 | # SUBTOKEN_VOCAB_SIZE, TARGET_VOCAB_SIZE -
19 | # - the number of subtokens and target words to keep
20 | # in the vocabulary (the top occurring words and paths will be kept).
21 | # NUM_THREADS - the number of parallel threads to use. It is
22 | # recommended to use a multi-core machine for the preprocessing
23 | # step and set this value to the number of cores.
24 | # PYTHON - python3 interpreter alias.
25 | BASE_PATH=$1
26 | OUT_PATH=$2
27 | DATASET_NAME=dataset
28 | MAX_DATA_CONTEXTS=1000
29 | MAX_CONTEXTS=200
30 | SUBTOKEN_VOCAB_SIZE=186277
31 | TARGET_VOCAB_SIZE=26347
32 | NUM_THREADS=64
33 | PYTHON=python3
34 | ###########################################################
35 |
36 | TRAIN_DATA_FILE=${BASE_PATH}/${DATASET_NAME}.train.c2s
37 | VAL_DATA_FILE=${BASE_PATH}/${DATASET_NAME}.val.c2s
38 | TEST_DATA_FILE=${BASE_PATH}/${DATASET_NAME}.test.c2s
39 |
40 | mkdir -p ${OUT_PATH}/data
41 | mkdir -p ${OUT_PATH}/data/${DATASET_NAME}
42 |
43 | TARGET_HISTOGRAM_FILE=${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME}.histo.tgt.c2s
44 | SOURCE_SUBTOKEN_HISTOGRAM=${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME}.histo.ori.c2s
45 | NODE_HISTOGRAM_FILE=${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME}.histo.node.c2s
46 |
47 | echo "Creating histograms from the training data"
48 | cat ${TRAIN_DATA_FILE} | cut -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${TARGET_HISTOGRAM_FILE}
49 | cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${SOURCE_SUBTOKEN_HISTOGRAM}
50 | cat ${TRAIN_DATA_FILE} | cut -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' > ${NODE_HISTOGRAM_FILE}
51 |
52 | ${PYTHON} preprocess.py --train_data ${TRAIN_DATA_FILE} --test_data ${TEST_DATA_FILE} --val_data ${VAL_DATA_FILE} \
53 | --max_contexts ${MAX_CONTEXTS} --max_data_contexts ${MAX_DATA_CONTEXTS} --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \
54 | --target_vocab_size ${TARGET_VOCAB_SIZE} --subtoken_histogram ${SOURCE_SUBTOKEN_HISTOGRAM} \
55 | --node_histogram ${NODE_HISTOGRAM_FILE} --target_histogram ${TARGET_HISTOGRAM_FILE} --output_name ${OUT_PATH}/data/${DATASET_NAME}/${DATASET_NAME}
56 |
57 | # If all went well, the raw data files can be deleted, because preprocess.py creates new files
58 | # with truncated and padded number of paths for each example.
59 | rm ${TARGET_HISTOGRAM_FILE} ${SOURCE_SUBTOKEN_HISTOGRAM} ${NODE_HISTOGRAM_FILE}
60 |
61 |
--------------------------------------------------------------------------------
/cpp_parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .context import Context
2 | from .path import Path
3 | from .sample import Sample
4 | from .ast_parser import AstParser
5 |
--------------------------------------------------------------------------------
/cpp_parser/ast_parser.py:
--------------------------------------------------------------------------------
1 | from clang.cindex import Index
2 | from .sample import Sample
3 | from .context import Context
4 | from .path import Path
5 | from .ast_utils import ast_to_graph, is_function, is_class, is_operator_token, is_namespace, make_ast_err_message
6 | from networkx.algorithms import shortest_path
7 | from networkx.drawing.nx_agraph import to_agraph
8 | from itertools import combinations
9 | import uuid
10 | import os
11 | import re
12 | import random
13 |
14 |
15 | def debug_save_graph(func_node, g):
16 | file_name = func_node.spelling + ".png"
17 | num = 0
18 | while os.path.exists(file_name):
19 | file_name = func_node.spelling + str(num) + ".png"
20 | num += 1
21 | a = to_agraph(g)
22 | a.draw(file_name, prog='dot')
23 | a.clear()
24 |
25 |
26 | def tokenize(name, max_subtokens_num):
27 | if is_operator_token(name):
28 | return [name]
29 | first_tokens = name.split('_')
30 | str_tokens = []
31 | for token in first_tokens:
32 | internal_tokens = re.findall('[a-z]+|[A-Z]+[a-z]*|[0-9.]+|[-*/&|%=()]+', token)
33 | str_tokens += [t for t in internal_tokens if len(t) > 0]
34 | assert len(str_tokens) > 0, "Can't tokenize expr: {0}".format(name)
35 | if max_subtokens_num != 0:
36 | str_tokens = str_tokens[:max_subtokens_num]
37 | return str_tokens
38 |
39 |
40 | class AstParser:
41 | def __init__(self, max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, out_path):
42 | self.validate = False
43 | self.save_buffer_size = 1000
44 | self.out_path = out_path
45 | self.max_subtokens_num = max_subtokens_num
46 | self.max_contexts_num = max_contexts_num
47 | self.max_path_len = max_path_len
48 | self.max_ast_depth = max_ast_depth
49 | self.index = Index.create()
50 | self.samples = set()
51 | self.header_only_functions = set()
52 |
53 | def __del__(self):
54 | self.save()
55 |
56 | def __parse_node(self, node):
57 | try:
58 | namespaces = [x for x in node.get_children() if is_namespace(x)]
59 | for n in namespaces:
60 | # ignore standard library functions
61 | if n.displayname != 'std' and not n.displayname.startswith('__'):
62 | self.__parse_node(n)
63 |
64 | functions = [x for x in node.get_children() if is_function(x)]
65 | for f in functions:
66 | self.__parse_function(f)
67 |
68 | classes = [x for x in node.get_children() if is_class(x)]
69 | for c in classes:
70 | methods = [x for x in c.get_children() if is_function(x)]
71 | for m in methods:
72 | self.__parse_function(m)
73 | except Exception as e:
74 | if 'Unknown template argument kind' not in str(e):
75 | msg = make_ast_err_message(str(e), node)
76 | raise Exception(msg)
77 |
78 | self.__dump_samples()
79 |
80 | def parse(self, compiler_args, file_path=None):
81 | ast = self.index.parse(file_path, compiler_args)
82 | self.__parse_node(ast.cursor)
83 |
84 | def __dump_samples(self):
85 | if len(self.samples) >= self.save_buffer_size:
86 | self.save()
87 |
88 | def save(self):
89 | if not self.out_path:
90 | return
91 | if not os.path.exists(self.out_path):
92 | os.makedirs(self.out_path)
93 | if len(self.samples) > 0:
94 | file_name = os.path.join(self.out_path, str(uuid.uuid4().hex) + ".c2s")
95 | # print(file_name)
96 | with open(file_name, "w") as file:
97 | for sample in self.samples:
98 | file.write(str(sample.source_mark) + str(sample) + "\n")
99 | self.samples.clear()
100 |
101 | def __parse_function(self, func_node):
102 | try:
103 | # ignore standard library functions
104 | if func_node.displayname.startswith('__'):
105 | return
106 |
107 | # detect header only function duplicates
108 | file_name = func_node.location.file.name
109 | source_mark = (file_name, func_node.extent.start.line)
110 | if file_name.endswith('.h') and func_node.is_definition:
111 | # print('Header only function: {0}'.format(func_node.displayname))
112 | if source_mark in self.header_only_functions:
113 | # print('Duplicate')
114 | return
115 | else:
116 | self.header_only_functions.add(source_mark)
117 |
118 | key = tokenize(func_node.spelling, self.max_subtokens_num)
119 | g = ast_to_graph(func_node, self.max_ast_depth)
120 |
121 | # debug_save_graph(func_node, g)
122 |
123 | terminal_nodes = [node for (node, degree) in g.degree() if degree == 1]
124 | random.shuffle(terminal_nodes)
125 |
126 | contexts = set()
127 | ends = combinations(terminal_nodes, 2)
128 |
129 | for start, end in ends:
130 | path = shortest_path(g, start, end)
131 | if path:
132 | if self.max_path_len != 0 and len(path) > self.max_path_len:
133 | continue # skip too long paths
134 | path = path[1:-1]
135 | start_node = g.nodes[start]['label']
136 | tokenize_start_node = not g.nodes[start]['is_reserved']
137 | end_node = g.nodes[end]['label']
138 | tokenize_end_node = not g.nodes[end]['is_reserved']
139 |
140 | path_tokens = []
141 | for path_item in path:
142 | path_node = g.nodes[path_item]['label']
143 | path_tokens.append(path_node)
144 |
145 | context = Context(
146 | tokenize(start_node, self.max_subtokens_num) if tokenize_start_node else [start_node],
147 | tokenize(end_node, self.max_subtokens_num) if tokenize_end_node else [end_node],
148 | Path(path_tokens, self.validate), self.validate)
149 | contexts.add(context)
150 | if len(contexts) > self.max_contexts_num:
151 | break
152 |
153 | if len(contexts) > 0:
154 | sample = Sample(key, contexts, source_mark, self.validate)
155 | self.samples.add(sample)
156 | except Exception as e:
157 | # skip unknown cursor exceptions
158 | if 'Unknown template argument kind' not in str(e):
159 | print('Failed to parse function : ')
160 | print('Filename : ' + func_node.location.file.name)
161 | print('Start {0}:{1}'.format(func_node.extent.start.line, func_node.extent.start.column))
162 | print('End {0}:{1}'.format(func_node.extent.end.line, func_node.extent.end.column))
163 | print(e)
164 |
--------------------------------------------------------------------------------
/cpp_parser/ast_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import networkx as nx
4 | import uuid
5 | from clang.cindex import CursorKind, TokenKind
6 |
7 |
8 | def make_ast_err_message(msg, ast_node):
9 | msg = 'Error: {0} in:\n'.format(msg)
10 | msg += 'Filename : {0}\n'.format(ast_node.location.file.name)
11 | msg += 'Start {0}:{1}'.format(ast_node.extent.start.line, ast_node.extent.start.column)
12 | msg += ' End {0}:{1}'.format(ast_node.extent.end.line, ast_node.extent.end.column)
13 | return msg
14 |
15 |
16 | def is_node_kind_safe(node, kinds):
17 | try:
18 | return node.kind in kinds
19 | except Exception as e:
20 | msg = make_ast_err_message(str(e), node)
21 | if 'Unknown template argument kind' not in str(e):
22 | print(msg)
23 | return CursorKind.NOT_IMPLEMENTED
24 |
25 |
26 | def is_namespace(node):
27 | if is_node_kind_safe(node, [CursorKind.NAMESPACE]):
28 | return True
29 | return False
30 |
31 |
32 | def is_function(node):
33 | if is_node_kind_safe(node, [CursorKind.FUNCTION_DECL,
34 | CursorKind.FUNCTION_TEMPLATE,
35 | CursorKind.CXX_METHOD,
36 | CursorKind.DESTRUCTOR,
37 | CursorKind.CONSTRUCTOR]):
38 | if node.is_definition():
39 | not_empty = False
40 | for _ in node.get_children():
41 | not_empty = True
42 | break
43 | return not_empty
44 | return False
45 |
46 |
47 | def is_class(node):
48 | return is_node_kind_safe(node, [CursorKind.CLASS_TEMPLATE,
49 | CursorKind.CLASS_TEMPLATE_PARTIAL_SPECIALIZATION,
50 | CursorKind.CLASS_DECL])
51 |
52 |
53 | def is_literal(node):
54 | return is_node_kind_safe(node, [CursorKind.INTEGER_LITERAL,
55 | CursorKind.FLOATING_LITERAL,
56 | CursorKind.IMAGINARY_LITERAL,
57 | CursorKind.STRING_LITERAL,
58 | CursorKind.CHARACTER_LITERAL])
59 |
60 |
61 | def is_template_parameter(node):
62 | return is_node_kind_safe(node, [CursorKind.TEMPLATE_TYPE_PARAMETER,
63 | CursorKind.TEMPLATE_TEMPLATE_PARAMETER])
64 |
65 |
66 | def is_reference(node):
67 | return is_node_kind_safe(node, [CursorKind.DECL_REF_EXPR, CursorKind.MEMBER_REF_EXPR])
68 |
69 |
70 | def is_operator(node):
71 | return is_node_kind_safe(node, [CursorKind.BINARY_OPERATOR,
72 | CursorKind.UNARY_OPERATOR,
73 | CursorKind.COMPOUND_ASSIGNMENT_OPERATOR])
74 |
75 |
76 | def is_call_expr(node):
77 | return is_node_kind_safe(node, [CursorKind.CALL_EXPR])
78 |
79 |
80 | binary_operators = ['+', '-', '*', '/', '%', '&', '|']
81 | unary_operators = ['++', '--']
82 | comparison_operators = ['==', '<=', '>=', '<', '>', '!=', '&&', '||']
83 | unary_assignment_operators = [op + '=' for op in binary_operators]
84 | assignment_operators = ['='] + unary_assignment_operators
85 |
86 |
87 | def is_operator_token(token):
88 | if token in binary_operators:
89 | return True
90 | if token in unary_operators:
91 | return True
92 | if token in comparison_operators:
93 | return True
94 | if token in unary_assignment_operators:
95 | return True
96 | if token in assignment_operators:
97 | return True
98 |     return False
99 |
100 | def get_id():
101 | node_id = uuid.uuid1()
102 | return node_id.int
103 |
104 |
105 | def add_node(ast_node, graph):
106 | try:
107 | node_id = ast_node.hash
108 | kind = ast_node.kind.name
109 | # skip meaningless AST primitives
110 | if ast_node.kind == CursorKind.DECL_STMT or \
111 | ast_node.kind == CursorKind.UNEXPOSED_EXPR:
112 | return False
113 |
114 | if is_operator(ast_node):
115 | op_name = get_operator(ast_node)
116 | kind = kind.strip() + "_" + "_".join(op_name)
117 |
118 | graph.add_node(node_id, label=kind, is_reserved=True)
119 |
120 | # print("Cursor kind : {0}".format(kind))
121 | if ast_node.kind.is_declaration():
122 | add_declaration(node_id, ast_node, graph)
123 | elif is_literal(ast_node):
124 | add_literal(node_id, ast_node, graph)
125 | elif is_reference(ast_node):
126 | add_reference(node_id, ast_node, graph)
127 | elif is_call_expr(ast_node):
128 | add_call_expr(node_id, ast_node, graph)
129 |
130 | return True
131 | except Exception as e:
132 | if 'Unknown template argument kind' not in str(e):
133 | msg = make_ast_err_message(str(e), ast_node)
134 | raise Exception(msg)
135 |
136 |
137 | def add_child(graph, parent_id, name, is_reserved=True):
138 | child_id = get_id()
139 | assert len(name) > 0, "Missing node name"
140 | graph.add_node(child_id, label=name, is_reserved=is_reserved)
141 | graph.add_edge(parent_id, child_id)
142 |
143 |
144 | def add_intermediate_node(graph, parent_id, name):
145 | child_id = get_id()
146 | assert len(name) > 0, "Missing node name"
147 | graph.add_node(child_id, label=name, is_reserved=True)
148 | graph.add_edge(parent_id, child_id)
149 | return child_id
150 |
151 |
152 | def add_call_expr(parent_id, ast_node, graph):
153 | expr_type = ast_node.type.spelling
154 | expr_type_node_id = add_intermediate_node(graph, parent_id, "EXPR_TYPE")
155 | add_child(graph, expr_type_node_id, expr_type, is_reserved=False)
156 |
157 |
158 | def fix_cpp_operator_spelling(op_name):
159 | if op_name == '|':
160 | return 'OPERATOR_BINARY_OR'
161 | elif op_name == '||':
162 | return 'OPERATOR_LOGICAL_OR'
163 | elif op_name == '|=':
164 | return 'OPERATOR_ASSIGN_OR'
165 | elif op_name == ',':
166 | return 'OPERATOR_COMMA'
167 | else:
168 | return op_name
169 |
170 |
171 | def get_operator(ast_node):
172 | name_token = None
173 | for token in ast_node.get_tokens():
174 | if is_operator_token(token.spelling):
175 | name_token = token
176 | break
177 |
178 | if not name_token:
179 | filename = ast_node.location.file.name
180 | with open(filename, 'r') as fh:
181 | contents = fh.read()
182 | code_str = contents[ast_node.extent.start.offset: ast_node.extent.end.offset]
183 | name = []
184 | for ch in code_str:
185 | if ch in binary_operators:
186 | name.append(fix_cpp_operator_spelling(ch).strip())
187 | return name
188 | else:
189 | name = name_token.spelling
190 | name = fix_cpp_operator_spelling(name)
191 | # print("\tName : {0}".format(name))
192 | return [name.strip()]
193 |
194 |
195 | def add_reference(parent_id, ast_node, graph):
196 | is_reserved = True
197 | name = "REFERENCE"
198 | if ast_node.kind in [CursorKind.DECL_REF_EXPR, CursorKind.MEMBER_REF_EXPR]:
199 | for token in ast_node.get_tokens():
200 | if token.kind == TokenKind.IDENTIFIER:
201 | name = token.spelling
202 | is_reserved = False
203 | break
204 | else:
205 | name = ast_node.spelling
206 | is_reserved = False
207 |
208 | add_child(graph, parent_id, name, is_reserved)
209 | # print("\tName : {0}".format(name))
210 |
211 |
212 | def add_literal(parent_id, ast_node, graph):
213 | if ast_node.kind in [CursorKind.STRING_LITERAL,
214 | CursorKind.CHARACTER_LITERAL]:
215 | add_child(graph, parent_id, 'STRING_VALUE', is_reserved=True)
216 | else:
217 | token = next(ast_node.get_tokens(), None)
218 | if token:
219 | value = token.spelling
220 | add_child(graph, parent_id, value)
221 | # print("\tValue : {0}".format(value))
222 |
223 |
224 | def add_declaration(parent_id, ast_node, graph):
225 | is_func = False
226 | if is_function(ast_node):
227 | is_func = True
228 | return_type = ast_node.type.get_result().spelling
229 | if len(return_type) > 0:
230 | return_type_node_id = add_intermediate_node(graph, parent_id, "RETURN_TYPE")
231 | add_child(graph, return_type_node_id, return_type, is_reserved=False)
232 | else:
233 | declaration_type = ast_node.type.spelling
234 | if len(declaration_type) > 0:
235 | declaration_type_node_id = add_intermediate_node(graph, parent_id, "DECLARATION_TYPE")
236 | add_child(graph, declaration_type_node_id, declaration_type, is_reserved=False)
237 | # print("\tDecl type : {0}".format(declaration_type))
238 |
239 | if not is_template_parameter(ast_node):
240 | is_reserved = False
241 | # we skip function names to prevent over-fitting in the code2seq learning tasks
242 | if is_func:
243 | name = "FUNCTION_NAME"
244 | is_reserved = True
245 | else:
246 | name = ast_node.spelling
247 |
248 | # handle unnamed declarations
249 | if len(name) == 0:
250 | name = ast_node.kind.name + "_UNNAMED"
251 | is_reserved = True
252 |
253 | name_node_id = add_intermediate_node(graph, parent_id, "DECLARATION_NAME")
254 |
255 | add_child(graph, name_node_id, name, is_reserved=is_reserved)
256 | # print("\tName : {0}".format(name))
257 |
258 |
259 | def func_from_pointer(ast_node):
260 | children = list(ast_node.get_children())
261 | if children:
262 | return func_from_pointer(children[0])
263 | else:
264 | return ast_node
265 |
266 |
267 | def ast_to_graph(ast_start_node, max_depth):
268 | g = nx.Graph()
269 | stack = [(ast_start_node, 0)]
270 | parent_map = {ast_start_node.hash: None}
271 | while stack:
272 | ast_node, depth = stack.pop()
273 | node_id = ast_node.hash
274 | if not g.has_node(node_id):
275 | parent_id = parent_map[node_id]
276 | if add_node(ast_node, g):
277 | if parent_id is not None:
278 | g.add_edge(parent_id, node_id)
279 |
280 | if is_call_expr(ast_node):
281 | func_name = None
282 | if ast_node.referenced:
283 | func_name = ast_node.referenced.spelling
284 | else:
285 | # pointer to function
286 | func_node = func_from_pointer(list(ast_node.get_children())[0])
287 | func_name = func_node.spelling
288 |
289 | if not func_name:
290 | func_name = "FUNCTION_CALL"
291 | func_name = re.sub(r'\s+|,+', '', func_name)
292 |
293 | call_expr_id = add_intermediate_node(g, node_id, func_name)
294 | node_id = call_expr_id
295 | else:
296 | node_id = parent_id
297 |
298 | # Ignore too deep trees
299 | if max_depth == 0 or depth <= max_depth:
300 | if is_call_expr(ast_node):
301 | for arg_node in ast_node.get_arguments():
302 | stack.append((arg_node, depth + 1))
303 | parent_map[arg_node.hash] = node_id
304 | else:
305 | for child_node in ast_node.get_children():
306 | stack.append((child_node, depth + 1))
307 | parent_map[child_node.hash] = node_id
308 | return g
309 |
--------------------------------------------------------------------------------
/cpp_parser/context.py:
--------------------------------------------------------------------------------
1 | class Context:
2 | def __init__(self, start_token, end_token, path, validate=False):
3 | self.start_token = start_token
4 | self.end_token = end_token
5 | self.path = path
6 | if validate:
7 | self.__validate()
8 |
9 | def __validate(self):
10 | self.__validate_token(self.start_token)
11 | self.__validate_token(self.end_token)
12 |
13 | def __validate_token(self, token):
14 | assert len(token) > 0, "Invalid token format: {0}".format(token)
15 | for sub_token in token:
16 | assert len(sub_token) > 0, "Invalid sub-token: {0}".format(token)
17 | self.validate_sub_token(sub_token)
18 |
19 | @staticmethod
20 | def validate_sub_token(sub_token):
21 | assert (' ' not in sub_token and
22 | '|' not in sub_token and
23 | ',' not in sub_token), "Invalid sub-token format: {0}".format(sub_token)
24 |
--------------------------------------------------------------------------------
/cpp_parser/path.py:
--------------------------------------------------------------------------------
1 | from .context import Context
2 |
3 |
4 | class Path:
5 | def __init__(self, tokens, validate=False):
6 | self.tokens = tokens
7 | if validate:
8 | self.__validate()
9 |
10 | def __validate(self):
11 | for sub_token in self.tokens:
12 | assert len(sub_token) > 0, "Invalid sub-token in the path: {0}".format(self.tokens)
13 | Context.validate_sub_token(sub_token)
14 |
--------------------------------------------------------------------------------
/cpp_parser/sample.py:
--------------------------------------------------------------------------------
1 | from .context import Context
2 |
3 |
4 | def make_str_key(list_value):
5 | str_value = ""
6 | for item in list_value:
7 | str_value += str(item)
8 | str_value += "|"
9 | str_value = str_value[:-1]
10 | return str_value
11 |
12 |
13 | class Sample:
14 | def __init__(self, key, contexts, source_mark, validate=False):
15 | self.key = key
16 | self.contexts = contexts
17 | self.source_mark = source_mark
18 | if validate:
19 | self.__validate()
20 |
21 | def __validate(self):
22 | assert len(self.key) > 0, "Invalid target key format: {0}".format(self.key)
23 | for sub_token in self.key:
24 | assert len(sub_token) > 0, "Invalid sub-token in the target key: {0}".format(self.key)
25 | Context.validate_sub_token(sub_token)
26 |
27 | def __str__(self):
28 | str_value = make_str_key(self.key)
29 | for context in self.contexts:
30 | str_value += " " + make_str_key(context.start_token)
31 | str_value += "," + make_str_key(context.path.tokens)
32 | str_value += "," + make_str_key(context.end_token)
33 | return str_value
34 |
--------------------------------------------------------------------------------
/data/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0)
2 | project(data)
3 |
4 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
5 | set(CMAKE_CXX_STANDARD 17)
6 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 | set(CMAKE_CXX_EXTENSIONS OFF)
8 |
9 | file(GLOB SRC_FILES
10 | "*.c"
11 | "*.cc"
12 | "*.cpp"
13 | )
14 |
15 | add_executable(${PROJECT_NAME} ${SRC_FILES})
16 |
--------------------------------------------------------------------------------
/data/big_op.cc:
--------------------------------------------------------------------------------
1 | #define SHA1_DIGEST_LENGTH 20
2 |
3 | int main() {
4 | char expected_sha1str[SHA1_DIGEST_LENGTH * 2 + 1];
5 | return expected_sha1str[6];
6 | }
7 |
--------------------------------------------------------------------------------
/data/func.cc:
--------------------------------------------------------------------------------
1 | int div(int x, int y) {
2 | float z = 45;
3 | if (y != 0) {
4 | return x /y;
5 | } else {
6 | return 0;
7 | }
8 | }
9 |
10 | int main() {
11 | auto e = 8;
12 | auto r = div(5, e);
13 | auto pr = &r;
14 | *pr = 45;
15 | int& ref = r;
16 | return r;
17 | }
18 |
--------------------------------------------------------------------------------
/data/main.cc:
--------------------------------------------------------------------------------
1 | template <typename T>
2 | class Math {
3 | public:
4 | T Add(T a, T b) {
5 | return AddImpl(a,b);
6 | }
7 |
8 | T Sub(T a, T b) {
9 | return AddImpl(a,b);
10 | }
11 | private:
12 | T AddImpl(T a, T b) {
13 | return a + b;
14 | }
15 |
16 | T SubImpl(T a, T b) {
17 | return a - b;
18 | }
19 | };
20 |
21 | class Base {
22 | public:
23 | virtual ~Base() { --x; }
24 | virtual void Print() = 0;
25 | virtual void Draw();
26 |   friend bool operator==(const Base& a, const Base& b);
27 |
28 | private:
29 | int x = 0;
30 | int y = 8;
31 | };
32 |
33 | void Base::Draw() {
34 | x = x + y;
35 | }
36 |
37 | class Derived : public Base {
38 | public:
39 | void Print() override {
40 | Draw();
41 | }
42 |
43 | void operator()() {
44 | Print();
45 | }
46 | };
47 |
48 | bool operator==(const Base& a, const Base& b) {
49 | return (a.x == b.x) && (a.y == b.y);
50 | }
51 |
52 |
53 | int func(float x, float y){
54 |   Math<float> m;
55 |   return static_cast<int>(m.Add(x, y));
56 | }
57 |
58 | int func_TestAll(float x, float y){
59 |   Math<float> m;
60 |   return static_cast<int>(m.Add(x, y));
61 | }
62 |
63 | template <typename T>
64 | T subtract(T&& x, T&& y){
65 | return x - y;
66 | }
67 |
68 | int maxof(int n_args, ...) {
69 | ++n_args;
70 | }
71 |
72 | template <typename T>
73 | T adder(T v) {
74 | return v;
75 | }
76 |
77 | template <typename T, typename... Args>
78 | T adder(T first, Args... args) {
79 | return first + adder(args...);
80 | }
81 |
82 | int main(int argc, char* argv[]) {
83 | auto r = subtract(4, 5);
84 | auto s = func(3.4, 5.8);
85 | return r+s;
86 | }
87 |
--------------------------------------------------------------------------------
/data/method.cc:
--------------------------------------------------------------------------------
1 | class Base {
2 | public:
3 | virtual void Draw();
4 | private:
5 | int x = 0;
6 | int y = 8;
7 | };
8 |
9 | void Base::Draw() {
10 | x = x + y;
11 | }
12 |
--------------------------------------------------------------------------------
/data/template.cc:
--------------------------------------------------------------------------------
1 | template <typename T>
2 | T adder(T v) {
3 | return v;
4 | }
5 |
6 | template <typename T, typename... Args>
7 | T adder(T first, Args... args) {
8 | return first + adder(args...);
9 | }
10 |
11 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Configure Docker environment for code2seq project
2 | ```
3 | cd docker
4 | docker build -t code2seq:1.0 .
5 | docker run -it -v [host_dataset_path]:[container_dataset_path] code2seq:1.0 bash
6 | cd /code2seq
7 | ```
8 |
9 |
--------------------------------------------------------------------------------
/docker/dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:18.04
2 | MAINTAINER Kirill rotate@ukr.net
3 |
4 | ENV DEBIAN_FRONTEND=noninteractive
5 |
6 | RUN apt-get -y update
7 | RUN apt-get -y upgrade
8 |
9 | RUN apt-get install -y git
10 | RUN apt-get install -y python3
11 | RUN apt-get install -y python3-pip
12 |
13 | RUN pip3 install rouge
14 | RUN pip3 install requests
15 | RUN pip3 install tensorflow==1.12
16 |
17 | RUN git clone https://github.com/tech-srl/code2seq
18 |
19 |
20 |
--------------------------------------------------------------------------------
/src/data_set_merge.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from bitarray import bitarray
3 | from tqdm import tqdm
4 | import lmdb
5 | import os
6 | import random
7 |
8 |
9 | class DataSetMerge:
10 | def __init__(self, output_path, map_size):
11 | self.output_path = output_path
12 | self.train_set_file = os.path.join(self.output_path, "dataset.train.c2s")
13 | self.test_set_file = os.path.join(self.output_path, "dataset.test.c2s")
14 | self.validation_set_file = os.path.join(self.output_path, "dataset.val.c2s")
15 | self.samples_db = lmdb.open(os.path.join(self.output_path, 'samples.db'), writemap=True)
16 | self.samples_db.set_mapsize(map_size)
17 | self.total_num = 0
18 |
19 | def merge(self, clear_resources=True):
20 | functions = set()
21 | sample_id = 0
22 | self.total_num = 0
23 | with self.samples_db.begin(write=True) as txn:
24 | files_num = sum(1 for _ in Path(self.output_path).rglob('*.c2s'))
25 | with tqdm(total=files_num) as pbar:
26 | for file_path in Path(self.output_path).rglob('*.c2s'):
27 | with file_path.open() as file:
28 | # print('Loading file: ' + file_path.absolute().as_posix())
29 | for line in file.readlines():
30 |                         src_mark_str, _, sample_line = line.partition(')')  # strip the "(file, line)" source mark prefix
31 |                         src_mark = src_mark_str[2:]  # the mark is used to skip duplicate functions
32 | if src_mark not in functions:
33 | txn.put(str(sample_id).encode('ascii'), sample_line.encode('ascii'))
34 | sample_id += 1
35 | functions.add(src_mark)
36 | if clear_resources:
37 | os.remove(file_path.absolute().as_posix())
38 | pbar.update(1)
39 | self.total_num = sample_id - 1
40 |
41 | def dump_datasets(self, train_set_ratio=0.7):
42 | # split samples into test, validation and training parts
43 | all_samples_num = self.total_num + 1
44 | train_samples_num = int(all_samples_num * train_set_ratio)
45 | test_samples_num = (all_samples_num - train_samples_num) // 2
46 | processed = bitarray(all_samples_num)
47 | processed.setall(False)
48 | train_file = open(self.train_set_file, 'w')
49 | train_index = 0
50 | test_file = open(self.test_set_file, 'w')
51 | test_index = 0
52 | validation_file = open(self.validation_set_file, 'w')
53 | validation_index = 0
54 | try:
55 | with self.samples_db.begin(write=False) as txn:
56 | for _ in tqdm(range(self.total_num + 1)):
57 | index = random.randint(0, self.total_num)
58 | while processed[index]:
59 | index = random.randint(0, self.total_num)
60 | processed[index] = True
61 | sample = txn.get(str(index).encode('ascii'))
62 | if train_index < train_samples_num:
63 | train_file.write(sample.decode('ascii'))
64 | train_index += 1
65 | elif test_index < test_samples_num:
66 | test_file.write(sample.decode('ascii'))
67 | test_index += 1
68 | elif validation_index < test_samples_num:
69 | validation_file.write(sample.decode('ascii'))
70 | validation_index += 1
71 | finally:
72 | print("Closing files ...")
73 | train_file.close()
74 | test_file.close()
75 | validation_file.close()
76 |
--------------------------------------------------------------------------------
/src/merge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from data_set_merge import DataSetMerge
4 | from pathlib import Path
5 |
6 |
7 | def main():
8 | args_parser = argparse.ArgumentParser(
9 | description='merge resources generated by cppminer to a code2seq dataset')
10 |
11 | args_parser.add_argument('DataPath',
12 | metavar='path',
13 | type=str,
14 | help='the dataset sources path')
15 |
16 | args_parser.add_argument('-c', '--clear_resources',
17 | metavar='clear_resources_flag',
18 | type=bool,
19 | help='if True clear resource files',
20 | default=False,
21 | required=False)
22 |
23 | args_parser.add_argument('-m', '--map_size',
24 | metavar='map_file_size',
25 | type=int,
26 |                              help='size of the DB file, default(100000000000 bytes)',
27 | default=100000000000,
28 | required=False)
29 |
30 | args = args_parser.parse_args()
31 |
32 | output_path = Path(args.DataPath).resolve().as_posix()
33 | print('Path: ' + output_path)
34 |
35 | map_size = args.map_size
36 | print('Map size: ' + str(map_size))
37 |
38 | print('Clear resources: ' + str(args.clear_resources))
39 |
40 | # shuffle and merge samples
41 | print("Merging samples ...")
42 | merge = DataSetMerge(output_path, map_size)
43 | merge.merge(args.clear_resources)
44 | print("Dumping datasets ...")
45 | merge.dump_datasets(0.7)
46 | print("Merging done")
47 |
48 |
49 | if __name__ == '__main__':
50 | main()
51 |
--------------------------------------------------------------------------------
/src/miner.py:
--------------------------------------------------------------------------------
1 | from clang.cindex import Config
2 | import argparse
3 | import time
4 | from pathlib import Path
5 | import multiprocessing
6 | import os
7 | from tqdm import tqdm
8 | from parser_process import ParserProcess
9 |
10 | file_types = ('*.c', '*.cc', '*.cpp', '*.cxx', '*.c++')
11 |
12 |
13 | def files(input_path):
14 | if os.path.isfile(input_path):
15 | yield input_path
16 | for file_type in file_types:
17 | for file_path in Path(input_path).rglob(file_type):
18 | yield file_path.as_posix()
19 |
20 |
21 | def main():
22 | args_parser = argparse.ArgumentParser(
23 | description='cppminer generates a code2seq dataset from C++ sources')
24 |
25 | args_parser.add_argument('Path',
26 | metavar='path',
27 | type=str,
28 | help='the path sources directory')
29 |
30 | args_parser.add_argument('OutPath',
31 | metavar='out',
32 | type=str,
33 | help='the output path')
34 |
35 | args_parser.add_argument('-c', '--max_contexts_num',
36 | metavar='contexts-number',
37 | type=int,
38 | help='maximum number of contexts per sample',
39 | default=100,
40 | required=False)
41 |
42 | args_parser.add_argument('-l', '--max_path_len',
43 | metavar='path-length',
44 | type=int,
45 | help='maximum path length (0 - no limit)',
46 | default=0,
47 | required=False)
48 |
49 | args_parser.add_argument('-s', '--max_subtokens_num',
50 | metavar='subtokens-num',
51 | type=int,
52 | help='maximum number of sub-tokens in a token (0 - no limit)',
53 | default=5,
54 | required=False)
55 |
56 | args_parser.add_argument('-d', '--max_ast_depth',
57 | metavar='ast-depth',
58 | type=int,
59 | help='maximum depth of AST (0 - no limit)',
60 | default=0,
61 | required=False)
62 |
63 | args_parser.add_argument('-p', '--processes_num',
64 | metavar='processes-number',
65 | type=int,
66 | help='number of parallel processes',
67 | default=4,
68 | required=False)
69 |
70 | args_parser.add_argument('-e', '--libclang',
71 | metavar='libclang-path',
72 | type=str,
73 | help='path to libclang.so file',
74 | required=False)
75 |
76 | args = args_parser.parse_args()
77 |
78 | if args.libclang:
79 | # File path example '/usr/lib/llvm-6.0/lib/libclang.so'
80 | Config.set_library_file(args.libclang)
81 |
82 | parallel_processes_num = args.processes_num
83 | print('Parallel processes num: ' + str(parallel_processes_num))
84 |
85 | max_contexts_num = args.max_contexts_num
86 | print('Max contexts num: ' + str(max_contexts_num))
87 |
88 | max_path_len = args.max_path_len
89 | print('Max path length: ' + str(max_path_len))
90 |
91 | max_subtokens_num = args.max_subtokens_num
92 | print('Max sub-tokens num: ' + str(max_subtokens_num))
93 |
94 | max_ast_depth = args.max_ast_depth
95 | print('Max AST depth: ' + str(max_ast_depth))
96 |
97 | input_path = Path(args.Path).resolve().as_posix()
98 | print('Input path: ' + input_path)
99 |
100 | output_path = Path(args.OutPath).resolve().as_posix()
101 | print('Output path: ' + output_path)
102 |
103 | print("Parsing files ...")
104 | tasks = multiprocessing.JoinableQueue()
105 | if parallel_processes_num == 1:
106 | parser = ParserProcess(tasks, max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, input_path,
107 | output_path)
108 | for file_path in files(input_path):
109 | print("Parsing : " + file_path)
110 | tasks.put(file_path)
111 | parser.parse_file()
112 | parser.save()
113 | tasks.join()
114 | else:
115 | processes = [ParserProcess(tasks, max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, input_path,
116 | output_path)
117 | for _ in range(parallel_processes_num)]
118 | for p in processes:
119 | p.start()
120 |
121 | for file_path in files(input_path):
122 | tasks.put(file_path)
123 |
124 | # add terminating tasks
125 | for i in range(parallel_processes_num):
126 | tasks.put(None)
127 |
128 | # Wait for all of the tasks to finish
129 | tasks_left = tasks.qsize()
130 | with tqdm(total=tasks_left) as pbar:
131 | while tasks_left > 0:
132 | time.sleep(1)
133 | tasks_num = tasks.qsize()
134 | pbar.update(tasks_left - tasks_num)
135 | tasks_left = tasks_num
136 |
137 | tasks.join()
138 | for p in processes:
139 | p.join()
140 | print("Parsing done")
141 |
142 |
143 | if __name__ == '__main__':
144 | main()
145 |
--------------------------------------------------------------------------------
/src/parser_process.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | from clang.cindex import CompilationDatabase, CompilationDatabaseError
4 | from cpp_parser import AstParser
5 |
6 |
7 | def is_object_file(file_path):
8 | file_name = os.path.basename(file_path)
9 | if '.o.' in file_name:
10 | return True
11 | elif '.o' in file_name[-2:]:
12 | return True
13 | else:
14 | return False
15 |
16 |
17 | class ParserProcess(multiprocessing.Process):
18 | def __init__(self, task_queue, max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, input_path, output_path):
19 | multiprocessing.Process.__init__(self)
20 | self.task_queue = task_queue
21 | self.parser = AstParser(max_contexts_num, max_path_len, max_subtokens_num, max_ast_depth, output_path)
22 | try:
23 | self.compdb = CompilationDatabase.fromDirectory(input_path)
24 | except CompilationDatabaseError:
25 | self.compdb = None
26 |
27 | def run(self):
28 | default_compile_args = []
29 |
30 | while self.parse_file(default_compile_args):
31 | pass
32 |
33 | self.save()
34 | return
35 |
36 | def save(self):
37 | self.parser.save()
38 |
39 | def parse_file(self, default_compile_args=[]):
40 | file_path = self.task_queue.get()
41 | if file_path is None:
42 | self.task_queue.task_done()
43 | return False
44 | # print('Parsing : {0} [{1}]'.format(file_path, os.getpid()))
45 | if not self.compdb:
46 | # print('Compilation database was not found in the input directory, using default args list')
47 | self.parser.parse([file_path] + default_compile_args)
48 | else:
49 | commands = self.compdb.getCompileCommands(file_path)
50 | if commands and len(commands) > 0:
51 | command = commands[0]
52 | cwd = os.getcwd()
53 | os.chdir(command.directory)
54 | args = []
55 | cmd_args = command.arguments
56 | next(cmd_args) # drop compiler executable path
57 | for arg in cmd_args:
58 | if arg == '-Xclang':
59 | next(cmd_args) # skip clang specific arguments
60 | elif arg == '-c':
61 | continue # drop input filename option
62 | elif arg == '-o':
63 | continue # drop output filename option
64 | elif os.path.isfile(arg) or os.path.isdir(arg):
65 | continue # drop file names
66 | elif is_object_file(arg):
67 | continue # drop file names
68 | else:
69 | args.append(arg)
70 | self.parser.parse(args, command.filename)
71 | os.chdir(cwd)
72 | self.task_queue.task_done()
73 | return True
74 |
--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | bitarray
2 | lmdb
3 | clang
4 | decorator
5 | networkx
6 | pygraphviz
7 | tqdm
8 |
--------------------------------------------------------------------------------