├── .clang-format ├── .gitignore ├── .travis.yml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── compy ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── dataset_test.py │ ├── devmap.py │ └── devmap_test.py ├── models │ ├── __init__.py │ ├── graphs │ │ ├── __init__.py │ │ ├── pytorch_dgl_model.py │ │ ├── pytorch_dgl_model_test.py │ │ ├── pytorch_geom_model.py │ │ ├── pytorch_geom_model_test.py │ │ ├── tf │ │ │ ├── __init__.py │ │ │ ├── cell │ │ │ │ ├── __init__.py │ │ │ │ ├── prediction_cell.py │ │ │ │ └── prediction_cell_test.py │ │ │ ├── layer │ │ │ │ ├── __init__.py │ │ │ │ ├── embedding_layer.py │ │ │ │ ├── gnn_model_layer.py │ │ │ │ ├── propagation_model_layer.py │ │ │ │ └── propagation_model_layer_test.py │ │ │ ├── test_utils.py │ │ │ └── utils.py │ │ ├── tf_graph_model.py │ │ └── tf_graph_model_test.py │ ├── model.py │ └── seqs │ │ ├── __init__.py │ │ ├── tf_seq_model.py │ │ └── tf_seq_model_test.py └── representations │ ├── __init__.py │ ├── ast_graphs.py │ ├── ast_graphs_test.py │ ├── common.py │ ├── common_test.py │ ├── extractors │ ├── CMakeLists.txt │ ├── __init__.py │ ├── clang_ast │ │ ├── CMakeLists.txt │ │ ├── clang_extractor.cc │ │ ├── clang_extractor.h │ │ ├── clang_extractor_test.cc │ │ ├── clang_graph_frontendaction.cc │ │ ├── clang_graph_frontendaction.h │ │ ├── clang_seq_frontendaction.cc │ │ └── clang_seq_frontendaction.h │ ├── common │ │ ├── clang_driver.cc │ │ ├── clang_driver.h │ │ ├── clang_driver_test.cc │ │ ├── common_test.h │ │ └── visitor.h │ ├── extractors.cc │ ├── extractors_test.py │ └── llvm_ir │ │ ├── CMakeLists.txt │ │ ├── llvm_extractor.cc │ │ ├── llvm_extractor.h │ │ ├── llvm_extractor_test.cc │ │ ├── llvm_graph_funcinfo.cc │ │ ├── llvm_graph_funcinfo.h │ │ ├── llvm_graph_pass.cc │ │ ├── llvm_graph_pass.h │ │ ├── llvm_pass_test.cc │ │ ├── llvm_seq_pass.cc │ │ └── llvm_seq_pass.h │ ├── llvm_graphs.py │ ├── llvm_graphs_test.py │ ├── llvm_seq.py │ ├── llvm_seq_test.py │ ├── syntax_seq.py │ └── syntax_seq_test.py ├── docs └── img │ ├── flow-overview.png │ └── representation-examples.png ├── examples └── devmap_exploration.py ├── install_deps.sh ├── setup.py └── tests ├── __init__.py └── test_runner.py /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | BasedOnStyle: Google 4 | IndentWidth: 2 5 | TabWidth: 2 6 | IncludeBlocks: Regroup 7 | IncludeCategories: 8 | - Regex: '^["<](llvm|llvm-c|clang|clang-c)/' 9 | Priority: 2 10 | - Regex: '^<' 11 | Priority: 1 12 | - Regex: '.*' 13 | Priority: 3 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Jetbrains IDEs 2 | .idea 3 | 4 | # VIM 5 | *.swp 6 | 7 | # C++ build artifacts 8 | build 9 | cmake-build-debug 10 | bin 11 | *.so 12 | 13 | # Python build artifacts 14 | __pycache__ 15 | *.pyc 16 | 17 | # Python package artifacts 18 | dist 19 | *.egg-info 20 | .eggs 21 | 22 | # Pytest runtime artifacts 23 | .coverage 24 | coverage.xml 25 | htmlcov 26 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | jobs: 4 | include: 5 | - python: 3.7 6 | dist: xenial 7 | sudo: true 8 | 9 | - python: 3.7 10 | dist: bionic 11 | sudo: true 12 | 13 | install: 14 | - ./install_deps.sh cpu 15 | 16 | script: 17 | - 
python setup.py test 18 | 19 | after_success: 20 | - bash <(curl -s https://codecov.io/bash) 21 | 22 | notifications: 23 | email: false 24 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12) 2 | project(gnns4code_native) 3 | 4 | set(CMAKE_CXX_STANDARD 14) 5 | 6 | find_package(LLVM 10) 7 | 8 | # prefer Clang matching LLVM 9 | find_package(Clang HINTS "${LLVM_CMAKE_DIR}/../clang") 10 | 11 | include(FetchContent) 12 | 13 | FetchContent_Declare( 14 | googletest 15 | GIT_REPOSITORY https://github.com/google/googletest.git 16 | GIT_TAG release-1.10.0 17 | ) 18 | FetchContent_GetProperties(googletest) 19 | if(NOT googletest_POPULATED) 20 | FetchContent_Populate(googletest) 21 | add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR}) 22 | endif() 23 | 24 | FetchContent_Declare( 25 | pybind11 26 | GIT_REPOSITORY https://github.com/pybind/pybind11 27 | GIT_TAG v2.5.0 28 | ) 29 | FetchContent_GetProperties(pybind11) 30 | if(NOT pybind11_POPULATED) 31 | FetchContent_Populate(pybind11) 32 | add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR}) 33 | endif() 34 | 35 | add_subdirectory(compy/representations/extractors) 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

5 | # ComPy-Learn 6 | [![Build Status](https://app.travis-ci.com/tud-ccc/compy-learn.svg?branch=master)](https://app.travis-ci.com/tud-ccc/compy-learn) 7 | [![codecov](https://codecov.io/gh/tud-ccc/compy-learn/branch/master/graph/badge.svg)](https://codecov.io/gh/tud-ccc/compy-learn) 8 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 9 | 10 | ComPy-Learn is a framework for defining and exploring program representations for machine learning on source code (ML4CODE) tasks. 11 | While its special focus is on compiler optimization tasks, ComPy-Learn can also be used in other domains such as software engineering or systems security. 12 | 13 | ## Project goals 14 | * **Exploration of the best-performing code representation and model:** Depending on the task, different representations and models have been shown to be suitable to different degrees. Finding the best-performing combination is not obvious and currently requires empirical evaluation. ComPy-Learn provides a common framework for exactly that: evaluating different representations on a given task to find the best-performing one. 15 | * **Design and discovery of new representations:** Custom, task-specific representations of code can improve a model's performance. However, extracting representations of program code is a tedious endeavor that requires low-level development with compiler tools. We aim to take away this burden by making it possible to define program representations with a simple, high-level programming interface. This allows easier design and faster iteration. 16 | * **Common tools, evaluation pipeline, and datasets:** Several promising representations, and models that learn embeddings from those representations, have been proposed recently. However, each uses its own tools and pipelines for evaluation, which makes further comparisons with those methods time-consuming and difficult. ComPy-Learn provides a common framework for representations, models, and datasets, and allows their combinations to be evaluated. Implementing a novel representation and model in this framework enables researchers to perform a complete evaluation with little effort and, at the same time, contributes another widely applicable method to the community. 17 | 18 | ## Design 19 | ComPy-Learn's main components are shown in the pipeline below: 20 |

21 | ![ComPy-Learn pipeline overview](docs/img/flow-overview.png) 22 |

23 | 24 | * `compy.representation` allows the user to define custom representations of source code (such as the ones from published work), based on semantic compiler-internal information that is currently extracted from the Clang/LLVM framework. Both linear and graph representations of code are supported. 25 | * `compy.model` contains ML models (in fact, it provides connectors to well-established model libraries) that embed the representations into vectors and finally output a prediction. 26 | * `compy.dataset` contains datasets of source code for evaluation, along with helper functions that allow the integration of new datasets. 27 | 28 | 29 | ## Supported representations 30 | Currently, the following representations and models from published work are implemented in this framework: 31 | * [Cummins, Chris, et al. "End-to-end deep learning of optimization heuristics."](https://ieeexplore.ieee.org/document/8091247) 2017 26th International Conference on Parallel Architectures and Compilation Techniques (PACT). IEEE, 2017. 32 | * [Barchi, Francesco, et al. "Code Mapping in Heterogeneous Platforms Using Deep Learning and LLVM-IR."](https://dl.acm.org/doi/10.1145/3316781.3317789) 2019 56th ACM/IEEE Design Automation Conference (DAC). IEEE, 2019. 33 | * [Brauckmann, Alexander, et al. "Compiler-based graph representations for deep learning models of code."](https://dl.acm.org/doi/abs/10.1145/3377555.3377894) Proceedings of the 29th International Conference on Compiler Construction. ACM, 2020. 34 | * [Cummins, Chris, et al. "ProGraML: Graph-based Deep Learning for Program Optimization and Analysis."](https://arxiv.org/abs/2003.10536) arXiv preprint arXiv:2003.10536 (2020). 35 | 36 | ## Installation 37 | 38 | We supply an installation script that automates the build, test, and installation process. The script currently supports the platforms listed below. Because the process builds ComPy-Learn from its sources, other platforms can be used with a bit of manual installation effort. 39 | 40 | Platform | Build status 41 | --- | --- 42 | Ubuntu 16.04 | [![Build Status](https://app.travis-ci.com/tud-ccc/compy-learn.svg?branch=master)](https://app.travis-ci.com/tud-ccc/compy-learn) 43 | Ubuntu 18.04 | [![Build Status](https://app.travis-ci.com/tud-ccc/compy-learn.svg?branch=master)](https://app.travis-ci.com/tud-ccc/compy-learn) 44 | Ubuntu 20.04 | [![Build Status](https://app.travis-ci.com/tud-ccc/compy-learn.svg?branch=master)](https://app.travis-ci.com/tud-ccc/compy-learn) 45 | 46 | To get started on one of the supported platforms, we suggest first creating a virtual environment and then running: 47 | ``` 48 | ./install_deps.sh ${CUDA} 49 | ``` 50 | where `${CUDA}` is one of `cpu`, `cu92`, `cu100`, or `cu102`, depending on your machine's capabilities. 51 | 52 | Once the dependencies are installed, compile and test ComPy-Learn by running: 53 | ``` 54 | python setup.py test 55 | ``` 56 | 57 | Finally, install ComPy-Learn in order to use it in your project: 58 | ``` 59 | python setup.py install 60 | ``` 61 | 62 | An example exploration is located in `examples/devmap_exploration.py`; a condensed sketch of the same workflow is shown below the publication list. 63 | 64 | 65 | ## Publications 66 | * [Brauckmann, Alexander, et al. "ComPy-Learn: A Toolbox for Exploring Machine Learning Representations for Compilers."](https://ieeexplore.ieee.org/document/9232946) 2020 Forum for Specification and Design Languages (FDL). IEEE, 2020.
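
## Usage sketch
The typical workflow is: load a dataset, preprocess it into a representation using a builder and a visitor, then train a model on the result. The following is a minimal, non-authoritative sketch of that flow: `builder` and `visitor` are placeholders for a concrete representation builder/visitor pair from `compy.representations` (see `examples/devmap_exploration.py` for a complete, runnable exploration), while the `preprocess`, `GnnPytorchGeomModel`, and `train` signatures follow `compy/datasets/devmap.py` and the model tests:
```
from compy.datasets import OpenCLDevmapDataset
from compy.models import GnnPytorchGeomModel

# `builder` and `visitor` are placeholders for a concrete representation
# builder/visitor pair from `compy.representations`.
dataset = OpenCLDevmapDataset()              # downloads and extracts the data on first use
data = dataset.preprocess(builder, visitor)  # -> {"samples": [...], "num_types": N}

model = GnnPytorchGeomModel(num_types=data["num_types"])
model.train(data["samples"], data["samples"])  # use proper train/validation splits in practice
```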
67 | -------------------------------------------------------------------------------- /compy/__init__.py: -------------------------------------------------------------------------------- 1 | import compy.datasets 2 | import compy.models 3 | import compy.representations 4 | -------------------------------------------------------------------------------- /compy/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from .devmap import OpenCLDevmapDataset 3 | -------------------------------------------------------------------------------- /compy/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import urllib.request 4 | import zipfile 5 | 6 | from appdirs import user_data_dir 7 | from git import Repo 8 | 9 | 10 | class Dataset(object): 11 | def __init__(self): 12 | self.name = self.__class__.__name__ 13 | 14 | app_dir = user_data_dir(appname="compy-Learn", version="1.0") 15 | self.dataset_dir = os.path.join(app_dir, self.name) 16 | os.makedirs(self.dataset_dir, exist_ok=True) 17 | 18 | self.content_dir = os.path.join(self.dataset_dir, "content") 19 | 20 | def preprocess(self, builder, visitor): 21 | raise NotImplementedError 22 | 23 | def download_http_and_extract(self, url): 24 | archive_file = os.path.join(self.dataset_dir, "content.zip") 25 | 26 | if not (os.path.isfile(archive_file) or os.path.isdir(self.content_dir)): 27 | urllib.request.urlretrieve(url, archive_file) 28 | 29 | os.makedirs(self.content_dir, exist_ok=True) 30 | with zipfile.ZipFile(archive_file, "r") as f: 31 | f.extractall(self.content_dir) 32 | 33 | return archive_file, self.content_dir 34 | 35 | def clone_git(self, uri): 36 | if not os.path.isdir(self.content_dir): 37 | Repo.clone_from(uri, self.content_dir) 38 | 39 | return self.content_dir 40 | -------------------------------------------------------------------------------- /compy/datasets/dataset_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import shutil 4 | 5 | from compy.datasets import dataset 6 | 7 | 8 | @pytest.fixture 9 | def dataset_fixture(): 10 | class TestDataset(dataset.Dataset): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | ds = TestDataset() 15 | 16 | yield ds 17 | 18 | shutil.rmtree(ds.dataset_dir) 19 | 20 | 21 | def test_dataset_has_correct_name(dataset_fixture): 22 | assert dataset_fixture.name == "TestDataset" 23 | 24 | 25 | def test_app_dir_is_initialized(dataset_fixture): 26 | assert type(dataset_fixture.dataset_dir) is str 27 | 28 | 29 | def test_app_dir_exists(dataset_fixture): 30 | assert os.path.isdir(dataset_fixture.dataset_dir) 31 | 32 | 33 | def test_download_http_and_extract(dataset_fixture): 34 | zip_file, content_dir = dataset_fixture.download_http_and_extract( 35 | "http://wwwpub.zih.tu-dresden.de/~s9602232/test.zip" 36 | ) 37 | 38 | assert os.path.isfile(zip_file) 39 | assert os.path.isdir(content_dir) 40 | 41 | 42 | def test_clone_git(dataset_fixture): 43 | content_dir = dataset_fixture.clone_git( 44 | "https://github.com/alexanderb14/build-webrtc-builds.git" 45 | ) 46 | 47 | assert os.path.isdir(content_dir) 48 | -------------------------------------------------------------------------------- /compy/datasets/devmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from tqdm import tqdm 4 | 5 | from 
compy.datasets import dataset 6 | 7 | 8 | class OpenCLDevmapDataset(dataset.Dataset): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | uri = "http://wwwpub.zih.tu-dresden.de/~s9602232/devmap.zip" 13 | self.download_http_and_extract(uri) 14 | 15 | self.additional_include_dirs = [ 16 | os.path.join(self.content_dir, "support/libclc") 17 | ] 18 | 19 | def preprocess(self, builder, visitor, benchmark_suites=None): 20 | suite_specifics = { 21 | "amd-app-sdk-3.0": {"subdir": "samples/opencl/cl/1.x"}, 22 | "npb-3.3": {"subdir": ""}, 23 | "nvidia-4.2": {"subdir": "OpenCL/src", "benchmark_name_prefix": "ocl"}, 24 | "parboil-0.2": {"subdir": "benchmarks"}, 25 | "polybench-gpu-1.0": { 26 | "subdir": "OpenCL", 27 | "remappings": { 28 | "2DConvolution": "2DCONV", 29 | "2mm": "2MM", 30 | "3DConvolution": "3DCONV", 31 | "3mm": "3MM", 32 | "atax": "ATAX", 33 | "bicg": "BICG", 34 | "correlation": "CORR", 35 | "covariance": "COVAR", 36 | "gemm": "GEMM", 37 | "gesummv": "GESUMMV", 38 | "gramschmidt": "GRAMSCHM", 39 | "mvt": "MVT", 40 | "syr2k": "SYR2K", 41 | "syrk": "SYRK", 42 | }, 43 | }, 44 | "rodinia-3.1": {"subdir": "opencl",}, 45 | "shoc-1.1.5": {"subdir": "src/opencl/level1"}, 46 | } 47 | if benchmark_suites is None: 48 | benchmark_suites = suite_specifics.keys() 49 | 50 | opencl_header = str.encode( 51 | '#include "' + self.content_dir + '/support/opencl-shim.h"\n' 52 | ) 53 | basedir = os.path.join(self.content_dir, "src") 54 | 55 | # Load cgo17 dataset. 56 | df = pd.read_csv(os.path.join(self.content_dir, "data", "cgo17-amd.csv")) 57 | 58 | # Remove preprocessing data from it. 59 | for column in [ 60 | "Unnamed: 0", 61 | "comp", 62 | "rational", 63 | "mem", 64 | "localmem", 65 | "coalesced", 66 | "atomic", 67 | "seq", 68 | "src", 69 | ]: 70 | del df[column] 71 | 72 | # Split benchmark name for better identification. 73 | for idx, row in df.iterrows(): 74 | b = row["benchmark"] 75 | 76 | df.loc[idx, "function_name"] = b.split("-")[-1] 77 | df.loc[idx, "benchmark_name"] = b.split("-")[-2] 78 | df.loc[idx, "suite_name"] = "-".join(b.split("-")[0:-2]) 79 | 80 | del df["benchmark"] 81 | df = df[df["suite_name"].isin(benchmark_suites)] 82 | 83 | # Build a map of files to process: {file_name: further infos e.g. aux inputs, mappings} 84 | to_process = {} 85 | for idx, row in df.iterrows(): 86 | suite_data = suite_specifics[row["suite_name"]] 87 | 88 | # Build benchmark name 89 | benchmark_name = row["benchmark_name"] 90 | if "remappings" in suite_data: 91 | benchmark_name = suite_data["remappings"][row["benchmark_name"]] 92 | 93 | if "benchmark_name_prefix" in suite_data: 94 | benchmark_name = suite_data["benchmark_name_prefix"] + benchmark_name 95 | 96 | if row["suite_name"] == "shoc-1.1.5": 97 | benchmark_name = benchmark_name.lower() 98 | 99 | # Build subdir 100 | subdir = suite_data["subdir"] 101 | if row["suite_name"] == "shoc-1.1.5": 102 | if benchmark_name == "s3d": 103 | subdir = os.path.join(subdir, "..", "level2") 104 | 105 | # Search for CL file 106 | # 1. Build search path 107 | bench_dir = os.path.join(basedir, row["suite_name"], subdir, benchmark_name) 108 | if row["suite_name"] == "parboil-0.2": 109 | bench_dir = os.path.join(bench_dir, "src", "opencl_base") 110 | assert os.path.isdir(bench_dir) 111 | 112 | # 2. Search. 
113 | cls = [] 114 | for folder, subfolders, files in os.walk(bench_dir): 115 | for file in files: 116 | if file.endswith(".cl"): 117 | cls.append(os.path.join(os.path.abspath(folder), file)) 118 | assert len(cls) >= 1 119 | if len(cls) > 1: 120 | if "filename_matcher" in suite_data: 121 | for cls_it in cls: 122 | if suite_data["filename_matcher"] in os.path.basename(cls_it): 123 | cls = [cls_it] 124 | break 125 | 126 | for bench_file in cls: 127 | assert os.path.isfile(bench_file) 128 | 129 | df.loc[idx, "bench_file"] = bench_file 130 | 131 | with open(bench_file, "rb") as file: 132 | source_code = file.read() 133 | 134 | # Additional data 135 | # - Include dirs 136 | additional_include_dir = bench_dir 137 | 138 | # Add to to_process. 139 | file_data = ( 140 | bench_file, 141 | source_code, 142 | additional_include_dir, 143 | row["suite_name"], 144 | row["benchmark_name"], 145 | ) 146 | if file_data not in to_process: 147 | to_process[file_data] = [] 148 | 149 | function_data = ( 150 | row["function_name"], 151 | row["transfer"], 152 | row["wgsize"], 153 | row["oracle"], 154 | ) 155 | to_process[file_data].append(function_data) 156 | 157 | # Process the map of files 158 | processed = {} 159 | for file_data in tqdm(to_process, desc="Source Code -> IR+"): 160 | ( 161 | bench_file, 162 | source_code, 163 | additional_include_dir, 164 | suite_name, 165 | benchmark_name, 166 | ) = file_data 167 | 168 | extractionInfo = builder.string_to_info( 169 | opencl_header + source_code, additional_include_dir 170 | ) 171 | 172 | for functionInfo in extractionInfo.functionInfos: 173 | processed[ 174 | (suite_name, benchmark_name, functionInfo.name) 175 | ] = functionInfo 176 | 177 | # Map to dataset and extract representations 178 | samples = {} 179 | for file_data, function_datas in tqdm( 180 | to_process.items(), desc="IR+ -> ML Representation" 181 | ): 182 | ( 183 | bench_file, 184 | source_code, 185 | additional_include_dir, 186 | suite_name, 187 | benchmark_name, 188 | ) = file_data 189 | 190 | for function_data in function_datas: 191 | function_name, transfer, wgsize, label = function_data 192 | 193 | item_info = (suite_name, benchmark_name, function_name) 194 | samples[ 195 | item_info + (transfer, wgsize, label) 196 | ] = builder.info_to_representation(processed[item_info], visitor) 197 | 198 | print("Size of dataset:", len(samples)) 199 | print("Number of unique tokens:", builder.num_tokens()) 200 | builder.print_tokens() 201 | 202 | return { 203 | "samples": [ 204 | { 205 | "info": info, 206 | "x": {"code_rep": sample, "aux_in": [info[3], info[4]]}, 207 | "y": 0 if info[5] == "CPU" else 1, 208 | } 209 | for info, sample in samples.items() 210 | ], 211 | "num_types": builder.num_tokens(), 212 | } 213 | -------------------------------------------------------------------------------- /compy/datasets/devmap_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from compy.datasets import OpenCLDevmapDataset 4 | from compy.representations import RepresentationBuilder 5 | 6 | 7 | class objectview(object): 8 | def __init__(self, d): 9 | self.__dict__ = d 10 | 11 | 12 | class TestBuilder(RepresentationBuilder): 13 | def string_to_info(self, src, additional_include_dir): 14 | functionInfo = objectview({"name": "xyz"}) 15 | return objectview({"functionInfos": [functionInfo]}) 16 | 17 | def info_to_representation(self, info, visitor): 18 | return "Repr" 19 | 20 | 21 | @pytest.fixture 22 | def devmap_fixture(): 23 | ds = OpenCLDevmapDataset() 24 
| yield ds 25 | 26 | 27 | def d_test_preprocess(devmap_fixture): 28 | builder = TestBuilder() 29 | devmap_fixture.preprocess(builder, None) 30 | -------------------------------------------------------------------------------- /compy/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .graphs import * 2 | from .seqs import * 3 | -------------------------------------------------------------------------------- /compy/models/graphs/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch_geom_model import GnnPytorchGeomModel 2 | from .pytorch_dgl_model import GnnPytorchDGLModel 3 | from .tf_graph_model import GnnTfModel 4 | -------------------------------------------------------------------------------- /compy/models/graphs/pytorch_dgl_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from torch import nn 5 | from torch.optim import Adam 6 | 7 | from compy.models.model import Model 8 | 9 | 10 | class Net(nn.Module): 11 | def __init__(self, config): 12 | from dgl.nn.pytorch import ( 13 | GatedGraphConv, 14 | ) # Prevents DGL from clashing with TensorFlow backend 15 | from dgl.nn.pytorch import GlobalAttentionPooling 16 | 17 | super(Net, self).__init__() 18 | 19 | annotation_size = config["hidden_size_orig"] 20 | self.annotation_size = annotation_size 21 | hidden_size = config["gnn_h_size"] 22 | n_steps = config["num_timesteps"] 23 | n_etypes = config["num_edge_types"] 24 | num_cls = 2 25 | 26 | self.reduce_layer = nn.Linear(annotation_size, hidden_size) 27 | self.ggnn = GatedGraphConv( 28 | in_feats=hidden_size, 29 | out_feats=hidden_size, 30 | n_steps=n_steps, 31 | n_etypes=n_etypes, 32 | ) 33 | 34 | pooling_gate_nn = nn.Linear(hidden_size * 2, 1) 35 | self.pooling = GlobalAttentionPooling(pooling_gate_nn) 36 | self.output_layer = nn.Linear(hidden_size * 2, num_cls) 37 | 38 | self.loss_fn = nn.CrossEntropyLoss() 39 | 40 | def forward(self, graph, labels=None): 41 | etypes = graph.edata.pop("type") 42 | annotation = graph.ndata.pop("annotation").float() 43 | assert annotation.size()[-1] == self.annotation_size 44 | 45 | annotation = self.reduce_layer(annotation) 46 | 47 | out = self.ggnn(graph, annotation, etypes) 48 | out = torch.cat([out, annotation], -1) 49 | out = self.pooling(graph, out) 50 | 51 | logits = self.output_layer(out) 52 | preds = torch.argmax(logits, -1) 53 | 54 | if labels is not None: 55 | loss = self.loss_fn(logits, labels) 56 | return loss, preds 57 | return preds 58 | 59 | 60 | class GnnPytorchDGLModel(Model): 61 | def __init__(self, config=None, num_types=None): 62 | if not config: 63 | config = { 64 | "num_timesteps": 4, 65 | "hidden_size_orig": num_types, 66 | "gnn_h_size": 32, 67 | "gnn_m_size": 2, 68 | "num_edge_types": 4, 69 | "learning_rate": 0.001, 70 | "batch_size": 64, 71 | "num_epochs": 1000, 72 | } 73 | super().__init__(config) 74 | 75 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 76 | 77 | self.model = Net(config) 78 | self.model = self.model.to(self.device) 79 | 80 | def __process_data(self, data): 81 | return [ 82 | { 83 | "nodes": data["x"]["code_rep"].get_node_list(), 84 | "edges": data["x"]["code_rep"].get_edge_list(), 85 | "aux_in": data["x"]["aux_in"], 86 | "label": data["y"], 87 | } 88 | for data in data 89 | ] 90 | 91 | def __build_dgl_graph_and_labels(self, batch_graphs): 92 | import dgl # Import DGL locally to prevent 
clashing with TensorFlow import 93 | 94 | dgl_graphs = [] 95 | labels = [] 96 | for batch_graph in batch_graphs: 97 | # Graph 98 | g = dgl.DGLGraph() 99 | 100 | # - nodes 101 | g.add_nodes(len(batch_graph["nodes"])) 102 | g.ndata["annotation"] = torch.zeros( 103 | [len(batch_graph["nodes"]), self.config["hidden_size_orig"]], 104 | dtype=torch.long, 105 | ) 106 | 107 | # -edges 108 | edge_types = [] 109 | for edge in batch_graph["edges"]: 110 | g.add_edge(edge[0], edge[2]) 111 | edge_types.append(edge[1]) 112 | g.edata["type"] = torch.tensor(edge_types, dtype=torch.long) 113 | 114 | dgl_graphs.append(g) 115 | 116 | # Label 117 | labels.append(batch_graph["label"]) 118 | 119 | # Put small graphs into a large graph with disconnected subgraphs 120 | dgl_graph = dgl.batch(dgl_graphs) 121 | 122 | labels = torch.tensor(labels, dtype=torch.long) 123 | 124 | dgl_graph = dgl_graph.to(self.device) 125 | labels = labels.to(self.device) 126 | 127 | return dgl_graph, labels 128 | 129 | def _train_init(self, data_train, data_valid): 130 | self.opt = Adam(self.model.parameters(), lr=self.config["learning_rate"]) 131 | 132 | return self.__process_data(data_train), self.__process_data(data_valid) 133 | 134 | def _train_with_batch(self, batch): 135 | g, labels = self.__build_dgl_graph_and_labels(batch) 136 | 137 | self.model.train() 138 | self.opt.zero_grad() 139 | 140 | loss, pred = self.model(g, labels) 141 | loss.backward() 142 | self.opt.step() 143 | 144 | train_accuracy = ( 145 | np.equal( 146 | labels.cpu().data.numpy().tolist(), pred.cpu().data.numpy().tolist() 147 | ) 148 | .astype(np.float) 149 | .tolist() 150 | ) 151 | train_accuracy = sum(train_accuracy) / len(train_accuracy) 152 | train_loss = loss / len(batch) 153 | 154 | return train_loss, train_accuracy 155 | 156 | def _test_init(self): 157 | self.model.eval() 158 | 159 | def _predict_with_batch(self, batch): 160 | g, labels = self.__build_dgl_graph_and_labels(batch) 161 | 162 | with torch.no_grad(): 163 | loss, pred = self.model(g, labels) 164 | 165 | valid_accuracy = ( 166 | np.equal( 167 | labels.cpu().data.numpy().tolist(), pred.cpu().data.numpy().tolist() 168 | ) 169 | .astype(np.float) 170 | .tolist() 171 | ) 172 | valid_accuracy = sum(valid_accuracy) / len(valid_accuracy) 173 | 174 | return valid_accuracy, pred 175 | -------------------------------------------------------------------------------- /compy/models/graphs/pytorch_dgl_model_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | from compy.models.graphs.pytorch_dgl_model import GnnPytorchDGLModel 4 | from compy.representations.common import Graph 5 | 6 | 7 | def test_model(): 8 | dummy_graph = nx.MultiDiGraph() 9 | dummy_graph.add_node("n1", attr="a") 10 | dummy_graph.add_node("n2", attr="b") 11 | dummy_graph.add_node("n3", attr="c") 12 | dummy_graph.add_edge("n1", "n2", attr="dummy") 13 | dummy_graph.add_edge("n2", "n3", attr="dummy") 14 | 15 | config = { 16 | "num_timesteps": 2, 17 | "hidden_size_orig": len(dummy_graph), 18 | "gnn_h_size": 4, 19 | "gnn_m_size": 2, 20 | "num_edge_types": 1, 21 | "learning_rate": 0.001, 22 | "batch_size": 4, 23 | "num_epochs": 1, 24 | } 25 | model = GnnPytorchDGLModel(config=config) 26 | 27 | data = [ 28 | { 29 | "x": { 30 | "code_rep": Graph(dummy_graph, ["a", "b", "c"], ["dummy"]), 31 | "aux_in": [0, 0], 32 | }, 33 | "y": 0, 34 | } 35 | ] 36 | model.train(data, data) 37 | -------------------------------------------------------------------------------- 
/compy/models/graphs/pytorch_geom_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from torch import nn 6 | from torch_geometric.nn import GatedGraphConv 7 | from torch_geometric.nn import GlobalAttention 8 | from torch_geometric.data import Data 9 | from torch_geometric.data import DataLoader 10 | 11 | from compy.models.model import Model 12 | 13 | 14 | class Net(torch.nn.Module): 15 | def __init__(self, config): 16 | super(Net, self).__init__() 17 | 18 | annotation_size = config["hidden_size_orig"] 19 | hidden_size = config["gnn_h_size"] 20 | n_steps = config["num_timesteps"] 21 | num_cls = 2 22 | 23 | self.reduce = nn.Linear(annotation_size, hidden_size) 24 | self.conv = GatedGraphConv(hidden_size, n_steps) 25 | self.agg = GlobalAttention(nn.Linear(hidden_size, 1), nn.Linear(hidden_size, 2)) 26 | self.lin = nn.Linear(hidden_size, num_cls) 27 | 28 | def forward( 29 | self, graph, 30 | ): 31 | x, edge_index, batch = graph.x, graph.edge_index, graph.batch 32 | 33 | x = self.reduce(x) 34 | 35 | x = self.conv(x, edge_index) 36 | x = self.agg(x, batch) 37 | 38 | x = F.log_softmax(x, dim=1) 39 | 40 | return x 41 | 42 | 43 | class GnnPytorchGeomModel(Model): 44 | def __init__(self, config=None, num_types=None): 45 | if not config: 46 | config = { 47 | "num_timesteps": 4, 48 | "hidden_size_orig": num_types, 49 | "gnn_h_size": 32, 50 | "gnn_m_size": 2, 51 | "learning_rate": 0.001, 52 | "batch_size": 64, 53 | "num_epochs": 1000, 54 | } 55 | super().__init__(config) 56 | 57 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 58 | 59 | self.model = Net(config) 60 | self.model = self.model.to(self.device) 61 | 62 | def __process_data(self, data): 63 | return [ 64 | { 65 | "nodes": data["x"]["code_rep"].get_node_list(), 66 | "edges": data["x"]["code_rep"].get_edge_list(), 67 | "aux_in": data["x"]["aux_in"], 68 | "label": data["y"], 69 | } 70 | for data in data 71 | ] 72 | 73 | def __build_pg_graphs(self, batch_graphs): 74 | pg_graphs = [] 75 | 76 | for batch_graph in batch_graphs: 77 | # Graph 78 | # - nodes 79 | one_hot = np.zeros( 80 | (len(batch_graph["nodes"]), self.config["hidden_size_orig"]) 81 | ) 82 | one_hot[np.arange(len(batch_graph["nodes"])), batch_graph["nodes"]] = 1 83 | x = torch.tensor(one_hot, dtype=torch.float) 84 | 85 | # -edges 86 | edge_index, edge_features = [], [] 87 | for edge in batch_graph["edges"]: 88 | edge_index.append([edge[0], edge[2]]) 89 | edge_features.append([edge[1]]) 90 | edge_index = torch.tensor(edge_index, dtype=torch.long) 91 | edge_features = torch.tensor(edge_features, dtype=torch.long) 92 | 93 | graph = Data( 94 | x=x, 95 | edge_index=edge_index.t().contiguous(), 96 | edge_features=edge_features, 97 | y=batch_graph["label"], 98 | ) 99 | pg_graphs.append(graph) 100 | 101 | return pg_graphs 102 | 103 | def _train_init(self, data_train, data_valid): 104 | self.opt = torch.optim.Adam( 105 | self.model.parameters(), lr=self.config["learning_rate"] 106 | ) 107 | 108 | return self.__process_data(data_train), self.__process_data(data_valid) 109 | 110 | def _train_with_batch(self, batch): 111 | loss_sum = 0 112 | correct_sum = 0 113 | 114 | graphs = self.__build_pg_graphs(batch) 115 | loader = DataLoader(graphs, batch_size=999999) 116 | for data in loader: 117 | data = data.to(self.device) 118 | 119 | self.model.train() 120 | self.opt.zero_grad() 121 | 122 | pred = self.model(data) 123 | loss = F.nll_loss(pred, data.y) 
124 | loss.backward() 125 | self.opt.step() 126 | 127 | loss_sum += loss 128 | correct_sum += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item() 129 | 130 | train_accuracy = correct_sum / len(loader.dataset) 131 | train_loss = loss_sum / len(loader.dataset) 132 | 133 | return train_loss, train_accuracy 134 | 135 | def _test_init(self): 136 | self.model.eval() 137 | 138 | def _predict_with_batch(self, batch): 139 | correct = 0 140 | 141 | graphs = self.__build_pg_graphs(batch) 142 | loader = DataLoader(graphs, batch_size=999999) 143 | for data in loader: 144 | data = data.to(self.device) 145 | 146 | with torch.no_grad(): 147 | pred = self.model(data) 148 | 149 | correct += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item() 150 | valid_accuracy = correct / len(loader.dataset) 151 | 152 | return valid_accuracy, pred 153 | -------------------------------------------------------------------------------- /compy/models/graphs/pytorch_geom_model_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | from compy.models.graphs.pytorch_geom_model import GnnPytorchGeomModel 4 | from compy.representations.common import Graph 5 | 6 | 7 | def test_model(): 8 | dummy_graph = nx.MultiDiGraph() 9 | dummy_graph.add_node("n1", attr="a") 10 | dummy_graph.add_node("n2", attr="b") 11 | dummy_graph.add_node("n3", attr="c") 12 | dummy_graph.add_edge("n1", "n2", attr="dummy") 13 | dummy_graph.add_edge("n2", "n3", attr="dummy") 14 | 15 | config = { 16 | "num_timesteps": 2, 17 | "hidden_size_orig": len(dummy_graph), 18 | "gnn_h_size": 4, 19 | "gnn_m_size": 2, 20 | "num_edge_types": 1, 21 | "learning_rate": 0.001, 22 | "batch_size": 4, 23 | "num_epochs": 1, 24 | } 25 | model = GnnPytorchGeomModel(config=config) 26 | 27 | data = [ 28 | { 29 | "x": { 30 | "code_rep": Graph(dummy_graph, ["a", "b", "c"], ["dummy"]), 31 | "aux_in": [0, 0], 32 | }, 33 | "y": 0, 34 | } 35 | ] 36 | model.train(data, data) 37 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tud-ccc/compy-learn/24a394bd1ddc109a37b348c3de440389fa9a1f23/compy/models/graphs/tf/__init__.py -------------------------------------------------------------------------------- /compy/models/graphs/tf/cell/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tud-ccc/compy-learn/24a394bd1ddc109a37b348c3de440389fa9a1f23/compy/models/graphs/tf/cell/__init__.py -------------------------------------------------------------------------------- /compy/models/graphs/tf/cell/prediction_cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from compy.models.graphs.tf import utils 4 | 5 | 6 | class PredictionCellState(object): 7 | """Holds the state / weights of a Prediction cell.""" 8 | 9 | def __init__(self, config): 10 | self.config = config 11 | 12 | h_size = self.config["gnn_h_size"] 13 | h_size_orig = self.config["hidden_size_orig"] 14 | m_size = self.config["gnn_m_size"] 15 | 16 | self.weights = {} 17 | 18 | self.weights["mlp_f_m"] = utils.MLP( 19 | h_size, 20 | h_size * m_size, 21 | self.config["prediction_cell"]["mlp_f_m_dims"], 22 | self.config["prediction_cell"]["mlp_f_m_activation"], 23 | "mlp_regression_transform", 24 | ) 25 | self.weights["mlp_g_m"] = utils.MLP( 26 | h_size + 
h_size_orig, 27 | h_size * m_size, 28 | self.config["prediction_cell"]["mlp_g_m_dims"], 29 | self.config["prediction_cell"]["mlp_g_m_activation"], 30 | "mlp_regression_gate", 31 | ) 32 | 33 | self.weights["mlp_reduce"] = utils.MLP( 34 | h_size * m_size, 35 | h_size * m_size, 36 | self.config["prediction_cell"]["mlp_reduce_dims"], 37 | self.config["prediction_cell"]["mlp_reduce_activation"], 38 | "mlp_reduce", 39 | ) 40 | 41 | offset = 0 42 | if config["with_aux_in"] == 1: 43 | offset = 2 44 | self.weights["mlp_reduce_after_aux_in_1"] = utils.MLP( 45 | h_size * m_size + offset, 46 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_1_out_dim"], 47 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_1_dims"], 48 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_1_activation"], 49 | "mlp_reduce_after_aux_in_1", 50 | ) 51 | 52 | self.weights["mlp_reduce_after_aux_in_2"] = utils.MLP( 53 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_1_out_dim"], 54 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_2_out_dim"], 55 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_2_dims"], 56 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_2_activation"], 57 | "mlp_reduce_after_aux_in_2", 58 | ) 59 | 60 | self.weights["graph_model_out"] = utils.MLP( 61 | self.config["prediction_cell"]["mlp_reduce_out_dim"], 62 | self.config["prediction_cell"]["mlp_reduce_after_aux_in_2_out_dim"], 63 | [], 64 | "sigmoid", 65 | "graph_model_out", 66 | ) 67 | 68 | 69 | class PredictionCell(object): 70 | """Implementation of the Prediction cell.""" 71 | 72 | def __init__(self, config, enable_training, state, with_aux_in): 73 | self.config = config 74 | self.enable_training = enable_training 75 | self.state = state 76 | self.with_aux_in = with_aux_in 77 | 78 | self.ops = {} 79 | self.placeholders = {} 80 | 81 | def compute_predictions(self, embeddings: tf.Tensor) -> tf.Tensor: 82 | """ 83 | Make prediction based on node embeddings. 84 | 85 | Args: 86 | embeddings: Tensor of shape [b*v, h]. 87 | 88 | Returns: 89 | Tensor of predictions. 
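
        Note: `self.initial_embeddings` must be assigned externally (the tests
        assign a `[None, hidden_size_orig]` placeholder) before calling this
        method, and the `is_training`, `embeddings_to_graph_mappings`, and,
        if enabled, `aux_in` / `labels` placeholders created here must be fed
        at session run time.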
90 | """ 91 | # Placeholders 92 | # ######################################### 93 | # # Initial embeddings 94 | # self.placeholders['initial_embeddings'] = tf.compat.v1.placeholder(tf.float32, [None, self.config['gnn_h_size']], name='initial_embeddings') 95 | # initial_embeddings = self.placeholders['initial_embeddings'] 96 | 97 | # Is training (for batch norm) 98 | self.placeholders["is_training"] = tf.compat.v1.placeholder( 99 | tf.bool, None, name="is_training" 100 | ) 101 | is_training = self.placeholders["is_training"] 102 | 103 | # Embeddings to graph mappings 104 | self.placeholders["embeddings_to_graph_mappings"] = tf.compat.v1.placeholder( 105 | tf.int32, [1, None], name="embeddings_to_graph_mappings" 106 | ) 107 | embeddings_to_graph_mappings = self.placeholders["embeddings_to_graph_mappings"] 108 | num_graphs = tf.reduce_max(embeddings_to_graph_mappings) + 1 # Scalar 109 | 110 | # Input 111 | if self.with_aux_in: 112 | self.placeholders["aux_in"] = tf.compat.v1.placeholder( 113 | tf.float32, [None, 2], name="aux_in" 114 | ) 115 | aux_in = tf.cast(self.placeholders["aux_in"], dtype=tf.float32) 116 | 117 | # Graph Model 118 | # ######################################### 119 | gate_input = tf.concat( 120 | [embeddings, tf.expand_dims(self.initial_embeddings, 0)], axis=-1 121 | ) # [b*v, 2h + h_init] 122 | h_v_G = self.state.weights["mlp_f_m"](embeddings) # [b*v, 2h] 123 | g_v_G = self.state.weights["mlp_g_m"](gate_input) # [b*v, 2h] 124 | g_v_G = tf.nn.sigmoid(g_v_G) # [b*v, 2h] 125 | 126 | h_G = h_v_G * g_v_G # [b*v, 2h] 127 | 128 | # Sum up all nodes per graph 129 | h_G = tf.compat.v1.unsorted_segment_sum( 130 | data=h_G, segment_ids=embeddings_to_graph_mappings, num_segments=num_graphs 131 | ) # [b, 2h] 132 | h_G = self.state.weights["mlp_reduce"](h_G) # [b, 2] 133 | 134 | graphmodel_logits = self.state.weights["graph_model_out"](h_G) 135 | 136 | # Prediction Model 137 | # ######################################### 138 | if self.with_aux_in: 139 | h_G = tf.concat([h_G, aux_in], axis=-1) # [b, 2h + 2] 140 | 141 | h_G = tf.compat.v1.layers.batch_normalization(h_G, training=is_training) 142 | 143 | output = self.state.weights["mlp_reduce_after_aux_in_1"](h_G) # [b, 32] 144 | output = self.state.weights["mlp_reduce_after_aux_in_2"](output) # [b, 2] 145 | 146 | predictionmodel_logits = tf.nn.softmax(output) 147 | 148 | # Training 149 | if self.enable_training: 150 | # Input 151 | self.placeholders["labels"] = tf.compat.v1.placeholder( 152 | tf.int32, 153 | [None, self.config["prediction_cell"]["output_dim"]], 154 | name="labels", 155 | ) 156 | labels = tf.cast(self.placeholders["labels"], dtype=tf.float32) 157 | 158 | # Graph model 159 | graphmodel_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2( 160 | labels=labels, logits=graphmodel_logits 161 | ) # [b, 2] 162 | 163 | # Prediction model 164 | predictionmodel_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2( 165 | labels=labels, logits=predictionmodel_logits 166 | ) # [b, 2] 167 | 168 | # Loss 169 | loss = tf.reduce_sum(graphmodel_loss + 0.2 * predictionmodel_loss) # [b] 170 | self.ops["loss"] = loss 171 | 172 | self.ops["output"] = predictionmodel_logits 173 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/cell/prediction_cell_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | tf.compat.v1.disable_eager_execution() 5 | 6 | from 
compy.models.graphs.tf.cell.prediction_cell import ( 7 | PredictionCell, 8 | PredictionCellState, 9 | ) 10 | from compy.models.graphs.tf.layer.gnn_model_layer import ( 11 | GGNNModelLayer, 12 | GGNNModelLayerState, 13 | ) 14 | from compy.models.graphs.tf import test_utils 15 | 16 | 17 | CONFIG = { 18 | "hidden_size_orig": 4, 19 | "num_timesteps": 4, 20 | "graph_rnn_cell": "gru", 21 | "gnn_h_size": 4, 22 | "gnn_m_size": 2, 23 | "num_node_types": 2, 24 | "num_edge_types": 2, 25 | "use_edge_bias": 0, 26 | "use_edge_msg_avg_aggregation": 0, 27 | "with_aux_in": 0, 28 | "prediction_cell": { 29 | "mlp_f_m_dims": [64, 64], 30 | "mlp_f_m_activation": "relu", 31 | "mlp_g_m_dims": [64, 64], 32 | "mlp_g_m_activation": "relu", 33 | "mlp_reduce_dims": [64, 64], 34 | "mlp_reduce_activation": "relu", 35 | "mlp_reduce_after_aux_in_1_dims": [], 36 | "mlp_reduce_after_aux_in_1_activation": "relu", 37 | "mlp_reduce_after_aux_in_1_out_dim": 32, 38 | "mlp_reduce_after_aux_in_2_dims": [], 39 | "mlp_reduce_after_aux_in_2_activation": "sigmoid", 40 | "mlp_reduce_after_aux_in_2_out_dim": 2, 41 | "output_dim": 2, 42 | "mlp_reduce_out_dim": 8, 43 | }, 44 | } 45 | 46 | 47 | # Helper functions 48 | def setup_deepgmg_cell_prediction_and_fetch_op(num_graphs: int): 49 | # Get data 50 | test_data = test_utils.get_test_data(CONFIG, num_graphs) 51 | 52 | embeddings_in = tf.compat.v1.placeholder( 53 | tf.float32, [None, test_data["h_dim"]], name="embeddings_in" 54 | ) 55 | 56 | # Create state 57 | ggnn_layer_state = GGNNModelLayerState(CONFIG) 58 | prediction_cell_state = PredictionCellState(CONFIG) 59 | 60 | # Create layer and propagate 61 | ggnn_layer = GGNNModelLayer(CONFIG, ggnn_layer_state) 62 | embeddings = ggnn_layer.compute_embeddings(embeddings_in) 63 | 64 | # Create cell and predict 65 | prediction_cell = PredictionCell( 66 | CONFIG, False, prediction_cell_state, CONFIG["with_aux_in"] 67 | ) 68 | prediction_cell.initial_embeddings = tf.compat.v1.placeholder( 69 | tf.float32, [None, CONFIG["hidden_size_orig"]], name="embeddings_in" 70 | ) 71 | prediction_cell.compute_predictions(embeddings) 72 | 73 | with tf.compat.v1.Session() as session: 74 | session.run(tf.compat.v1.global_variables_initializer()) 75 | 76 | fetch_list = [prediction_cell.ops["output"]] 77 | feed_dict = { 78 | ggnn_layer.placeholders["adjacency_lists"][0]: test_data["adjacency_lists"][ 79 | 0 80 | ], 81 | ggnn_layer.placeholders["adjacency_lists"][1]: test_data["adjacency_lists"][ 82 | 1 83 | ], 84 | prediction_cell.initial_embeddings: test_data["node_features"], 85 | embeddings_in: test_data["node_features"], 86 | prediction_cell.placeholders["embeddings_to_graph_mappings"]: test_data[ 87 | "embeddings_to_graph_mappings" 88 | ], 89 | prediction_cell.placeholders["is_training"]: False, 90 | } 91 | 92 | result = session.run(fetch_list, feed_dict=feed_dict) 93 | 94 | return result 95 | 96 | 97 | def setup_deepgmg_cell_training_and_fetch_op(num_graphs: int): 98 | # Get data 99 | test_data = test_utils.get_test_data(CONFIG, num_graphs) 100 | 101 | embeddings_in = tf.compat.v1.placeholder( 102 | tf.float32, [None, test_data["h_dim"]], name="embeddings_in" 103 | ) 104 | 105 | # Create state 106 | ggnn_layer_state = GGNNModelLayerState(CONFIG) 107 | prediction_cell_state = PredictionCellState(CONFIG) 108 | 109 | # Create layer and propagate 110 | ggnn_layer = GGNNModelLayer(CONFIG, ggnn_layer_state) 111 | embeddings = ggnn_layer.compute_embeddings(embeddings_in) 112 | 113 | # Create cell and predict 114 | prediction_cell = PredictionCell( 115 | CONFIG, 
True, prediction_cell_state, CONFIG["with_aux_in"] 116 | ) 117 | prediction_cell.initial_embeddings = tf.compat.v1.placeholder( 118 | tf.float32, [None, CONFIG["hidden_size_orig"]], name="embeddings_in" 119 | ) 120 | prediction_cell.compute_predictions(embeddings) 121 | 122 | with tf.compat.v1.Session() as session: 123 | session.run(tf.compat.v1.global_variables_initializer()) 124 | 125 | fetch_list = [prediction_cell.ops["loss"]] 126 | feed_dict = { 127 | ggnn_layer.placeholders["adjacency_lists"][0]: test_data["adjacency_lists"][ 128 | 0 129 | ], 130 | ggnn_layer.placeholders["adjacency_lists"][1]: test_data["adjacency_lists"][ 131 | 1 132 | ], 133 | prediction_cell.initial_embeddings: test_data["node_features"], 134 | embeddings_in: test_data["node_features"], 135 | prediction_cell.placeholders["embeddings_to_graph_mappings"]: test_data[ 136 | "embeddings_to_graph_mappings" 137 | ], 138 | prediction_cell.placeholders["labels"]: test_data["labels"], 139 | prediction_cell.placeholders["is_training"]: True, 140 | } 141 | 142 | result = session.run(fetch_list, feed_dict=feed_dict) 143 | 144 | return result 145 | 146 | 147 | # Prediction Tests 148 | def test_prediction_cell_1_graph(): 149 | result = setup_deepgmg_cell_prediction_and_fetch_op(1) 150 | 151 | assert isinstance(result[0], np.ndarray) 152 | 153 | 154 | def test_prediction_cell_2_graphs(): 155 | result = setup_deepgmg_cell_prediction_and_fetch_op(2) 156 | 157 | assert isinstance(result[0], np.ndarray) 158 | assert result[0].shape[0] == 2 159 | 160 | 161 | # Training Tests 162 | # 1 Graph 163 | def test_training_cell_1_graph(): 164 | result = setup_deepgmg_cell_training_and_fetch_op(1) 165 | 166 | 167 | def test_training_cell_2_graphs(): 168 | result = setup_deepgmg_cell_training_and_fetch_op(2) 169 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tud-ccc/compy-learn/24a394bd1ddc109a37b348c3de440389fa9a1f23/compy/models/graphs/tf/layer/__init__.py -------------------------------------------------------------------------------- /compy/models/graphs/tf/layer/embedding_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from compy.models.graphs.tf import utils 4 | from compy.models.graphs.tf.layer.propagation_model_layer import PropagationModelLayer 5 | 6 | 7 | class EmbeddingLayerState(object): 8 | """Holds the state / weights of a Embedding Layer.""" 9 | 10 | def __init__(self, config): 11 | self.config = config 12 | 13 | hidden_size_orig = self.config["hidden_size_orig"] 14 | h_size = self.config["gnn_h_size"] 15 | 16 | self.weights = {} 17 | 18 | self.weights["mapping"] = utils.MLP( 19 | hidden_size_orig, 20 | h_size, 21 | self.config["embedding_layer"]["mapping_dims"], 22 | "relu", 23 | "mapping", 24 | ) 25 | 26 | 27 | class EmbeddingLayer(PropagationModelLayer): 28 | """Implementation of the Embedding Layer.""" 29 | 30 | def __init__(self, config, state): 31 | super().__init__() 32 | 33 | self.config = config 34 | self.state = state 35 | 36 | def compute_embeddings(self, embeddings: tf.Tensor) -> tf.Tensor: 37 | """ 38 | Uses the model layer to process embeddings to new embeddings. All embeddings are in one dimension. 39 | Propagation is made in one pass with many disconnected graphs. 40 | 41 | Args: 42 | embeddings: Tensor of shape [v, h]. 
43 | 44 | Returns: 45 | Tensor of shape [v, h]. 46 | """ 47 | embeddings_new = self.state.weights["mapping"](embeddings) # [v, h] 48 | 49 | return embeddings_new 50 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/layer/gnn_model_layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from compy.models.graphs.tf.layer.propagation_model_layer import PropagationModelLayer 5 | 6 | 7 | def glorot_init(shape): 8 | initialization_range = np.sqrt(6.0 / (shape[-2] + shape[-1])) 9 | return np.random.uniform( 10 | low=-initialization_range, high=initialization_range, size=shape 11 | ).astype(np.float32) 12 | 13 | 14 | class GGNNModelLayerState(object): 15 | """Holds the state / weights of a GGNN Layer.""" 16 | 17 | def __init__(self, config): 18 | self.config = config 19 | 20 | h_dim = self.config["gnn_h_size"] 21 | num_edge_types = self.config["num_edge_types"] 22 | 23 | self.weights = {} 24 | 25 | edge_weights = tf.Variable( 26 | glorot_init([num_edge_types * h_dim, h_dim]), name="edge_weights" 27 | ) 28 | self.weights["edge_weights"] = tf.reshape( 29 | edge_weights, [num_edge_types, h_dim, h_dim] 30 | ) 31 | 32 | if self.config["use_edge_bias"] == 1: 33 | self.weights["edge_biases"] = tf.Variable( 34 | np.zeros([num_edge_types, h_dim], dtype=np.float32), 35 | name="gnn_edge_biases", 36 | ) 37 | 38 | cell_type = config["graph_rnn_cell"] 39 | activation_fun = tf.nn.tanh 40 | if cell_type == "gru": 41 | cell = tf.compat.v1.keras.layers.GRUCell(h_dim, activation=activation_fun) 42 | elif cell_type == "cudnncompatiblegrucell": 43 | import tensorflow.contrib.cudnn_rnn as cudnn_rnn 44 | 45 | cell = cudnn_rnn.CudnnCompatibleGRUCell(h_dim) 46 | elif cell_type == "rnn": 47 | cell = tf.nn.rnn_cell.BasicRNNCell(h_dim, activation=activation_fun) 48 | else: 49 | raise Exception("Unknown RNN cell type '%s'." % cell_type) 50 | self.weights["rnn_cells"] = cell 51 | 52 | 53 | class GGNNModelLayer(PropagationModelLayer): 54 | """Implementation of the model described in 55 | Li, Yujia, et al. "Gated graph sequence neural networks." 56 | 57 | Sparse implementation from: https://github.com/microsoft/gated-graph-neural-network-samples""" 58 | 59 | def __init__(self, config, state): 60 | super().__init__() 61 | 62 | self.config = config 63 | self.state = state 64 | 65 | num_edge_types = self.config["num_edge_types"] 66 | 67 | # Placeholders 68 | h_dim = self.config["gnn_h_size"] 69 | self.placeholders["adjacency_lists"] = [ 70 | tf.compat.v1.placeholder(tf.int32, [None, 2], name="adjacency_e%s" % e) 71 | for e in range(num_edge_types) 72 | ] 73 | 74 | if self.config["use_edge_bias"] == 1: 75 | self.placeholders["num_incoming_edges_per_type"] = tf.compat.v1.placeholder( 76 | tf.float32, [None, num_edge_types], name="num_incoming_edges_per_type" 77 | ) 78 | 79 | def compute_embeddings(self, embeddings: tf.Tensor) -> tf.Tensor: 80 | """ 81 | Uses the model layer to process embeddings to new embeddings. All embeddings are in one dimension. 82 | Propagation is made in one pass with many disconnected graphs. 83 | 84 | Args: 85 | embeddings: Tensor of shape [v, h]. 86 | 87 | Returns: 88 | Tensor of shape [v, h]. 
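
        Note: one `[e, 2]` int32 array of (source, target) node-index pairs per
        edge type must be fed into `self.placeholders["adjacency_lists"]`
        alongside the embeddings, e.g. as produced by
        `utils.graph_to_adjacency_lists`.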
89 | """ 90 | num_nodes = tf.shape(embeddings, out_type=tf.int32)[0] 91 | 92 | # Get all edge targets (aggregate of typed edges) 93 | edge_targets = [] # list of tensors of message targets of shape [e] 94 | for edge_type_idx, adjacency_list_for_edge_type in enumerate( 95 | self.placeholders["adjacency_lists"] 96 | ): 97 | edge_targets_for_one_type = adjacency_list_for_edge_type[:, 1] 98 | edge_targets.append(edge_targets_for_one_type) 99 | edge_targets = tf.concat(edge_targets, axis=0) # [M] 100 | 101 | # Propagate 102 | embeddings = [embeddings] 103 | for step in range(self.config["num_timesteps"]): 104 | messages = [] # list of tensors of messages of shape [e, h] 105 | message_source_states = ( 106 | [] 107 | ) # list of tensors of edge source states of shape [e, h] 108 | 109 | # Collect incoming messages per edge type 110 | for edge_type_idx, adjacency_list_for_edge_type in enumerate( 111 | self.placeholders["adjacency_lists"] 112 | ): 113 | edge_sources = adjacency_list_for_edge_type[:, 0] 114 | edge_source_states = tf.nn.embedding_lookup( 115 | params=embeddings, ids=edge_sources 116 | ) # [e, h] 117 | all_messages_for_edge_type = tf.matmul( 118 | edge_source_states, 119 | self.state.weights["edge_weights"][edge_type_idx], 120 | ) # Shape [e, h] 121 | messages.append(all_messages_for_edge_type) 122 | message_source_states.append(edge_source_states) 123 | 124 | messages = tf.concat(messages, axis=0) # [M, h] 125 | 126 | messages = tf.math.unsorted_segment_sum( 127 | data=messages, segment_ids=edge_targets, num_segments=num_nodes 128 | ) # [v, h] 129 | 130 | if self.config["use_edge_bias"] == 1: 131 | embeddings += tf.matmul( 132 | self.placeholders["num_incoming_edges_per_type"], 133 | self.state.weights["edge_biases"], 134 | ) 135 | 136 | # pass updated vertex features into RNN cell 137 | embeddings = self.state.weights["rnn_cells"](messages, embeddings)[ 138 | 1 139 | ] # [v, h] 140 | 141 | return embeddings 142 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/layer/propagation_model_layer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class PropagationModelLayer(object): 5 | def __init__(self): 6 | self.placeholders = {} 7 | 8 | def compute_embeddings(self, embeddings: tf.Tensor) -> tf.Tensor: 9 | """Uses the model layer to process embeddings to new embeddings. All embeddings are in one dimension. 10 | Propagation is made in one pass with many disconnected graphs. 11 | 12 | Args: 13 | embeddings: Tensor of shape [V, D]. 14 | 15 | Returns: 16 | Tensor of shape [V, D]. 
17 | """ 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/layer/propagation_model_layer_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | tf.compat.v1.disable_eager_execution() 5 | 6 | from compy.models.graphs.tf.layer.gnn_model_layer import ( 7 | GGNNModelLayer, 8 | GGNNModelLayerState, 9 | ) 10 | from compy.models.graphs.tf import test_utils 11 | from compy.models.graphs.tf import utils 12 | 13 | 14 | CONFIG = { 15 | "num_timesteps": 4, 16 | "graph_rnn_cell": "gru", 17 | "gnn_h_size": 4, 18 | "num_edge_types": 2, 19 | "use_edge_bias": 0, 20 | } 21 | 22 | 23 | def run_propagation_1_layer(graph_layer): 24 | # Get data 25 | test_data = test_utils.get_test_data(CONFIG, 1) 26 | 27 | # Process embeddings 28 | embeddings_in = tf.compat.v1.placeholder( 29 | tf.float32, [None, test_data["h_dim"]], name="embeddings_in" 30 | ) 31 | embeddings_out = graph_layer.compute_embeddings(embeddings_in) 32 | 33 | with tf.compat.v1.Session() as session: 34 | session.run(tf.compat.v1.global_variables_initializer()) 35 | 36 | fetch_list = [embeddings_out] 37 | feed_dict = { 38 | graph_layer.placeholders["adjacency_lists"][0]: test_data[ 39 | "adjacency_lists" 40 | ][0], 41 | graph_layer.placeholders["adjacency_lists"][1]: test_data[ 42 | "adjacency_lists" 43 | ][1], 44 | embeddings_in: test_data["node_features"], 45 | } 46 | 47 | result = session.run(fetch_list, feed_dict=feed_dict) 48 | 49 | # Check if shape is [v_dim, h_dim] 50 | assert len(result[0][0]) == test_data["v_dim"] 51 | assert len(result[0][0][0]) == test_data["h_dim"] 52 | 53 | 54 | def run_propagation_model_1_layer_sparse(graph_layer): 55 | # Get data 56 | test_data = test_utils.get_test_data(CONFIG, 1) 57 | 58 | # Process embeddings 59 | embeddings_in = tf.compat.v1.placeholder( 60 | tf.float32, [None, test_data["h_dim"]], name="embeddings_in" 61 | ) 62 | embeddings_out = graph_layer.compute_embeddings(embeddings_in) 63 | 64 | test_data["adjacency_lists"] = {0: [[0, 1]], 1: [[0, 1]]} 65 | 66 | with tf.compat.v1.Session() as session: 67 | session.run(tf.compat.v1.global_variables_initializer()) 68 | 69 | fetch_list = [embeddings_out] 70 | feed_dict = { 71 | graph_layer.placeholders["adjacency_lists"][0]: test_data[ 72 | "adjacency_lists" 73 | ][0], 74 | graph_layer.placeholders["adjacency_lists"][1]: test_data[ 75 | "adjacency_lists" 76 | ][1], 77 | embeddings_in: test_data["node_features"], 78 | } 79 | 80 | result = session.run(fetch_list, feed_dict=feed_dict) 81 | 82 | assert result 83 | 84 | 85 | # GNN Tests 86 | def test_ggnn_propagation_model_1_layer(): 87 | state = GGNNModelLayerState(CONFIG) 88 | graph_layer = GGNNModelLayer(CONFIG, state) 89 | 90 | run_propagation_1_layer(graph_layer) 91 | 92 | 93 | def test_ggnn_propagation_model_1_layer_sparse(): 94 | state = GGNNModelLayerState(CONFIG) 95 | graph_layer = GGNNModelLayer(CONFIG, state) 96 | 97 | run_propagation_model_1_layer_sparse(graph_layer) 98 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from compy.models.graphs.tf import utils 4 | 5 | 6 | ONE_GRAPH = [(0, 0, 1), (1, 0, 2), (2, 1, 1)] 7 | TWO_GRAPHS = [(0, 0, 1), (1, 0, 2), (2, 1, 1), (3, 0, 4), (4, 1, 5), (5, 0, 4)] 8 | 9 | 10 | # Helper functions 11 | def 
get_test_data( 12 | config: dict, num_graphs: int 13 | ) -> dict: 14 | h_dim = config["gnn_h_size"] 15 | v_dim = len(ONE_GRAPH) 16 | 17 | if num_graphs == 1: 18 | node_features = np.ones((v_dim, h_dim)) 19 | adjacency_lists = utils.graph_to_adjacency_lists(ONE_GRAPH, False)[0] 20 | embeddings_to_graph_mappings = [[0, 0, 0]] 21 | embeddings_last_added_node_idxs = [2] 22 | last_added_node_types = [1] 23 | labels = np.array([[1, 0]]) 24 | a1_labels = np.array([1]) 25 | a2_labels = np.array([1]) 26 | a3_labels = np.array([[0, 1], [0, 0], [0, 0]]) 27 | 28 | elif num_graphs == 2: 29 | node_features = np.ones((v_dim * 2, h_dim)) 30 | adjacency_lists = utils.graph_to_adjacency_lists(TWO_GRAPHS, False)[0] 31 | embeddings_to_graph_mappings = [[0, 0, 0, 1, 1, 1]] 32 | embeddings_last_added_node_idxs = [2, 5] 33 | last_added_node_types = [1, 2] 34 | labels = np.array([[1, 0], [1, 0]]) 35 | a1_labels = np.array([1]) 36 | a2_labels = np.array([1]) 37 | a3_labels = np.array([[0, 0], [1, 0], [0, 0], [0, 0], [0, 0], [1, 0]]) 38 | 39 | return { 40 | "node_features": node_features, 41 | "adjacency_lists": adjacency_lists, 42 | "embeddings_to_graph_mappings": embeddings_to_graph_mappings, 43 | "embeddings_last_added_node_idxs": embeddings_last_added_node_idxs, 44 | "last_added_node_types": last_added_node_types, 45 | "labels": labels, 46 | "a1_labels": a1_labels, 47 | "a2_labels": a2_labels, 48 | "a3_labels": a3_labels, 49 | "v_dim": v_dim, 50 | "h_dim": h_dim, 51 | } 52 | -------------------------------------------------------------------------------- /compy/models/graphs/tf/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from collections import defaultdict 4 | from typing import Dict, Tuple 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | 10 | # Constants 11 | LABEL_OFFSET = 20 12 | I_OFFSET = 40 13 | 14 | # Enums 15 | ####### 16 | class AE: 17 | ( 18 | GRAPH_IDX, 19 | STEP_IDX, 20 | ACTION, 21 | LAST_ADDED_NODE_ID, 22 | LAST_ADDED_NODE_TYPE, 23 | ACTIONS, 24 | GRAPH, 25 | NODE_STATES, 26 | ADJ_LIST, 27 | ACTION_CURRENT_IDX, 28 | ACTION_CURRENT, 29 | SKIP_NEXT, 30 | SUBGRAPH_START, 31 | NUM_NODES, 32 | PROBABILITY, 33 | NUMS_INCOMING_EDGES_BY_TYPE, 34 | KERNEL_NAME, 35 | ) = range(0, 17) 36 | 37 | 38 | # Labels 39 | class L: 40 | LABEL_0, LABEL_1 = range(LABEL_OFFSET, LABEL_OFFSET + 2) 41 | 42 | 43 | # Type 44 | class T: 45 | NODES, EDGES, NODE_VALUES = range(30, 33) 46 | 47 | 48 | # Inputs 49 | class I: 50 | AUX_IN_0 = range(LABEL_OFFSET, I_OFFSET + 1) 51 | 52 | 53 | # Functions 54 | ########### 55 | def glorot_init(shape): 56 | initialization_range = np.sqrt(6.0 / (shape[-2] + shape[-1])) 57 | return np.random.uniform( 58 | low=-initialization_range, high=initialization_range, size=shape 59 | ).astype(np.float32) 60 | 61 | 62 | def graph_to_adjacency_lists( 63 | graph, tie_fwd_bkwd, edge_type_filter=[] 64 | ) -> Tuple[Dict[int, np.ndarray], Dict[int, Dict[int, int]]]: 65 | adj_lists = defaultdict(list) 66 | num_incoming_edges_dicts_per_type = defaultdict(lambda: defaultdict(lambda: 0)) 67 | for src, e, dest in graph: 68 | fwd_edge_type = e 69 | if fwd_edge_type not in edge_type_filter and len(edge_type_filter) > 0: 70 | continue 71 | 72 | adj_lists[fwd_edge_type].append((src, dest)) 73 | num_incoming_edges_dicts_per_type[fwd_edge_type][dest] += 1 74 | 75 | if tie_fwd_bkwd: 76 | adj_lists[fwd_edge_type].append((dest, src)) 77 | num_incoming_edges_dicts_per_type[fwd_edge_type][src] += 1 78 |
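    # At this point adj_lists maps edge type -> list of (src, dest) pairs. As a
    # worked example, the triple list [(0, 0, 1), (1, 0, 2), (2, 1, 1)] with
    # tie_fwd_bkwd=False yields {0: [(0, 1), (1, 2)], 1: [(2, 1)]}; the dict
    # comprehension below sorts each list and converts it to an int32 array.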
79 | final_adj_lists = { 80 | e: np.array(sorted(lm), dtype=np.int32) for e, lm in adj_lists.items() 81 | } 82 | 83 | return final_adj_lists, num_incoming_edges_dicts_per_type 84 | 85 | 86 | def get_one_hot(targets, nb_classes): 87 | res = np.eye(nb_classes)[np.array(targets).reshape(-1)] 88 | return res.reshape(list(targets.shape) + [nb_classes]) 89 | 90 | 91 | # Classes 92 | ######### 93 | class MLP(object): 94 | def __init__(self, in_size, out_size, hid_sizes, activation, func_name): 95 | self.in_size = in_size 96 | self.out_size = out_size 97 | self.hid_sizes = hid_sizes 98 | self.activation = activation 99 | self.func_name = func_name 100 | self.params = self.make_network_params() 101 | 102 | def make_network_params(self) -> dict: 103 | dims = [self.in_size] + self.hid_sizes + [self.out_size] 104 | weight_sizes = list(zip(dims[:-1], dims[1:])) 105 | weights = [ 106 | tf.Variable(self.init_weights(s), name="%s_W_layer%i" % (self.func_name, i)) 107 | for (i, s) in enumerate(weight_sizes) 108 | ] 109 | biases = [ 110 | tf.Variable( 111 | np.zeros(s[-1]).astype(np.float32), 112 | name="%s_b_layer%i" % (self.func_name, i), 113 | ) 114 | for (i, s) in enumerate(weight_sizes) 115 | ] 116 | 117 | network_params = { 118 | "weights": weights, 119 | "biases": biases, 120 | } 121 | 122 | return network_params 123 | 124 | def init_weights(self, shape: tuple): 125 | return np.sqrt(6.0 / (shape[-2] + shape[-1])) * ( 126 | 2 * np.random.rand(*shape).astype(np.float32) - 1 127 | ) 128 | 129 | def __call__(self, inputs): 130 | acts = inputs 131 | for W, b in zip(self.params["weights"], self.params["biases"]): 132 | hid = tf.matmul(acts, W) + b 133 | if self.activation == "relu": 134 | acts = tf.nn.relu(hid) 135 | elif self.activation == "sigmoid": 136 | acts = tf.nn.sigmoid(hid) 137 | elif self.activation == "linear": 138 | acts = hid 139 | else: 140 | raise Exception("Unknown activation function: %s" % self.activation) 141 | last_hidden = hid 142 | return last_hidden 143 | -------------------------------------------------------------------------------- /compy/models/graphs/tf_graph_model_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import networkx as nx 3 | import tensorflow.python.util.deprecation as deprecation 4 | 5 | from compy.models.graphs.tf_graph_model import GnnTfModel 6 | from compy.models.graphs.tf_graph_model import GnnTfModelState 7 | from compy.representations.common import Graph 8 | 9 | # Disable TensorFlow messages and deprecation warnings 10 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 11 | deprecation._PRINT_DEPRECATION_WARNINGS = False 12 | 13 | 14 | CONFIG = { 15 | "graph_rnn_cell": "gru", 16 | "num_timesteps": 2, 17 | "gnn_h_size": 4, 18 | "gnn_m_size": 2, 19 | "num_edge_types": 1, 20 | "prediction_cell": { 21 | "mlp_f_m_dims": [], 22 | "mlp_f_m_activation": "relu", 23 | "mlp_g_m_dims": [], 24 | "mlp_g_m_activation": "relu", 25 | "mlp_reduce_dims": [], 26 | "mlp_reduce_activation": "relu", 27 | "mlp_reduce_after_aux_in_1_dims": [], 28 | "mlp_reduce_after_aux_in_1_activation": "relu", 29 | "mlp_reduce_after_aux_in_1_out_dim": 2, 30 | "mlp_reduce_after_aux_in_2_dims": [], 31 | "mlp_reduce_after_aux_in_2_activation": "sigmoid", 32 | "mlp_reduce_after_aux_in_2_out_dim": 2, 33 | "mlp_reduce_out_dim": 8, 34 | "output_dim": 2, 35 | }, 36 | "embedding_layer": {"mapping_dims": []}, 37 | "learning_rate": 0.0005, 38 | "clamp_gradient_norm": 1.0, 39 | "L2_loss_factor": 0, 40 | "batch_size": 4, 41 | "num_epochs": 1, 42 | 
"tie_fwd_bkwd": 1, 43 | "use_edge_bias": 0, 44 | "save_best_model_interval": 1, 45 | "with_aux_in": 0, 46 | "with_gradient_monitoring": 1, 47 | "seed": 0, 48 | } 49 | 50 | 51 | def test_train_model(): 52 | dummy_graph = nx.MultiDiGraph() 53 | dummy_graph.add_node("n1", attr="a") 54 | dummy_graph.add_node("n2", attr="b") 55 | dummy_graph.add_node("n3", attr="c") 56 | 57 | config = CONFIG 58 | config["hidden_size_orig"] = len(dummy_graph) 59 | 60 | state = GnnTfModelState(config) 61 | model = GnnTfModel(config, state) 62 | 63 | data = [ 64 | { 65 | "x": { 66 | "code_rep": Graph(dummy_graph, ["a", "b", "c"], ["dummy"]), 67 | "aux_in": [0, 0], 68 | }, 69 | "y": 0, 70 | } 71 | ] 72 | model.train(data, data) 73 | 74 | state.backup_best_weights() 75 | state.restore_best_weights() 76 | 77 | num_params = state.count_number_trainable_params() 78 | assert num_params 79 | -------------------------------------------------------------------------------- /compy/models/model.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | import time 3 | 4 | import numpy as np 5 | 6 | 7 | class Model(object): 8 | def __init__(self, config): 9 | pp = pprint.PrettyPrinter(indent=2) 10 | pp.pprint(config) 11 | 12 | self.config = config 13 | 14 | def train(self, data_train, data_valid): 15 | train_summary = [] 16 | data_train, data_valid = self._train_init(data_train, data_valid) 17 | 18 | print() 19 | for epoch in range(self.config["num_epochs"]): 20 | batch_size = self.config["batch_size"] 21 | np.random.shuffle(data_train) 22 | batches = [ 23 | data_train[i * batch_size : (i + 1) * batch_size] 24 | for i in range((len(data_train) + batch_size - 1) // batch_size) 25 | ] 26 | 27 | # Train 28 | start_time = time.time() 29 | for batch in batches: 30 | train_loss, train_accuracy = self._train_with_batch(batch) 31 | end_time = time.time() 32 | 33 | # Valid 34 | self._test_init() 35 | 36 | batch_size = self.config["batch_size"] 37 | np.random.shuffle(data_valid) 38 | batches = [ 39 | data_valid[i * batch_size : (i + 1) * batch_size] 40 | for i in range((len(data_valid) + batch_size - 1) // batch_size) 41 | ] 42 | 43 | valid_count = 0 44 | for batch in batches: 45 | batch_accuracy, _ = self._predict_with_batch(batch) 46 | valid_count += batch_accuracy * len(batch) 47 | valid_accuracy = valid_count / len(data_valid) 48 | 49 | # Logging 50 | instances_per_sec = len(data_train) / (end_time - start_time) 51 | print( 52 | "epoch: %i, train_loss: %.8f, train_accuracy: %.4f, valid_accuracy:" 53 | " %.4f, train instances/sec: %.2f" 54 | % (epoch, train_loss, train_accuracy, valid_accuracy, instances_per_sec) 55 | ) 56 | 57 | train_summary.append({"train_accuracy": train_accuracy}) 58 | train_summary.append({"valid_accuracy": valid_accuracy}) 59 | 60 | return train_summary 61 | 62 | def predict(self, data): 63 | _, pred = self._predict_with_batch(data) 64 | 65 | return pred 66 | 67 | def _train_init(self, data_train, data_valid): 68 | return data_train, data_valid 69 | 70 | def _test_init(self): 71 | pass 72 | 73 | def _train_with_batch(self, batch): 74 | raise NotImplementedError 75 | 76 | def _predict_with_batch(self, batch): 77 | raise NotImplementedError 78 | -------------------------------------------------------------------------------- /compy/models/seqs/__init__.py: -------------------------------------------------------------------------------- 1 | from .tf_seq_model import RnnTfModel 2 | -------------------------------------------------------------------------------- 
/compy/models/seqs/tf_seq_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | tf.compat.v1.disable_eager_execution() 5 | 6 | from compy.models.model import Model 7 | 8 | 9 | class SummaryCallback(tf.keras.callbacks.Callback): 10 | def __init__(self, summary): 11 | self.__summary = summary 12 | 13 | def on_epoch_end(self, epoch, logs=None): 14 | self.__summary["accuracy"] = logs["dense_2_accuracy"] 15 | self.__summary["loss"] = logs["loss"] 16 | 17 | 18 | class RnnTfModel(Model): 19 | def __init__(self, config=None, num_types=None): 20 | if not config: 21 | config = { 22 | "learning_rate": 0.001, 23 | "batch_size": 64, 24 | "num_epochs": 1000, 25 | } 26 | super().__init__(config) 27 | 28 | self.__num_types = num_types 29 | 30 | np.random.seed(0) 31 | 32 | # Language model. Takes as inputs source code sequences 33 | code_in = tf.keras.layers.Input(shape=(1024,), dtype="int32", name="code_in") 34 | x = tf.keras.layers.Embedding( 35 | input_dim=num_types + 1, input_length=1024, output_dim=64, name="embedding" 36 | )(code_in) 37 | x = tf.keras.layers.LSTM( 38 | 64, implementation=1, return_sequences=True, name="lstm_1" 39 | )(x) 40 | x = tf.keras.layers.LSTM(64, implementation=1, name="lstm_2")(x) 41 | langmodel_out = tf.keras.layers.Dense(2, activation="sigmoid")(x) 42 | 43 | # Auxiliary inputs. wgsize and dsize 44 | auxiliary_inputs = tf.keras.layers.Input(shape=(2,)) 45 | 46 | # Heuristic model. Takes as inputs the language model, outputs 1-hot encoded device mapping 47 | x = tf.keras.layers.Concatenate()([auxiliary_inputs, x]) 48 | x = tf.keras.layers.BatchNormalization()(x) 49 | x = tf.keras.layers.Dense(32, activation="relu")(x) 50 | out = tf.keras.layers.Dense(2, activation="sigmoid")(x) 51 | 52 | self.model = tf.keras.models.Model( 53 | inputs=[auxiliary_inputs, code_in], outputs=[out, langmodel_out] 54 | ) 55 | self.model.compile( 56 | optimizer="adam", 57 | metrics=["accuracy"], 58 | loss=["categorical_crossentropy", "categorical_crossentropy"], 59 | loss_weights=[1.0, 0.2], 60 | ) 61 | 62 | def __process_data(self, data): 63 | processed = {"sequences": [], "aux_in": [], "label": []} 64 | for item in data: 65 | processed["sequences"].append(item["x"]["code_rep"].get_token_list()) 66 | processed["aux_in"].append(item["x"]["aux_in"]) 67 | processed["label"].append(item["y"]) 68 | 69 | return processed 70 | 71 | def __process(self, data): 72 | # Pad sequences 73 | encoded = np.array( 74 | tf.keras.preprocessing.sequence.pad_sequences( 75 | data["sequences"], maxlen=1024, value=self.__num_types 76 | ) 77 | ) 78 | seqs = np.vstack([np.expand_dims(x, axis=0) for x in encoded]) 79 | 80 | aux_in = data["aux_in"] 81 | 82 | # Encode labels one-hot 83 | ys = tf.keras.utils.to_categorical(data["label"], num_classes=2) 84 | 85 | return seqs, aux_in, ys 86 | 87 | def _train_with_batch(self, batch): 88 | seqs, aux_in, ys = self.__process(self.__process_data(batch)) 89 | 90 | summary = {} 91 | callback = SummaryCallback(summary) 92 | 93 | self.model.fit( 94 | x=[np.array(aux_in), np.array(seqs)], 95 | y=[np.array(ys), np.array(ys)], 96 | epochs=1, 97 | batch_size=self.config["batch_size"], 98 | verbose=False, 99 | shuffle=True, 100 | callbacks=[callback], 101 | ) 102 | 103 | return summary["loss"], summary["accuracy"] 104 | 105 | def _predict_with_batch(self, batch): 106 | seqs, aux_in, ys = self.__process(self.__process_data(batch)) 107 | 108 | pred = self.model.predict( 109 | x=[np.array(aux_in), 
np.array(seqs)], batch_size=999999, verbose=False 110 | )[0] 111 | 112 | valid_accuracy = np.sum(np.argmax(pred, axis=1) == np.argmax(ys, axis=1)) / len( 113 | pred 114 | ) 115 | 116 | return valid_accuracy, pred 117 | -------------------------------------------------------------------------------- /compy/models/seqs/tf_seq_model_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow.python.util.deprecation as deprecation 3 | 4 | from compy.models.seqs.tf_seq_model import RnnTfModel 5 | from compy.representations.common import Sequence 6 | 7 | # Disable TensorFlow messages and deprecation warnings 8 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 9 | deprecation._PRINT_DEPRECATION_WARNINGS = False 10 | 11 | 12 | def test_model(): 13 | config = { 14 | "batch_size": 4, 15 | "num_epochs": 1, 16 | } 17 | model = RnnTfModel(config, num_types=3) 18 | 19 | data = [ 20 | { 21 | "x": { 22 | "code_rep": Sequence( 23 | ["a", "b", "c", "a", "b", "b", "a"], ["a", "b", "c"] 24 | ), 25 | "aux_in": [0, 0], 26 | }, 27 | "y": 0, 28 | } 29 | ] 30 | model.train(data, data) 31 | -------------------------------------------------------------------------------- /compy/representations/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import RepresentationBuilder, Sequence, Graph 2 | from .extractors import * 3 | from .ast_graphs import ASTVisitor, ASTDataVisitor, ASTDataCFGVisitor, ASTGraphBuilder 4 | from .llvm_graphs import ( 5 | LLVMCDFGVisitor, 6 | LLVMCDFGCallVisitor, 7 | LLVMCDFGPlusVisitor, 8 | LLVMProGraMLVisitor, 9 | LLVMGraphBuilder, 10 | ) 11 | from .syntax_seq import ( 12 | SyntaxSeqVisitor, 13 | SyntaxTokenkindVisitor, 14 | SyntaxTokenkindVariableVisitor, 15 | SyntaxSeqBuilder, 16 | ) 17 | from .llvm_seq import LLVMSeqVisitor, LLVMSeqBuilder 18 | -------------------------------------------------------------------------------- /compy/representations/ast_graphs.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | from compy.representations.extractors import clang_driver_scoped_options 4 | from compy.representations.extractors.extractors import Visitor 5 | from compy.representations.extractors.extractors import ClangDriver 6 | from compy.representations.extractors.extractors import ClangExtractor 7 | from compy.representations.extractors.extractors import clang 8 | from compy.representations import common 9 | 10 | 11 | def filter_type(type): 12 | if "[" in type or "]" in type: 13 | return "arrayType" 14 | elif "(" in type or ")" in type: 15 | return "fnType" 16 | elif "int" in type: 17 | return "intType" 18 | elif "float" in type: 19 | return "floatType" 20 | else: 21 | return "type" 22 | 23 | 24 | def add_ast_edges(g: nx.MultiDiGraph, node): 25 | """Add edges with attr `ast` that represent the AST parent-child relationship""" 26 | 27 | if isinstance(node, clang.graph.FunctionInfo): 28 | g.add_node(node, attr="function") 29 | for arg in node.args: 30 | g.add_node(arg, attr=("argument", filter_type(arg.type))) 31 | g.add_edge(node, arg, attr="ast") 32 | 33 | g.add_node(node.entryStmt, attr=(node.entryStmt.name)) 34 | g.add_edge(node, node.entryStmt, attr="ast") 35 | 36 | if isinstance(node, clang.graph.StmtInfo): 37 | for ast_rel in node.ast_relations: 38 | g.add_node(ast_rel, attr=(ast_rel.name)) 39 | g.add_edge(node, ast_rel, attr="ast") 40 | 41 | 42 | def add_ref_edges(g: nx.MultiDiGraph, node): 43 | """Add edges 
with attr `data` for data references of the given node""" 44 | 45 | if isinstance(node, clang.graph.StmtInfo): 46 | for ref_rel in node.ref_relations: 47 | g.add_node(ref_rel, attr=(filter_type(ref_rel.type))) 48 | g.add_edge(node, ref_rel, attr="data") 49 | 50 | 51 | def add_cfg_edges(g: nx.MultiDiGraph, node): 52 | """Add edges with attr `cfg` or `in` for control flow for the given node""" 53 | 54 | if isinstance(node, clang.graph.FunctionInfo): 55 | for cfg_b in node.cfgBlocks: 56 | g.add_node(cfg_b, attr="cfg") 57 | for succ in cfg_b.successors: 58 | g.add_edge(cfg_b, succ, attr="cfg") 59 | g.add_node(succ, attr="cfg") 60 | for stmt in cfg_b.statements: 61 | g.add_edge(stmt, cfg_b, attr="in") 62 | g.add_node(stmt, attr=(stmt.name)) 63 | 64 | 65 | def add_token_ast_edges(g: nx.MultiDiGraph, node): 66 | """Add edges with attr `token` connecting tokens to the closest AST node covering them""" 67 | if hasattr(node, 'tokens'): 68 | for token in node.tokens: 69 | g.add_node(token, attr=token.name, seq_order=token.index) 70 | g.add_edge(node, token, attr="token") 71 | 72 | 73 | class ASTVisitor(Visitor): 74 | def __init__(self): 75 | Visitor.__init__(self) 76 | self.edge_types = ["ast"] 77 | self.G = nx.MultiDiGraph() 78 | 79 | def visit(self, v): 80 | add_ast_edges(self.G, v) 81 | 82 | 83 | class ASTDataVisitor(Visitor): 84 | def __init__(self): 85 | Visitor.__init__(self) 86 | self.edge_types = ["ast", "data"] 87 | self.G = nx.MultiDiGraph() 88 | 89 | def visit(self, v): 90 | add_ast_edges(self.G, v) 91 | add_ref_edges(self.G, v) 92 | 93 | 94 | class ASTDataCFGVisitor(Visitor): 95 | def __init__(self): 96 | Visitor.__init__(self) 97 | self.edge_types = ["ast", "cfg", "in", "data"] 98 | self.G = nx.MultiDiGraph() 99 | 100 | def visit(self, v): 101 | add_ast_edges(self.G, v) 102 | add_ref_edges(self.G, v) 103 | add_cfg_edges(self.G, v) 104 | 105 | 106 | class ASTDataCFGTokenVisitor(Visitor): 107 | def __init__(self): 108 | Visitor.__init__(self) 109 | self.edge_types = ["ast", "cfg", "in", "data", "token"] 110 | self.G = nx.MultiDiGraph() 111 | 112 | def visit(self, v): 113 | add_ast_edges(self.G, v) 114 | add_ref_edges(self.G, v) 115 | add_cfg_edges(self.G, v) 116 | add_token_ast_edges(self.G, v) 117 | 118 | 119 | class ASTGraphBuilder(common.RepresentationBuilder): 120 | def __init__(self, clang_driver=None): 121 | common.RepresentationBuilder.__init__(self) 122 | 123 | if clang_driver: 124 | self.__clang_driver = clang_driver 125 | else: 126 | self.__clang_driver = ClangDriver( 127 | ClangDriver.ProgrammingLanguage.C, 128 | ClangDriver.OptimizationLevel.O3, 129 | [], 130 | ["-Wall"], 131 | ) 132 | self.__extractor = ClangExtractor(self.__clang_driver) 133 | 134 | self.__graphs = [] 135 | 136 | def string_to_info(self, src, additional_include_dir=None, filename=None): 137 | with clang_driver_scoped_options(self.__clang_driver, additional_include_dir=additional_include_dir, filename=filename): 138 | return self.__extractor.GraphFromString(src) 139 | 140 | def info_to_representation(self, info, visitor=ASTDataVisitor): 141 | vis = visitor() 142 | info.accept(vis) 143 | 144 | for (n, data) in vis.G.nodes(data=True): 145 | attr = data["attr"] 146 | if attr not in self._tokens: 147 | self._tokens[attr] = 0 148 | self._tokens[attr] += 1 149 | 150 | return common.Graph(vis.G, self.get_tokens(), vis.edge_types) 151 | -------------------------------------------------------------------------------- /compy/representations/ast_graphs_test.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import networkx as nx 4 | import pytest 5 | 6 | from compy.representations.extractors.extractors import Visitor 7 | from compy.representations.extractors.extractors import clang 8 | from compy.representations.ast_graphs import ASTGraphBuilder 9 | from compy.representations.ast_graphs import ASTVisitor 10 | from compy.representations.ast_graphs import ASTDataVisitor 11 | from compy.representations.ast_graphs import ASTDataCFGVisitor 12 | from compy.representations.ast_graphs import ASTDataCFGTokenVisitor 13 | 14 | 15 | program_1fn_2 = """ 16 | int bar(int a) { 17 | if (a > 10) 18 | return a; 19 | return -1; 20 | } 21 | """ 22 | 23 | program_fib = """ 24 | int fib(int x) { 25 | switch(x) { 26 | case 0: 27 | return 0; 28 | case 1: 29 | return 1; 30 | default: 31 | return fib(x-1) + fib(x-2); 32 | } 33 | } 34 | """ 35 | 36 | 37 | # Construction 38 | def test_construct_with_custom_visitor(): 39 | class CustomVisitor(Visitor): 40 | def __init__(self): 41 | Visitor.__init__(self) 42 | self.edge_types = [] 43 | self.G = nx.DiGraph() 44 | 45 | def visit(self, v): 46 | if not isinstance(v, clang.graph.ExtractionInfo): 47 | self.G.add_node(v, attr=type(v)) 48 | 49 | builder = ASTGraphBuilder() 50 | info = builder.string_to_info(program_1fn_2) 51 | ast = builder.info_to_representation(info, CustomVisitor) 52 | 53 | assert len(ast.G) == 35 54 | 55 | 56 | # Attributes 57 | def test_get_node_list(): 58 | builder = ASTGraphBuilder() 59 | info = builder.string_to_info(program_1fn_2) 60 | ast = builder.info_to_representation(info, ASTDataVisitor) 61 | nodes = ast.get_node_list() 62 | 63 | assert len(nodes) == 14 64 | 65 | 66 | def test_get_edge_list(): 67 | builder = ASTGraphBuilder() 68 | info = builder.string_to_info(program_1fn_2) 69 | ast = builder.info_to_representation(info, ASTDataVisitor) 70 | edges = ast.get_edge_list() 71 | 72 | assert len(edges) > 0 73 | 74 | assert type(edges[0][0]) is int 75 | assert type(edges[0][1]) is int 76 | assert type(edges[0][2]) is int 77 | 78 | 79 | # Plot 80 | def test_plot(tmpdir): 81 | for visitor in [ASTDataVisitor]: 82 | builder = ASTGraphBuilder() 83 | info = builder.string_to_info(program_fib) 84 | graph = builder.info_to_representation(info, visitor) 85 | 86 | outfile = os.path.join(tmpdir, str(visitor.__name__) + ".png") 87 | graph.draw(path=outfile, with_legend=True) 88 | 89 | assert os.path.isfile(outfile) 90 | 91 | # os.system('xdg-open ' + str(tmpdir)) 92 | 93 | 94 | # All visitors 95 | def test_all_visitors(): 96 | for visitor in [ASTVisitor, ASTDataVisitor, ASTDataCFGVisitor]: 97 | builder = ASTGraphBuilder() 98 | info = builder.string_to_info(program_1fn_2) 99 | ast = builder.info_to_representation(info, visitor) 100 | 101 | assert ast 102 | 103 | 104 | def test_token_visitor(): 105 | builder = ASTGraphBuilder() 106 | info = builder.string_to_info(program_1fn_2) 107 | ast = builder.info_to_representation(info, ASTDataCFGTokenVisitor) 108 | 109 | assert ast 110 | token_data = sorted([data for t, data in ast.G.nodes(data=True) if 'seq_order' in data], key=lambda x: x['seq_order']) 111 | tokens = [t['attr'] for t in token_data] 112 | assert tokens == [ 113 | 'int', 'bar', '(', 'int', 'a', ')', 114 | '{', 115 | 'if', '(', 'a', '>', '10', ')', 'return', 'a', ';', 116 | 'return', '-', '1', ';', 117 | '}' 118 | ] 119 | 120 | leaves = ast.get_leaf_node_list() 121 | labels = ast.get_node_str_list() 122 | assert [labels[n] for n in leaves] == tokens 123 
| -------------------------------------------------------------------------------- /compy/representations/common.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import networkx as nx 4 | import pygraphviz as pgv 5 | 6 | 7 | class RepresentationBuilder(object): 8 | def __init__(self): 9 | self._tokens = collections.OrderedDict() 10 | 11 | def num_tokens(self): 12 | return len(self._tokens) 13 | 14 | def get_tokens(self): 15 | return list(self._tokens.keys()) 16 | 17 | def print_tokens(self): 18 | print("-" * 50) 19 | print("{:<8} {:<25} {:<10}".format("NodeID", "Label", "Number")) 20 | t_view = [(v, k) for k, v in self._tokens.items()] 21 | t_view = sorted(t_view, key=lambda x: x[0], reverse=True) 22 | for v, k in t_view: 23 | idx = list(self._tokens.keys()).index(k) 24 | print("{:<8} {:<25} {:<10}".format(str(idx), str(k), str(v))) 25 | print("-" * 50) 26 | 27 | 28 | class Sequence(object): 29 | def __init__(self, S, token_types): 30 | self.S = S 31 | self.__token_types = token_types 32 | 33 | def get_token_list(self): 34 | node_ints = [self.__token_types.index(token_str) for token_str in self.S] 35 | 36 | return node_ints 37 | 38 | def size(self): 39 | return len(self.S) 40 | 41 | def draw(self, width=8, limit=30, path=None): 42 | # Create dot graph. 43 | graphviz_graph = pgv.AGraph( 44 | directed=True, 45 | splines=False, 46 | rankdir="LR", 47 | nodesep=0.001, 48 | ranksep=0.4, 49 | outputorder="edgesfirst", 50 | fillcolor="white", 51 | ) 52 | 53 | remaining_tokens = None 54 | for i, token in enumerate(self.S): 55 | if i == limit: 56 | remaining_tokens = 5 57 | 58 | if remaining_tokens is not None: 59 | if remaining_tokens > 0: 60 | token = "..." 61 | remaining_tokens -= 1 62 | else: 63 | break 64 | 65 | if i % width == 0: 66 | subgraph = graphviz_graph.subgraph( 67 | name="cluster_%i" % i, label="", color="white" 68 | ) 69 | 70 | graphviz_graph.add_node(i, label=token, shape="box") 71 | if i > 0: 72 | graphviz_graph.add_edge( 73 | i - width, i, color="white", constraint=False 74 | ) 75 | else: 76 | subgraph.add_node(i, label=token, shape="box") 77 | if i > 0: 78 | if i % width == 0: 79 | graphviz_graph.add_edge(i - 1, i, constraint=False, color="gray") 80 | else: 81 | graphviz_graph.add_edge(i - 1, i) 82 | 83 | graphviz_graph.layout("dot") 84 | 85 | return graphviz_graph.draw(path) 86 | 87 | 88 | class Graph(object): 89 | def __init__(self, graph, node_types, edge_types): 90 | self.G = graph 91 | self.__node_types = node_types 92 | self.__node_types_dict = {n: i for i, n in enumerate(node_types)} 93 | self.__edge_types = edge_types 94 | 95 | def _get_node_attr_dict(self): 96 | return collections.OrderedDict(self.G.nodes(data="attr", default="N/A")) 97 | 98 | def get_node_str_list(self): 99 | node_strs = list(self._get_node_attr_dict().values()) 100 | 101 | return node_strs 102 | 103 | def get_node_list(self): 104 | node_strs = list(self._get_node_attr_dict().values()) 105 | node_ints = [self.__node_types_dict[node_str] for node_str in node_strs] 106 | 107 | return node_ints 108 | 109 | def get_edge_list(self): 110 | nodes_keys = {n: i for i, n in enumerate(self._get_node_attr_dict().keys())} 111 | 112 | edges = [] 113 | for node1, node2, data in self.G.edges(data=True): 114 | edges.append( 115 | ( 116 | nodes_keys[node1], 117 | self.__edge_types.index(data["attr"]), 118 | nodes_keys[node2], 119 | ) 120 | ) 121 | 122 | return edges 123 | 124 | def get_leaf_node_list(self): 125 | """Return an ordered list of node indices 
for leaves of the graph. 126 | 127 | Only useful for graphs that are built based on a sequence (like ASTs on tokens) 128 | """ 129 | nodes_keys = list(self._get_node_attr_dict().keys()) 130 | 131 | data = { n: order for n, order in self.G.nodes(data='seq_order') if order is not None } 132 | return [nodes_keys.index(n) for n, _ in sorted(data.items(), key=lambda x: x[1])] 133 | 134 | def map_to_leaves(self, relations=None): 135 | """Map inner nodes of the graph to leaf nodes. 136 | 137 | Leaves are all nodes which have a `seq_order` attribute. 138 | Each node in the graph which is not a leaf is mapped to the descendant leaf with lowest seq_order. 139 | All edges are moved to the mapped leaf nodes. 140 | Nodes which have no leaf descendants are not transformed. 141 | 142 | The `relations` parameter specifies which edges indicate a parent-child relationship. 143 | The value of this parameter must be a dict with two keys, `child` and `parent`. 144 | For both keys, the value must be a collection of edge types of that kind. 145 | 146 | Returns the new graph. 147 | """ 148 | if relations is None: 149 | relations = { 150 | 'child': {'ast', 'token'}, 151 | 'parent': {'in', 'data'}, 152 | } 153 | relations.setdefault('parent', set()) 154 | relations.setdefault('child', set()) 155 | 156 | result = nx.MultiDiGraph() 157 | 158 | # Map nodes to leaf node 159 | leaf_for_node = {} 160 | 161 | # Walk leaves in sequential order, so sequentially first leaves get priority 162 | leaves = { n: data for n, data in self.G.nodes(data=True) if 'seq_order' in data } 163 | for leaf, data in sorted(leaves.items(), key=lambda x: x[1]['seq_order']): 164 | result.add_node(leaf, **data) 165 | # Walk the tree upwards, and assign any unassigned nodes to this leaf 166 | todo = { leaf } 167 | while todo: 168 | n = todo.pop() 169 | if n in leaf_for_node and leaf_for_node[n][1] <= data['seq_order']: 170 | # this node has already been assigned to an earlier leaf, so no need to traverse further 171 | continue 172 | leaf_for_node[n] = (leaf, data['seq_order']) 173 | todo.update(set(target for _, target, attr in self.G.out_edges(n, data='attr') if attr in relations['parent'])) 174 | todo.update(set(source for source, _, attr in self.G.in_edges(n, data='attr') if attr in relations['child'])) 175 | # ensure leaves are never mapped to other leaves 176 | # this makes sure that the function is idempotent 177 | leaf_for_node[leaf] = (leaf, data['seq_order']) 178 | 179 | # Translate edges 180 | for source, target, data in self.G.edges(data=True): 181 | source = leaf_for_node.get(source, (source, 0))[0] 182 | target = leaf_for_node.get(target, (target, 0))[0] 183 | 184 | if source in leaf_for_node: 185 | source = leaf_for_node[source][0] 186 | else: 187 | result.add_node(source, **self.G.nodes(data=True)[source]) 188 | 189 | if target in leaf_for_node: 190 | target = leaf_for_node[target][0] 191 | else: 192 | result.add_node(target, **self.G.nodes(data=True)[target]) 193 | 194 | result.add_edge(source, target, **data) 195 | 196 | return Graph(result, list(self.__node_types), list(self.__edge_types)) 197 | 198 | def size(self): 199 | return len(self.G) 200 | 201 | def draw(self, path=None, with_legend=False, align_tokens=True): 202 | # Copy graph object because attr modifications for a cleaner view are needed. 203 | G = self.G.copy() 204 | 205 | # Add node labels. 
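        # (node attrs may be tuples, e.g. ("argument", "intType") as produced
        # by ast_graphs; the branch below joins tuple entries with newlines)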
206 | for (n, data) in G.nodes(data=True): 207 | if "attr" in data: 208 | if type(data["attr"]) is tuple: 209 | label = "\n".join(data["attr"]) 210 | else: 211 | label = data["attr"] 212 | 213 | G.nodes[n]["label"] = label 214 | 215 | # Add edge colors. 216 | edge_colors_by_types = { 217 | "ast": "black", 218 | "cfg": "green", 219 | "data": "blue", 220 | "mem": "pink", 221 | "call": "yellow", 222 | } 223 | edge_colors_available = ["orange", "pink", "cyan", "crimson", "darkgreen", "darkblue", "darkcyan"] 224 | for etype in self.__edge_types: 225 | if etype in edge_colors_by_types: continue 226 | edge_colors_by_types[etype] = edge_colors_available.pop(0) 227 | 228 | for u, v, key, data in G.edges(keys=True, data=True): 229 | edge_type = data["attr"] 230 | if edge_type not in edge_colors_by_types: 231 | edge_colors_by_types[edge_type] = edge_colors_available.pop(0) 232 | 233 | G[u][v][key]["color"] = edge_colors_by_types[edge_type] 234 | 235 | # G[u][v][key]['weight'] = 10 if edge_type == 'cfg' else 0 236 | 237 | # Create dot graph. 238 | graphviz_graph = nx.drawing.nx_agraph.to_agraph(G) 239 | 240 | # Add Legend. 241 | if with_legend: 242 | edge_types_used = set() 243 | for (u, v, key, data) in G.edges(keys=True, data=True): 244 | edge_type = data["attr"] 245 | edge_types_used.add(edge_type) 246 | 247 | subgraph = graphviz_graph.subgraph(name="cluster", label="Edges") 248 | for edge_type, color in edge_colors_by_types.items(): 249 | if edge_type in edge_types_used: 250 | subgraph.add_node(edge_type, color="invis", fontcolor=color) 251 | 252 | # Put all tokens on single level ("rank") and enforce order 253 | if align_tokens: 254 | tokens = graphviz_graph.subgraph(rank="sink", rankdir="LR") 255 | leaves = { n: data for n, data in self.G.nodes(data=True) if 'seq_order' in data } 256 | leaf_nodes = list(sorted(leaves.items(), key=lambda x: x[1]['seq_order'])) 257 | for a, b in zip(leaf_nodes, leaf_nodes[1:]) : 258 | tokens.add_edge(a[0], b[0], color="invis") 259 | 260 | graphviz_graph.layout("dot") 261 | return graphviz_graph.draw(path) 262 | -------------------------------------------------------------------------------- /compy/representations/common_test.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | import compy.representations.common as common 4 | 5 | 6 | def sample_graph(): 7 | G = nx.MultiDiGraph() 8 | G.add_node("root1", attr="root") 9 | for n in ["n1", "n2", "n3", "n4", "n5"]: 10 | G.add_node(n, attr=n) 11 | for n in ["n1", "n2", "n3"]: 12 | G.add_edge("root1", n, attr="child") 13 | G.add_edge("n4", "n3", attr="parent") 14 | G.add_edge("n4", "n5", attr="child") 15 | for l in range(7): 16 | G.add_node("l" + str(l+1), attr="leaf" + str(l+1), seq_order=l) 17 | G.add_edge("n1", "l1", attr="token") 18 | G.add_edge("n1", "l2", attr="token") 19 | G.add_edge("n2", "l3", attr="token") 20 | G.add_edge("root1", "l4", attr="token") 21 | G.add_edge("n4", "l5", attr="token") 22 | G.add_edge("n4", "l6", attr="token") 23 | G.add_edge("n4", "l7", attr="token") 24 | G.add_edge("root1", "n3", attr="rel2") 25 | G.add_edge("l1", "l2", attr="token_rel") 26 | G.add_edge("n2", "l2", attr="node_token_rel") 27 | return common.Graph(G, 28 | list(sorted(set(attr for _, attr in G.nodes(data="attr")))), 29 | list(sorted(set(attr for _, _, attr in G.edges(data="attr"))))) 30 | 31 | 32 | def test_map_to_leaves(): 33 | graph = sample_graph() 34 | relations = {'child': {'token', 'child'}, 'parent': {'parent'}} 35 | leaves_only = 
graph.map_to_leaves(relations) 36 | 37 | # n5 is kept because it has no child nodes 38 | assert sorted(leaves_only.get_node_str_list()) == ['leaf1', 'leaf2', 'leaf3', 'leaf4', 'leaf5', 'leaf6', 'leaf7', 'n5'] 39 | 40 | # map to leaves should be idempotent 41 | assert sorted(graph.map_to_leaves(relations).G.edges(data='attr')) == sorted( 42 | leaves_only.map_to_leaves(relations).G.edges(data='attr')) 43 | 44 | edges = sorted(leaves_only.G.edges(data='attr')) 45 | expected_edges = [ 46 | ('l1', 'l1', 'token'), 47 | ('l1', 'l2', 'token'), 48 | ('l3', 'l3', 'token'), 49 | ('l1', 'l4', 'token'), 50 | ('l5', 'l5', 'token'), 51 | ('l5', 'l6', 'token'), 52 | ('l5', 'l7', 'token'), 53 | ('l1', 'l1', 'child'), 54 | ('l1', 'l3', 'child'), 55 | ('l1', 'l5', 'child'), 56 | ('l5', 'n5', 'child'), 57 | ('l1', 'l2', 'token_rel'), 58 | ('l3', 'l2', 'node_token_rel'), 59 | ('l1', 'l5', 'rel2'), 60 | ('l5', 'l5', 'parent'), 61 | ] 62 | assert edges == sorted(expected_edges) 63 | 64 | without_parent = graph.map_to_leaves({'child': ['token', 'child']}) 65 | assert sorted(without_parent.get_node_str_list()) == ['leaf1', 'leaf2', 'leaf3', 'leaf4', 'leaf5', 'leaf6', 'leaf7', 'n3', 'n5'] 66 | 67 | def test_map_to_leaves_cycle(): 68 | cycle = nx.MultiDiGraph() 69 | for node in ["0", "1", "2", "4"]: 70 | cycle.add_node(node, attr=node) 71 | cycle.add_edge("0", "1", attr='flow') 72 | cycle.add_edge("1", "2", attr='flow') 73 | cycle.add_edge("2", "0", attr='flow') 74 | cycle.add_edge("4", "0", attr='flow') 75 | cycle.add_node('leaf', attr='leaf', seq_order=0) 76 | cycle.add_edge("0", 'leaf', attr='flow') 77 | 78 | graph = common.Graph(cycle, [], ['leaf', "0", "1", "2", "4"]) 79 | mapped = graph.map_to_leaves({'child': 'flow'}) 80 | assert sorted(mapped.G.edges(data=False)) == [("leaf", "leaf")] * 5 81 | -------------------------------------------------------------------------------- /compy/representations/extractors/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Common compile options 2 | set(compile_options_common 3 | -fno-rtti -fPIC 4 | -Wall -Wextra 5 | # LLVM warnings 6 | -Wno-unused-parameter -Wno-tautological-overlap-compare -Wno-deprecated-copy -Wno-class-memaccess -Wno-maybe-uninitialized 7 | # Clang warnings 8 | -Wno-comment -Wno-strict-aliasing) 9 | 10 | # Common library 11 | add_library(extractors_common 12 | common/clang_driver.cc) 13 | llvm_map_components_to_libnames(REQ_LLVM_LIBRARIES ${LLVM_TARGETS_TO_BUILD} 14 | asmparser 15 | core 16 | linker 17 | bitreader 18 | irreader 19 | ipo 20 | scalaropts 21 | analysis 22 | support 23 | frontendopenmp 24 | option 25 | passes 26 | objcarcopts 27 | coroutines 28 | lto 29 | coverage 30 | aarch64codegen 31 | ) 32 | target_include_directories(extractors_common PUBLIC 33 | . 
34 | ${LLVM_INCLUDE_DIRS} 35 | ${CLANG_INCLUDE_DIRS} 36 | ) 37 | 38 | # if tools are linked against llvm shared object, we need to do the same 39 | # otherwise, we end up with two versions (shared and static) of llvm libs 40 | if(LLVM_LINK_LLVM_DYLIB) 41 | set(REQ_LLVM_LIBRARIES LLVM) 42 | endif() 43 | 44 | # if clang is built as a shared lib, use that, otherwise link to the static components 45 | if(DEFINED CLANG_LINK_CLANG_DYLIB AND CLANG_LINK_CLANG_DYLIB) 46 | set(REQ_CLANG_LIBRARIES clang-cpp) 47 | else() 48 | set(REQ_CLANG_LIBRARIES clangBasic clangFrontendTool) 49 | endif() 50 | 51 | target_link_libraries(extractors_common 52 | -Wl,--start-group 53 | ${REQ_LLVM_LIBRARIES} 54 | ${REQ_CLANG_LIBRARIES} 55 | -Wl,--end-group 56 | ) 57 | target_compile_options(extractors_common PRIVATE 58 | ${compile_options_common} 59 | ) 60 | 61 | # Common tests 62 | add_executable(extractors_common_tests 63 | common/clang_driver_test.cc 64 | ) 65 | target_link_libraries(extractors_common_tests 66 | extractors_common 67 | 68 | gmock 69 | gtest 70 | gtest_main 71 | ) 72 | target_compile_options(extractors_common_tests PRIVATE 73 | -fno-rtti -fPIC 74 | ) 75 | 76 | # LLVM pybind11 module 77 | pybind11_add_module(extractors 78 | extractors.cc 79 | ) 80 | target_link_libraries(extractors PRIVATE 81 | clang_extractor 82 | llvm_extractor 83 | ) 84 | target_compile_options(extractors PRIVATE 85 | -Wno-unused-value 86 | ) 87 | 88 | add_subdirectory(clang_ast) 89 | add_subdirectory(llvm_ir) -------------------------------------------------------------------------------- /compy/representations/extractors/__init__.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import functools 3 | import itertools 4 | import shutil 5 | import subprocess 6 | import warnings 7 | from typing import Optional 8 | 9 | from .extractors import ClangDriver, LLVM_VERSION 10 | del extractors # HACK: don't override extractors 11 | 12 | 13 | @functools.lru_cache() 14 | def clang_binary_path(): 15 | """Find the clang compiler binary, trying to match the version that the native extension was compiled with. 16 | 17 | Prints a warning if the binary that was found is a different version. 18 | Raises RuntimeError if no clang compiler is found at all. 
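    A usage sketch (the path and version shown are illustrative; the actual result depends on the environment):

    >>> clang_binary_path()  # doctest: +SKIP
    '/usr/bin/clang-10'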
19 | """ 20 | best_score = None 21 | best_version = None 22 | best_path = None 23 | 24 | components = LLVM_VERSION.split('.') 25 | for suffix_len in reversed(range(len(components) + 1)): 26 | suffix = '.'.join(components[:suffix_len]) 27 | path = shutil.which(f"clang-{suffix}" if suffix else "clang") 28 | if path is None: 29 | continue 30 | 31 | llvm_config_path = subprocess.run( 32 | [path, "-print-prog-name=llvm-config"], 33 | check=True, stdout=subprocess.PIPE 34 | ).stdout.decode().strip() 35 | 36 | this_version = subprocess.run( 37 | [llvm_config_path, "--version"], 38 | check=True, stdout=subprocess.PIPE 39 | ).stdout.decode().strip() 40 | 41 | if this_version == LLVM_VERSION: 42 | return path 43 | 44 | score = sum(1 for _ in itertools.takewhile(lambda x: x[0] == x[1], zip(this_version, LLVM_VERSION))) 45 | if best_score is None or score > best_score: 46 | best_score = score 47 | best_version = this_version 48 | best_path = path 49 | 50 | if not best_path: 51 | raise RuntimeError("cannot find clang compiler binary in PATH") 52 | 53 | warnings.warn(f"found clang compiler at {best_path} for LLVM {best_version}, but native extension was compiled " 54 | f"against LLVM {LLVM_VERSION}") 55 | return best_path 56 | 57 | 58 | @contextlib.contextmanager 59 | def clang_driver_scoped_options(clang_driver, additional_include_dir: Optional[str] = None, filename: Optional[str] = None): 60 | """A context manager to set and restore options for the clang driver in a local scope. 61 | 62 | >>> with clang_driver_scoped_options(clang_driver, filename="foo"): 63 | ... # clang_driver's file name is set to foo in this scope 64 | ... pass 65 | ... # filename is restored to previous value after the with block 66 | """ 67 | prev_filename = None 68 | if filename is not None: 69 | prev_filename = clang_driver.getFileName() 70 | clang_driver.setFileName(filename) 71 | 72 | if additional_include_dir: 73 | clang_driver.addIncludeDir( 74 | additional_include_dir, ClangDriver.IncludeDirType.User 75 | ) 76 | 77 | try: 78 | yield clang_driver 79 | finally: 80 | if prev_filename is not None: 81 | clang_driver.setFileName(prev_filename) 82 | if additional_include_dir: 83 | clang_driver.removeIncludeDir( 84 | additional_include_dir, ClangDriver.IncludeDirType.User 85 | ) 86 | 87 | -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Extractor library 2 | add_library(clang_extractor 3 | clang_extractor.cc 4 | clang_graph_frontendaction.cc 5 | clang_seq_frontendaction.cc 6 | ) 7 | target_link_libraries(clang_extractor 8 | extractors_common 9 | ) 10 | target_compile_options(clang_extractor PRIVATE 11 | ${compile_options_common} 12 | ) 13 | 14 | # Extractor tests 15 | add_executable(clang_extractor_tests 16 | clang_extractor_test.cc 17 | ) 18 | target_link_libraries(clang_extractor_tests 19 | clang_extractor 20 | 21 | gmock 22 | gtest 23 | gtest_main 24 | ) 25 | target_compile_options(clang_extractor_tests PRIVATE 26 | -fno-rtti -fPIC 27 | ) -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/clang_extractor.cc: -------------------------------------------------------------------------------- 1 | #include "clang_extractor.h" 2 | 3 | #include 4 | 5 | #include "clang/Config/config.h" 6 | #include "clang/Frontend/CompilerInvocation.h" 7 | #include "clang/Lex/PreprocessorOptions.h" 8 | 
#include "llvm/LinkAllPasses.h" 9 | #include "llvm/Support/Compiler.h" 10 | 11 | #include "clang_graph_frontendaction.h" 12 | #include "clang_seq_frontendaction.h" 13 | 14 | using namespace ::clang; 15 | using namespace ::llvm; 16 | 17 | namespace compy { 18 | namespace clang { 19 | 20 | ClangExtractor::ClangExtractor(ClangDriverPtr clangDriver) 21 | : clangDriver_(clangDriver) {} 22 | 23 | graph::ExtractionInfoPtr ClangExtractor::GraphFromString(std::string src) { 24 | auto fa = std::make_unique(); 25 | 26 | std::vector<::clang::FrontendAction *> frontendActions; 27 | std::vector<::llvm::Pass *> passes; 28 | 29 | frontendActions.push_back(fa.get()); 30 | 31 | clangDriver_->Invoke(src, frontendActions, passes); 32 | 33 | return fa->extractionInfo; 34 | } 35 | 36 | seq::ExtractionInfoPtr ClangExtractor::SeqFromString(std::string src) { 37 | auto fa = std::make_unique(); 38 | 39 | std::vector<::clang::FrontendAction *> frontendActions; 40 | std::vector<::llvm::Pass *> passes; 41 | 42 | frontendActions.push_back(fa.get()); 43 | 44 | clangDriver_->Invoke(src, frontendActions, passes); 45 | 46 | return fa->extractionInfo; 47 | } 48 | 49 | } // namespace clang 50 | } // namespace compy 51 | -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/clang_extractor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "common/clang_driver.h" 9 | #include "common/visitor.h" 10 | 11 | namespace compy { 12 | namespace clang { 13 | 14 | namespace seq { 15 | struct FunctionInfo; 16 | using FunctionInfoPtr = std::shared_ptr; 17 | 18 | struct ExtractionInfo; 19 | using ExtractionInfoPtr = std::shared_ptr; 20 | 21 | struct TokenInfo; 22 | using TokenInfoPtr = std::shared_ptr; 23 | 24 | struct TokenInfo : IVisitee { 25 | std::string name; 26 | std::string kind; 27 | 28 | void accept(IVisitor* v) override { v->visit(this); } 29 | }; 30 | 31 | struct FunctionInfo : IVisitee { 32 | std::string name; 33 | std::vector tokenInfos; 34 | 35 | void accept(IVisitor* v) override { 36 | v->visit(this); 37 | for (const auto& it : tokenInfos) it->accept(v); 38 | } 39 | }; 40 | 41 | struct ExtractionInfo : IVisitee { 42 | std::vector functionInfos; 43 | 44 | void accept(IVisitor* v) override { 45 | v->visit(this); 46 | for (const auto& it : functionInfos) it->accept(v); 47 | } 48 | }; 49 | } // namespace seq 50 | 51 | namespace graph { 52 | struct OperandInfo; 53 | using OperandInfoPtr = std::shared_ptr; 54 | 55 | struct DeclInfo; 56 | using DeclInfoPtr = std::shared_ptr; 57 | 58 | struct StmtInfo; 59 | using StmtInfoPtr = std::shared_ptr; 60 | 61 | struct CFGBlockInfo; 62 | using CFGBlockInfoPtr = std::shared_ptr; 63 | 64 | struct FunctionInfo; 65 | using FunctionInfoPtr = std::shared_ptr; 66 | 67 | struct ExtractionInfo; 68 | using ExtractionInfoPtr = std::shared_ptr; 69 | 70 | struct TokenInfo : IVisitee { 71 | std::uint64_t index; 72 | std::string name; 73 | std::string kind; 74 | ::clang::SourceLocation location; 75 | 76 | void accept(IVisitor* v) override { v->visit(this); } 77 | }; 78 | 79 | struct OperandInfo : IVisitee { 80 | virtual ~OperandInfo() = default; 81 | }; 82 | 83 | struct DeclInfo : OperandInfo { 84 | std::string name; 85 | std::string type; 86 | std::string kind; 87 | std::vector tokens; 88 | TokenInfo nameToken; 89 | 90 | void accept(IVisitor* v) override { 91 | v->visit(this); 92 | for (auto& it : tokens) 
it.accept(v); 93 | } 94 | }; 95 | 96 | struct StmtInfo : OperandInfo { 97 | std::string name; 98 | std::vector<TokenInfo> tokens; 99 | std::string operation; 100 | std::vector<OperandInfoPtr> ast_relations; 101 | std::vector<DeclInfoPtr> ref_relations; 102 | 103 | void accept(IVisitor* v) override { 104 | v->visit(this); 105 | for (auto& it : tokens) it.accept(v); 106 | for (const auto& it : ast_relations) it->accept(v); 107 | } 108 | }; 109 | 110 | struct CFGBlockInfo { 111 | std::string name; 112 | std::vector<StmtInfoPtr> statements; 113 | std::vector<CFGBlockInfoPtr> successors; 114 | }; 115 | 116 | struct FunctionInfo : IVisitee { 117 | std::string name; 118 | std::string type; 119 | std::vector<TokenInfo> tokens; 120 | std::vector<DeclInfoPtr> args; 121 | std::vector<CFGBlockInfoPtr> cfgBlocks; 122 | StmtInfoPtr entryStmt; 123 | 124 | void accept(IVisitor* v) override { 125 | v->visit(this); 126 | for (auto& it : tokens) it.accept(v); 127 | for (const auto& it : args) it->accept(v); 128 | entryStmt->accept(v); 129 | } 130 | }; 131 | 132 | struct ExtractionInfo : IVisitee { 133 | std::vector<FunctionInfoPtr> functionInfos; 134 | 135 | void accept(IVisitor* v) override { 136 | v->visit(this); 137 | for (const auto& it : functionInfos) it->accept(v); 138 | } 139 | }; 140 | } // namespace graph 141 | 142 | class ClangExtractor { 143 | public: 144 | ClangExtractor(ClangDriverPtr clangDriver); 145 | 146 | graph::ExtractionInfoPtr GraphFromString(std::string src); 147 | seq::ExtractionInfoPtr SeqFromString(std::string src); 148 | 149 | private: 150 | ClangDriverPtr clangDriver_; 151 | }; 152 | 153 | } // namespace clang 154 | } // namespace compy 155 | -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/clang_extractor_test.cc: -------------------------------------------------------------------------------- 1 | #include "clang_extractor.h" 2 | 3 | #include <memory> 4 | #include <string> 5 | #include <vector> 6 | 7 | #include "common/clang_driver.h" 8 | #include "common/common_test.h" 9 | #include "gtest/gtest.h" 10 | 11 | using namespace testing; 12 | using namespace compy; 13 | using namespace compy::clang; 14 | 15 | using CE = ClangExtractor; 16 | using CD = ClangDriver; 17 | 18 | class ClangExtractorFixture : public testing::Test { 19 | protected: 20 | void Init(CD::ProgrammingLanguage programmingLanguage) { 21 | // Init extractor 22 | std::vector<std::tuple<std::string, CD::IncludeDirType>> includeDirs = { 23 | std::make_tuple("/usr/include", CD::IncludeDirType::SYSTEM), 24 | std::make_tuple("/usr/include/x86_64-linux-gnu", 25 | CD::IncludeDirType::SYSTEM), 26 | std::make_tuple("/usr/lib/llvm-10/lib/clang/10.0.0/include", 27 | CD::IncludeDirType::SYSTEM), 28 | std::make_tuple("/usr/lib/llvm-10/lib/clang/10.0.1/include", 29 | CD::IncludeDirType::SYSTEM)}; 30 | std::vector<std::string> compilerFlags = {"-Werror"}; 31 | 32 | driver_.reset(new ClangDriver(programmingLanguage, 33 | CD::OptimizationLevel::O0, includeDirs, 34 | compilerFlags)); 35 | extractor_.reset(new CE(driver_)); 36 | } 37 | 38 | std::shared_ptr<CD> driver_; 39 | std::shared_ptr<CE> extractor_; 40 | }; 41 | 42 | class ClangExtractorCFixture : public ClangExtractorFixture { 43 | protected: 44 | void SetUp() override { Init(CD::ProgrammingLanguage::C); } 45 | }; 46 | 47 | class ClangExtractorCPlusPlusFixture : public ClangExtractorFixture { 48 | protected: 49 | void SetUp() override { Init(CD::ProgrammingLanguage::CPLUSPLUS); } 50 | }; 51 | 52 | // TEST_F(ClangExtractorCFixture, ExtractGraphFromFunction5) { 53 | // graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram5); 54 | // 55 | // ASSERT_EQ(info->functionInfos.size(), 2UL); 56 | //} 57 | // 58 | // 
TEST_F(ClangExtractorCFixture, ExtractSeqFromFunction5) { 59 | // seq::ExtractionInfoPtr info = extractor_->SeqFromString(kProgram5); 60 | //} 61 | // 62 | // TEST(O, WithOpenCL) { 63 | // std::shared_ptr<CD> driver_; 64 | // std::shared_ptr<CE> extractor_; 65 | // 66 | // // Init extractor 67 | // std::vector<std::tuple<std::string, CD::IncludeDirType>> includeDirs = { 68 | // std::make_tuple("/usr/include", CD::IncludeDirType::SYSTEM), 69 | // std::make_tuple("/usr/include/x86_64-linux-gnu", 70 | // CD::IncludeDirType::SYSTEM), 71 | // std::make_tuple("/devel/git_3rd/llvm-project/build_release/lib/clang/" 72 | // "7.1.0/include/", 73 | // CD::IncludeDirType::SYSTEM)}; 74 | // std::vector<std::string> compilerFlags = {"-xcl"}; 75 | // 76 | // driver_.reset(new ClangDriver(CD::ProgrammingLanguage::OPENCL, 77 | // CD::OptimizationLevel::O0, includeDirs, 78 | // compilerFlags)); 79 | // extractor_.reset(new CE(driver_)); 80 | // 81 | // std::ifstream 82 | // t("/devel/git/research/code_graphs/eval/datasets/devmap/data/rodinia-3.1/opencl/leukocyte/OpenCL/track_ellipse_kernel_opt.cl"); 83 | // std::string str((std::istreambuf_iterator<char>(t)), 84 | // std::istreambuf_iterator<char>()); 85 | // str = "#include \"/devel/git/gnns4code/c/3rd_party/opencl-shim.h\"\n" + str; 86 | // 87 | // seq::ExtractionInfoPtr info = extractor_->SeqFromString(str); 88 | //} -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/clang_graph_frontendaction.cc: -------------------------------------------------------------------------------- 1 | #include "clang_graph_frontendaction.h" 2 | 3 | #include <memory> 4 | #include <string> 5 | #include <vector> 6 | 7 | #include "clang/AST/ASTConsumer.h" 8 | #include "clang/AST/Decl.h" 9 | #include "clang/Analysis/CFG.h" 10 | #include "clang/Frontend/ASTConsumers.h" 11 | #include "clang/Frontend/CompilerInstance.h" 12 | #include "clang/Frontend/MultiplexConsumer.h" 13 | #include "clang/StaticAnalyzer/Core/Checker.h" 14 | #include "llvm/Support/raw_ostream.h" 15 | 16 | using namespace ::clang; 17 | using namespace ::llvm; 18 | 19 | namespace compy { 20 | namespace clang { 21 | namespace graph { 22 | 23 | bool ExtractorASTVisitor::VisitStmt(Stmt *s) { 24 | // Collect child stmts 25 | std::vector<OperandInfoPtr> ast_relations; 26 | for (auto it : s->children()) { 27 | if (it) { 28 | StmtInfoPtr childInfo = getInfo(*it); 29 | ast_relations.push_back(childInfo); 30 | } 31 | } 32 | 33 | if (auto *ds = dyn_cast<DeclStmt>(s)) { 34 | for (const Decl *decl : ds->decls()) { 35 | ast_relations.push_back(getInfo(*decl, false)); 36 | } 37 | } 38 | 39 | StmtInfoPtr info = getInfo(*s); 40 | info->ast_relations.insert(info->ast_relations.end(), ast_relations.begin(), 41 | ast_relations.end()); 42 | 43 | return RecursiveASTVisitor::VisitStmt(s); 44 | } 45 | 46 | bool ExtractorASTVisitor::VisitFunctionDecl(FunctionDecl *f) { 47 | // Only proceed on function definitions, not declarations. Otherwise, all 48 | // function declarations in headers are traversed also. 49 | if (!f->hasBody() || !f->getDeclName().isIdentifier()) { 50 | // throw away the tokens 51 | tokenQueue_.popTokensForRange(f->getSourceRange()); 52 | return true; 53 | } 54 | 55 | // ::llvm::errs() << f->getNameAsString() << "\n"; 56 | 57 | FunctionInfoPtr functionInfo = getInfo(*f); 58 | extractionInfo_->functionInfos.push_back(functionInfo); 59 | 60 | // Add entry stmt. 61 | functionInfo->entryStmt = getInfo(*f->getBody()); 62 | 63 | // Add args. 64 | for (auto it : f->parameters()) { 65 | functionInfo->args.push_back(getInfo(*it, true)); 66 | } 67 | 68 | // Dump CFG. 
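  // buildCFG constructs the intraprocedural control-flow graph for the
  // function body; getInfo(const CFGBlock &) below then mirrors each block's
  // statements and successor edges into CFGBlockInfo objects.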
69 | std::unique_ptr<CFG> cfg = 70 | CFG::buildCFG(f, f->getBody(), &context_, CFG::BuildOptions()); 71 | // cfg->dump(LangOptions(), true); 72 | 73 | // Create CFG Blocks. 74 | for (CFG::iterator it = cfg->begin(), Eb = cfg->end(); it != Eb; ++it) { 75 | CFGBlock *B = *it; 76 | functionInfo->cfgBlocks.push_back(getInfo(*B)); 77 | } 78 | 79 | return RecursiveASTVisitor::VisitFunctionDecl(f); 80 | } 81 | 82 | CFGBlockInfoPtr ExtractorASTVisitor::getInfo(const ::clang::CFGBlock &block) { 83 | auto it = cfgBlockInfos_.find(&block); 84 | if (it != cfgBlockInfos_.end()) return it->second; 85 | 86 | CFGBlockInfoPtr info(new CFGBlockInfo); 87 | cfgBlockInfos_[&block] = info; 88 | 89 | // Collect name. 90 | info->name = "cfg_" + std::to_string(block.getBlockID()); 91 | 92 | // Collect statements. 93 | for (CFGBlock::const_iterator it = block.begin(), Es = block.end(); it != Es; 94 | ++it) { 95 | if (Optional<CFGStmt> CS = it->getAs<CFGStmt>()) { 96 | const Stmt *S = CS->getStmt(); 97 | info->statements.push_back(getInfo(*S)); 98 | } 99 | } 100 | if (block.getTerminatorStmt()) { 101 | const Stmt *S = block.getTerminatorStmt(); 102 | info->statements.push_back(getInfo(*S)); 103 | } 104 | 105 | // Collect successors. 106 | for (CFGBlock::const_succ_iterator it = block.succ_begin(), 107 | Es = block.succ_end(); 108 | it != Es; ++it) { 109 | CFGBlock *B = *it; 110 | if (B) info->successors.push_back(getInfo(*B)); 111 | } 112 | 113 | return info; 114 | } 115 | 116 | FunctionInfoPtr ExtractorASTVisitor::getInfo(const FunctionDecl &func) { 117 | FunctionInfoPtr info(new FunctionInfo()); 118 | 119 | // Collect name. 120 | info->name = func.getNameAsString(); 121 | 122 | // Collect type. 123 | info->type = func.getType().getAsString(); 124 | 125 | // Collect tokens 126 | info->tokens = tokenQueue_.popTokensForRange(func.getSourceRange()); 127 | 128 | return info; 129 | } 130 | 131 | StmtInfoPtr ExtractorASTVisitor::getInfo(const Stmt &stmt) { 132 | auto it = stmtInfos_.find(&stmt); 133 | if (it != stmtInfos_.end()) return it->second; 134 | 135 | StmtInfoPtr info(new StmtInfo); 136 | stmtInfos_[&stmt] = info; 137 | 138 | // Collect name. 139 | info->name = stmt.getStmtClassName(); 140 | 141 | // Collect referencing targets. 142 | if (const DeclRefExpr *de = dyn_cast<DeclRefExpr>(&stmt)) { 143 | info->ref_relations.push_back(getInfo(*de->getDecl(), false)); 144 | } 145 | 146 | // Collect tokens 147 | info->tokens = tokenQueue_.popTokensForRange(stmt.getSourceRange()); 148 | 149 | return info; 150 | } 151 | 152 | DeclInfoPtr ExtractorASTVisitor::getInfo(const Decl &decl, bool consumeTokens) { 153 | auto it = declInfos_.find(&decl); 154 | if (it != declInfos_.end()) { 155 | if (consumeTokens) { 156 | auto tokens = tokenQueue_.popTokensForRange(decl.getSourceRange()); 157 | it->second->tokens.insert(it->second->tokens.end(), tokens.begin(), 158 | tokens.end()); 159 | } 160 | 161 | return it->second; 162 | } 163 | 164 | DeclInfoPtr info(new DeclInfo); 165 | declInfos_[&decl] = info; 166 | 167 | info->kind = decl.getDeclKindName(); 168 | 169 | // Collect name. 170 | if (const ValueDecl *vd = dyn_cast<ValueDecl>(&decl)) { 171 | info->name = vd->getQualifiedNameAsString(); 172 | 173 | if (const auto nameTokenPtr = tokenQueue_.getTokenAt(vd->getLocation())) { 174 | info->nameToken = *nameTokenPtr; 175 | } 176 | } 177 | 178 | // Collect type. 
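  // (only ValueDecls carry a type; for all other decl kinds, type stays empty)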
179 | if (const ValueDecl *vd = dyn_cast(&decl)) { 180 | info->type = vd->getType().getAsString(); 181 | } 182 | 183 | // Collect tokens 184 | if (consumeTokens) { 185 | info->tokens = tokenQueue_.popTokensForRange(decl.getSourceRange()); 186 | } 187 | 188 | return info; 189 | } 190 | 191 | ExtractorASTConsumer::ExtractorASTConsumer(CompilerInstance &CI, 192 | ExtractionInfoPtr extractionInfo) 193 | : visitor_(CI.getASTContext(), std::move(extractionInfo), tokenQueue_), 194 | tokenQueue_(CI.getPreprocessor()) {} 195 | 196 | bool ExtractorASTConsumer::HandleTopLevelDecl(DeclGroupRef DR) { 197 | for (auto it = DR.begin(), e = DR.end(); it != e; ++it) { 198 | visitor_.TraverseDecl(*it); 199 | } 200 | 201 | return true; 202 | } 203 | 204 | std::unique_ptr ExtractorFrontendAction::CreateASTConsumer( 205 | CompilerInstance &CI, StringRef file) { 206 | extractionInfo.reset(new ExtractionInfo()); 207 | // CI.getASTContext().getLangOpts().OpenCL 208 | return std::make_unique(CI, extractionInfo); 209 | } 210 | 211 | std::vector TokenQueue::popTokensForRange( 212 | ::clang::SourceRange range) { 213 | std::vector result; 214 | auto startPos = token_index_[range.getBegin().getRawEncoding()]; 215 | auto endPos = token_index_[range.getEnd().getRawEncoding()]; 216 | for (std::size_t i = startPos; i <= endPos; ++i) { 217 | if (token_consumed_[i]) continue; 218 | 219 | result.push_back(tokens_[i]); 220 | token_consumed_[i] = true; 221 | } 222 | 223 | return result; 224 | } 225 | 226 | void TokenQueue::addToken(::clang::Token token) { 227 | TokenInfo info; 228 | info.index = nextIndex++; 229 | info.kind = token.getName(); 230 | info.name = pp_.getSpelling(token, nullptr); 231 | info.location = token.getLocation(); 232 | tokens_.push_back(info); 233 | token_consumed_.push_back(false); 234 | token_index_[info.location.getRawEncoding()] = tokens_.size() - 1; 235 | } 236 | 237 | TokenInfo *TokenQueue::getTokenAt(SourceLocation loc) { 238 | auto pos = token_index_.find(loc.getRawEncoding()); 239 | if (pos == token_index_.end()) return nullptr; 240 | 241 | return &tokens_[pos->second]; 242 | } 243 | 244 | } // namespace graph 245 | } // namespace clang 246 | } // namespace compy 247 | -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/clang_graph_frontendaction.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "clang/AST/AST.h" 8 | #include "clang/AST/ExternalASTSource.h" 9 | #include "clang/AST/RecursiveASTVisitor.h" 10 | #include "clang/Frontend/FrontendActions.h" 11 | #include "llvm/ADT/StringRef.h" 12 | #include 13 | 14 | #include "clang_extractor.h" 15 | 16 | namespace compy { 17 | namespace clang { 18 | namespace graph { 19 | 20 | /** 21 | * Keeps a queue of tokens that are not assigned to any AST-Graph node yet 22 | */ 23 | class TokenQueue { 24 | public: 25 | TokenQueue(::clang::Preprocessor &pp) : pp_(pp) { 26 | pp_.setTokenWatcher([this](auto token) { this->addToken(token); }); 27 | } 28 | 29 | TokenQueue(TokenQueue const &) = delete; 30 | TokenQueue &operator=(TokenQueue const &) = delete; 31 | 32 | ~TokenQueue() { pp_.setTokenWatcher(nullptr); } 33 | 34 | std::vector popTokensForRange(::clang::SourceRange range); 35 | TokenInfo *getTokenAt(::clang::SourceLocation loc); 36 | 37 | private: 38 | void addToken(::clang::Token token); 39 | 40 | ::clang::Preprocessor &pp_; 41 | std::vector tokens_; 42 | std::vector 
token_consumed_; 43 | llvm::DenseMap token_index_; 44 | 45 | std::uint64_t nextIndex = 0; 46 | }; 47 | 48 | class ExtractorASTVisitor 49 | : public ::clang::RecursiveASTVisitor { 50 | public: 51 | ExtractorASTVisitor(::clang::ASTContext &context, 52 | ExtractionInfoPtr extractionInfo, TokenQueue &tokenQueue) 53 | : context_(context), 54 | extractionInfo_(extractionInfo), 55 | tokenQueue_(tokenQueue) {} 56 | 57 | bool VisitStmt(::clang::Stmt *s); 58 | bool VisitFunctionDecl(::clang::FunctionDecl *f); 59 | 60 | // postorder traversal is necessary so that tokens get assigned to 61 | // nodes closer to the leaves first 62 | bool shouldTraversePostOrder() const { return true; } 63 | 64 | private: 65 | FunctionInfoPtr getInfo(const ::clang::FunctionDecl &func); 66 | CFGBlockInfoPtr getInfo(const ::clang::CFGBlock &block); 67 | StmtInfoPtr getInfo(const ::clang::Stmt &stmt); 68 | DeclInfoPtr getInfo(const ::clang::Decl &decl, bool consumeTokens); 69 | 70 | private: 71 | ::clang::ASTContext &context_; 72 | ExtractionInfoPtr extractionInfo_; 73 | TokenQueue &tokenQueue_; 74 | 75 | std::unordered_map stmtInfos_; 76 | std::unordered_map cfgBlockInfos_; 77 | std::unordered_map declInfos_; 78 | }; 79 | 80 | class ExtractorASTConsumer : public ::clang::ASTConsumer { 81 | public: 82 | ExtractorASTConsumer(::clang::CompilerInstance &CI, 83 | ExtractionInfoPtr extractionInfo); 84 | 85 | bool HandleTopLevelDecl(::clang::DeclGroupRef DR) override; 86 | 87 | private: 88 | ExtractorASTVisitor visitor_; 89 | TokenQueue tokenQueue_; 90 | }; 91 | 92 | class ExtractorFrontendAction : public ::clang::ASTFrontendAction { 93 | public: 94 | std::unique_ptr<::clang::ASTConsumer> CreateASTConsumer( 95 | ::clang::CompilerInstance &CI, ::llvm::StringRef file) override; 96 | 97 | ExtractionInfoPtr extractionInfo; 98 | }; 99 | 100 | } // namespace graph 101 | } // namespace clang 102 | } // namespace compy 103 | -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/clang_seq_frontendaction.cc: -------------------------------------------------------------------------------- 1 | #include "clang_seq_frontendaction.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "clang/AST/AST.h" 8 | #include "clang/AST/RecursiveASTVisitor.h" 9 | #include "clang/Frontend/CompilerInstance.h" 10 | #include "clang/Lex/Lexer.h" 11 | #include "clang/Rewrite/Core/Rewriter.h" 12 | #include "clang/Rewrite/Core/TokenRewriter.h" 13 | #include "llvm/ADT/StringRef.h" 14 | 15 | using namespace ::clang; 16 | using namespace ::llvm; 17 | 18 | namespace compy { 19 | namespace clang { 20 | namespace seq { 21 | 22 | void ExtractorASTVisitor::init() { 23 | mappedNames_.clear(); 24 | num_functions_ = 0; 25 | num_variables_ = 0; 26 | } 27 | 28 | void ExtractorASTVisitor::setState(STATE state) { state_ = state; } 29 | 30 | bool ExtractorASTVisitor::VisitFunctionDecl(FunctionDecl *f) { 31 | // Only proceed on function definitions, not declarations. Otherwise, all 32 | // function declarations in headers are traversed also. 33 | if (!f->hasBody() || !f->getDeclName().isIdentifier()) { 34 | return true; 35 | } 36 | 37 | if (state_ == STATE::Map) { 38 | mapName(*f); 39 | } 40 | 41 | else if (state_ == STATE::Capture) { 42 | FunctionInfoPtr functionInfo = getInfo(*f); 43 | extractionInfo_->functionInfos.push_back(functionInfo); 44 | 45 | // Get string of function. 
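The next lines lean on two clang::Lexer helpers. Their contract, as a short sketch where sm, lo, and sr are placeholder SourceManager, LangOptions, and SourceRange values rather than variables of this function:

// getAsCharRange converts a token range into a character range whose end
// points one past the last character of the final token.
clang::CharSourceRange cr = clang::Lexer::getAsCharRange(sr, sm, lo);
// getSourceText then returns those raw characters verbatim, including
// whitespace and comments.
std::string text = clang::Lexer::getSourceText(cr, sm, lo).str();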
46 | SourceRange sourceRange = f->getSourceRange(); 47 | 48 | CharSourceRange charSourceRange = ::clang::Lexer::getAsCharRange( 49 | sourceRange, context_.getSourceManager(), context_.getLangOpts()); 50 | 51 | std::string str = ::clang::Lexer::getSourceText(charSourceRange, 52 | context_.getSourceManager(), 53 | context_.getLangOpts()) 54 | .str(); 55 | 56 | // Create file id in source manager. So it will follow the same lex flow as 57 | // regular files. 58 | FileID fid = context_.getSourceManager().createFileID( 59 | MemoryBuffer::getMemBuffer(str)); 60 | 61 | // Create lexer. 62 | ::clang::Lexer lex(context_.getSourceManager().getLocForStartOfFile(fid), 63 | context_.getLangOpts(), str.data(), str.data(), 64 | str.data() + str.size()); 65 | 66 | // Lex function. 67 | Token tok; 68 | lex.LexFromRawLexer(tok); 69 | while (true) { 70 | // Add to tokens. 71 | TokenInfoPtr tokenInfo(new TokenInfo()); 72 | functionInfo->tokenInfos.push_back(tokenInfo); 73 | 74 | // Get string token. 75 | // - Get token from lexer. 76 | std::string strToken = ::clang::Lexer::getSpelling( 77 | tok, context_.getSourceManager(), context_.getLangOpts(), nullptr); 78 | 79 | // - Check if there exists a mapping for this token. 80 | auto it = mappedNames_.find(strToken); 81 | if (it != mappedNames_.end()) strToken = it->second; 82 | 83 | tokenInfo->name = strToken; 84 | 85 | // Get token kind. 86 | tokenInfo->kind = tok.getName(); 87 | 88 | // Check if done and get next token if not. 89 | if (tok.getLocation() == sourceRange.getEnd() || tok.is(tok::eof)) { 90 | break; 91 | } 92 | lex.LexFromRawLexer(tok); 93 | } 94 | } 95 | 96 | return RecursiveASTVisitor::VisitFunctionDecl(f); 97 | } 98 | 99 | bool ExtractorASTVisitor::VisitVarDecl(VarDecl *decl) { 100 | if (state_ == STATE::Map) { 101 | mapName(*decl); 102 | } 103 | 104 | return RecursiveASTVisitor::VisitVarDecl(decl); 105 | } 106 | 107 | FunctionInfoPtr ExtractorASTVisitor::getInfo(const FunctionDecl &func) { 108 | FunctionInfoPtr info(new FunctionInfo()); 109 | 110 | // Collect name. 
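The token spellings captured in VisitFunctionDecl above are filtered through mappedNames_, which the Map pass fills via mapName() (defined further down). A hypothetical illustration of that scheme; the helper and counters below are not part of this file:

// fn_/var_ placeholders are handed out in the order names are first seen.
std::string placeholderFor(bool isFunction, unsigned &nFuncs, unsigned &nVars) {
  return isFunction ? "fn_" + std::to_string(nFuncs++)
                    : "var_" + std::to_string(nVars++);
}
// e.g. "int max(int a, int b)" captures as: int fn_0 ( int var_0 , int var_1 )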
111 | info->name = func.getNameAsString(); 112 | 113 | return info; 114 | } 115 | 116 | std::string ExtractorASTVisitor::mapName(const NamedDecl &decl) { 117 | std::string name = decl.getNameAsString(); 118 | 119 | auto it = mappedNames_.find(name); 120 | if (it != mappedNames_.end()) return it->second; 121 | 122 | std::string mappedName; 123 | if (isa(decl)) { 124 | mappedName = "fn_" + std::to_string(num_functions_); 125 | num_functions_++; 126 | } else if (isa(decl)) { 127 | mappedName = "var_" + std::to_string(num_variables_); 128 | num_variables_++; 129 | } 130 | mappedNames_[name] = mappedName; 131 | 132 | return mappedName; 133 | } 134 | 135 | ExtractorASTConsumer::ExtractorASTConsumer(ASTContext &context, 136 | ExtractionInfoPtr extractionInfo) 137 | : visitor_(context, extractionInfo) {} 138 | 139 | bool ExtractorASTConsumer::HandleTopLevelDecl(DeclGroupRef DR) { 140 | for (auto it = DR.begin(), e = DR.end(); it != e; ++it) { 141 | visitor_.setState(ExtractorASTVisitor::STATE::Map); 142 | visitor_.TraverseDecl(*it); 143 | 144 | visitor_.setState(ExtractorASTVisitor::STATE::Capture); 145 | visitor_.TraverseDecl(*it); 146 | } 147 | 148 | return true; 149 | } 150 | 151 | std::unique_ptr ExtractorFrontendAction::CreateASTConsumer( 152 | CompilerInstance &CI, StringRef file) { 153 | extractionInfo.reset(new ExtractionInfo()); 154 | 155 | return std::make_unique(CI.getASTContext(), 156 | extractionInfo); 157 | } 158 | 159 | } // namespace seq 160 | } // namespace clang 161 | } // namespace compy 162 | -------------------------------------------------------------------------------- /compy/representations/extractors/clang_ast/clang_seq_frontendaction.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "clang/AST/AST.h" 8 | #include "clang/AST/ExternalASTSource.h" 9 | #include "clang/AST/RecursiveASTVisitor.h" 10 | #include "clang/Frontend/FrontendActions.h" 11 | #include "llvm/ADT/StringRef.h" 12 | 13 | #include "clang_extractor.h" 14 | 15 | namespace compy { 16 | namespace clang { 17 | namespace seq { 18 | 19 | class ExtractorASTVisitor 20 | : public ::clang::RecursiveASTVisitor { 21 | public: 22 | enum STATE { 23 | Map = 0, 24 | Capture = 1, 25 | }; 26 | 27 | public: 28 | ExtractorASTVisitor(::clang::ASTContext &context, 29 | ExtractionInfoPtr extractionInfo) 30 | : state_(STATE::Map), context_(context), extractionInfo_(extractionInfo) { 31 | init(); 32 | } 33 | 34 | void init(); 35 | void setState(STATE state); 36 | 37 | bool VisitFunctionDecl(::clang::FunctionDecl *f); 38 | bool VisitVarDecl(::clang::VarDecl *decl); 39 | 40 | private: 41 | FunctionInfoPtr getInfo(const ::clang::FunctionDecl &func); 42 | 43 | std::string mapName(const ::clang::NamedDecl &decl); 44 | 45 | private: 46 | STATE state_; 47 | ::clang::ASTContext &context_; 48 | ExtractionInfoPtr extractionInfo_; 49 | 50 | std::unordered_map mappedNames_; 51 | unsigned int num_functions_; 52 | unsigned int num_variables_; 53 | }; 54 | 55 | class ExtractorASTConsumer : public ::clang::ASTConsumer { 56 | public: 57 | ExtractorASTConsumer(::clang::ASTContext &context, 58 | ExtractionInfoPtr extractionInfo); 59 | 60 | bool HandleTopLevelDecl(::clang::DeclGroupRef DR) override; 61 | 62 | private: 63 | ExtractorASTVisitor visitor_; 64 | }; 65 | 66 | class ExtractorFrontendAction : public ::clang::ASTFrontendAction { 67 | public: 68 | std::unique_ptr<::clang::ASTConsumer> CreateASTConsumer( 69 | ::clang::CompilerInstance &CI, 
::llvm::StringRef file) override; 70 | 71 | ExtractionInfoPtr extractionInfo; 72 | }; 73 | 74 | } // namespace seq 75 | } // namespace clang 76 | } // namespace compy 77 | -------------------------------------------------------------------------------- /compy/representations/extractors/common/clang_driver.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "clang/Frontend/FrontendAction.h" 9 | #include "llvm/IR/LegacyPassManager.h" 10 | #include "llvm/IR/Module.h" 11 | #include "llvm/Pass.h" 12 | 13 | namespace compy { 14 | 15 | class ClangDriver { 16 | public: 17 | enum ProgrammingLanguage { 18 | C = 0, 19 | CPLUSPLUS = 1, 20 | OPENCL = 3, 21 | LLVM = 4, 22 | }; 23 | 24 | enum OptimizationLevel { O0 = 0, O1 = 1, O2 = 2, O3 = 3 }; 25 | 26 | enum IncludeDirType { 27 | SYSTEM = 0, 28 | USER = 1, 29 | }; 30 | 31 | public: 32 | ClangDriver(ProgrammingLanguage programmingLanguage, 33 | OptimizationLevel optimizationLevel, 34 | std::vector> includeDirs, 35 | std::vector compilerFlags); 36 | 37 | void addIncludeDir(std::string includeDir, IncludeDirType includeDirType); 38 | void removeIncludeDir(std::string includeDir, IncludeDirType includeDirType); 39 | void setOptimizationLevel(OptimizationLevel optimizationLevel); 40 | void setFileName(std::string fileName); 41 | std::string getFileName() const; 42 | void setCompilerBinary(std::string path); 43 | std::string getCompilerBinary() const; 44 | 45 | void Invoke(std::string src, 46 | std::vector<::clang::FrontendAction *> frontendActions, 47 | std::vector<::llvm::Pass *> passes); 48 | 49 | private: 50 | void InvokeClangAndLLVM(std::string& src, 51 | std::vector<::clang::FrontendAction *>& frontendActions, 52 | std::vector<::llvm::Pass *>& passes); 53 | void InvokeLLVM(std::string& src, 54 | std::vector<::llvm::Pass *>& passes); 55 | void runLLVMPasses(std::unique_ptr<::llvm::Module> Module, 56 | std::vector<::llvm::Pass *>& passes); 57 | 58 | private: 59 | std::shared_ptr<::llvm::legacy::PassManager> pm_; 60 | 61 | ProgrammingLanguage programmingLanguage_; 62 | OptimizationLevel optimizationLevel_; 63 | std::vector> includeDirs_; 64 | std::vector compilerFlags_; 65 | std::string fileName_; 66 | std::string compilerBinary_; 67 | }; 68 | using ClangDriverPtr = std::shared_ptr; 69 | 70 | } // namespace compy 71 | -------------------------------------------------------------------------------- /compy/representations/extractors/common/clang_driver_test.cc: -------------------------------------------------------------------------------- 1 | #include "clang_driver.h" 2 | 3 | #include 4 | 5 | #include "common/common_test.h" 6 | #include "gmock/gmock.h" 7 | #include "gtest/gtest.h" 8 | 9 | using namespace testing; 10 | using namespace compy; 11 | 12 | class MockPass : public llvm::ModulePass { 13 | public: 14 | char ID = 0; 15 | MockPass() : llvm::ModulePass(ID) {} 16 | 17 | MOCK_METHOD1(runOnModule, bool(llvm::Module &M)); 18 | MOCK_CONST_METHOD1(getAnalysisUsage, void(llvm::AnalysisUsage &au)); 19 | }; 20 | 21 | class ClangDriverFixture : public testing::Test { 22 | protected: 23 | void SetUp() override { 24 | // Init extractor 25 | std::vector> 26 | includeDirs = { 27 | std::make_tuple("/usr/include", 28 | ClangDriver::IncludeDirType::SYSTEM), 29 | std::make_tuple("/usr/include/x86_64-linux-gnu", 30 | ClangDriver::IncludeDirType::SYSTEM), 31 | std::make_tuple( 32 | "/devel/git_3rd/llvm-project/build_release/lib/clang/" 33 | 
"7.1.0/include/", 34 | ClangDriver::IncludeDirType::SYSTEM)}; 35 | std::vector compilerFlags = {"-Werror"}; 36 | 37 | clang_.reset(new ClangDriver(ClangDriver::ProgrammingLanguage::C, 38 | ClangDriver::OptimizationLevel::O0, 39 | includeDirs, compilerFlags)); 40 | } 41 | 42 | std::shared_ptr clang_; 43 | }; 44 | 45 | // Tests 46 | TEST_F(ClangDriverFixture, CompileWithPassFunction1) { 47 | NiceMock *pass = new NiceMock(); 48 | EXPECT_CALL(*pass, runOnModule(_)).Times(AtLeast(1)); 49 | 50 | std::vector<::clang::FrontendAction *> frontendActions; 51 | std::vector<::llvm::Pass *> passes; 52 | passes.push_back(pass); 53 | 54 | clang_->Invoke(kProgram1, frontendActions, passes); 55 | } 56 | -------------------------------------------------------------------------------- /compy/representations/extractors/common/common_test.h: -------------------------------------------------------------------------------- 1 | // C samples 2 | constexpr char kProgram1[] = 3 | "int foo() {\n" 4 | " return 1;\n" 5 | "}"; 6 | constexpr char kProgram2[] = 7 | "int max(int a, int b) {\n" 8 | " if(a > b) {\n" 9 | " return a;\n" 10 | " } else {\n" 11 | " return b;\n" 12 | " }\n" 13 | "}"; 14 | constexpr char kProgram3[] = 15 | "#include \n" 16 | "\n" 17 | "void foo() {\n" 18 | " printf(\"Hello\");\n" 19 | "}"; 20 | constexpr char kProgram4[] = 21 | "#include \"tempHdr.h\"\n" 22 | "\n" 23 | "void foo() {\n" 24 | " barbara(1.2, 3.4);\n" 25 | "}"; 26 | constexpr char kProgram5[] = 27 | "int max(int a, int b) {\n" 28 | " if (a > b) {\n" 29 | " return a;\n" 30 | " } else {\n" 31 | " return b;\n" 32 | " }\n" 33 | "}\n" 34 | "int foo(int x) {\n" 35 | " return max(1, x);\n" 36 | "}"; 37 | 38 | // LLVM samples 39 | constexpr char kLLVM1[] = 40 | "define dso_local void @A(i32*) #0 {\n" 41 | " %2 = alloca i32*, align 8\n" 42 | " %3 = alloca i32, align 4\n" 43 | " store i32* %0, i32** %2, align 8\n" 44 | " store i32 2, i32* %3, align 4\n" 45 | " %4 = load i32, i32* %3, align 4\n" 46 | " %5 = load i32*, i32** %2, align 8\n" 47 | " %6 = getelementptr inbounds i32, i32* %5, i64 0\n" 48 | " store i32 %4, i32* %6, align 4\n" 49 | " ret void\n" 50 | "}\n"; 51 | constexpr char kLLVM2[] = 52 | "define dso_local void @A(i32*) #0 {\n" 53 | " %2 = alloca i32*, align 8\n" 54 | " %3 = alloca i32, align 4\n" 55 | " store i32* %0, i32** %2, align 8\n" 56 | "}\n"; -------------------------------------------------------------------------------- /compy/representations/extractors/common/visitor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace compy { 4 | 5 | struct IVisitee; 6 | struct IVisitor { 7 | virtual void visit(IVisitee* v) = 0; 8 | }; 9 | 10 | struct IVisitee { 11 | virtual void accept(IVisitor* v) = 0; 12 | }; 13 | 14 | } // namespace compy -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Extractor library 2 | add_library(llvm_extractor 3 | llvm_extractor.cc 4 | llvm_graph_pass.cc 5 | llvm_graph_funcinfo.cc 6 | llvm_seq_pass.cc 7 | ) 8 | target_link_libraries(llvm_extractor 9 | extractors_common 10 | ) 11 | target_compile_options(llvm_extractor PRIVATE 12 | ${compile_options_common} 13 | ) 14 | 15 | # Extractor tests 16 | add_executable(llvm_extractor_tests 17 | llvm_pass_test.cc 18 | llvm_extractor_test.cc 19 | ) 20 | target_link_libraries(llvm_extractor_tests 21 | llvm_extractor 22 | 23 | gmock 
24 | gtest 25 | gtest_main 26 | ) 27 | target_compile_options(llvm_extractor_tests PRIVATE 28 | -fno-rtti -fPIC 29 | ) 30 | target_compile_definitions(llvm_extractor_tests PRIVATE 31 | CLANG_INSTALL_PREFIX=${CLANG_INSTALL_PREFIX} 32 | ) 33 | -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_extractor.cc: -------------------------------------------------------------------------------- 1 | #include "llvm_extractor.h" 2 | 3 | #include 4 | 5 | #include "clang/Config/config.h" 6 | #include "clang/Frontend/CompilerInvocation.h" 7 | #include "clang/Lex/PreprocessorOptions.h" 8 | #include "llvm/LinkAllPasses.h" 9 | #include "llvm/Support/Compiler.h" 10 | 11 | #include "llvm_graph_pass.h" 12 | #include "llvm_seq_pass.h" 13 | 14 | using namespace ::clang; 15 | using namespace ::llvm; 16 | 17 | namespace compy { 18 | namespace llvm { 19 | 20 | LLVMIRExtractor::LLVMIRExtractor(ClangDriverPtr clangDriver) 21 | : clangDriver_(clangDriver) {} 22 | 23 | graph::ExtractionInfoPtr LLVMIRExtractor::GraphFromString(std::string src) { 24 | std::vector<::clang::FrontendAction *> frontendActions; 25 | std::vector<::llvm::Pass *> passes; 26 | 27 | passes.push_back(createStripSymbolsPass()); 28 | 29 | graph::ExtractorPass *extractorPass = new graph::ExtractorPass(); 30 | passes.push_back(extractorPass); 31 | 32 | clangDriver_->Invoke(src, frontendActions, passes); 33 | 34 | return extractorPass->extractionInfo; 35 | } 36 | 37 | seq::ExtractionInfoPtr LLVMIRExtractor::SeqFromString(std::string src) { 38 | std::vector<::clang::FrontendAction *> frontendActions; 39 | std::vector<::llvm::Pass *> passes; 40 | 41 | passes.push_back(createStripSymbolsPass()); 42 | seq::ExtractorPass *pass = new seq::ExtractorPass(); 43 | passes.push_back(pass); 44 | 45 | clangDriver_->Invoke(src, frontendActions, passes); 46 | 47 | return pass->extractionInfo; 48 | } 49 | 50 | } // namespace llvm 51 | } // namespace compy 52 | -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_extractor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "common/clang_driver.h" 8 | #include "common/visitor.h" 9 | 10 | namespace compy { 11 | namespace llvm { 12 | 13 | namespace seq { 14 | struct InstructionInfo; 15 | using InstructionInfoPtr = std::shared_ptr; 16 | 17 | struct BasicBlockInfo; 18 | using BasicBlockInfoPtr = std::shared_ptr; 19 | 20 | struct FunctionInfo; 21 | using FunctionInfoPtr = std::shared_ptr; 22 | 23 | struct ExtractionInfo; 24 | using ExtractionInfoPtr = std::shared_ptr; 25 | 26 | struct InstructionInfo : IVisitee { 27 | std::vector tokens; 28 | 29 | void accept(IVisitor* v) override { v->visit(this); } 30 | }; 31 | 32 | struct BasicBlockInfo : IVisitee { 33 | std::string name; 34 | std::vector instructions; 35 | 36 | void accept(IVisitor* v) override { 37 | v->visit(this); 38 | for (const auto& it : instructions) it->accept(v); 39 | } 40 | }; 41 | 42 | struct FunctionInfo : IVisitee { 43 | std::string name; 44 | std::vector signature; 45 | std::vector basicBlocks; 46 | std::string str; 47 | 48 | void accept(IVisitor* v) override { 49 | v->visit(this); 50 | for (const auto& it : basicBlocks) it->accept(v); 51 | } 52 | }; 53 | 54 | struct ExtractionInfo : IVisitee { 55 | std::vector functionInfos; 56 | 57 | void accept(IVisitor* v) override { 58 | v->visit(this); 59 | for 
(const auto& it : functionInfos) it->accept(v); 60 | } 61 | }; 62 | } // namespace seq 63 | 64 | namespace graph { 65 | struct OperandInfo; 66 | using OperandInfoPtr = std::shared_ptr; 67 | 68 | struct ArgInfo; 69 | using ArgInfoPtr = std::shared_ptr; 70 | 71 | struct ConstantInfo; 72 | using ConstantInfoPtr = std::shared_ptr; 73 | 74 | struct InstructionInfo; 75 | using InstructionInfoPtr = std::shared_ptr; 76 | 77 | struct BasicBlockInfo; 78 | using BasicBlockInfoPtr = std::shared_ptr; 79 | 80 | struct MemoryAccessInfo; 81 | using MemoryAccessInfoPtr = std::shared_ptr; 82 | 83 | struct FunctionInfo; 84 | using FunctionInfoPtr = std::shared_ptr; 85 | 86 | struct CallGraphInfo; 87 | using CallGraphInfoPtr = std::shared_ptr; 88 | 89 | struct ExtractionInfo; 90 | using ExtractionInfoPtr = std::shared_ptr; 91 | 92 | struct OperandInfo : IVisitee { 93 | virtual ~OperandInfo() = default; 94 | }; 95 | 96 | struct ArgInfo : OperandInfo { 97 | std::string name; 98 | std::string type; 99 | 100 | void accept(IVisitor* v) override { v->visit(this); } 101 | }; 102 | 103 | struct ConstantInfo : OperandInfo { 104 | std::string type; 105 | std::string value; 106 | 107 | void accept(IVisitor* v) override { v->visit(this); } 108 | }; 109 | 110 | struct InstructionInfo : OperandInfo { 111 | std::string type; 112 | std::string opcode; 113 | std::string callTarget; 114 | bool isLoadOrStore; 115 | std::vector operands; 116 | FunctionInfoPtr function; 117 | 118 | void accept(IVisitor* v) override { v->visit(this); } 119 | }; 120 | 121 | struct BasicBlockInfo : IVisitee { 122 | std::string name; 123 | std::vector instructions; 124 | std::vector successors; 125 | 126 | void accept(IVisitor* v) override { 127 | v->visit(this); 128 | for (const auto& it : instructions) it->accept(v); 129 | } 130 | }; 131 | 132 | struct MemoryAccessInfo { 133 | std::string type; 134 | InstructionInfoPtr inst; 135 | BasicBlockInfoPtr block; 136 | std::vector dependencies; 137 | }; 138 | 139 | struct FunctionInfo : IVisitee { 140 | std::string name; 141 | std::string type; 142 | InstructionInfoPtr entryInstruction; 143 | std::vector exitInstructions; 144 | std::vector args; 145 | std::vector basicBlocks; 146 | std::vector memoryAccesses; 147 | 148 | void accept(IVisitor* v) override { 149 | v->visit(this); 150 | for (const auto& it : basicBlocks) it->accept(v); 151 | } 152 | }; 153 | 154 | struct CallGraphInfo { 155 | std::vector calls; 156 | }; 157 | 158 | struct ExtractionInfo : IVisitee { 159 | std::vector functionInfos; 160 | CallGraphInfoPtr callGraphInfo; 161 | 162 | void accept(IVisitor* v) override { 163 | v->visit(this); 164 | for (const auto& it : functionInfos) it->accept(v); 165 | } 166 | }; 167 | } // namespace graph 168 | 169 | class LLVMIRExtractor { 170 | public: 171 | LLVMIRExtractor(ClangDriverPtr clangDriver); 172 | 173 | graph::ExtractionInfoPtr GraphFromString(std::string src); 174 | seq::ExtractionInfoPtr SeqFromString(std::string src); 175 | 176 | private: 177 | ClangDriverPtr clangDriver_; 178 | }; 179 | 180 | } // namespace llvm 181 | } // namespace compy 182 | -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_extractor_test.cc: -------------------------------------------------------------------------------- 1 | #include "llvm_extractor.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "common/common_test.h" 7 | #include "gtest/gtest.h" 8 | 9 | #define TO_STRING(prefix) #prefix 10 | #define COMPILER_BINARY(prefix) TO_STRING(prefix) 
"/bin/clang" 11 | 12 | using namespace ::llvm; 13 | using namespace compy; 14 | using namespace compy::llvm; 15 | 16 | using LE = LLVMIRExtractor; 17 | using CD = ClangDriver; 18 | 19 | constexpr char kProgram4ForwardDecl[] = "int barbara(float x, float y);"; 20 | 21 | void createFileWithContents(std::string filename, std::string filecontent) { 22 | std::ofstream tempHeaderFile(filename.c_str()); 23 | tempHeaderFile << filecontent << std::endl; 24 | tempHeaderFile.close(); 25 | } 26 | void removeFile(std::string filename) { std::remove(filename.c_str()); } 27 | 28 | class LLVMExtractorFixture : public testing::Test { 29 | protected: 30 | void Init(CD::ProgrammingLanguage programmingLanguage) { 31 | // Init extractor 32 | std::vector> includeDirs = {}; 33 | std::vector compilerFlags = {"-Werror"}; 34 | 35 | driver_.reset(new ClangDriver(programmingLanguage, 36 | CD::OptimizationLevel::O0, includeDirs, 37 | compilerFlags)); 38 | driver_->setCompilerBinary(COMPILER_BINARY(CLANG_INSTALL_PREFIX)); 39 | extractor_.reset(new LE(driver_)); 40 | } 41 | 42 | std::shared_ptr driver_; 43 | std::shared_ptr extractor_; 44 | }; 45 | 46 | class LLVMExtractorCFixture : public LLVMExtractorFixture { 47 | protected: 48 | void SetUp() override { Init(CD::ProgrammingLanguage::C); } 49 | }; 50 | 51 | class LLVMExtractorCPlusPlusFixture : public LLVMExtractorFixture { 52 | protected: 53 | void SetUp() override { Init(CD::ProgrammingLanguage::CPLUSPLUS); } 54 | }; 55 | 56 | class LLVMExtractorLLVMFixture : public LLVMExtractorFixture { 57 | protected: 58 | void SetUp() override { Init(CD::ProgrammingLanguage::LLVM); } 59 | }; 60 | 61 | // C tests 62 | TEST_F(LLVMExtractorCFixture, ExtractFromFunction1) { 63 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram1); 64 | 65 | ASSERT_EQ(info->functionInfos.size(), 1UL); 66 | ASSERT_EQ(info->functionInfos[0]->name, "foo"); 67 | ASSERT_EQ(info->functionInfos[0]->args.size(), 0UL); 68 | } 69 | 70 | TEST_F(LLVMExtractorCFixture, ExtractFromFunction2) { 71 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram2); 72 | 73 | ASSERT_EQ(info->functionInfos.size(), 1UL); 74 | ASSERT_EQ(info->functionInfos[0]->name, "max"); 75 | ASSERT_EQ(info->functionInfos[0]->args.size(), 2UL); 76 | } 77 | 78 | TEST_F(LLVMExtractorCFixture, ExtractFromFunction5) { 79 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram5); 80 | 81 | ASSERT_EQ(info->functionInfos.size(), 2UL); 82 | ASSERT_EQ(info->functionInfos[0]->name, "max"); 83 | ASSERT_EQ(info->functionInfos[0]->args.size(), 2UL); 84 | } 85 | 86 | TEST_F(LLVMExtractorCFixture, ExtractFromFunctionWithSystemInclude) { 87 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram3); 88 | 89 | ASSERT_EQ(info->functionInfos.size(), 1UL); 90 | ASSERT_EQ(info->functionInfos[0]->name, "foo"); 91 | ASSERT_EQ(info->functionInfos[0]->args.size(), 0UL); 92 | } 93 | 94 | TEST_F(LLVMExtractorCFixture, ExtractFromFunctionWithUserInclude) { 95 | std::string headerFilename = "/tmp/tempHdr.h"; 96 | createFileWithContents(headerFilename, kProgram4ForwardDecl); 97 | 98 | driver_->addIncludeDir("/tmp", CD::IncludeDirType::SYSTEM); 99 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram4); 100 | 101 | removeFile(headerFilename); 102 | 103 | ASSERT_EQ(info->functionInfos.size(), 1UL); 104 | ASSERT_EQ(info->functionInfos[0]->args.size(), 0UL); 105 | } 106 | 107 | TEST_F(LLVMExtractorCFixture, ExtractFromNoFunction) { 108 | graph::ExtractionInfoPtr info = 
extractor_->GraphFromString(""); 109 | 110 | ASSERT_EQ(info->functionInfos.size(), 0UL); 111 | } 112 | 113 | TEST_F(LLVMExtractorCFixture, ExtractFromBadFunction) { 114 | EXPECT_THROW( 115 | { 116 | try { 117 | graph::ExtractionInfoPtr info = extractor_->GraphFromString("foobar"); 118 | } catch (std::runtime_error const& err) { 119 | EXPECT_EQ(err.what(), std::string("Failed compiling to LLVM module")); 120 | throw; 121 | } 122 | }, 123 | std::runtime_error); 124 | } 125 | 126 | TEST_F(LLVMExtractorCFixture, ExtractWithDifferentOptimizationlevels) { 127 | driver_->setOptimizationLevel(CD::OptimizationLevel::O0); 128 | graph::ExtractionInfoPtr infoO0 = extractor_->GraphFromString(kProgram2); 129 | 130 | driver_->setOptimizationLevel(CD::OptimizationLevel::O1); 131 | graph::ExtractionInfoPtr infoO1 = extractor_->GraphFromString(kProgram2); 132 | 133 | ASSERT_TRUE(infoO0->functionInfos[0]->basicBlocks.size() > 134 | infoO1->functionInfos[0]->basicBlocks.size()); 135 | } 136 | 137 | // C++ tests 138 | TEST_F(LLVMExtractorCPlusPlusFixture, ExtractFromFunction1) { 139 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram1); 140 | 141 | ASSERT_EQ(info->functionInfos.size(), 1UL); 142 | ASSERT_EQ(info->functionInfos[0]->name, "_Z3foov"); 143 | ASSERT_EQ(info->functionInfos[0]->args.size(), 0UL); 144 | } 145 | 146 | TEST_F(LLVMExtractorCPlusPlusFixture, ExtractFromFunction2) { 147 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram2); 148 | 149 | ASSERT_EQ(info->functionInfos.size(), 1UL); 150 | ASSERT_EQ(info->functionInfos[0]->name, "_Z3maxii"); 151 | ASSERT_EQ(info->functionInfos[0]->args.size(), 2UL); 152 | } 153 | 154 | TEST_F(LLVMExtractorCPlusPlusFixture, ExtractFromFunctionWithSystemInclude) { 155 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram3); 156 | 157 | ASSERT_EQ(info->functionInfos.size(), 1UL); 158 | ASSERT_EQ(info->functionInfos[0]->name, "_Z3foov"); 159 | ASSERT_EQ(info->functionInfos[0]->args.size(), 0UL); 160 | } 161 | 162 | TEST_F(LLVMExtractorCPlusPlusFixture, ExtractFromFunctionWithUserInclude) { 163 | std::string headerFilename = "/tmp/tempHdr.h"; 164 | createFileWithContents(headerFilename, kProgram4ForwardDecl); 165 | 166 | driver_->addIncludeDir("/tmp", CD::IncludeDirType::SYSTEM); 167 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kProgram4); 168 | 169 | removeFile(headerFilename); 170 | 171 | ASSERT_EQ(info->functionInfos.size(), 1UL); 172 | ASSERT_EQ(info->functionInfos[0]->args.size(), 0UL); 173 | } 174 | 175 | // LLVM tests 176 | TEST_F(LLVMExtractorLLVMFixture, ExtractFromFunction1) { 177 | graph::ExtractionInfoPtr info = extractor_->GraphFromString(kLLVM1); 178 | 179 | ASSERT_EQ(info->functionInfos.size(), 1UL); 180 | ASSERT_EQ(info->functionInfos[0]->name, "A"); 181 | ASSERT_EQ(info->functionInfos[0]->args.size(), 1UL); 182 | } 183 | -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_graph_funcinfo.cc: -------------------------------------------------------------------------------- 1 | #include "llvm_graph_funcinfo.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "llvm/IR/Instructions.h" 7 | 8 | using namespace ::llvm; 9 | 10 | namespace compy { 11 | namespace llvm { 12 | namespace graph { 13 | 14 | std::string llvmTypeToString(Type *type) { 15 | std::string typeName; 16 | raw_string_ostream rso(typeName); 17 | type->print(rso); 18 | return rso.str(); 19 | } 20 | 21 | /** 22 | * Get a unique Name for an LLVM value. 
23 | * 24 | * This function should always be used instead of the values getName() 25 | * function. If the object has no name yet, a new unique name is generated 26 | * based on the default name. 27 | */ 28 | std::string FunctionInfoPass::getUniqueName(const Value &v) { 29 | if (v.hasName()) return v.getName().str(); 30 | 31 | auto iter = valueNames.find(&v); 32 | if (iter != valueNames.end()) return iter->second; 33 | 34 | std::stringstream ss; 35 | if (isa(v)) 36 | ss << "val"; 37 | else if (isa(v)) 38 | ss << "bb"; 39 | else if (isa(v)) 40 | ss << "func"; 41 | else 42 | ss << "v"; 43 | 44 | ss << valueNames.size(); 45 | 46 | valueNames[&v] = ss.str(); 47 | return ss.str(); 48 | } 49 | 50 | ArgInfoPtr FunctionInfoPass::getInfo(const Argument &arg) { 51 | auto it = argInfos.find(&arg); 52 | if (it != argInfos.end()) return it->second; 53 | 54 | ArgInfoPtr info(new ArgInfo()); 55 | argInfos[&arg] = info; 56 | 57 | info->name = getUniqueName(arg); 58 | 59 | // collect the type 60 | info->type = llvmTypeToString(arg.getType()); 61 | 62 | return info; 63 | } 64 | 65 | ConstantInfoPtr FunctionInfoPass::getInfo(const ::llvm::Constant &con) { 66 | auto it = constantInfos.find(&con); 67 | if (it != constantInfos.end()) return it->second; 68 | 69 | ConstantInfoPtr info(new ConstantInfo()); 70 | constantInfos[&con] = info; 71 | 72 | // collect the type 73 | info->type = llvmTypeToString(con.getType()); 74 | 75 | return info; 76 | } 77 | 78 | InstructionInfoPtr FunctionInfoPass::getInfo(const Instruction &inst) { 79 | auto it = instructionInfos.find(&inst); 80 | if (it != instructionInfos.end()) return it->second; 81 | 82 | InstructionInfoPtr info(new InstructionInfo()); 83 | instructionInfos[&inst] = info; 84 | 85 | // collect opcode 86 | info->opcode = inst.getOpcodeName(); 87 | 88 | if (inst.getOpcodeName() == std::string("ret")) { 89 | info_->exitInstructions.push_back(info); 90 | } 91 | 92 | // collect type 93 | std::string typeName; 94 | raw_string_ostream rso(typeName); 95 | inst.getType()->print(rso); 96 | info->type = rso.str(); 97 | 98 | // collect data dependencies 99 | for (auto &use : inst.operands()) { 100 | if (isa(use.get())) { 101 | auto &opInst = *cast(use.get()); 102 | info->operands.push_back(getInfo(opInst)); 103 | } 104 | 105 | if (isa(use.get())) { 106 | auto &opInst = *cast(use.get()); 107 | info->operands.push_back(getInfo(opInst)); 108 | } 109 | 110 | if (isa(use.get())) { 111 | auto &opInst = *cast(use.get()); 112 | info->operands.push_back(getInfo(opInst)); 113 | } 114 | } 115 | 116 | // collect called function (if this instruction is a call) 117 | if (isa(inst)) { 118 | auto &call = cast(inst); 119 | Function *calledFunction = call.getCalledFunction(); 120 | if (calledFunction != nullptr) { 121 | info->callTarget = getUniqueName(*calledFunction); 122 | } 123 | } 124 | 125 | // load or store? 
  info->isLoadOrStore = false;
  if (isa<LoadInst>(inst)) info->isLoadOrStore = true;
  if (isa<StoreInst>(inst)) info->isLoadOrStore = true;

  // collect function this instruction belongs to
  info->function = info_;

  return info;
}

BasicBlockInfoPtr FunctionInfoPass::getInfo(const BasicBlock &bb) {
  auto it = basicBlockInfos.find(&bb);
  if (it != basicBlockInfos.end()) return it->second;

  BasicBlockInfoPtr info(new BasicBlockInfo());
  basicBlockInfos[&bb] = info;

  info->name = getUniqueName(bb);

  // collect all successors
  auto term = bb.getTerminator();
  for (size_t i = 0; i < term->getNumSuccessors(); i++) {
    BasicBlock *succ = term->getSuccessor(i);
    info->successors.push_back(getInfo(*succ));
  }

  return info;
}

MemoryAccessInfoPtr FunctionInfoPass::getInfo(MemoryAccess &acc) {
  auto it = memoryAccessInfos.find(&acc);
  if (it != memoryAccessInfos.end()) return it->second;

  MemoryAccessInfoPtr info(new MemoryAccessInfo());
  memoryAccessInfos[&acc] = info;

  info->block = getInfo(*acc.getBlock());

  if (isa<MemoryUseOrDef>(acc)) {
    // a MemoryUse models a load, a MemoryDef models a store or other clobber
    if (isa<MemoryUse>(acc))
      info->type = "use";
    else
      info->type = "def";

    auto inst = cast<MemoryUseOrDef>(acc).getMemoryInst();
    if (inst != nullptr) {
      info->inst = getInfo(*inst);
    } else {
      info->inst = nullptr;
      assert(info->type == "def");
      info->type = "live on entry";
    }

    auto dep = cast<MemoryUseOrDef>(acc).getDefiningAccess();
    if (dep != nullptr) {
      info->dependencies.push_back(getInfo(*dep));
    }
  } else {
    info->type = "phi";
    info->inst = nullptr;
    auto &phi = cast<MemoryPhi>(acc);
    for (unsigned i = 0; i < phi.getNumIncomingValues(); i++) {
      auto dep = phi.getIncomingValue(i);
      info->dependencies.push_back(getInfo(*dep));
    }
  }

  return info;
}

bool FunctionInfoPass::runOnFunction(::llvm::Function &func) {
  // wipe all data from the previous run
  valueNames.clear();
  argInfos.clear();
  basicBlockInfos.clear();
  instructionInfos.clear();
  memoryAccessInfos.clear();

  // create a new info object and invalidate the old one
  info_ = FunctionInfoPtr(new FunctionInfo());

  info_->name = getUniqueName(func);
  info_->entryInstruction =
      getInfo(*func.getEntryBlock().getInstList().begin());

  std::string rtypeName;
  raw_string_ostream rso(rtypeName);
  func.getReturnType()->print(rso);
  info_->type = rso.str();

  // collect all basic blocks and their instructions
  for (auto &bb : func.getBasicBlockList()) {
    BasicBlockInfoPtr bbInfo = getInfo(bb);
    for (auto &inst : bb) {
      bbInfo->instructions.push_back(getInfo(inst));
    }
    info_->basicBlocks.push_back(bbInfo);
  }

  // collect all arguments
  for (auto &arg : func.args()) {
    info_->args.push_back(getInfo(arg));
  }

  // dump all memory accesses
  auto &mssaPass = getAnalysis<MemorySSAWrapperPass>();
  auto &mssa = mssaPass.getMSSA();
  for (auto &bb : func.getBasicBlockList()) {
    // live on entry
    auto entry = mssa.getLiveOnEntryDef();
    info_->memoryAccesses.push_back(getInfo(*entry));

    // memory phis
    auto phi = mssa.getMemoryAccess(&bb);
    if (phi != nullptr) {
      info_->memoryAccesses.push_back(getInfo(*phi));
    }
244 | 245 | // memory use or defs 246 | for (auto &inst : bb) { 247 | auto access = mssa.getMemoryAccess(&inst); 248 | if (access != nullptr) { 249 | info_->memoryAccesses.push_back(getInfo(*access)); 250 | } 251 | } 252 | } 253 | 254 | // indicate that nothing was changed 255 | return false; 256 | } 257 | 258 | void FunctionInfoPass::getAnalysisUsage(AnalysisUsage &au) const { 259 | au.addRequired(); 260 | au.setPreservesAll(); 261 | } 262 | 263 | char FunctionInfoPass::ID = 0; 264 | 265 | static RegisterPass X("funcinfo", "Function Info Extractor", 266 | true /* Only looks at CFG */, 267 | true /* Analysis Pass */); 268 | 269 | } // namespace graph 270 | } // namespace llvm 271 | } // namespace compy 272 | -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_graph_funcinfo.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "llvm/Analysis/MemorySSA.h" 9 | #include "llvm/IR/Function.h" 10 | #include "llvm/Pass.h" 11 | #include "llvm/Support/raw_ostream.h" 12 | 13 | #include "llvm_extractor.h" 14 | 15 | namespace compy { 16 | namespace llvm { 17 | namespace graph { 18 | 19 | class FunctionInfoPass : public ::llvm::FunctionPass { 20 | private: 21 | FunctionInfoPtr info_; 22 | 23 | public: 24 | static char ID; 25 | 26 | FunctionInfoPass() : ::llvm::FunctionPass(ID), info_(nullptr) {} 27 | 28 | bool runOnFunction(::llvm::Function &func) override; 29 | void getAnalysisUsage(::llvm::AnalysisUsage &au) const override; 30 | 31 | const FunctionInfoPtr &getInfo() const { return info_; } 32 | FunctionInfoPtr &getInfo() { return info_; } 33 | 34 | private: 35 | std::string getUniqueName(const ::llvm::Value &v); 36 | ArgInfoPtr getInfo(const ::llvm::Argument &arg); 37 | ConstantInfoPtr getInfo(const ::llvm::Constant &con); 38 | BasicBlockInfoPtr getInfo(const ::llvm::BasicBlock &bb); 39 | InstructionInfoPtr getInfo(const ::llvm::Instruction &inst); 40 | MemoryAccessInfoPtr getInfo(::llvm::MemoryAccess &acc); 41 | 42 | private: 43 | std::unordered_map argInfos; 44 | std::unordered_map constantInfos; 45 | std::unordered_map 46 | basicBlockInfos; 47 | std::unordered_map 48 | instructionInfos; 49 | std::unordered_map 50 | memoryAccessInfos; 51 | std::unordered_map valueNames; 52 | }; 53 | 54 | } // namespace graph 55 | } // namespace llvm 56 | } // namespace compy 57 | -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_graph_pass.cc: -------------------------------------------------------------------------------- 1 | #include "llvm_graph_pass.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "llvm/Analysis/CallGraph.h" 7 | 8 | #include "llvm_graph_funcinfo.h" 9 | 10 | using namespace ::llvm; 11 | 12 | namespace compy { 13 | namespace llvm { 14 | namespace graph { 15 | 16 | bool ExtractorPass::runOnModule(::llvm::Module &module) { 17 | ExtractionInfoPtr info(new ExtractionInfo()); 18 | 19 | // Collect and dump all the function information 20 | for (auto &func : module.functions()) { 21 | // Skip functions without definition (fwd declarations) 22 | if (func.isDeclaration()) { 23 | continue; 24 | } 25 | 26 | auto &pass = getAnalysis(func); 27 | auto functionInfo = std::move(pass.getInfo()); 28 | info->functionInfos.push_back(std::move(functionInfo)); 29 | } 30 | 31 | // Dump the call graph 32 | info->callGraphInfo.reset(new 
CallGraphInfo());

  const auto &callGraph = getAnalysis<CallGraphWrapperPass>().getCallGraph();
  for (auto &kv : callGraph) {
    auto *func = kv.first;
    auto &node = kv.second;

    // Skip the null entry
    if (func == nullptr) continue;

    for (auto &kv : *node) {
      // Skip calls to functions without definition (fwd declarations)
      if (kv.second->getFunction()) {
        info->callGraphInfo->calls.push_back(
            kv.second->getFunction()->getName().str());
      }
    }
  }

  this->extractionInfo = info;

  // Returning false indicates that we didn't change anything
  return false;
}

void ExtractorPass::getAnalysisUsage(AnalysisUsage &au) const {
  au.addRequired<FunctionInfoPass>();
  au.addRequired<CallGraphWrapperPass>();
  au.setPreservesAll();
}

char ExtractorPass::ID = 0;
static ::llvm::RegisterPass<ExtractorPass> X("graphExtractor", "GraphExtractor",
                                             true /* Only looks at CFG */,
                                             true /* Analysis Pass */);

}  // namespace graph
}  // namespace llvm
}  // namespace compy
-------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_graph_pass.h: --------------------------------------------------------------------------------
#pragma once

#include
#include
#include

#include "llvm/IR/Module.h"
#include "llvm/Pass.h"

#include "llvm_extractor.h"

namespace compy {
namespace llvm {
namespace graph {

struct ExtractionInfo;
using ExtractionInfoPtr = std::shared_ptr<ExtractionInfo>;

class ExtractorPass : public ::llvm::ModulePass {
 public:
  static char ID;
  ExtractorPass() : ::llvm::ModulePass(ID) {}

  bool runOnModule(::llvm::Module &M) override;
  void getAnalysisUsage(::llvm::AnalysisUsage &au) const override;

  ExtractionInfoPtr extractionInfo;
};

}  // namespace graph
}  // namespace llvm
}  // namespace compy
-------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_pass_test.cc: --------------------------------------------------------------------------------
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/SourceMgr.h"

#include "common/common_test.h"
#include "gtest/gtest.h"
#include "llvm_graph_pass.h"
#include "llvm_seq_pass.h"

using namespace llvm;
using namespace compy;
using namespace compy::llvm;

using std::string;

class LLVMGraphPassFixture : public testing::Test {
 protected:
  void SetUp() override {
    // Register the llvm passes the extractor depends on
    PassRegistry& reg = *PassRegistry::getPassRegistry();
    initializeCallGraphWrapperPassPass(reg);
    initializeMemorySSAWrapperPassPass(reg);

    // Setup the pass manager, add pass
    _pm = new legacy::PassManager();
    _ep = new graph::ExtractorPass();
    _pm->add(_ep);
  }

  void TearDown() override {
    // The legacy PassManager owns the passes added to it, so deleting it
    // also frees _ep. (free() would be undefined behavior on new-ed objects.)
    delete _pm;
  }

  graph::ExtractionInfoPtr Extract(std::string ir) {
    // Parse an LLVM module from the in-memory IR string.
    SMDiagnostic err;
    LLVMContext context;
    std::unique_ptr<MemoryBuffer> buffer = MemoryBuffer::getMemBuffer(ir);
    std::unique_ptr<Module> module =
        parseIR(buffer->getMemBufferRef(), err, context);
    if (!module.get()) {
      throw std::runtime_error("Failed compiling to LLVM module");
    }

    // Run pass
    _pm->run(*module);

    // Return extraction info
    return _ep->extractionInfo;
  }

  legacy::PassManager* _pm;
  graph::ExtractorPass* _ep;
};

class LLVMSeqPassFixture : public testing::Test {
 protected:
  void SetUp() override {
    // Register the llvm passes the extractor depends on
    PassRegistry& reg = *PassRegistry::getPassRegistry();
    initializeCallGraphWrapperPassPass(reg);
    initializeMemorySSAWrapperPassPass(reg);

    // Setup the pass manager, add pass
    _pm = new legacy::PassManager();
    _ep = new seq::ExtractorPass();
    _pm->add(_ep);
  }

  void TearDown() override {
    // As above: the PassManager owns _ep, so one delete cleans up both.
    delete _pm;
  }

  seq::ExtractionInfoPtr Extract(std::string ir) {
    // Parse an LLVM module from the in-memory IR string.
    SMDiagnostic err;
    LLVMContext context;
    std::unique_ptr<MemoryBuffer> buffer = MemoryBuffer::getMemBuffer(ir);
    std::unique_ptr<Module> module =
        parseIR(buffer->getMemBufferRef(), err, context);
    if (!module.get()) {
      throw std::runtime_error("Failed compiling to LLVM module");
    }

    // Run pass
    _pm->run(*module);

    // Return extraction info
    return _ep->extractionInfo;
  }

  legacy::PassManager* _pm;
  seq::ExtractorPass* _ep;
};

TEST_F(LLVMGraphPassFixture, RunPassAndRetrieveSuccess) {
  graph::ExtractionInfoPtr info = Extract(kLLVM1);

  ASSERT_EQ(info->functionInfos.size(), 1UL);
}

TEST_F(LLVMSeqPassFixture, RunPassAndRetrieveSuccess2) {
  seq::ExtractionInfoPtr info = Extract(kLLVM1);

  ASSERT_EQ(info->functionInfos.size(), 1UL);

  std::vector<std::string> signature = info->functionInfos[0]->signature;

  seq::BasicBlockInfoPtr basicBlock = info->functionInfos[0]->basicBlocks[0];
  ASSERT_GT(basicBlock->instructions.size(), 1UL);

  seq::InstructionInfoPtr instructionInfoPtr = basicBlock->instructions[0];
  ASSERT_GT(instructionInfoPtr->tokens.size(), 1UL);
}

TEST_F(LLVMGraphPassFixture, RunPassAndRetrieveFail) {
  EXPECT_THROW(
      {
        try {
          graph::ExtractionInfoPtr info = Extract(kLLVM2);
        } catch (std::runtime_error const& err) {
          EXPECT_EQ(err.what(), std::string("Failed compiling to LLVM module"));
          throw;
        }
      },
      std::runtime_error);
}

TEST_F(LLVMGraphPassFixture, RunPassAndRetrieveZero) {
  graph::ExtractionInfoPtr info = Extract("");

  ASSERT_EQ(info->functionInfos.size(), 0UL);
}
-------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_seq_pass.cc: --------------------------------------------------------------------------------
#include "llvm_seq_pass.h"

#include
#include

#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/AssemblyAnnotationWriter.h"
#include "llvm/Transforms/IPO.h"

using namespace ::llvm;

namespace compy {
namespace llvm {
namespace seq {

class InfoBuilder {
 public:
  InfoBuilder() {
    functionInfo_.reset(new FunctionInfo);
    largestSlotIdSoFar_ = 0;
  }

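The tokenization below hinges on an unbuffered llvm::raw_ostream that hands every printed fragment to a callback; token_ostream, later in this file, is one such stream. The minimal shape of the idea, as a sketch with illustrative names that are not part of this class:

#include <functional>

#include "llvm/Support/raw_ostream.h"

class CallbackStream : public llvm::raw_ostream {
  std::function<void(llvm::StringRef)> cb_;
  uint64_t pos_ = 0;

  void write_impl(const char *p, size_t n) override {
    cb_(llvm::StringRef(p, n));  // forward each fragment as it is printed
    pos_ += n;
  }
  uint64_t current_pos() const override { return pos_; }

 public:
  explicit CallbackStream(std::function<void(llvm::StringRef)> cb)
      : cb_(std::move(cb)) {
    SetUnbuffered();  // deliver writes fragment by fragment, not in chunks
  }
  ~CallbackStream() override { flush(); }
};

Because the assembly printer emits small pieces and the stream is unbuffered, the callback observes the IR essentially token by token, which is what InfoBuilder exploits.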
void AddToken(std::string token) { tokenBuffer_.push_back(token); } 24 | 25 | void onBasicBlockStart() { 26 | // Before first BB was the function signature. 27 | if (functionInfo_->basicBlocks.empty()) { 28 | stripComments(tokenBuffer_); 29 | stripNewlines(tokenBuffer_); 30 | stripEntryInstruction(tokenBuffer_); 31 | 32 | functionInfo_->signature = tokenBuffer_; 33 | tokenBuffer_.clear(); 34 | } 35 | 36 | BasicBlockInfoPtr basicBlockInfo(new BasicBlockInfo); 37 | basicBlockInfo->name = std::to_string(++largestSlotIdSoFar_); 38 | 39 | functionInfo_->basicBlocks.push_back(basicBlockInfo); 40 | } 41 | 42 | void onInstructionStart() { tokenBuffer_.clear(); } 43 | 44 | void onInstructionEnd() { 45 | stripDoubleWhitespaces(tokenBuffer_); 46 | 47 | InstructionInfoPtr instructionInfo(new InstructionInfo); 48 | instructionInfo->tokens = tokenBuffer_; 49 | 50 | BasicBlockInfoPtr basicBlockInfo = functionInfo_->basicBlocks.back(); 51 | basicBlockInfo->instructions.push_back(instructionInfo); 52 | 53 | // Track largest slot id. 54 | for (std::size_t i = 0; i != instructionInfo->tokens.size() - 2; ++i) { 55 | if (instructionInfo->tokens[i] == "%" && 56 | instructionInfo->tokens[i + 2] == " = ") { 57 | int slotId = std::stoi(instructionInfo->tokens[i + 1]); 58 | largestSlotIdSoFar_ = std::max(largestSlotIdSoFar_, slotId); 59 | } 60 | } 61 | } 62 | 63 | FunctionInfoPtr getInfo() { return functionInfo_; } 64 | 65 | FunctionInfoPtr functionInfo_; 66 | 67 | private: 68 | void stripComments(std::vector &tokens) { 69 | std::vector::iterator semicolonEle; 70 | 71 | bool semicolonFound = false; 72 | for (auto it = tokens.begin(); it != tokens.end(); it++) { 73 | auto token = *it; 74 | 75 | if (token.find(";") != std::string::npos) { 76 | semicolonEle = it; 77 | semicolonFound = true; 78 | } else if (semicolonFound && token.find("\n") != std::string::npos) { 79 | tokens.erase(semicolonEle, it); 80 | semicolonFound = false; 81 | } 82 | } 83 | } 84 | 85 | void stripEntryInstruction(std::vector &tokens) { 86 | auto itEntry = std::find(tokens.begin(), tokens.end(), "entry"); 87 | tokens.erase(itEntry - 1, tokens.end()); 88 | } 89 | 90 | void stripNewlines(std::vector &tokens) { 91 | tokens.erase(std::remove(tokens.begin(), tokens.end(), "\n"), tokens.end()); 92 | } 93 | 94 | void stripDoubleWhitespaces(std::vector &tokens) { 95 | tokens.erase(std::remove(tokens.begin(), tokens.end(), " "), tokens.end()); 96 | } 97 | 98 | private: 99 | int largestSlotIdSoFar_; 100 | std::vector tokenBuffer_; 101 | }; 102 | using InfoBuilderPtr = std::shared_ptr; 103 | 104 | class token_ostream : public ::llvm::raw_ostream { 105 | void write_impl(const char *Ptr, size_t Size) override { 106 | std::string str(Ptr, Size); 107 | infoBuilder_->AddToken(str); 108 | 109 | OS.append(Ptr, Size); 110 | } 111 | 112 | uint64_t current_pos() const override { return OS.size(); } 113 | 114 | public: 115 | explicit token_ostream(InfoBuilderPtr infoBuilder) 116 | : infoBuilder_(infoBuilder) { 117 | // Set unbufferd, so we get token-by-token 118 | SetUnbuffered(); 119 | } 120 | ~token_ostream() override { flush(); } 121 | 122 | std::string &str() { 123 | flush(); 124 | return OS; 125 | } 126 | 127 | std::string &getStr() { return OS; } 128 | 129 | private: 130 | InfoBuilderPtr infoBuilder_; 131 | std::string OS; 132 | }; 133 | 134 | class TokenAnnotator : public ::llvm::AssemblyAnnotationWriter { 135 | public: 136 | TokenAnnotator(InfoBuilderPtr infoBuilder) : infoBuilder_(infoBuilder) {} 137 | 138 | virtual void emitBasicBlockStartAnnot(const 
BasicBlock *bb, 139 | formatted_raw_ostream &) { 140 | infoBuilder_->onBasicBlockStart(); 141 | } 142 | virtual void emitInstructionAnnot(const Instruction *, 143 | formatted_raw_ostream &) { 144 | infoBuilder_->onInstructionStart(); 145 | } 146 | virtual void printInfoComment(const Value &, formatted_raw_ostream &) { 147 | infoBuilder_->onInstructionEnd(); 148 | } 149 | 150 | private: 151 | InfoBuilderPtr infoBuilder_; 152 | }; 153 | 154 | bool ExtractorPass::runOnModule(::llvm::Module &module) { 155 | ExtractionInfoPtr info(new ExtractionInfo); 156 | 157 | for (const auto &F : module.functions()) { 158 | // InfoBuilder holds the state of the tokenization. It is built using a 159 | // custom stream that captures token by token. An Annotator object with hook 160 | // functions is regularly called by the LLVM stack, structuring the token 161 | // stream into the entities. 162 | InfoBuilderPtr infoBuilder(new InfoBuilder); 163 | 164 | token_ostream TokenStream(infoBuilder); 165 | std::unique_ptr tokenAnnotator( 166 | new TokenAnnotator(infoBuilder)); 167 | 168 | F.print(TokenStream, tokenAnnotator.get()); 169 | 170 | FunctionInfoPtr functionInfo = infoBuilder->getInfo(); 171 | functionInfo->name = F.getName().str(); 172 | functionInfo->str = TokenStream.getStr(); 173 | info->functionInfos.push_back(functionInfo); 174 | } 175 | 176 | this->extractionInfo = info; 177 | 178 | return false; 179 | } 180 | 181 | void ExtractorPass::getAnalysisUsage(AnalysisUsage &au) const { 182 | au.addRequired(); 183 | 184 | au.setPreservesAll(); 185 | } 186 | 187 | char ExtractorPass::ID = 0; 188 | static ::llvm::RegisterPass X("seqExtractor", "SeqExtractor", 189 | true /* Only looks at CFG */, 190 | true /* Analysis Pass */); 191 | 192 | } // namespace seq 193 | } // namespace llvm 194 | } // namespace compy 195 | -------------------------------------------------------------------------------- /compy/representations/extractors/llvm_ir/llvm_seq_pass.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "llvm/IR/Module.h" 8 | #include "llvm/Pass.h" 9 | 10 | #include "llvm_extractor.h" 11 | 12 | namespace compy { 13 | namespace llvm { 14 | namespace seq { 15 | 16 | class ExtractorPass : public ::llvm::ModulePass { 17 | public: 18 | static char ID; 19 | ExtractorPass() : ::llvm::ModulePass(ID) {} 20 | 21 | bool runOnModule(::llvm::Module &M) override; 22 | void getAnalysisUsage(::llvm::AnalysisUsage &au) const override; 23 | 24 | ExtractionInfoPtr extractionInfo; 25 | }; 26 | 27 | } // namespace seq 28 | } // namespace llvm 29 | } // namespace compy 30 | -------------------------------------------------------------------------------- /compy/representations/llvm_graphs.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | from compy.representations.extractors import clang_driver_scoped_options 4 | from compy.representations.extractors.extractors import Visitor 5 | from compy.representations.extractors.extractors import ClangDriver 6 | from compy.representations.extractors.extractors import LLVMIRExtractor 7 | from compy.representations.extractors.extractors import llvm 8 | from compy.representations import common 9 | 10 | 11 | class LLVMCDFGVisitor(Visitor): 12 | def __init__(self): 13 | Visitor.__init__(self) 14 | self.edge_types = ["cfg", "data", "mem"] 15 | self.G = nx.MultiDiGraph() 16 | 17 | def visit(self, v): 18 | if isinstance(v, 
18 |         if isinstance(v, llvm.graph.FunctionInfo):
19 |             # Function arg nodes.
20 |             for arg in v.args:
21 |                 self.G.add_node(arg, attr=(arg.type))
22 | 
23 |             # Memory access edges.
24 |             for memacc in v.memoryAccesses:
25 |                 if memacc.inst:
26 |                     for dep in memacc.dependencies:
27 |                         if dep.inst:
28 |                             self.G.add_edge(dep.inst, memacc.inst, attr="mem")
29 | 
30 |         if isinstance(v, llvm.graph.BasicBlockInfo):
31 |             # CFG edges: Intra-BB.
32 |             instr_prev = v.instructions[0]
33 |             for instr in v.instructions[1:]:
34 |                 self.G.add_edge(instr_prev, instr, attr="cfg")
35 |                 instr_prev = instr
36 | 
37 |             # CFG edges: Inter-BB.
38 |             for succ in v.successors:
39 |                 self.G.add_edge(v.instructions[-1], succ.instructions[0], attr="cfg")
40 | 
41 |         if isinstance(v, llvm.graph.InstructionInfo):
42 |             # Instruction nodes.
43 |             self.G.add_node(v, attr=(v.opcode))
44 | 
45 |             # Operands.
46 |             for operand in v.operands:
47 |                 if isinstance(operand, llvm.graph.ArgInfo) or isinstance(
48 |                     operand, llvm.graph.InstructionInfo
49 |                 ):
50 |                     self.G.add_edge(operand, v, attr="data")
51 | 
52 | 
53 | class LLVMCDFGCallVisitor(Visitor):
54 |     def __init__(self):
55 |         Visitor.__init__(self)
56 |         self.edge_types = ["cfg", "data", "mem", "call"]
57 |         self.G = nx.MultiDiGraph()
58 |         self.functions = {}
59 | 
60 |     def visit(self, v):
61 |         if isinstance(v, llvm.graph.FunctionInfo):
62 |             self.functions[v.name] = v
63 | 
64 |             # Function root node.
65 |             self.G.add_node(v, attr="function")
66 |             self.G.add_edge(v, v.entryInstruction, attr="call")
67 | 
68 |             # Function arg nodes.
69 |             for arg in v.args:
70 |                 self.G.add_node(arg, attr=(arg.type))
71 | 
72 |             # Memory access edges.
73 |             for memacc in v.memoryAccesses:
74 |                 if memacc.inst:
75 |                     for dep in memacc.dependencies:
76 |                         if dep.inst:
77 |                             self.G.add_edge(dep.inst, memacc.inst, attr="mem")
78 | 
79 |         if isinstance(v, llvm.graph.BasicBlockInfo):
80 |             # CFG edges: Intra-BB.
81 |             instr_prev = v.instructions[0]
82 |             for instr in v.instructions[1:]:
83 |                 self.G.add_edge(instr_prev, instr, attr="cfg")
84 |                 instr_prev = instr
85 | 
86 |             # CFG edges: Inter-BB.
87 |             for succ in v.successors:
88 |                 self.G.add_edge(v.instructions[-1], succ.instructions[0], attr="cfg")
89 | 
90 |         if isinstance(v, llvm.graph.InstructionInfo):
91 |             # Instruction nodes.
92 |             self.G.add_node(v, attr=(v.opcode))
93 | 
94 |             # Call edges.
95 |             if v.opcode == "ret":
96 |                 self.G.add_edge(v, v.function, attr="call")
97 |             if v.opcode == "call":
98 |                 called_function = (
99 |                     self.functions[v.callTarget]
100 |                     if v.callTarget in self.functions
101 |                     else None
102 |                 )
103 |                 if called_function:
104 |                     self.G.add_edge(v, called_function.entryInstruction, attr="call")
105 |                     for exit in called_function.exitInstructions:
106 |                         self.G.add_edge(exit, v, attr="call")
107 | 
108 |             # Operands.
109 |             for operand in v.operands:
110 |                 if isinstance(operand, llvm.graph.ArgInfo) or isinstance(
111 |                     operand, llvm.graph.InstructionInfo
112 |                 ):
113 |                     self.G.add_edge(operand, v, attr="data")
114 | 
115 | 
116 | class LLVMCDFGPlusVisitor(Visitor):
117 |     def __init__(self):
118 |         Visitor.__init__(self)
119 |         self.edge_types = ["cfg", "data", "mem", "call", "bb"]
120 |         self.G = nx.MultiDiGraph()
121 | 
122 |     def visit(self, v):
123 |         if isinstance(v, llvm.graph.FunctionInfo):
124 |             # Function root node.
125 |             self.G.add_node(v, attr="function")
126 |             self.G.add_edge(v, v.entryInstruction, attr="cfg")
127 | 
128 |             # Function arg nodes.
129 |             for arg in v.args:
130 |                 self.G.add_node(arg, attr=(arg.type))
131 |                 self.G.add_edge(v, arg, attr="data")
132 | 
133 |             # Memory access edges.
134 |             for memacc in v.memoryAccesses:
135 |                 if memacc.inst:
136 |                     for dep in memacc.dependencies:
137 |                         if dep.inst:
138 |                             self.G.add_edge(dep.inst, memacc.inst, attr="mem")
139 | 
140 |         if isinstance(v, llvm.graph.BasicBlockInfo):
141 |             # BB nodes.
142 |             self.G.add_node(v, attr="bb")
143 |             for instr in v.instructions:
144 |                 self.G.add_edge(instr, v, attr="bb")
145 |             for succ in v.successors:
146 |                 self.G.add_edge(v, succ, attr="bb")
147 | 
148 |             # CFG edges: Intra-BB.
149 |             instr_prev = v.instructions[0]
150 |             for instr in v.instructions[1:]:
151 |                 self.G.add_edge(instr_prev, instr, attr="cfg")
152 |                 instr_prev = instr
153 | 
154 |             # CFG edges: Inter-BB.
155 |             for succ in v.successors:
156 |                 self.G.add_edge(v.instructions[-1], succ.instructions[0], attr="cfg")
157 | 
158 |         if isinstance(v, llvm.graph.InstructionInfo):
159 |             # Instruction nodes.
160 |             self.G.add_node(v, attr=(v.opcode))
161 | 
162 |             # Operands.
163 |             for operand in v.operands:
164 |                 if isinstance(operand, llvm.graph.ArgInfo) or isinstance(
165 |                     operand, llvm.graph.InstructionInfo
166 |                 ):
167 |                     self.G.add_edge(operand, v, attr="data")
168 | 
169 | 
170 | class LLVMProGraMLVisitor(Visitor):
171 |     def __init__(self):
172 |         Visitor.__init__(self)
173 |         self.edge_types = ["cfg", "data", "call"]
174 |         self.G = nx.MultiDiGraph()
175 |         self.functions = {}
176 | 
177 |     def visit(self, v):
178 |         if isinstance(v, llvm.graph.FunctionInfo):
179 |             self.functions[v.name] = v
180 | 
181 |             # Function node.
182 |             self.G.add_node(v, attr="function")
183 |             self.G.add_edge(v, v.entryInstruction, attr="call")
184 | 
185 |             # Function arg nodes.
186 |             for arg in v.args:
187 |                 self.G.add_node(arg, attr=(arg.type))
188 | 
189 |         if isinstance(v, llvm.graph.BasicBlockInfo):
190 |             # CFG edges: Intra-BB.
191 |             instr_prev = v.instructions[0]
192 |             for instr in v.instructions[1:]:
193 |                 self.G.add_edge(instr_prev, instr, attr="cfg")
194 |                 instr_prev = instr
195 | 
196 |             # CFG edges: Inter-BB.
197 |             for succ in v.successors:
198 |                 self.G.add_edge(v.instructions[-1], succ.instructions[0], attr="cfg")
199 | 
200 |         if isinstance(v, llvm.graph.InstructionInfo):
201 |             # Instruction nodes.
202 |             self.G.add_node(v, attr=(v.opcode))
203 | 
204 |             # Call edges.
205 |             if v.opcode == "ret":
206 |                 self.G.add_edge(v, v.function, attr="call")
207 |             if v.opcode == "call":
208 |                 called_function = (
209 |                     self.functions[v.callTarget]
210 |                     if v.callTarget in self.functions
211 |                     else None
212 |                 )
213 |                 if called_function:
214 |                     self.G.add_edge(v, called_function.entryInstruction, attr="call")
215 |                     for exit in called_function.exitInstructions:
216 |                         self.G.add_edge(exit, v, attr="call")
217 | 
218 |             # Operands.
219 |             for operand in v.operands:
220 |                 if isinstance(operand, llvm.graph.ArgInfo) or isinstance(
221 |                     operand, llvm.graph.ConstantInfo
222 |                 ):
223 |                     self.G.add_node(operand, attr=(operand.type))
224 |                     self.G.add_edge(operand, v, attr="data")
225 |                 elif isinstance(operand, llvm.graph.InstructionInfo):
226 |                     self.G.add_node((v, operand), attr=(operand.type))
227 |                     self.G.add_edge(operand, (v, operand), attr="data")
228 |                     self.G.add_edge((v, operand), v, attr="data")
229 | 
230 | 
231 | class LLVMGraphBuilder(common.RepresentationBuilder):
232 |     def __init__(self, clang_driver=None):
233 |         common.RepresentationBuilder.__init__(self)
234 | 
235 |         if clang_driver:
236 |             self.__clang_driver = clang_driver
237 |         else:
238 |             self.__clang_driver = ClangDriver(
239 |                 ClangDriver.ProgrammingLanguage.C,
240 |                 ClangDriver.OptimizationLevel.O3,
241 |                 [],
242 |                 ["-Wall"],
243 |             )
244 |         self.__extractor = LLVMIRExtractor(self.__clang_driver)
245 | 
246 |     def string_to_info(self, src, additional_include_dir=None, filename=None):
247 |         with clang_driver_scoped_options(self.__clang_driver, additional_include_dir=additional_include_dir, filename=filename):
248 |             return self.__extractor.GraphFromString(src)
249 | 
250 |     def info_to_representation(self, info, visitor=LLVMCDFGVisitor):
251 |         vis = visitor()
252 |         info.accept(vis)
253 | 
254 |         for (n, data) in vis.G.nodes(data=True):
255 |             attr = data["attr"]
256 |             if attr not in self._tokens:
257 |                 self._tokens[attr] = 0
258 |             self._tokens[attr] += 1
259 | 
260 |         return common.Graph(vis.G, self.get_tokens(), vis.edge_types)
261 | 
-------------------------------------------------------------------------------- /compy/representations/llvm_graphs_test.py: --------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | 
4 | import networkx as nx
5 | 
6 | from compy.representations.extractors.extractors import Visitor
7 | from compy.representations.extractors.extractors import llvm
8 | from compy.representations.llvm_graphs import LLVMGraphBuilder
9 | from compy.representations.llvm_graphs import LLVMCDFGVisitor
10 | from compy.representations.llvm_graphs import LLVMCDFGCallVisitor
11 | from compy.representations.llvm_graphs import LLVMCDFGPlusVisitor
12 | from compy.representations.llvm_graphs import LLVMProGraMLVisitor
13 | 
14 | 
15 | program_1fn_2 = """
16 | int bar(int a) {
17 |   if (a > 10)
18 |     return a;
19 |   return -1;
20 | }
21 | """
22 | 
23 | program_fib = """
24 | int fib(int x) {
25 |   switch(x) {
26 |     case 0:
27 |       return 0;
28 |     case 1:
29 |       return 1;
30 |     default:
31 |       return fib(x-1) + fib(x-2);
32 |   }
33 | }
34 | """
35 | 
36 | 
37 | def get_nodes_with_attr(graph, attr):
38 |     return [x for x, y in graph.G.nodes(data=True) if y["attr"] == attr]
39 | 
40 | 
41 | def get_function_nodes_by_name(graph):
42 |     ret = {}
43 |     for function_node in get_nodes_with_attr(graph, "function"):
44 |         ret[function_node.name] = function_node
45 | 
46 |     return ret
47 | 
48 | 
49 | def get_first_instruction(graph):
50 |     for node in graph.G.nodes():
51 |         if isinstance(node, llvm.graph.InstructionInfo):
52 |             return node
53 | 
54 | 
55 | def get_all_instructions(graph):
56 |     ret = []
57 |     for node in graph.G.nodes():
58 |         if isinstance(node, llvm.graph.InstructionInfo):
59 |             ret.append(node)
60 |     return ret
61 | 
62 | 
63 | def explore_cfg_with_dfs(graph, start_node):
64 |     to_explore = [start_node]
65 |     explored = []
66 |     while len(to_explore):
67 |         node = to_explore.pop(0)
68 |         explored.append(node)
69 | 
70 |         for u, v, data in graph.G.out_edges(node, data=True):
71 |             if data["attr"] == "cfg" and v not in explored:
72 |                 to_explore.append(v)
73 | 
74 |     return explored
75 | 
76 | 
77 | # General tests: Construction
78 | def test_construct_with_custom_visitor():
79 |     class CustomVisitor(Visitor):
80 |         def __init__(self):
81 |             Visitor.__init__(self)
82 |             self.edge_types = []
83 |             self.G = nx.DiGraph()
84 | 
85 |         def visit(self, v):
86 |             if not isinstance(v, llvm.graph.ExtractionInfo):
87 |                 self.G.add_node(v, attr=type(v))
88 | 
89 |     builder = LLVMGraphBuilder()
90 |     info = builder.string_to_info(program_1fn_2)
91 |     graph = builder.info_to_representation(info, CustomVisitor)
92 | 
93 |     assert len(graph.G) > 0
94 | 
95 | 
96 | # General tests: Attributes
97 | def test_get_node_list():
98 |     builder = LLVMGraphBuilder()
99 |     info = builder.string_to_info(program_1fn_2)
100 |     graph = builder.info_to_representation(info, LLVMCDFGVisitor)
101 |     nodes = graph.get_node_list()
102 | 
103 |     assert len(nodes) > 0
104 | 
105 | 
106 | def test_get_edge_list():
107 |     builder = LLVMGraphBuilder()
108 |     info = builder.string_to_info(program_1fn_2)
109 |     graph = builder.info_to_representation(info, LLVMCDFGVisitor)
110 |     edges = graph.get_edge_list()
111 | 
112 |     assert len(edges) > 0
113 | 
114 |     assert type(edges[0][0]) is int
115 |     assert type(edges[0][1]) is int
116 |     assert type(edges[0][2]) is int
117 | 
118 | 
119 | # General tests: Plot
120 | def test_plot(tmpdir):
121 |     for visitor in [LLVMCDFGVisitor, LLVMCDFGPlusVisitor, LLVMProGraMLVisitor]:
122 |         builder = LLVMGraphBuilder()
123 |         info = builder.string_to_info(program_fib)
124 |         graph = builder.info_to_representation(info, visitor)
125 | 
126 |         outfile = os.path.join(tmpdir, str(visitor.__name__) + ".png")
127 |         graph.draw(path=outfile, with_legend=True)
128 | 
129 |         assert os.path.isfile(outfile)
130 | 
131 |     # os.system('xdg-open ' + str(tmpdir))
132 | 
133 | 
134 | # All visitors
135 | def test_all_visitors():
136 |     for visitor in [
137 |         LLVMCDFGVisitor,
138 |         LLVMCDFGCallVisitor,
139 |         LLVMCDFGPlusVisitor,
140 |         LLVMProGraMLVisitor,
141 |     ]:
142 |         builder = LLVMGraphBuilder()
143 |         info = builder.string_to_info(program_1fn_2)
144 |         ast = builder.info_to_representation(info, visitor)
145 | 
146 |         assert ast
147 | 
148 | 
149 | # CDFG
150 | # ############################
151 | @pytest.fixture
152 | def llvm_cdfg_graph():
153 |     builder = LLVMGraphBuilder()
154 |     info = builder.string_to_info(program_fib)
155 | 
156 |     return builder.info_to_representation(info, LLVMCDFGVisitor)
157 | 
158 | 
159 | # CFG edges (the d_ prefix keeps this test disabled)
160 | def d_test_cdfg_cfg_edges_reach_all_nodes(llvm_cdfg_graph):
161 |     first_instruction = get_first_instruction(llvm_cdfg_graph)
162 |     explored = explore_cfg_with_dfs(llvm_cdfg_graph, first_instruction)
163 | 
164 |     assert set(explored) == set(get_all_instructions(llvm_cdfg_graph))
165 | 
166 | 
167 | # ProGraML
168 | # ############################
169 | @pytest.fixture
170 | def llvm_programl_graph():
171 |     builder = LLVMGraphBuilder()
172 |     info = builder.string_to_info(program_fib)
173 | 
174 |     return builder.info_to_representation(info, LLVMProGraMLVisitor)
175 | 
176 | 
177 | # General
178 | def test_programl_has_root_node(llvm_programl_graph):
179 |     assert llvm_programl_graph.get_node_str_list().count("function") == 1
180 | 
181 | 
182 | # Call edges
183 | def test_programl_call_edges_exist_from_ret_instructions_to_root_node(
184 |     llvm_programl_graph,
185 | ):
186 |     for ret_instr in get_nodes_with_attr(llvm_programl_graph, "ret"):
187 |         assert llvm_programl_graph.G.has_edge(ret_instr, ret_instr.function)
188 | 
189 | 
190 | def test_programl_call_edges_exist_from_call_instructions_to_entry_instructions(
191 |     llvm_programl_graph,
192 | ):
193 |     function_nodes_by_name = get_function_nodes_by_name(llvm_programl_graph)
194 | 
195 |     for call_instr in get_nodes_with_attr(llvm_programl_graph, "call"):
196 |         called_function_node = function_nodes_by_name[call_instr.callTarget]
197 | 
198 |         assert llvm_programl_graph.G.has_edge(
199 |             call_instr, called_function_node.entryInstruction
200 |         )
201 | 
202 | 
203 | def test_programl_call_edges_exist_from_exit_instructions_to_their_callsite_instructions(
204 |     llvm_programl_graph,
205 | ):
206 |     function_nodes_by_name = get_function_nodes_by_name(llvm_programl_graph)
207 | 
208 |     for call_instr in get_nodes_with_attr(llvm_programl_graph, "call"):
209 |         called_function_node = function_nodes_by_name[call_instr.callTarget]
210 | 
211 |         for exit_instruction in called_function_node.exitInstructions:
212 |             assert llvm_programl_graph.G.has_edge(exit_instruction, call_instr)
213 | 
214 | 
215 | # CFG edges
216 | def test_programl_cfg_edges_reach_all_nodes(llvm_programl_graph):
217 |     first_instruction = get_first_instruction(llvm_programl_graph)
218 |     explored = explore_cfg_with_dfs(llvm_programl_graph, first_instruction)
219 | 
220 |     assert set(explored) == set(get_all_instructions(llvm_programl_graph))
221 | 
-------------------------------------------------------------------------------- /compy/representations/llvm_seq.py: --------------------------------------------------------------------------------
1 | from compy.representations.extractors import clang_driver_scoped_options
2 | from compy.representations.extractors.extractors import Visitor
3 | from compy.representations.extractors.extractors import ClangDriver
4 | from compy.representations.extractors.extractors import LLVMIRExtractor
5 | from compy.representations.extractors.extractors import llvm
6 | from compy.representations import common
7 | 
8 | 
9 | def merge_after_element_on_condition(elements, element_conditions):
10 |     """
11 |     Ex.: If merged on conditions ['a'], ['a', 'b', 'c', 'a', 'e'] becomes ['ab', 'c', 'ae']
12 |     """
13 |     for i in range(len(elements) - 2, -1, -1):
14 |         if elements[i] in element_conditions:
15 |             elements[i] = elements[i] + elements.pop(i + 1)
16 | 
17 |     return elements
18 | 
19 | 
20 | def filter_elements(elements, element_filter):
21 |     """
22 |     Ex.: If filtered on elements [' '], ['a', ' ', 'c'] becomes ['a', 'c']
23 |     """
24 |     return [element for element in elements if element not in element_filter]
25 | 
26 | 
27 | def strip_elements(elements, element_filters):
28 |     """
29 |     Ex.: If stripped on elements [' '], ['a', ' b', 'c'] becomes ['a', 'b', 'c']
30 |     """
31 |     ret = []
32 |     for element in elements:
33 |         for element_filter in element_filters:
34 |             element = element.strip(element_filter)
35 |         ret.append(element)
36 | 
37 |     return ret
38 | 
39 | 
40 | def strip_function_name(elements):
41 |     for i in range(len(elements) - 1):
42 |         if elements[i] == "@":
43 |             elements[i + 1] = "fn_0"
44 | 
45 |     return elements
46 | 
47 | 
48 | def transform_elements(elements):
49 |     elements = merge_after_element_on_condition(elements, ["%", "i"])
50 |     elements = strip_elements(elements, ["\n", " "])
51 |     elements = filter_elements(elements, ["", " ", "local_unnamed_addr"])
52 | 
53 |     return elements
54 | 
55 | 
56 | class LLVMSeqVisitor(Visitor):
57 |     def __init__(self):
58 |         Visitor.__init__(self)
59 |         self.S = []
60 | 
61 |     def visit(self, v):
62 |         if isinstance(v, llvm.seq.FunctionInfo):
63 |             self.S += strip_function_name(transform_elements(v.signature))
64 | 
65 |         if isinstance(v, llvm.seq.BasicBlockInfo):
66 |             self.S += [v.name + ":"]
67 | 
68 |         if isinstance(v, llvm.seq.InstructionInfo):
69 |             self.S += transform_elements(v.tokens)
70 | 
71 | 
72 | class LLVMSeqBuilder(common.RepresentationBuilder):
73 |     def __init__(self, clang_driver=None):
74 |         common.RepresentationBuilder.__init__(self)
75 | 
76 |         if clang_driver:
77 |             self.__clang_driver = clang_driver
78 |         else:
79 |             self.__clang_driver = ClangDriver(
80 |                 ClangDriver.ProgrammingLanguage.C,
81 |                 ClangDriver.OptimizationLevel.O3,
82 |                 [],
83 |                 ["-Wall"],
84 |             )
85 |         self.__extractor = LLVMIRExtractor(self.__clang_driver)
86 | 
87 |     def string_to_info(self, src, additional_include_dir=None, filename=None):
88 |         with clang_driver_scoped_options(self.__clang_driver, additional_include_dir=additional_include_dir, filename=filename):
89 |             return self.__extractor.SeqFromString(src)
90 | 
91 |     def info_to_representation(self, info, visitor=LLVMSeqVisitor):
92 |         vis = visitor()
93 |         info.accept(vis)
94 | 
95 |         for token in vis.S:
96 |             if token not in self._tokens:
97 |                 self._tokens[token] = 0
98 |             self._tokens[token] += 1
99 | 
100 |         return common.Sequence(vis.S, self.get_tokens())
101 | 
-------------------------------------------------------------------------------- /compy/representations/llvm_seq_test.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | from compy.representations.extractors.extractors import Visitor
4 | from compy.representations.extractors.extractors import llvm
5 | from compy.representations.llvm_seq import LLVMSeqBuilder
6 | from compy.representations.llvm_seq import LLVMSeqVisitor
7 | 
8 | 
9 | program_1fn_2 = """
10 | int bar(int a) {
11 |   if (a > 10)
12 |     return a;
13 |   return -1;
14 | }
15 | """
16 | 
17 | 
18 | program_fib = """
19 | int fib(int x) {
20 |   switch(x) {
21 |     case 0:
22 |       return 0;
23 |     case 1:
24 |       return 1;
25 |     default:
26 |       return fib(x-1) + fib(x-2);
27 |   }
28 | }
29 | """
30 | 
31 | 
32 | # Construction
33 | def test_construct_with_custom_visitor():
34 |     builder = LLVMSeqBuilder()
35 |     info = builder.string_to_info(program_1fn_2)
36 |     seq = builder.info_to_representation(info, LLVMSeqVisitor)
37 | 
38 | 
39 | # General tests: Plot
40 | def test_plot(tmpdir):
41 |     builder = LLVMSeqBuilder()
42 |     info = builder.string_to_info(program_1fn_2)
43 |     seq = builder.info_to_representation(info)
44 | 
45 |     outfile = os.path.join(tmpdir, "llvm_seq.png")
46 |     seq.draw(path=outfile, width=8)
47 | 
48 |     assert os.path.isfile(outfile)
49 | 
50 |     # os.system('xdg-open ' + str(outfile))
51 | 
-------------------------------------------------------------------------------- /compy/representations/syntax_seq.py: --------------------------------------------------------------------------------
1 | from compy.representations.extractors import clang_driver_scoped_options
2 | from compy.representations.extractors.extractors import Visitor
3 | from compy.representations.extractors.extractors import ClangDriver
4 | from compy.representations.extractors.extractors import ClangExtractor
5 | from compy.representations.extractors.extractors import clang
6 | from compy.representations import common
7 | 
8 | 
9 | class SyntaxSeqVisitor(Visitor):
10 |     def __init__(self):
11 |         Visitor.__init__(self)
12 |         self.S = []
13 | 
14 |     def visit(self, v):
15 |         if isinstance(v, clang.seq.TokenInfo):
16 |             self.S.append(v.name)
17 | 
18 | 
19 | class SyntaxTokenkindVisitor(Visitor):
20 |     def __init__(self):
21 |         Visitor.__init__(self)
22 |         self.S = []
23 | 
24 |     def visit(self, v):
25 |         if isinstance(v, clang.seq.TokenInfo):
26 |             self.S.append(v.kind)
27 | 
28 | 
29 | class SyntaxTokenkindVariableVisitor(Visitor):
30 |     def __init__(self):
31 |         Visitor.__init__(self)
32 |         self.S = []
33 | 
34 |     def visit(self, v):
35 |         if isinstance(v, clang.seq.TokenInfo):
36 |             if v.kind == "raw_identifier" and "var" in v.name:
37 |                 self.S.append(v.name)
38 |             elif (
39 |                 v.name in ["for", "while", "do", "if", "else", "return"]
40 |                 or v.name in ["fn_0"]
41 |                 or v.name.startswith("int")
42 |                 or v.name.startswith("float")
43 |             ):
44 |                 self.S.append(v.name)
45 |             else:
46 |                 self.S.append(v.kind)
47 | 
48 | 
49 | class SyntaxSeqBuilder(common.RepresentationBuilder):
50 |     def __init__(self, clang_driver=None):
51 |         common.RepresentationBuilder.__init__(self)
52 | 
53 |         if clang_driver:
54 |             self.__clang_driver = clang_driver
55 |         else:
56 |             self.__clang_driver = ClangDriver(
57 |                 ClangDriver.ProgrammingLanguage.C,
58 |                 ClangDriver.OptimizationLevel.O3,
59 |                 [],
60 |                 ["-Wall"],
61 |             )
62 |         self.__extractor = ClangExtractor(self.__clang_driver)
63 | 
64 |     def string_to_info(self, src, additional_include_dir=None, filename=None):
65 |         with clang_driver_scoped_options(self.__clang_driver, additional_include_dir=additional_include_dir, filename=filename):
66 |             return self.__extractor.SeqFromString(src)
67 | 
68 |     def info_to_representation(self, info, visitor=SyntaxTokenkindVariableVisitor):
69 |         vis = visitor()
70 |         info.accept(vis)
71 | 
72 |         for token in vis.S:
73 |             if token not in self._tokens:
74 |                 self._tokens[token] = 0
75 |             self._tokens[token] += 1
76 | 
77 |         return common.Sequence(vis.S, self.get_tokens())
78 | 
-------------------------------------------------------------------------------- /compy/representations/syntax_seq_test.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | from compy.representations.extractors.extractors import Visitor
4 | from compy.representations.extractors.extractors import clang
5 | from compy.representations.syntax_seq import SyntaxSeqBuilder
6 | from compy.representations.syntax_seq import SyntaxSeqVisitor
7 | from compy.representations.syntax_seq import SyntaxTokenkindVisitor
8 | from compy.representations.syntax_seq import SyntaxTokenkindVariableVisitor
9 | 
10 | 
11 | program_1fn_2 = """
12 | int bar(int a) {
13 |   if (a > 10)
14 |     return a;
15 |   return -1;
16 | }
17 | """
18 | 
19 | 
20 | program_fib = """
21 | int fib(int x) {
22 |   switch(x) {
23 |     case 0:
24 |       return 0;
25 |     case 1:
26 |       return 1;
27 |     default:
28 |       return fib(x-1) + fib(x-2);
29 |   }
30 | }
31 | """
32 | 
33 | 
34 | # Construction
35 | def test_construct_with_custom_visitor():
36 |     builder = SyntaxSeqBuilder()
37 |     info = builder.string_to_info(program_1fn_2)
38 |     seq = builder.info_to_representation(info, SyntaxTokenkindVariableVisitor)
39 | 
40 | 
41 | # General tests: Plot
42 | def test_plot(tmpdir):
43 |     builder = SyntaxSeqBuilder()
44 |     info = builder.string_to_info(program_1fn_2)
45 |     seq = builder.info_to_representation(info)
46 | 
47 |     outfile = os.path.join(tmpdir, "syntax_seq.png")
48 |     seq.draw(path=outfile, width=8)
49 | 
50 |     assert os.path.isfile(outfile)
51 | 
52 |     # os.system('xdg-open ' + str(outfile))
53 | 
54 | 
55 | # All visitors
56 | def test_all_visitors():
57 |     for visitor in [
58 |         SyntaxSeqVisitor,
59 |         SyntaxTokenkindVisitor,
60 |         SyntaxTokenkindVariableVisitor,
61 |     ]:
62 |         builder = SyntaxSeqBuilder()
63 |         info = builder.string_to_info(program_1fn_2)
64 |         ast = builder.info_to_representation(info, visitor)
65 | 
66 |         assert ast
67 | 
-------------------------------------------------------------------------------- /docs/img/flow-overview.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/tud-ccc/compy-learn/24a394bd1ddc109a37b348c3de440389fa9a1f23/docs/img/flow-overview.png
-------------------------------------------------------------------------------- /docs/img/representation-examples.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/tud-ccc/compy-learn/24a394bd1ddc109a37b348c3de440389fa9a1f23/docs/img/representation-examples.png
-------------------------------------------------------------------------------- /examples/devmap_exploration.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from sklearn.model_selection import StratifiedKFold
4 | 
5 | from compy import datasets as D
6 | from compy import models as M
7 | from compy import representations as R
8 | from compy.representations.extractors import ClangDriver
9 | 
10 | 
11 | # Load dataset
12 | dataset = D.OpenCLDevmapDataset()
13 | 
14 | # Explore combinations
15 | combinations = [
16 |     # CGO 20: AST+DF, CDFG
17 |     (R.ASTGraphBuilder, R.ASTDataVisitor, M.GnnPytorchGeomModel),
18 |     (R.LLVMGraphBuilder, R.LLVMCDFGVisitor, M.GnnPytorchGeomModel),
19 |     # Arxiv 20: ProGraML
20 |     (R.LLVMGraphBuilder, R.LLVMProGraMLVisitor, M.GnnPytorchGeomModel),
21 |     # PACT 17: DeepTune
22 |     (R.SyntaxSeqBuilder, R.SyntaxTokenkindVariableVisitor, M.RnnTfModel),
23 |     # Extra
24 |     (R.ASTGraphBuilder, R.ASTDataCFGVisitor, M.GnnPytorchGeomModel),
25 |     (R.LLVMGraphBuilder, R.LLVMCDFGCallVisitor, M.GnnPytorchGeomModel),
26 |     (R.LLVMGraphBuilder, R.LLVMCDFGPlusVisitor, M.GnnPytorchGeomModel),
27 | ]
28 | 
29 | for builder, visitor, model in combinations:
30 |     print("Processing %s-%s-%s" % (builder.__name__, visitor.__name__, model.__name__))
31 | 
32 |     # Build representation
33 |     clang_driver = ClangDriver(
34 |         ClangDriver.ProgrammingLanguage.OpenCL,
35 |         ClangDriver.OptimizationLevel.O3,
36 |         [(x, ClangDriver.IncludeDirType.User) for x in dataset.additional_include_dirs],
37 |         ["-xcl", "-target", "x86_64-pc-linux-gnu"],
38 |     )
39 |     data = dataset.preprocess(builder(clang_driver), visitor)
40 | 
41 |     # Train and test
42 |     kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=204)
43 |     split = kf.split(data["samples"], [sample["info"][5] for sample in data["samples"]])
44 |     for train_idx, test_idx in split:
45 |         model_instance = model(num_types=data["num_types"])
46 |         train_summary = model_instance.train(
47 |             list(np.array(data["samples"])[train_idx]),
48 |             list(np.array(data["samples"])[test_idx]),
49 |         )
50 |         print(train_summary)
51 | 
52 |         break  # Only evaluate the first fold.
53 | 
-------------------------------------------------------------------------------- /install_deps.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | function add_llvm_10_apt_source {
4 |   curl https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
5 |   if [[ $1 == "16.04" ]]; then
6 |     echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-10 main" | sudo tee -a /etc/apt/sources.list
7 |   elif [[ $1 == "18.04" ]]; then
8 |     echo "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main" | sudo tee -a /etc/apt/sources.list
9 |   fi
10 |   sudo apt-get -qq update
11 | }
12 | 
13 | function install_system_packages {
14 |   sudo apt install -y graphviz libgraphviz-dev
15 |   sudo apt install -y libllvm10 llvm-10-dev
16 |   sudo apt install -y clang-10 libclang1-10 libclang-10-dev libclang-common-10-dev
17 | }
18 | 
19 | function install_python_packages {
20 |   CUDA=$1
21 | 
22 |   python3 -m pip install torch==1.5.0+${CUDA} -f https://download.pytorch.org/whl/torch_stable.html
23 |   python3 -m pip install torchvision==0.6.0
24 |   python3 -m pip install torch-scatter==2.0.5 -f https://pytorch-geometric.com/whl/torch-1.5.0+${CUDA}.html
25 |   python3 -m pip install torch-sparse==0.6.7 -f https://pytorch-geometric.com/whl/torch-1.5.0+${CUDA}.html
26 |   python3 -m pip install torch-cluster==1.5.7 -f https://pytorch-geometric.com/whl/torch-1.5.0+${CUDA}.html
27 |   python3 -m pip install torch-spline-conv==1.2.0 -f https://pytorch-geometric.com/whl/torch-1.5.0+${CUDA}.html
28 |   if [[ "$CUDA" != "cpu" ]]; then
29 |     python3 -m pip install dgl-$CUDA
30 |   else
31 |     python3 -m pip install dgl
32 |   fi
33 |   python3 -m pip install tensorflow==2.2.2
34 | }
35 | 
36 | 
37 | if [ $# -eq 0 ]
38 | then
39 |   echo "Usage: install_deps {cpu|cu92|cu100|cu101}"
40 |   exit 1
41 | fi
42 | 
43 | if [[ $(lsb_release -rs) == "16.04" ]] || [[ $(lsb_release -rs) == "18.04" ]]; then
44 |   echo "OS supported."
45 |   if ! grep -q 'llvm-toolchain-.*-10' /etc/apt/sources.list; then
46 |     add_llvm_10_apt_source $(lsb_release -rs)
47 |   fi
48 | elif [[ $(lsb_release -rs) == "20.04" ]]; then
49 |   echo "OS supported."
50 | else
51 |   echo "Unsupported OS. You will have to install the packages manually."
52 |   exit 1
53 | fi
54 | 
55 | install_system_packages
56 | install_python_packages $1
57 | 
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | 
3 | import os
4 | import re
5 | import sys
6 | import platform
7 | import subprocess
8 | 
9 | from distutils.version import LooseVersion
10 | from setuptools import setup, Extension, find_packages
11 | from setuptools.command.build_ext import build_ext
12 | from shutil import copyfile, copymode
13 | from pathlib import Path
14 | 
15 | install_requires = [
16 |     "appdirs",
17 |     "gitpython",
18 |     "numpy",
19 |     "torch",
20 |     "torch-geometric",
21 |     "tensorflow",
22 |     "pygraphviz",
23 |     "pandas",
24 |     "tqdm",
25 | ]
26 | 
27 | tests_require = ["pytest", "pytest-cov"]
28 | 
29 | 
30 | class CMakeExtension(Extension):
31 |     def __init__(self, name, sourcedir=""):
32 |         Extension.__init__(self, name, sources=[])
33 |         self.sourcedir = os.path.abspath(sourcedir)
34 | 
35 | 
36 | class CMakeBuild(build_ext):
37 |     def run(self):
38 |         try:
39 |             out = subprocess.check_output(["cmake", "--version"])
40 |         except OSError:
41 |             raise RuntimeError(
42 |                 "CMake must be installed to build the following extensions: "
43 |                 + ", ".join(e.name for e in self.extensions)
44 |             )
45 | 
46 |         if platform.system() == "Windows":
47 |             cmake_version = LooseVersion(
48 |                 re.search(r"version\s*([\d.]+)", out.decode()).group(1)
49 |             )
50 |             if cmake_version < "3.1.0":
51 |                 raise RuntimeError("CMake >= 3.1.0 is required on Windows")
52 | 
53 |         for ext in self.extensions:
54 |             self.build_extension(ext)
55 | 
56 |     def build_extension(self, ext):
57 |         self.dist_folder = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
58 | 
59 |         cmake_args = ["-DPYTHON_EXECUTABLE=" + sys.executable]
60 | 
61 |         cfg = "Debug" if self.debug else "Release"
62 |         build_args = ["--config", cfg]
63 | 
64 |         cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
65 |         build_args += ["--", "-j" + os.environ.get("COMPY_BUILD_JOBS", '8')]
66 | 
67 |         print(" ".join(cmake_args))
68 | 
69 |         env = os.environ.copy()
70 |         env["CXXFLAGS"] = '{} -DVERSION_INFO=\\"{}\\"'.format(
71 |             env.get("CXXFLAGS", ""), self.distribution.get_version()
72 |         )
73 |         if not os.path.exists(self.build_temp):
74 |             os.makedirs(self.build_temp)
75 |         subprocess.check_call(
76 |             ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env
77 |         )
78 |         subprocess.check_call(
79 |             ["cmake", "--build", "."] + build_args, cwd=self.build_temp
80 |         )
81 | 
82 |         # Deploy shared library files to the python folder structure.
83 |         all_files = [
84 |             os.path.join(dp, f)
85 |             for dp, dn, filenames in os.walk(self.build_temp)
86 |             for f in filenames
87 |         ]
88 |         for filename in all_files:
89 |             if filename.endswith(".so"):
90 |                 self.copy_pybind11_file(filename)
91 | 
92 |         # Copy *_tests files to the tests directory
93 |         all_files = [
94 |             os.path.join(dp, f)
95 |             for dp, dn, filenames in os.walk(self.build_temp)
96 |             for f in filenames
97 |         ]
98 |         for filename in all_files:
99 |             if filename.endswith("_tests"):
100 |                 self.copy_test_file(filename)
101 |         print()
102 | 
103 |     def copy_pybind11_file(self, src_file):
104 |         dest_file = os.path.relpath(src_file, self.build_temp)
105 |         # dest_file = os.path.join(os.path.dirname(dest_file), '..', os.path.basename(dest_file))
106 |         print("copying {} -> {}".format(src_file, dest_file))
107 |         copyfile(src_file, dest_file)
108 |         copymode(src_file, dest_file)
109 | 
110 |         dest_dist = os.path.join(self.dist_folder, dest_file)
111 |         os.makedirs(os.path.dirname(dest_dist), exist_ok=True)
112 |         copyfile(src_file, dest_dist)
113 |         copymode(src_file, dest_dist)
114 | 
115 |     def copy_test_file(self, src_file):
116 |         """
117 |         Copy ``src_file`` to ``dest_file``, ensuring the parent directory exists.
118 |         """
119 |         # Create directory if needed
120 |         dest_dir = os.path.join(
121 |             os.path.dirname(os.path.abspath(__file__)), "tests", "bin"
122 |         )
123 |         if dest_dir != "" and not os.path.exists(dest_dir):
124 |             print("creating directory {}".format(dest_dir))
125 |             os.makedirs(dest_dir)
126 | 
127 |         # Copy file
128 |         dest_file = os.path.join(dest_dir, os.path.basename(src_file))
129 |         print("copying {} -> {}".format(src_file, dest_file))
130 |         copyfile(src_file, dest_file)
131 |         copymode(src_file, dest_file)
132 | 
133 | 
134 | setup(
135 |     name="ComPy",
136 |     version="0.1b0",
137 |     description="",
138 |     author="ComPy-Learn authors",
139 |     author_email="",
140 |     long_description="",
141 |     install_requires=install_requires,
142 |     tests_require=tests_require,
143 |     packages=find_packages("."),
144 |     ext_modules=[CMakeExtension("")],
145 |     cmdclass=dict(build_ext=CMakeBuild),
146 |     test_suite="tests",
147 |     zip_safe=False,
148 | )
149 | 
-------------------------------------------------------------------------------- /tests/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/tud-ccc/compy-learn/24a394bd1ddc109a37b348c3de440389fa9a1f23/tests/__init__.py
-------------------------------------------------------------------------------- /tests/test_runner.py: --------------------------------------------------------------------------------
1 | import unittest
2 | import subprocess
3 | import os
4 | import pytest
5 | import sys
6 | 
7 | 
8 | class MainTest(unittest.TestCase):
9 |     def test_cpp(self):
10 |         print("\n\nRunning C++ tests...")
11 |         all_files = [
12 |             os.path.join(dp, f)
13 |             for dp, dn, filenames in os.walk(
14 |                 os.path.join(os.path.dirname(os.path.relpath(__file__)), "bin")
15 |             )
16 |             for f in filenames
17 |         ]
18 |         for filename in all_files:
19 |             print(filename)
20 |             subprocess.check_call(filename)
21 | 
22 |     def test_python(self):
23 |         print("\n\nRunning Python tests...")
24 |         sys.path.append("src")
25 | 
26 |         args = [os.path.join(os.path.dirname(__file__), "..", "compy")]
27 | 
28 |         # Verbose
29 |         args += ["-v"]
30 | 
31 |         # Show outputs
32 |         args += ["-s"]
33 | 
34 |         # Coverage check
35 |         args += ["--cov=compy"]
36 | 
37 |         # # HTML report for coverage check
38 |         # args = ['--cov-report', 'html'] + args
39 | 
40 |         # XML report for codecov
41 |         args = ["--cov-report=xml"] + args
42 | 
43 |         assert pytest.main(args) == pytest.ExitCode.OK
44 | 
45 | 
46 | if __name__ == "__main__":
47 |     unittest.main()
48 | 
--------------------------------------------------------------------------------