├── tests ├── __init__.py ├── test_files │ ├── empty_test.c │ ├── struct_test.c │ └── patch_test.c ├── conftest.py ├── test_patch_lang.py └── test_tourniquet.py ├── tourniquet ├── version.py ├── __init__.py ├── extractor.pyi ├── error.py ├── target.py ├── location.py ├── tourniquet.py ├── models.py └── patch_lang.py ├── .dockerignore ├── mypy.ini ├── .flake8 ├── MANIFEST.in ├── dev-requirements.txt ├── .gitignore ├── .github ├── dependabot.yml └── workflows │ ├── release.yml │ ├── docs-deploy.yml │ └── ci.yml ├── pyproject.toml ├── Dockerfile ├── Makefile ├── CMakeLists.txt ├── transformer ├── ASTPatch.h ├── ASTExporter.h ├── extractor.cpp └── ASTExporter.cpp ├── setup.py ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_files/empty_test.c: -------------------------------------------------------------------------------- 1 | int main() {} 2 | -------------------------------------------------------------------------------- /tourniquet/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.5" 2 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | *.so 2 | *.egg-info 3 | build/ 4 | env/ 5 | dist/ 6 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.7 3 | plugins = sqlmypy 4 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | ignore = E203, W503 4 | 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include transformer/*.cpp 2 | include transformer/*.h 3 | include CMakeLists.txt 4 | include dev-requirements.txt 5 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | ipython 2 | pytest 3 | pytest-cov 4 | flake8 5 | black 6 | mypy 7 | isort 8 | sqlalchemy-stubs 9 | pdoc3 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | pip-wheel-metadata/ 2 | *.egg-info/ 3 | build/ 4 | env/ 5 | .mypy_cache/ 6 | __pycache__/ 7 | extractor.so 8 | .coverage 9 | html/ 10 | dist/ 11 | -------------------------------------------------------------------------------- /tourniquet/__init__.py: -------------------------------------------------------------------------------- 1 | from .tourniquet import * # noqa: F401, F403 2 | from .version import __version__ # noqa: F401 3 | 4 | __pdoc__ = {"extractor": False, "version": False} 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "10:00" 8 | open-pull-requests-limit: 10 9 | -------------------------------------------------------------------------------- /tests/test_files/struct_test.c: -------------------------------------------------------------------------------- 1 | /* 2 | * Typedef Struct 3 | */ 4 | typedef struct _test_t { 5 | char buff[20]; 6 | int num_items; 7 | } test_t; 8 | 9 | int main() { 10 | 11 | return 0; 12 | } 13 | 
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | line_length = 100 3 | multi_line_output = 3 4 | known_first_party = "tourniquet" 5 | include_trailing_comma = true 6 | force_grid_wrap = 0 7 | use_parentheses = true 8 | ensure_newline_before_comments = true 9 | 10 | 11 | [tool.black] 12 | line-length = 100 13 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | TEST_DIR = Path(__file__).resolve().parent 6 | TEST_FILE_DIR = TEST_DIR / "test_files" 7 | 8 | 9 | @pytest.fixture 10 | def test_files() -> Path: 11 | return TEST_FILE_DIR 12 | 13 | 14 | @pytest.fixture 15 | def tmp_db(tmp_path) -> Path: 16 | return tmp_path.with_suffix(".db") 17 | -------------------------------------------------------------------------------- /tourniquet/extractor.pyi: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Any, Dict 3 | 4 | def extract_ast(filename: PathLike, is_cxx: bool) -> Dict[str, Any]: ... 5 | def transform( 6 | filename: PathLike, 7 | is_cxx: bool, 8 | replacement: str, 9 | start_line: int, 10 | start_col: int, 11 | end_line: int, 12 | end_col: int, 13 | ): ... 
14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 as base 2 | 3 | MAINTAINER Carson Harmon 4 | 5 | RUN apt-get update 6 | 7 | RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y \ 8 | clang-9 \ 9 | git \ 10 | build-essential \ 11 | python3.7-dev \ 12 | python3-pip \ 13 | llvm-9-dev \ 14 | libclang-9-dev \ 15 | apt-transport-https \ 16 | ca-certificates \ 17 | gnupg \ 18 | software-properties-common \ 19 | wget 20 | 21 | RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null \ 22 | | gpg --dearmor - \ 23 | | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null && \ 24 | apt-add-repository 'deb https://apt.kitware.com/ubuntu/ bionic main' && \ 25 | apt-get update && \ 26 | apt-get install -y cmake 27 | 28 | RUN python3.7 -m pip install --upgrade pip 29 | 30 | WORKDIR / 31 | COPY . /tourniquet 32 | 33 | WORKDIR /tourniquet 34 | 35 | RUN pip3 install -e .[dev] 36 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | tags: 4 | - 'v*' 5 | 6 | name: release 7 | 8 | jobs: 9 | pypi: 10 | name: upload release to PyPI 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: actions/setup-python@v1 15 | with: 16 | python-version: 3.8 17 | - name: create release 18 | id: create_release 19 | uses: actions/create-release@v1 20 | env: 21 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 22 | with: 23 | tag_name: ${{ github.ref }} 24 | release_name: Release ${{ github.ref }} 25 | draft: false 26 | prerelease: ${{ contains(github.ref, 'pre') || contains(github.ref, 'rc') }} 27 | - name: sdist 28 | run: python3 setup.py sdist 29 | - name: publish 30 | uses: pypa/gh-action-pypi-publish@master 31 | with: 32 | user: __token__ 
33 | password: ${{ secrets.PYPI_TOKEN }} 34 | -------------------------------------------------------------------------------- /tourniquet/error.py: -------------------------------------------------------------------------------- 1 | class Error(Exception): 2 | """ 3 | A base error for all tourniquet exceptions. 4 | """ 5 | 6 | pass 7 | 8 | 9 | class TemplateError(Error): 10 | """ 11 | A base error for template-related tourniquet exceptions. 12 | """ 13 | 14 | pass 15 | 16 | 17 | class TemplateNameError(TemplateError): 18 | """ 19 | Raised whenever a template name conflict or name lookup failure occurs. 20 | """ 21 | 22 | pass 23 | 24 | 25 | class ASTError(Error): 26 | """ 27 | A base error for AST-related tourniquet exceptions. 28 | """ 29 | 30 | pass 31 | 32 | 33 | class PatchSituationError(ASTError): 34 | """ 35 | Raised whenever a patch can't be situated at the requested location. 36 | """ 37 | 38 | 39 | class PatchConcretizationError(ASTError): 40 | """ 41 | Raised whenever a PatchLang expression can't be concretized against the AST. 42 | """ 43 | 44 | pass 45 | 46 | 47 | # TODO(ww): Think about errors for location concretization. 
48 | -------------------------------------------------------------------------------- /.github/workflows/docs-deploy.yml: -------------------------------------------------------------------------------- 1 | name: docs-deploy 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | docs: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | with: 14 | submodules: true 15 | 16 | - uses: actions/setup-python@v1 17 | with: 18 | python-version: "3.8" 19 | 20 | - name: system deps 21 | run: | 22 | sudo apt-get update -y 23 | sudo apt install -y \ 24 | clang-9 \ 25 | llvm-9-dev \ 26 | libclang-9-dev 27 | 28 | - name: docs 29 | run: | 30 | export LLVM_DIR=/usr/lib/llvm-9/lib/cmake/llvm 31 | export Clang_DIR=/usr/lib/llvm-9/lib/cmake/clang 32 | make doc 33 | 34 | - name: deploy 35 | uses: peaceiris/actions-gh-pages@v3.7.3 36 | with: 37 | github_token: ${{ secrets.GITHUB_TOKEN }} 38 | publish_dir: ./html/tourniquet 39 | publish_branch: gh-pages 40 | force_orphan: true 41 | -------------------------------------------------------------------------------- /tourniquet/target.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from typing import List, Tuple 4 | 5 | 6 | class Target: 7 | """ 8 | This class represents the candidate program to repair 9 | The path to the program, the tests (as a list of tuples for now) 10 | and the build/test commands are required 11 | 12 | We might be able to automatically get some info from blight 13 | """ 14 | 15 | def __init__( 16 | self, 17 | filepath: str, 18 | tests: List[Tuple[str, int]], 19 | build_cmd: List[str], 20 | executable_path: str, 21 | ): 22 | self.file_path = filepath 23 | if not os.path.exists(self.file_path): 24 | raise FileNotFoundError 25 | self.tests = tests 26 | self.build_cmd = build_cmd 27 | self.bin_path = executable_path 28 | 29 | def build(self) -> bool: 30 | ret_code = subprocess.call(self.build_cmd) 31 | return 
ret_code == 0 32 | 33 | # This runs the bin specified by bin path with the tests as arguments 34 | def run_tests(self) -> bool: 35 | return False 36 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ALL_CXX_SRCS := $(shell find transformer -name '*.cpp' -o -name '*.h') 2 | 3 | ALL_PY_SRCS := setup.py \ 4 | $(shell find tourniquet -name '*.py') \ 5 | $(shell find tests -name '*.py') 6 | 7 | .PHONY: all 8 | all: 9 | @echo "Run my targets individually!" 10 | 11 | env: 12 | test -d env || python3 -m venv env && \ 13 | . env/bin/activate && \ 14 | pip install --upgrade pip 15 | 16 | .PHONY: dev 17 | dev: env 18 | . env/bin/activate && \ 19 | pip install -r dev-requirements.txt 20 | 21 | .PHONY: build 22 | build: env 23 | . env/bin/activate && \ 24 | pip install -e . 25 | 26 | .PHONY: py-lint 27 | py-lint: dev 28 | . env/bin/activate && \ 29 | black $(ALL_PY_SRCS) && \ 30 | isort $(ALL_PY_SRCS) && \ 31 | flake8 $(ALL_PY_SRCS) 32 | 33 | 34 | .PHONY: cxx-lint 35 | cxx-lint: 36 | clang-format -i $(ALL_CXX_SRCS) 37 | 38 | .PHONY: lint 39 | lint: py-lint cxx-lint 40 | git diff --exit-code 41 | 42 | .PHONY: typecheck 43 | typecheck: 44 | . env/bin/activate && \ 45 | mypy tourniquet 46 | 47 | .PHONY: test 48 | test: dev build 49 | . env/bin/activate && \ 50 | pytest tests/ 51 | 52 | .PHONY: test-cov 53 | test-cov: dev 54 | . env/bin/activate && \ 55 | pytest --cov=tourniquet/ tests/ 56 | 57 | .PHONY: doc 58 | doc: dev build 59 | . 
env/bin/activate && \ 60 | PYTHONWARNINGS='error::UserWarning' pdoc --force --html tourniquet 61 | 62 | .PHONY: edit 63 | edit: 64 | $(EDITOR) $(ALL_PY_SRCS) 65 | -------------------------------------------------------------------------------- /tests/test_files/patch_test.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | /* 4 | * World's easiest crackme 5 | * Copies the arg into the buff, and checks to see if the password is password. 6 | * You can either exploit the challenge or solve it by reversing the password 7 | * 8 | * This is a PoC of the automated patcher prototype 9 | * 10 | * It uses the MATE CPG to collect all globals, parameters, and local variables 11 | * to use in patch templates. It then tries to automatically fill in templates, 12 | * and pass a test suite. 13 | * 14 | * To prevent the exploit/crash, a plausible patch is to just delete the strcpy 15 | * statement. But deleting the strcpy statement ruins the rest of the program, 16 | * as you can no longer copy passwords into the password buffer. Having an 17 | * additional test with input "password" which normally should pass, and fails 18 | * when removing strcpy, prunes these plausible but incorrect patches. 
19 | */ 20 | char *pass = "password"; 21 | 22 | int main(int argc, char *argv[]) { 23 | char buff[10]; 24 | int buff_len = sizeof(buff); 25 | char *pov = argv[1]; 26 | int len = strlen(argv[1]); 27 | // Just use the variables so the build system builds our challenge 28 | if (buff_len == len) { 29 | printf("Buffer sizes are of similar length!\n"); 30 | } 31 | // Possible patch, if (len < buff_len) { 32 | strcpy(buff, pov); 33 | if (strcmp(buff, pass) == 0) { 34 | return 0; 35 | } 36 | 37 | return 2; 38 | } 39 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: push 4 | 5 | jobs: 6 | docker: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | 11 | - name: docker build 12 | run: docker build -t tourniquet . 13 | 14 | tests-linux: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.7, 3.8] 19 | llvm-version: [9, 10, 11] 20 | steps: 21 | - uses: actions/checkout@v2 22 | 23 | - uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: llvm deps 28 | run: | 29 | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - 30 | sudo add-apt-repository "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-${{ matrix.llvm-version }} main" 31 | sudo apt update 32 | sudo apt install \ 33 | llvm-${{ matrix.llvm-version }} \ 34 | llvm-${{ matrix.llvm-version }}-dev \ 35 | clang-${{ matrix.llvm-version }} \ 36 | libclang-${{ matrix.llvm-version }}-dev 37 | 38 | - name: lint 39 | run: make lint 40 | 41 | - name: typecheck 42 | run: make typecheck 43 | 44 | - name: test 45 | run: | 46 | export LLVM_DIR=/usr/lib/llvm-${{ matrix.llvm-version }}/lib/cmake/llvm 47 | export Clang_DIR=/usr/lib/llvm-${{ matrix.llvm-version }}/lib/cmake/clang 48 | make test 49 | 50 | tests-macos: 51 | runs-on: macos-latest 52 | 
strategy: 53 | matrix: 54 | python-version: [3.7, 3.8] 55 | llvm-version: [9] 56 | steps: 57 | - uses: actions/checkout@v2 58 | 59 | - uses: actions/setup-python@v1 60 | with: 61 | python-version: ${{ matrix.python-version }} 62 | 63 | - name: system deps 64 | run: brew install llvm@${{ matrix.llvm-version }} 65 | 66 | - name: python deps 67 | run: make dev 68 | 69 | - name: build 70 | run: make build 71 | 72 | - name: test 73 | run: make test 74 | -------------------------------------------------------------------------------- /tourniquet/location.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Iterator 5 | 6 | 7 | @dataclass(frozen=True) 8 | class SourceCoordinate: 9 | """ 10 | Encapsulates the bare amount of state required to uniquely locate 11 | a source feature within *some* unspecified source file. 12 | """ 13 | 14 | line: int 15 | """ 16 | The line that the feature occurs on. 17 | """ 18 | 19 | column: int 20 | """ 21 | The column that the feature occurs on. 22 | """ 23 | 24 | 25 | @dataclass(frozen=True) 26 | class Span: 27 | """ 28 | Encapsulates a "span" of a source feature, i.e. its start and end lines 29 | and columns. 30 | """ 31 | 32 | filename: Path 33 | """ 34 | The path to the file that the span occurs in. 35 | """ 36 | 37 | start: SourceCoordinate 38 | """ 39 | The coordinates (line and column) that the span starts on. 40 | """ 41 | 42 | end: SourceCoordinate 43 | """ 44 | The coordinates (line and column) that the span ends on. 45 | """ 46 | 47 | 48 | @dataclass(frozen=True) 49 | class Location: 50 | """ 51 | Encapsulates the bare amount of state required to uniquely locate a source 52 | feature within a program's source code. 53 | 54 | Observe that `Location` does not represent "spans," i.e. the start and end 55 | lines and columns for a source feature. This is intentional. 
56 | """ 57 | 58 | filename: Path 59 | """ 60 | The path to the file that the feature occurs in. 61 | """ 62 | 63 | coordinates: SourceCoordinate 64 | """ 65 | The coordinates (line and column) of the feature. 66 | 67 | See also `line` and `column`. 68 | """ 69 | 70 | @property 71 | def line(self) -> int: 72 | """ 73 | Returns the line that the feature occurs on. 74 | """ 75 | return self.coordinates.line 76 | 77 | @property 78 | def column(self) -> int: 79 | """ 80 | Returns the column that the feature occurs on. 81 | """ 82 | return self.coordinates.column 83 | 84 | 85 | class Locator(ABC): 86 | """ 87 | Represents an abstract "locator," which can be concretized into a 88 | iterator of unique source locations. 89 | """ 90 | 91 | @abstractmethod 92 | def concretize(self) -> Iterator[Location]: 93 | ... 94 | 95 | 96 | class TrivialLocator(Locator): 97 | """ 98 | A trivial locator that just forwards a single unique source location. 99 | """ 100 | 101 | def __init__(self, filename, line, column): 102 | self._filename = Path(filename) 103 | self._line = line 104 | self._column = column 105 | 106 | def concretize(self) -> Iterator[Location]: 107 | yield Location(self._filename, SourceCoordinate(self._line, self._column)) 108 | -------------------------------------------------------------------------------- /tests/test_patch_lang.py: -------------------------------------------------------------------------------- 1 | from tourniquet import Tourniquet 2 | from tourniquet.location import Location as L 3 | from tourniquet.location import SourceCoordinate as SC 4 | from tourniquet.patch_lang import ( 5 | BinaryBoolOperator, 6 | BinaryMathOperator, 7 | ElseStmt, 8 | FixPattern, 9 | IfStmt, 10 | LessThanExpr, 11 | Lit, 12 | NodeStmt, 13 | ReturnStmt, 14 | StaticBufferSize, 15 | Variable, 16 | ) 17 | 18 | 19 | def test_concretize_lit(): 20 | lit = Lit("1") 21 | assert set(lit.concretize(None, None)) == {"1"} 22 | 23 | 24 | def test_concretize_variable(test_files, tmp_db): 25 | 
tourniquet = Tourniquet(tmp_db) 26 | test_file = test_files / "patch_test.c" 27 | tourniquet.collect_info(test_file) 28 | 29 | variable = Variable() 30 | location = L(test_file, SC(23, 3)) 31 | concretized = set(variable.concretize(tourniquet.db, location)) 32 | assert len(concretized) == 6 33 | assert concretized == {"argc", "argv", "buff", "buff_len", "pov", "len"} 34 | 35 | 36 | def test_concretize_staticbuffersize(test_files, tmp_db): 37 | tourniquet = Tourniquet(tmp_db) 38 | test_file = test_files / "patch_test.c" 39 | tourniquet.collect_info(test_file) 40 | 41 | sbs = StaticBufferSize() 42 | location = L(test_file, SC(23, 3)) 43 | concretized = set(sbs.concretize(tourniquet.db, location)) 44 | assert concretized == {"sizeof(buff)"} 45 | 46 | 47 | def test_concretize_binarymathoperator(): 48 | bmo = BinaryMathOperator(Lit("1"), Lit("2")) 49 | concretized = set(bmo.concretize(None, None)) 50 | assert concretized == {"1 + 2", "1 - 2", "1 / 2", "1 * 2", "1 << 2"} 51 | 52 | 53 | def test_concretize_binarybooloperator(): 54 | bbo = BinaryBoolOperator(Lit("1"), Lit("2")) 55 | concretized = set(bbo.concretize(None, None)) 56 | assert concretized == {"1 == 2", "1 != 2", "1 <= 2", "1 < 2", "1 >= 2", "1 > 2"} 57 | 58 | 59 | def test_concretize_lessthanexpr(): 60 | lte = LessThanExpr(Lit("1"), Lit("2")) 61 | concretized = set(lte.concretize(None, None)) 62 | assert concretized == {"1 < 2"} 63 | 64 | 65 | def test_concretize_nodestmt(test_files, tmp_db): 66 | tourniquet = Tourniquet(tmp_db) 67 | test_file = test_files / "patch_test.c" 68 | tourniquet.collect_info(test_file) 69 | 70 | node = NodeStmt() 71 | location = L(test_file, SC(23, 3)) 72 | concretized = list(node.concretize(tourniquet.db, location)) 73 | assert len(concretized) == 1 74 | assert concretized[0] == "char buff[10];" 75 | 76 | 77 | def test_concretize_ifstmt(): 78 | ifs = IfStmt(Lit("1"), Lit("bark;")) 79 | assert set(ifs.concretize(None, None)) == {"if (1) {\nbark;\n}\n"} 80 | 81 | 82 | def 
test_concretize_elsestmt(): 83 | elses = ElseStmt(Lit("bark;")) 84 | assert set(elses.concretize(None, None)) == {"else {\nbark;\n}\n"} 85 | 86 | 87 | def test_concretize_returnstmt(): 88 | rets = ReturnStmt(Lit("foo")) 89 | assert set(rets.concretize(None, None)) == {"return foo;"} 90 | 91 | 92 | # TODO(ww): Tests for: 93 | # * Statement 94 | # * StatementList 95 | # * PatchTemplate 96 | 97 | 98 | def test_fixpattern(): 99 | fp = FixPattern(IfStmt(Lit("1"), Lit("exit(1);")), ElseStmt(Lit("exit(2);"))) 100 | concretized = list(fp.concretize(None, None)) 101 | assert len(concretized) == 1 102 | assert concretized[0] == "if (1) {\nexit(1);\n}\n\nelse {\nexit(2);\n}\n" 103 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10.2) 2 | project(transformer) 3 | 4 | set(CMAKE_FIND_PACKAGE_SORT_ORDER NATURAL) 5 | set(CMAKE_FIND_PACKAGE_SORT_DIRECTION DEC) 6 | 7 | set(SOURCE_DIR "transformer") 8 | include_directories(${SOURCE_DIR}) 9 | 10 | if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") 11 | set(MACOSX TRUE) 12 | endif() 13 | set(CMAKE_CXX_STANDARD 17) 14 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 15 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 16 | find_package( 17 | Python3 18 | COMPONENTS Development 19 | REQUIRED) 20 | include_directories(${Python3_INCLUDE_DIRS}) 21 | link_directories(${Python3_LIBRARY_DIRS}) 22 | add_link_options(${Python3_LINK_OPTIONS}) 23 | 24 | find_package(LLVM REQUIRED CONFIG HINTS "${LLVM_DIR}" ENV LLVM_DIR) 25 | message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") 26 | message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") 27 | 28 | if(("${LLVM_PACKAGE_VERSION}" VERSION_LESS 9.0.0) 29 | OR ("${LLVM_PACKAGE_VERSION}" VERSION_GREATER_EQUAL 12)) 30 | message(FATAL_ERROR "Unsupported LLVM version: ${LLVM_PACKAGE_VERSION}") 31 | endif() 32 | 33 | include_directories(${LLVM_INCLUDE_DIRS}) 34 | 
link_directories(${LLVM_LIBRARY_DIRS}) 35 | add_definitions(${LLVM_DEFINITIONS}) 36 | 37 | find_package(Clang REQUIRED CONFIG HINTS "${Clang_DIR}" ENV Clang_DIR) 38 | message(STATUS "Using ClangConfig.cmake in: ${Clang_DIR}") 39 | 40 | include_directories(${CLANG_INCLUDE_DIRS}) 41 | add_definitions(${CLANG_DEFINITIONS}) 42 | 43 | set(LIBRARY_LIST 44 | clangAST 45 | clangASTMatchers 46 | clangAnalysis 47 | clangBasic 48 | clangDriver 49 | clangEdit 50 | clangFrontend 51 | clangFrontendTool 52 | clangLex 53 | clangParse 54 | clangSema 55 | clangEdit 56 | clangRewrite 57 | clangRewriteFrontend 58 | clangStaticAnalyzerFrontend 59 | clangStaticAnalyzerCheckers 60 | clangStaticAnalyzerCore 61 | clangCrossTU 62 | clangIndex 63 | clangSerialization 64 | clangToolingCore 65 | clangToolingInclusions 66 | clangTooling 67 | clangFormat) 68 | 69 | if("${LLVM_PACKAGE_VERSION}" VERSION_LESS 10.0.0) 70 | list(APPEND LIBRARY_LIST clangToolingRefactoring) 71 | elseif("${LLVM_PACKAGE_VERSION}" VERSION_GREATER_EQUAL 10.0.0) 72 | list(APPEND LIBRARY_LIST clangTransformer) 73 | endif() 74 | 75 | if(NOT MACOSX) 76 | # NOTE(ww): For reasons that are unclear to me, macOS does *not* want Python 77 | # extensions to link to the Python development libraries directly. Instead, it 78 | # expects to resolve them dynamically via dynamic_lookup. See 79 | # target_link_options below. 80 | set(LIBRARY_LIST ${LIBRARY_LIST} ${Python3_LIBRARIES}) 81 | endif() 82 | 83 | add_definitions(-D__STDC_LIMIT_MACROS -D__STDC_CONSTANT_MACROS) 84 | 85 | add_library(${PROJECT_NAME} SHARED ${SOURCE_DIR}/ASTExporter.cpp 86 | ${SOURCE_DIR}/extractor.cpp) 87 | set_target_properties(${PROJECT_NAME} 88 | PROPERTIES PREFIX "" OUTPUT_NAME ${PYTHON_MODULE_NAME}) 89 | 90 | if(MACOSX) 91 | target_link_libraries(${PROJECT_NAME} ${LIBRARY_LIST}) 92 | # NOTE(ww): See the conditional inclusion above. 
93 | target_link_options(${PROJECT_NAME} PUBLIC -undefined dynamic_lookup) 94 | # NOTE(ww): Python extensions on macOS don't use .dylib. 95 | set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".so") 96 | else() 97 | target_link_libraries(${PROJECT_NAME} -Wl,--start-group ${LIBRARY_LIST} 98 | -Wl,--end-group) 99 | endif() 100 | target_link_libraries(${PROJECT_NAME} LLVM pthread m) 101 | -------------------------------------------------------------------------------- /transformer/ASTPatch.h: -------------------------------------------------------------------------------- 1 | /* 2 | * ASTPatch.h 3 | * 4 | * Created on: Aug 18, 2020 5 | * Author: carson 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "clang/AST/RecursiveASTVisitor.h" 11 | #include "clang/Frontend/CompilerInstance.h" 12 | #include "clang/Lex/Lexer.h" 13 | #include "clang/Rewrite/Core/Rewriter.h" 14 | 15 | #include "clang/Lex/Lexer.h" 16 | #if LLVM_VERSION_MAJOR <= 9 17 | #include "clang/Tooling/Refactoring/SourceCode.h" 18 | #else 19 | #include "clang/Tooling/Transformer/SourceCode.h" 20 | #endif 21 | 22 | #include "llvm/Support/JSON.h" 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | #include 36 | 37 | using namespace clang; 38 | using namespace llvm; 39 | 40 | class ASTPatchConsumer : public clang::ASTConsumer { 41 | public: 42 | explicit ASTPatchConsumer(clang::ASTContext *Context, Rewriter *rewriter, 43 | int start_line, int start_col, int end_line, 44 | int end_col, std::string replacement) { 45 | SourceManager &srcMgr = rewriter->getSourceMgr(); 46 | FileID id = rewriter->getSourceMgr().getMainFileID(); 47 | const FileEntry *entry = rewriter->getSourceMgr().getFileEntryForID(id); 48 | SourceLocation start_loc = 49 | srcMgr.translateFileLineCol(entry, start_line, start_col); 50 | SourceLocation end_loc = 51 | srcMgr.translateFileLineCol(entry, end_line, end_col); 52 | 
SourceRange src_range(start_loc, end_loc); 53 | rewriter->ReplaceText(src_range, replacement.c_str()); 54 | } 55 | }; 56 | 57 | class ASTPatchAction : public ASTFrontendAction { 58 | public: 59 | explicit ASTPatchAction(int start_line, int start_col, int end_line, 60 | int end_col, std::string replacement, 61 | std::string filepath) 62 | : start_line(start_line), start_col(start_col), end_line(end_line), 63 | end_col(end_col), replacement(replacement), filepath(filepath) {} 64 | 65 | ASTPatchAction(const ASTPatchAction &) = delete; 66 | ASTPatchAction &operator=(const ASTPatchAction &) = delete; 67 | 68 | // TODO There is probably a better place to do this, HandleTranslationUnit 69 | // maybe? 70 | // TODO Best way to handle errors? 71 | void EndSourceFileAction() override { 72 | FileID id = rewriter.getSourceMgr().getMainFileID(); 73 | const FileEntry *Entry = rewriter.getSourceMgr().getFileEntryForID(id); 74 | std::string string_buffer; 75 | llvm::raw_string_ostream output_stream(string_buffer); 76 | rewriter.getRewriteBufferFor(id)->write(output_stream); 77 | std::ofstream ofs(filepath, std::ofstream::trunc); 78 | if (ofs.is_open()) { 79 | ofs << output_stream.str(); 80 | } else { 81 | PyErr_SetString(PyExc_IOError, "Failed to open file for patching"); 82 | } 83 | } 84 | 85 | std::unique_ptr CreateASTConsumer(CompilerInstance &Compiler, 86 | StringRef file) override { 87 | rewriter.setSourceMgr(Compiler.getSourceManager(), Compiler.getLangOpts()); 88 | return std::unique_ptr( 89 | new ASTPatchConsumer(&Compiler.getASTContext(), &rewriter, start_line, 90 | start_col, end_line, end_col, replacement)); 91 | } 92 | 93 | private: 94 | Rewriter rewriter; 95 | int start_line, start_col, end_line, end_col; 96 | std::string replacement; 97 | std::string filepath; 98 | }; 99 | -------------------------------------------------------------------------------- /transformer/ASTExporter.h: -------------------------------------------------------------------------------- 1 | /* 2 
| * ASTExporter.h 3 | * 4 | * Created on: Aug 6, 2020 5 | * Author: carson 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "clang/AST/RecursiveASTVisitor.h" 11 | #include "clang/Frontend/CompilerInstance.h" 12 | 13 | #include "clang/Lex/Lexer.h" 14 | #if LLVM_VERSION_MAJOR <= 9 15 | #include "clang/Tooling/Refactoring/SourceCode.h" 16 | #else 17 | #include "clang/Tooling/Transformer/SourceCode.h" 18 | #endif 19 | 20 | #include "llvm/Support/JSON.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #include 33 | 34 | using namespace clang::tooling; 35 | using namespace clang; 36 | 37 | /* 38 | * This is a simpler AST visitor that collects some information from nodes it 39 | * visits and records some information about the nodes. For search based repair, 40 | * we are not doing deep analysis of individual statements. Instead, we want to 41 | * be able to access functions, parameters, local/global variables, statements, 42 | * and arguments passed to call exprs. There might be some more info we need 43 | * when dealing with templates etc. 44 | */ 45 | class ASTExporterVisitor 46 | : public clang::RecursiveASTVisitor { 47 | public: 48 | ASTExporterVisitor(ASTContext *Context, PyObject *info); 49 | bool VisitDeclStmt(Stmt *stmt); 50 | bool VisitVarDecl(VarDecl *vdecl); 51 | bool VisitCallExpr(CallExpr *call_expr); 52 | bool VisitFunctionDecl(FunctionDecl *func_decl); 53 | 54 | private: 55 | void PyDictUpdateEntry(PyObject *dict, const char *key, PyObject *new_item); 56 | PyObject *BuildStmtEntry(Stmt *stmt); 57 | void AddGlobalEntry(PyObject *entry); 58 | void AddFunctionEntry(const char *func_name, PyObject *entry); 59 | 60 | ASTContext *Context; 61 | PyObject *tree_info; 62 | // Clang doesn't store parental relationships for statements (it does for 63 | // decls), meaning from a CallExpr you can't find the Caller Function with any 64 | // get method etc. 
Just keep track of our current function as we traverse 65 | FunctionDecl *current_func; 66 | }; 67 | 68 | /* 69 | * The Consumer/FrontendAction/Factory are all clang boilerplate 70 | * which allow us to trigger an action on the AST and pass variables up/down 71 | * the class hierarchy. 72 | */ 73 | 74 | class ASTExporterConsumer : public clang::ASTConsumer { 75 | public: 76 | explicit ASTExporterConsumer(clang::ASTContext *Context, PyObject *info) 77 | : Visitor(Context, info) {} 78 | 79 | virtual void HandleTranslationUnit(clang::ASTContext &Context) { 80 | Visitor.TraverseDecl(Context.getTranslationUnitDecl()); 81 | } 82 | 83 | private: 84 | ASTExporterVisitor Visitor; 85 | }; 86 | 87 | class ASTExporterFrontendAction : public clang::ASTFrontendAction { 88 | public: 89 | std::unique_ptr 90 | CreateASTConsumer(clang::CompilerInstance &Compiler, llvm::StringRef InFile) { 91 | return std::unique_ptr( 92 | new ASTExporterConsumer(&Compiler.getASTContext(), extract_results_)); 93 | } 94 | 95 | explicit ASTExporterFrontendAction(PyObject *extract_results) 96 | : extract_results_{extract_results} {} 97 | 98 | ASTExporterFrontendAction(const ASTExporterFrontendAction &) = delete; 99 | ASTExporterFrontendAction & 100 | operator=(const ASTExporterFrontendAction &) = delete; 101 | 102 | private: 103 | PyObject *extract_results_; 104 | }; 105 | -------------------------------------------------------------------------------- /transformer/extractor.cpp: -------------------------------------------------------------------------------- 1 | #define PY_SSIZE_T_CLEAN 2 | #include "ASTExporter.h" 3 | #include "ASTPatch.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | static bool read_file_to_string(const std::string &filename, 11 | std::string &data) { 12 | std::ifstream file(filename); 13 | if (!file.is_open()) { 14 | return false; 15 | } 16 | 17 | std::stringstream buf; 18 | buf << file.rdbuf(); 19 | data = buf.str(); 20 | 21 | return true; 22 | } 23 | 24 
| template 25 | static void run_clang_tool(std::string &data, int is_cxx, 26 | ToolArgs &&...tool_args) { 27 | std::vector args{"-x"}; 28 | if (is_cxx) { 29 | args.push_back("c++"); 30 | } else { 31 | args.push_back("c"); 32 | } 33 | 34 | #if LLVM_VERSION_MAJOR <= 9 35 | runToolOnCodeWithArgs(new Tool(std::forward(tool_args)...), data, 36 | args); 37 | #else 38 | runToolOnCodeWithArgs(std::make_unique(tool_args...), data, args); 39 | #endif 40 | } 41 | 42 | static PyObject *extract_ast(PyObject *self, PyObject *args) { 43 | PyObject *filename_bytes; 44 | int is_cxx; 45 | if (!PyArg_ParseTuple(args, "O&p", PyUnicode_FSConverter, &filename_bytes, 46 | &is_cxx)) { 47 | return nullptr; 48 | } 49 | 50 | std::string filename = PyBytes_AsString(filename_bytes); 51 | Py_DECREF(filename_bytes); 52 | 53 | std::string data; 54 | if (!read_file_to_string(filename, data)) { 55 | PyErr_SetString(PyExc_IOError, "Failed to open file for extraction"); 56 | return nullptr; 57 | } 58 | 59 | // Allocate dictionary to return to Python 60 | PyObject *extract_results = PyDict_New(); 61 | if (!extract_results) { 62 | PyErr_SetString(PyExc_MemoryError, "Allocation failed for dict"); 63 | return nullptr; 64 | } 65 | PyDict_SetItem(extract_results, PyUnicode_FromString("module_name"), 66 | PyUnicode_FromString(filename.c_str())); 67 | 68 | run_clang_tool(data, is_cxx, extract_results); 69 | 70 | // Return the python dictionary back to the python code. 
71 | return extract_results; 72 | } 73 | 74 | static PyObject *transform(PyObject *self, PyObject *args) { 75 | PyObject *filename_bytes; 76 | char *replacement; 77 | int is_cxx; 78 | int start_line, start_col, end_line, end_col; 79 | if (!PyArg_ParseTuple(args, "O&psiiii", PyUnicode_FSConverter, 80 | &filename_bytes, &is_cxx, &replacement, &start_line, 81 | &start_col, &end_line, &end_col)) { 82 | return nullptr; 83 | } 84 | 85 | std::string filename = PyBytes_AsString(filename_bytes); 86 | Py_DECREF(filename_bytes); 87 | 88 | std::string data; 89 | if (!read_file_to_string(filename, data)) { 90 | PyErr_SetString(PyExc_IOError, "Failed to open file for patching"); 91 | return nullptr; 92 | } 93 | 94 | run_clang_tool(data, is_cxx, start_line, start_col, end_line, 95 | end_col, std::string(replacement), filename); 96 | 97 | // The patching action might have failed (and set an appropriate Python 98 | // exception) on an I/O error. If so, return nullptr and allow the exception 99 | // to propagate. 
100 | if (PyErr_Occurred()) { 101 | return nullptr; 102 | } 103 | 104 | Py_RETURN_TRUE; 105 | } 106 | 107 | PyMethodDef extractor_methods[] = { 108 | {"extract_ast", extract_ast, METH_VARARGS, 109 | "Returns a dictionary containing AST info for a file"}, 110 | {"transform", transform, METH_VARARGS, 111 | "Transforms the target program with a replacement"}, 112 | {nullptr, nullptr, 0, nullptr}, 113 | }; 114 | 115 | static struct PyModuleDef extractor_definition = { 116 | PyModuleDef_HEAD_INIT, 117 | "extractor", 118 | "The extractor extension uses clang to extract AST information and perform " 119 | "transformations", 120 | -1, 121 | extractor_methods, 122 | }; 123 | 124 | PyMODINIT_FUNC PyInit_extractor(void) { 125 | PyObject *m = PyModule_Create(&extractor_definition); 126 | return m; 127 | } 128 | -------------------------------------------------------------------------------- /tests/test_tourniquet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tourniquet import Tourniquet 4 | from tourniquet.location import Location as L 5 | from tourniquet.location import SourceCoordinate as SC 6 | from tourniquet.models import Function, Global, Module 7 | from tourniquet.patch_lang import ( 8 | ElseStmt, 9 | Expression, 10 | FixPattern, 11 | IfStmt, 12 | LessThanExpr, 13 | Lit, 14 | NodeStmt, 15 | PatchTemplate, 16 | ReturnStmt, 17 | ) 18 | 19 | 20 | def test_tourniquet_extract_ast(test_files, tmp_db): 21 | # No DB name for now 22 | test_extractor = Tourniquet(tmp_db) 23 | test_file = test_files / "patch_test.c" 24 | ast_dict: dict = test_extractor._extract_ast(test_file) 25 | 26 | # The dictionary produced from the AST has three top-level keys: 27 | # module_name, globals, and functions. 28 | assert "module_name" in ast_dict 29 | assert "globals" in ast_dict 30 | assert "functions" in ast_dict 31 | 32 | # The module name is our source file. 
33 | assert ast_dict["module_name"] == str(test_file) 34 | 35 | # Everything in globals is a "var_type". 36 | assert all(global_[0] == "var_type" for global_ in ast_dict["globals"]) 37 | 38 | # There's at least one global named "pass". 39 | assert any(global_[5] == "pass" for global_ in ast_dict["globals"]) 40 | 41 | # There's a "main" function in the "functions" dictionary. 42 | assert "main" in ast_dict["functions"] 43 | main = ast_dict["functions"]["main"] 44 | 45 | # There are 4 variables in "main", plus "argc" and "argv". 46 | main_vars = [var_decl[5] for var_decl in main if var_decl[0] == "var_type"] 47 | assert set(main_vars) == {"argc", "argv", "buff", "buff_len", "pov", "len"} 48 | 49 | 50 | def test_tourniquet_db(test_files, tmp_db): 51 | tourniquet = Tourniquet(tmp_db) 52 | test_file = test_files / "patch_test.c" 53 | tourniquet.collect_info(test_file) 54 | 55 | module = tourniquet.db.query(Module).filter_by(name=str(test_file)).one() 56 | assert len(module.functions) >= 1 57 | 58 | pass_ = tourniquet.db.query(Global).filter_by(name="pass").one() 59 | assert pass_.name == "pass" 60 | assert pass_.type_ == "char *" 61 | assert pass_.start_line == 20 62 | assert pass_.start_column == 1 63 | assert pass_.end_line == 20 64 | assert pass_.end_column == 14 65 | assert not pass_.is_array 66 | assert pass_.size == 8 67 | 68 | main = tourniquet.db.query(Function).filter_by(name="main").one() 69 | assert main.name == "main" 70 | assert main.start_line == 22 71 | assert main.start_column == 1 72 | assert main.end_line == 38 73 | assert main.end_column == 1 74 | assert len(main.var_decls) == 6 75 | assert len(main.calls) == 4 76 | assert len(main.statements) == 8 77 | 78 | # TODO(ww): Test main.{var_decls,calls,statements} 79 | 80 | 81 | def test_tourniquet_extract_ast_invalid_file(test_files, tmp_db): 82 | test_extractor = Tourniquet(tmp_db) 83 | test_file = test_files / "does-not-exist" 84 | with pytest.raises(FileNotFoundError): 85 | 
test_extractor._extract_ast(test_file) 86 | 87 | 88 | def test_register_template(tmp_db): 89 | test_extractor = Tourniquet(tmp_db) 90 | new_template = PatchTemplate(FixPattern(NodeStmt()), lambda x, y: True) 91 | test_extractor.register_template("testme", new_template) 92 | assert len(test_extractor.patch_templates) == 1 93 | 94 | 95 | def test_auto_patch(test_files, tmp_db): 96 | tourniquet = Tourniquet(tmp_db) 97 | test_file = test_files / "patch_test.c" 98 | tourniquet.collect_info(test_file) 99 | 100 | class DummyErrorAnalysis(Expression): 101 | def __init__(self, expr): 102 | self.lit = Lit(expr) 103 | 104 | def concretize(self, db, location): 105 | yield from self.lit.concretize(db, location) 106 | 107 | class DummyCallable: 108 | def __init__(self, unused): 109 | self._unused = unused 110 | 111 | def __call__(self, line, col): 112 | return line == 32 and col == 3 113 | 114 | location = L(test_file, SC(32, 3)) 115 | template = PatchTemplate( 116 | FixPattern( 117 | IfStmt(LessThanExpr(Lit("len"), Lit("buff_len")), NodeStmt()), 118 | ElseStmt(ReturnStmt(DummyErrorAnalysis("1"))), 119 | ), 120 | DummyCallable("unused"), 121 | ) 122 | tourniquet.register_template("buffer_guard", template) 123 | 124 | assert ( 125 | tourniquet.auto_patch( 126 | "buffer_guard", 127 | [("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 1), ("password", 0)], 128 | location, 129 | ) 130 | is not None 131 | ) 132 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import subprocess 4 | import sys 5 | import sysconfig 6 | from pathlib import Path 7 | 8 | from setuptools import Extension, find_packages, setup 9 | from setuptools.command.build_ext import build_ext 10 | 11 | module_name = "extractor" 12 | 13 | version = {} 14 | with open("./tourniquet/version.py") as f: 15 | exec(f.read(), version) 16 | 17 | with 
open("README.md") as f: 18 | long_description = f.read() 19 | 20 | with open("dev-requirements.txt") as f: 21 | dev_requirements = f.readlines() 22 | 23 | 24 | def get_ext_filename_without_platform_suffix(filename): 25 | name, ext = os.path.splitext(filename) 26 | ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") 27 | 28 | if ext_suffix == ext: 29 | return filename 30 | 31 | ext_suffix = ext_suffix.replace(ext, "") 32 | idx = name.find(ext_suffix) 33 | 34 | if idx == -1: 35 | return filename 36 | else: 37 | return name[:idx] + ext 38 | 39 | 40 | class CMakeExtension(Extension): 41 | """ 42 | This class defines a setup.py extension that stores the directory 43 | of the root CMake file 44 | 45 | """ 46 | 47 | def __init__(self, name, cmake_lists_dir=None, **kwargs): 48 | Extension.__init__(self, name, sources=[], **kwargs) 49 | if cmake_lists_dir is None: 50 | self.sourcedir = os.path.dirname(os.path.abspath(__file__)) 51 | else: 52 | self.sourcedir = os.path.dirname(os.path.abspath(cmake_lists_dir)) 53 | 54 | 55 | class CMakeBuild(build_ext): 56 | """ 57 | This class defines a build extension 58 | get_ext_filename determines the expected output name of the library 59 | build_extension sets the appropriate cmake flags and invokes cmake to build the extension 60 | 61 | """ 62 | 63 | def get_ext_filename(self, ext_name): 64 | filename = super().get_ext_filename(ext_name) 65 | return get_ext_filename_without_platform_suffix(filename) 66 | 67 | def _platform_cmake_args(self): 68 | args = [] 69 | if platform.system() == "Darwin": 70 | deps = { 71 | "LLVM 9": "/usr/local/opt/llvm@9", 72 | } 73 | 74 | for name, dir_ in deps.items(): 75 | if not Path(dir_).is_dir(): 76 | print(f"Error: Couldn't find a {name} installation at {dir_}", file=sys.stderr) 77 | sys.exit(1) 78 | 79 | prefix_path = ":".join(deps.values()) 80 | args.append(f"-DCMAKE_PREFIX_PATH={prefix_path}") 81 | return args 82 | 83 | def build_extension(self, ext): 84 | extdir = 
os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) 85 | 86 | cmake_args = [] 87 | build_type = "Release" 88 | if os.getenv("DEBUG", None) is not None: 89 | cmake_args.append("-DCMAKE_VERBOSE_MAKEFILE=ON") 90 | build_type = "Debug" 91 | 92 | llvm_dir = os.getenv("LLVM_DIR") 93 | if llvm_dir is not None: 94 | cmake_args.append(f"-DLLVM_DIR={llvm_dir}") 95 | 96 | clang_dir = os.getenv("Clang_DIR") 97 | if clang_dir is not None: 98 | cmake_args.append(f"-DClang_DIR={clang_dir}") 99 | 100 | cmake_args += [ 101 | f"-DCMAKE_BUILD_TYPE={build_type}", 102 | f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", 103 | f"-DPYTHON_MODULE_NAME={module_name}", 104 | ] 105 | cmake_args += self._platform_cmake_args() 106 | print(cmake_args) 107 | if not os.path.exists(self.build_temp): 108 | os.makedirs(self.build_temp) 109 | subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp) 110 | subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp) 111 | 112 | 113 | setup( 114 | name="tourniquet", 115 | author="Carson Harmon", 116 | author_email="carson.harmon@trailofbits.com", 117 | python_requires=">=3.7", 118 | packages=find_packages(), 119 | version=version["__version__"], 120 | description="Syntax Guided Repair/Transformation Package", 121 | long_description=long_description, 122 | long_description_content_type="text/markdown", 123 | license="Apache-2.0", 124 | url="https://github.com/trailofbits/tourniquet", 125 | project_urls={"Documentation": "https://trailofbits.github.io/tourniquet/"}, 126 | ext_package="tourniquet", 127 | ext_modules=[CMakeExtension(module_name)], 128 | cmdclass={"build_ext": CMakeBuild}, 129 | install_requires=["sqlalchemy ~= 1.3"], 130 | extras_require={"dev": dev_requirements}, 131 | classifiers=[ 132 | "Programming Language :: Python :: 3", 133 | "License :: OSI Approved :: Apache Software License", 134 | "Development Status :: 3 - Alpha", 135 | "Intended Audience :: Science/Research", 136 | "Topic :: Software 
Development :: Code Generators", 137 | ], 138 | ) 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tourniquet 2 | 3 | ![CI](https://github.com/trailofbits/tourniquet/workflows/CI/badge.svg) 4 | 5 | A Python library for easy C/C++ syntax guided program transformation/repair. 6 | This is still very early in development. 7 | 8 | ## Installation 9 | 10 | ### PyPI 11 | 12 | Tourniquet is available via PyPI: 13 | 14 | ```bash 15 | $ pip install tourniquet 16 | ``` 17 | 18 | You'll need a relatively recent CMake and LLVM (9, 10, and 11 are supported). 19 | 20 | ### Docker 21 | 22 | Tourniquet can also be built and run within a Docker container: 23 | 24 | ```bash 25 | $ docker build -t trailofbits/tourniquet . 26 | $ docker run -it trailofbits/tourniquet 27 | ``` 28 | 29 | Enter an `ipython` instance and `import tourniquet` 30 | 31 | ## Syntax Guided Program Repair TL;DR 32 | 33 | Fixing programs is hard, and one of the most popular techniques for automated repair is search. If you're familiar with 34 | fuzzing, it is basically the reverse. In fuzzing you are mutating inputs to cause a crash, but in search based program 35 | repair you are mutating the program until you pass some specific tests. 36 | 37 | One huge drawback of this search based approach is that the search space is huge and most mutations in it 38 | are useless. The idea behind syntax guided program repair is to come up with some generally useful syntax patterns and 39 | to use those for your search space. This means your search space is smaller (restricted by the syntax patterns) and 40 | you are focusing on patch candidates that might actually fix whatever bug is in front of you. 41 | 42 | ## So What Even Is Tourniquet? 43 | 44 | Tourniquet is a library and domain specific language for syntax guided program repair. 
Current tools have 45 | hard coded fix patterns within them, making it hard for humans to interact and tweak them. The goal of Tourniquet is to 46 | make it quick and easy to create repair templates that can immediately be used to try and repair bugs. We plan on using 47 | Tourniquet alongside program analysis tools to allow humans to create fix patterns that have semantic meaning. 48 | 49 | ## Domain Specific Language 50 | 51 | Rather than writing individual tree transform passes or other types of source manipulation, Tourniquet makes it easy to 52 | describe part of the syntax and part of the semantics of a repair and lets the computer do the rest. Here is a simple 53 | example template: 54 | 55 | ```python 56 | class YourSemanticAnalysis(Expression): 57 | def concretize(self, _db, _location): 58 | yield "SOME_ERROR_CONSTANT" 59 | 60 | 61 | def your_matcher_func(line, col): 62 | return True 63 | 64 | 65 | demo_template = PatchTemplate( 66 | FixPattern( # Location 1 67 | IfStmt( 68 | LessThanExpr(Variable(), Variable()), # Location 2 69 | NodeStmt() # Location 3 70 | ), 71 | ElseStmt( 72 | ReturnStmt(YourSemanticAnalysis()) # Location 4 73 | ) 74 | ), 75 | your_matcher_func # Location 5 76 | ) 77 | ``` 78 | 79 | *Location 1* is the beginning of the `FixPattern`. The `FixPattern` describes the overall shape of 80 | the repair. This means the human provides part of the syntax, and part of the semantics of the 81 | repair. 82 | 83 | *Location 2* shows some of the different types in the DSL. What this line is describing is a less 84 | than statement with two variables, all the variable information is automatically extracted from the 85 | Clang AST. 86 | 87 | *Location 3* is whatever source was matched by your matcher function, also extracted from the 88 | Clang AST. 89 | 90 | *Location 4* is an example of how you could integrate program analysis tools with Tourniquet. 
91 | The `FixPattern` is trying to do a basic `if...else` statement where the `else` case returns some 92 | value. Return values have semantic properties; returning some arbitrary integer isn't usually a 93 | good idea. This means you can use some program analysis technique to 94 | infer what an appropriate return code might actually be, or simply ask a human to intervene. 95 | 96 | *Location 5* is for a matcher. The matcher is a callable that is supposed to 97 | take source line and column information and return `True` or `False` depending on whether the `FixPattern` is 98 | applicable to that source location. The idea here is that we couple specific types of fixes with 99 | specific types of bugs. We intend to use some other tools 100 | (such as [Manticore](https://github.com/trailofbits/manticore)) to help determine bug classes. 101 | 102 | ## Using Tourniquet 103 | 104 | ```python 105 | # Create a new Tourniquet instance 106 | demo = Tourniquet("test.db") 107 | 108 | # Extract info from its AST into the database 109 | demo.collect_info("demo_prog.c") 110 | 111 | # Create a new patch template 112 | demo_template = PatchTemplate( 113 | FixPattern( 114 | IfStmt( 115 | LessThanExpr(Variable(), Variable()), 116 | NodeStmt() 117 | ) 118 | ), 119 | lambda x, y: True, 120 | ) 121 | 122 | # Add the template to the tourniquet instance 123 | demo.register_template("demo_template", demo_template) 124 | 125 | # Tell Tourniquet you want to see results from this program, with this template, 126 | # matching against some location 127 | location = Location("demo_prog.c", SourceCoordinate(44, 3)) 128 | samples = demo.concretize_template("demo_template", location) 129 | 130 | # Look at all the patch candidates! 
131 | print(list(samples)) 132 | 133 | # Attempt to automatically repair the program using that template 134 | # Specify the file, some testcases, and the location information again 135 | demo.auto_patch( 136 | "demo_template", 137 | [ 138 | ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 1), 139 | ("password", 0) 140 | ], 141 | location 142 | ) 143 | ``` 144 | 145 | Auto patch will return `True` or `False` depending on whether you successfully found a patch to fix all testcases. Eventually 146 | we will support having a test case directory etc.; this is still early in development. 147 | 148 | Check out tourniquet's [API documentation](https://trailofbits.github.io/tourniquet) for more details. 149 | 150 | ## Development 151 | 152 | Install venv to be able to run `make` commands 153 | 154 | ```bash 155 | $ docker build -t trailofbits/tourniquet . 156 | $ docker run -it trailofbits/tourniquet 157 | root@b9f3a28655b6:/tourniquet# apt-get install -y python3-venv 158 | root@b9f3a28655b6:/tourniquet# python3 -m venv env 159 | root@b9f3a28655b6:/tourniquet# make test 160 | ``` 161 | 162 | ## Contributors 163 | 164 | * Carson Harmon (carson.harmon@trailofbits.com) 165 | * Evan Sultanik (evan.sultanik@trailofbits.com) 166 | * William Woodruff (william@trailofbits.com) 167 | 168 | ## Acknowledgements 169 | 170 | The project or effort depicted is sponsored by the Air Force Research Laboratory (AFRL) and 171 | DARPA under contract FA8750-19-C-0004. Any opinions, findings and conclusions or recommendations 172 | expressed in this material are those of the author(s) and do not necessarily reflect the views of 173 | the Air Force Research Laboratory (AFRL) and DARPA. 
174 | 175 | Distribution Statement: Approved for Public Release, Distribution Unlimited 176 | -------------------------------------------------------------------------------- /transformer/ASTExporter.cpp: -------------------------------------------------------------------------------- 1 | #include "ASTExporter.h" 2 | 3 | ASTExporterVisitor::ASTExporterVisitor(ASTContext *Context, PyObject *info) 4 | : Context(Context), tree_info(info), current_func(nullptr) {} 5 | 6 | void ASTExporterVisitor::PyDictUpdateEntry(PyObject *dict, const char *key, 7 | PyObject *new_item) { 8 | if (auto old_item = PyDict_GetItemString(dict, key)) { 9 | PyList_Append(old_item, new_item); 10 | } else { 11 | // Always init to be [item] 12 | PyObject *arr = PyList_New(0); 13 | PyList_Append(arr, new_item); 14 | PyDict_SetItemString(dict, key, arr); 15 | } 16 | } 17 | 18 | PyObject *ASTExporterVisitor::BuildStmtEntry(Stmt *stmt) { 19 | const auto expr = getText(*stmt, *Context).str(); 20 | 21 | unsigned int start_line = 22 | Context->getSourceManager().getExpansionLineNumber(stmt->getBeginLoc()); 23 | unsigned int start_col = 24 | Context->getSourceManager().getExpansionColumnNumber(stmt->getBeginLoc()); 25 | unsigned int end_line = 26 | Context->getSourceManager().getExpansionLineNumber(stmt->getEndLoc()); 27 | unsigned int end_col = 28 | Context->getSourceManager().getExpansionColumnNumber(stmt->getEndLoc()); 29 | 30 | PyObject *new_arr = PyList_New(0); 31 | PyList_Append(new_arr, PyUnicode_FromString("stmt_type")); 32 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_line)); 33 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_col)); 34 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_line)); 35 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_col)); 36 | PyList_Append(new_arr, PyUnicode_FromString(expr.c_str())); 37 | 38 | return new_arr; 39 | } 40 | 41 | void ASTExporterVisitor::AddGlobalEntry(PyObject *entry) { 42 | PyDictUpdateEntry(tree_info, "globals", entry); 43 | 
} 44 | 45 | void ASTExporterVisitor::AddFunctionEntry(const char *func_name, 46 | PyObject *entry) { 47 | if (auto functions = PyDict_GetItemString(tree_info, "functions")) { 48 | PyDictUpdateEntry(functions, func_name, entry); 49 | } else { 50 | functions = PyDict_New(); 51 | PyDict_SetItemString(tree_info, "functions", functions); 52 | PyDictUpdateEntry(functions, func_name, entry); 53 | } 54 | } 55 | 56 | bool ASTExporterVisitor::VisitDeclStmt(Stmt *stmt) { 57 | AddFunctionEntry(current_func->getNameAsString().c_str(), 58 | BuildStmtEntry(stmt)); 59 | return true; 60 | } 61 | 62 | // Current func, var name, var type, qual types, string 63 | // Down cast to other types of param decl? 64 | // TODO test non canonical types 65 | bool ASTExporterVisitor::VisitVarDecl(VarDecl *vdecl) { 66 | // Ignore externs. Parameter declarations are not filtered out here. 67 | if (vdecl->getStorageClass() == SC_Extern) { 68 | return true; 69 | } 70 | 71 | unsigned int start_line = 72 | Context->getSourceManager().getExpansionLineNumber(vdecl->getBeginLoc()); 73 | unsigned int start_col = Context->getSourceManager().getExpansionColumnNumber( 74 | vdecl->getBeginLoc()); 75 | unsigned int end_line = 76 | Context->getSourceManager().getExpansionLineNumber(vdecl->getEndLoc()); 77 | unsigned int end_col = 78 | Context->getSourceManager().getExpansionColumnNumber(vdecl->getEndLoc()); 79 | 80 | // This Python list object contains our variable declaration state, 81 | // with the following layout: 82 | // [ 83 | // "var_type", start_line, start_col, end_line, end_col, 84 | // var_name, var_type, is_array, size 85 | // ] 86 | // The variable declaration is either added to the list under the "globals" 87 | // key or to its enclosing function, depending on whether it's in a function. 
88 | 89 | PyObject *new_arr = PyList_New(0); 90 | PyList_Append(new_arr, PyUnicode_FromString("var_type")); 91 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_line)); 92 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_col)); 93 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_line)); 94 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_col)); 95 | PyList_Append(new_arr, 96 | PyUnicode_FromString(vdecl->getNameAsString().c_str())); 97 | 98 | auto qt = vdecl->getType(); 99 | if (auto arr_type = llvm::dyn_cast(qt.getTypePtr())) { 100 | auto size = arr_type->getSize().getZExtValue(); 101 | std::string type = arr_type->getElementType().getAsString(); 102 | PyList_Append(new_arr, PyUnicode_FromString(type.c_str())); 103 | PyList_Append(new_arr, PyLong_FromUnsignedLong(1)); 104 | PyList_Append(new_arr, PyLong_FromUnsignedLong(size)); 105 | } else { 106 | PyList_Append(new_arr, PyUnicode_FromString(qt.getAsString().c_str())); 107 | PyList_Append(new_arr, PyLong_FromUnsignedLong(0)); 108 | auto type_info = Context->getTypeInfo(qt); 109 | PyList_Append(new_arr, PyLong_FromUnsignedLong(type_info.Width / 8)); 110 | } 111 | 112 | auto parent_func = vdecl->getParentFunctionOrMethod(); 113 | if (parent_func == nullptr) { 114 | AddGlobalEntry(new_arr); 115 | } else { 116 | FunctionDecl *fdecl = llvm::dyn_cast(parent_func); 117 | if (fdecl->isFileContext()) { 118 | return true; 119 | } 120 | AddFunctionEntry(fdecl->getNameAsString().c_str(), new_arr); 121 | } 122 | 123 | return true; 124 | } 125 | 126 | // Current func, Callee, args, arg types, string 127 | bool ASTExporterVisitor::VisitCallExpr(CallExpr *call_expr) { 128 | FunctionDecl *func = call_expr->getDirectCallee(); 129 | if (func == nullptr) { 130 | return true; 131 | } 132 | 133 | auto decl_name = func->getNameInfo().getName(); 134 | if (decl_name.isEmpty()) { 135 | return true; 136 | } 137 | 138 | // NOTE(ww): This can happen for compiler inserted calls, e.g. a call 139 | // to __builtin_object_size. 
These don't appear anywhere in the source 140 | // so we skip them. There might be a better way to check for these, 141 | // like func->getBuiltinID() != 0. 142 | const auto expr = getText(*call_expr, *Context).str(); 143 | if (expr.empty()) { 144 | return true; 145 | } 146 | 147 | auto func_name = current_func->getNameAsString(); 148 | // Every CallExpr is a Stmt, so also record it as a Stmt. 149 | AddFunctionEntry(func_name.c_str(), BuildStmtEntry(call_expr)); 150 | 151 | std::string callee = decl_name.getAsString(); 152 | unsigned int start_line = Context->getSourceManager().getExpansionLineNumber( 153 | call_expr->getBeginLoc()); 154 | unsigned int start_col = Context->getSourceManager().getExpansionColumnNumber( 155 | call_expr->getBeginLoc()); 156 | unsigned int end_line = Context->getSourceManager().getExpansionLineNumber( 157 | call_expr->getEndLoc()); 158 | unsigned int end_col = Context->getSourceManager().getExpansionColumnNumber( 159 | call_expr->getEndLoc()); 160 | 161 | PyObject *new_arr = PyList_New(0); 162 | PyList_Append(new_arr, PyUnicode_FromString("call_type")); 163 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_line)); 164 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_col)); 165 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_line)); 166 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_col)); 167 | PyList_Append(new_arr, PyUnicode_FromString(expr.c_str())); 168 | PyList_Append(new_arr, PyUnicode_FromString(callee.c_str())); 169 | 170 | for (auto arg : call_expr->arguments()) { 171 | auto arg_arr = PyList_New(0); 172 | PyList_Append(arg_arr, 173 | PyUnicode_FromString(getText(*arg, *Context).str().c_str())); 174 | PyList_Append(arg_arr, 175 | PyUnicode_FromString(arg->getType().getAsString().c_str())); 176 | PyList_Append(new_arr, arg_arr); 177 | } 178 | AddFunctionEntry(func_name.c_str(), new_arr); 179 | return true; 180 | } 181 | 182 | // Name, Parameters, Parameter Types? 
183 | bool ASTExporterVisitor::VisitFunctionDecl(FunctionDecl *func_decl) { 184 | if (func_decl->getStorageClass() == SC_Extern) { 185 | return true; 186 | } 187 | 188 | unsigned int start_line = Context->getSourceManager().getExpansionLineNumber( 189 | func_decl->getBeginLoc()); 190 | unsigned int start_col = Context->getSourceManager().getExpansionColumnNumber( 191 | func_decl->getBeginLoc()); 192 | unsigned int end_line = Context->getSourceManager().getExpansionLineNumber( 193 | func_decl->getEndLoc()); 194 | unsigned int end_col = Context->getSourceManager().getExpansionColumnNumber( 195 | func_decl->getEndLoc()); 196 | 197 | PyObject *new_arr = PyList_New(0); 198 | PyList_Append(new_arr, PyUnicode_FromString("func_decl")); 199 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_line)); 200 | PyList_Append(new_arr, PyLong_FromUnsignedLong(start_col)); 201 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_line)); 202 | PyList_Append(new_arr, PyLong_FromUnsignedLong(end_col)); 203 | 204 | AddFunctionEntry(func_decl->getNameAsString().c_str(), new_arr); 205 | 206 | // NOTE(ww) Subsequent visitor methods use this member to determine which 207 | // function they're in. 208 | current_func = func_decl; 209 | 210 | return true; 211 | } 212 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | -------------------------------------------------------------------------------- /tourniquet/tourniquet.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import subprocess 3 | import tempfile 4 | from contextlib import contextmanager 5 | from pathlib import Path 6 | from typing import Any, Dict, Iterator, Optional 7 | 8 | from . 
import extractor, models 9 | from .error import PatchSituationError, TemplateNameError 10 | from .location import Location, SourceCoordinate 11 | from .patch_lang import PatchTemplate 12 | 13 | 14 | class Tourniquet: 15 | def __init__(self, database_name): 16 | self.db_name = database_name 17 | self.db = models.DB.create(database_name) 18 | self.patch_templates: Dict[str, PatchTemplate] = {} 19 | 20 | def _path_looks_like_cxx(self, source_path: Path): 21 | return source_path.suffix in [".cpp", ".cc", ".cxx"] 22 | 23 | def _extract_ast(self, source_path: Path, is_cxx: bool = True) -> Dict[str, Any]: 24 | if not source_path.is_file(): 25 | raise FileNotFoundError(f"{source_path} is not a file") 26 | 27 | return extractor.extract_ast(source_path, is_cxx) 28 | 29 | def _store_ast(self, ast_info: Dict[str, Any]): 30 | module = models.Module(name=ast_info["module_name"]) 31 | for global_ in ast_info["globals"]: 32 | assert global_[0] == "var_type", f"{global_[0]} != var_type" 33 | global_ = models.Global( 34 | module=module, 35 | name=global_[5], 36 | type_=global_[6], 37 | start_line=global_[1], 38 | start_column=global_[2], 39 | end_line=global_[3], 40 | end_column=global_[4], 41 | is_array=bool(global_[7]), 42 | size=global_[8], 43 | ) 44 | self.db.session.add(global_) 45 | 46 | # Every subsequent member 47 | for func_name, exprs in ast_info["functions"].items(): 48 | # NOTE(ww): We expected the first member of each function's list to be 49 | # a "func_decl" list, containing information about the function declaration 50 | # itself. We use this to construct the initial Function model. 51 | # If a list doesn't begin with "func_decl," then it was external and we 52 | # skip it. 53 | # TODO(ww): Think more about the above. 
54 | exprs = iter(exprs) 55 | func_decl = next(exprs) 56 | if func_decl[0] != "func_decl": 57 | continue 58 | 59 | function = models.Function( 60 | module=module, 61 | name=func_name, 62 | start_line=func_decl[1], 63 | start_column=func_decl[2], 64 | end_line=func_decl[3], 65 | end_column=func_decl[4], 66 | ) 67 | self.db.session.add(function) 68 | 69 | for expr in exprs: 70 | # From here, the exprs we know are "var_type" (models.VarDecl), 71 | # "call_type" (models.Call), and "stmt_type" (models.Statement). 72 | # "call_type" lists contain, in turn, a list of arguments, 73 | # which we promote to models.Argument objects. 74 | if expr[0] == "var_type": 75 | var_decl = models.VarDecl( 76 | function=function, 77 | name=expr[5], 78 | type_=expr[6], 79 | start_line=expr[1], 80 | start_column=expr[2], 81 | end_line=expr[3], 82 | end_column=expr[4], 83 | is_array=bool(expr[7]), 84 | size=expr[8], 85 | ) 86 | self.db.session.add(var_decl) 87 | elif expr[0] == "call_type": 88 | call = models.Call( 89 | module=module, 90 | function=function, 91 | expr=expr[5], 92 | name=expr[6], 93 | start_line=expr[1], 94 | start_column=expr[2], 95 | end_line=expr[3], 96 | end_column=expr[4], 97 | ) 98 | self.db.session.add(call) 99 | 100 | for name, type_ in expr[7:]: 101 | argument = models.Argument(call=call, name=name, type_=type_) 102 | self.db.session.add(argument) 103 | elif expr[0] == "stmt_type": 104 | stmt = models.Statement( 105 | module=module, 106 | function=function, 107 | expr=expr[5], 108 | start_line=expr[1], 109 | start_column=expr[2], 110 | end_line=expr[3], 111 | end_column=expr[4], 112 | ) 113 | self.db.session.add(stmt) 114 | else: 115 | assert False, expr[0] 116 | 117 | self.db.session.commit() 118 | 119 | # TODO Should take a target 120 | def collect_info(self, source_path: Path): 121 | """ 122 | Collect information about the given source file and add it to the backing database. 
123 | """ 124 | ast_info = self._extract_ast(source_path, is_cxx=self._path_looks_like_cxx(source_path)) 125 | self._store_ast(ast_info) 126 | 127 | def register_template(self, name: str, template: PatchTemplate): 128 | """ 129 | Register a patching template with the given name. 130 | """ 131 | if name in self.patch_templates: 132 | raise TemplateNameError(f"a template has already been registered as {name}") 133 | self.patch_templates[name] = template 134 | 135 | # TODO Should take target 136 | # TODO(ww): Consider rehoming this? 137 | def view_template(self, template_name, location: Location) -> Optional[str]: 138 | """ 139 | Pretty-print the given template, partially concretized to the given 140 | module and source location. 141 | """ 142 | template = self.patch_templates.get(template_name) 143 | if template is None: 144 | return None 145 | 146 | print("=" * 10, template_name, "=" * 10) 147 | view_str = template.view(self.db, location) 148 | print(view_str) 149 | print("=" * 10, "END", "=" * 10) 150 | return view_str 151 | 152 | # TODO Should take a target 153 | def concretize_template(self, template_name: str, location: Location) -> Iterator[str]: 154 | """ 155 | Concretize the given registered template to the given 156 | module and source location, yielding each candidate patch. 157 | 158 | Args: 159 | template_name: The name of the template to concretize. This name 160 | must have been previously registered with `register_template`. 161 | location: The `Location` to concretize the template at. 162 | 163 | Returns: 164 | A generator of strings, each one representing a concrete patch suitable 165 | for placement at the supplied location. 166 | 167 | Raises: 168 | TemplateNameError: If the supplied template name isn't registered. 
169 | """ 170 | template = self.patch_templates.get(template_name) 171 | if template is None: 172 | raise TemplateNameError(f"no template registed with name {template_name}") 173 | 174 | yield from template.concretize(self.db, location) 175 | 176 | # TODO Should take a target 177 | # TODO(ww): This should take a span instead of a location, so that it doesn't have 178 | # to depend on the patch location being a statement. 179 | @contextmanager 180 | def patch(self, replacement: str, location: Location) -> Iterator[bool]: 181 | """ 182 | Applies the given replacement to the given location. 183 | 184 | Rolls back the replacement after context closure. 185 | 186 | Args: 187 | replacement: The patch to insert. 188 | location: The `Location` to insert at, including the source file. 189 | 190 | Returns: 191 | A generator whose single yield is the status of the replacement operation. 192 | 193 | Raises: 194 | PatchSituationError: If the supplied location can't be used for a patch. 195 | """ 196 | try: 197 | temp_file = tempfile.NamedTemporaryFile() 198 | shutil.copyfile(location.filename, temp_file.name) 199 | 200 | statement = self.db.statement_at(location) 201 | if statement is None: 202 | raise PatchSituationError(f"no statement at ({location.line}, {location.column})") 203 | 204 | yield self.transform( 205 | location.filename, 206 | replacement, 207 | statement.start_coordinate, 208 | statement.end_coordinate, 209 | ) 210 | finally: 211 | shutil.copyfile(temp_file.name, location.filename) 212 | temp_file.close() 213 | 214 | # TODO Should take a target 215 | def auto_patch(self, template_name, tests, location: Location) -> Optional[str]: 216 | # TODO(ww): This should be a NamedTempFile, at the absolute minimum. 
217 | EXEC_FILE = Path("/tmp/target") 218 | 219 | # Collect replacements 220 | replacements = self.concretize_template(template_name, location) 221 | 222 | # Patch 223 | for replacement in replacements: 224 | with self.patch(replacement, location): 225 | # Just compile with clang for now 226 | ret = subprocess.call(["clang", "-g", "-o", EXEC_FILE, location.filename]) 227 | if ret != 0: 228 | print("Error, build failed?") 229 | continue 230 | 231 | # Run the test suite 232 | failed_test = False 233 | for test_case in tests: 234 | (input_, output) = test_case 235 | ret = subprocess.call([EXEC_FILE, input_]) 236 | if output != ret: 237 | failed_test = True 238 | break 239 | if not failed_test: 240 | # This means that its fixed :) 241 | return replacement 242 | 243 | return None 244 | 245 | def transform( 246 | self, filename: Path, replacement: str, start: SourceCoordinate, end: SourceCoordinate 247 | ): 248 | res = extractor.transform( 249 | filename, 250 | self._path_looks_like_cxx(filename), 251 | replacement, 252 | start.line, 253 | start.column, 254 | end.line, 255 | end.column, 256 | ) 257 | return res 258 | -------------------------------------------------------------------------------- /tourniquet/models.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | 4 | from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, create_engine 5 | from sqlalchemy.ext.declarative import declarative_base 6 | from sqlalchemy.orm import relationship, sessionmaker 7 | 8 | from .location import Location, SourceCoordinate 9 | 10 | Base = declarative_base() 11 | 12 | 13 | class Module(Base): 14 | """ 15 | Represents a C or C++ file, or "module" in Clang's terminology. 16 | """ 17 | 18 | __tablename__ = "modules" 19 | 20 | id = Column(Integer, primary_key=True) 21 | """ 22 | This module's database ID. 
23 | """ 24 | 25 | name = Column(String, unique=True, nullable=False) 26 | """ 27 | The name (i.e. source file) of this module. 28 | """ 29 | 30 | functions = relationship("Function", uselist=True) 31 | """ 32 | The `Function`s present in this module. 33 | """ 34 | 35 | global_variables = relationship("Global", uselist=True) 36 | """ 37 | The `Global`s present in this module. 38 | """ 39 | 40 | statements = relationship("Statement", uselist=True) 41 | """ 42 | The `Statement`s present in this module. 43 | """ 44 | 45 | calls = relationship("Call", uselist=True) 46 | """ 47 | The `Call`s present in this module. 48 | """ 49 | 50 | # TODO(ww): Module -> [VarDecl] relationship 51 | 52 | def __repr__(self): 53 | return f"" 54 | 55 | 56 | class Function(Base): 57 | """ 58 | Represents a single C or C++ function. 59 | """ 60 | 61 | __tablename__ = "functions" 62 | 63 | id = Column(Integer, primary_key=True) 64 | """ 65 | This function's database ID. 66 | """ 67 | 68 | module_name = Column(String, ForeignKey("modules.name")) 69 | """ 70 | The name of the `Module` that this function belongs to. 71 | """ 72 | 73 | module = relationship("Module", back_populates="functions") 74 | """ 75 | The `Module` that this function belongs to. 76 | """ 77 | 78 | name = Column(String, nullable=False) 79 | """ 80 | The name of this function. 81 | """ 82 | 83 | start_line = Column(Integer, nullable=False) 84 | """ 85 | The line that this function begins on. 86 | """ 87 | 88 | start_column = Column(Integer, nullable=False) 89 | """ 90 | The column that this function begins on. 91 | """ 92 | 93 | end_line = Column(Integer, nullable=False) 94 | """ 95 | The line that this function ends on. 96 | """ 97 | 98 | end_column = Column(Integer, nullable=False) 99 | """ 100 | The column that this function ends on. 101 | """ 102 | 103 | var_decls = relationship("VarDecl", uselist=True) 104 | """ 105 | The `VarDecl`s present in this function. 
106 | """ 107 | 108 | calls = relationship("Call", uselist=True) 109 | """ 110 | The `Call`s present in this function. 111 | """ 112 | 113 | statements = relationship("Statement", uselist=True) 114 | """ 115 | The `Statement`s present in this function. 116 | """ 117 | 118 | @property 119 | def start_coordinate(self): 120 | """ 121 | Returns a `SourceCoordinate` representing where this function 122 | begins in its source file. 123 | """ 124 | return SourceCoordinate(self.start_line, self.start_column) 125 | 126 | @property 127 | def end_coordinate(self): 128 | """ 129 | Returns a `SourceCoordinate` representing where this function ends 130 | in its source file. 131 | """ 132 | return SourceCoordinate(self.end_line, self.end_column) 133 | 134 | @property 135 | def location(self): 136 | """ 137 | Returns a `Location` representing this function's source file and start 138 | coordinate. 139 | """ 140 | return Location(Path(self.module_name), self.start_coordinate) 141 | 142 | def __repr__(self): 143 | return f"" 144 | 145 | 146 | class Global(Base): 147 | """ 148 | Represents a global variable declaration. 149 | """ 150 | 151 | __tablename__ = "globals" 152 | 153 | id = Column(Integer, primary_key=True) 154 | """ 155 | This global's database ID. 156 | """ 157 | 158 | module_name = Column(String, ForeignKey("modules.name")) 159 | """ 160 | The name of the `Module` that this global belongs to. 161 | """ 162 | 163 | module = relationship("Module", back_populates="global_variables") 164 | """ 165 | The `Module` that this global belongs to. 166 | """ 167 | 168 | name = Column(String, nullable=False) 169 | """ 170 | The declared name of this global. 171 | """ 172 | 173 | type_ = Column(String, nullable=False) 174 | """ 175 | The declared type of this global. 176 | """ 177 | 178 | start_line = Column(Integer, nullable=False) 179 | """ 180 | The line that this global begins on. 
181 | """ 182 | 183 | start_column = Column(Integer, nullable=False) 184 | """ 185 | The column that this global begins on. 186 | """ 187 | 188 | end_line = Column(Integer, nullable=False) 189 | """ 190 | The line that this global ends on. 191 | """ 192 | 193 | end_column = Column(Integer, nullable=False) 194 | """ 195 | The column that this global ends on. 196 | """ 197 | 198 | is_array = Column(Boolean, nullable=False) 199 | """ 200 | Whether or not this global's declaration is for an array type. 201 | """ 202 | 203 | size = Column(Integer, nullable=False) 204 | """ 205 | The size of this global, in bytes. 206 | """ 207 | 208 | def __repr__(self): 209 | return f"" 210 | 211 | 212 | class VarDecl(Base): 213 | """ 214 | Represents a local variable or function parameter declaration. 215 | """ 216 | 217 | __tablename__ = "var_decls" 218 | 219 | id = Column(Integer, primary_key=True) 220 | """ 221 | This declaration's database ID. 222 | """ 223 | 224 | function_id = Column(Integer, ForeignKey("functions.id")) 225 | """ 226 | The ID of the `Function` that this declaration is present in. 227 | """ 228 | 229 | function = relationship("Function", back_populates="var_decls") 230 | """ 231 | The `Function` that this declaration is present in. 232 | """ 233 | 234 | name = Column(String, nullable=False) 235 | """ 236 | The name of this declared variable. 237 | """ 238 | 239 | type_ = Column(String, nullable=False) 240 | """ 241 | The type of this declared variable. 242 | """ 243 | 244 | start_line = Column(Integer, nullable=False) 245 | """ 246 | The line that this declaration begins on. 247 | """ 248 | 249 | start_column = Column(Integer, nullable=False) 250 | """ 251 | The column that this declaration begins on. 252 | """ 253 | 254 | end_line = Column(Integer, nullable=False) 255 | """ 256 | The line that this declaration ends on. 257 | """ 258 | 259 | end_column = Column(Integer, nullable=False) 260 | """ 261 | The column that this declaration ends on. 
262 | """ 263 | 264 | is_array = Column(Boolean, nullable=False) 265 | """ 266 | Whether or not this declaration is for an array type. 267 | """ 268 | 269 | size = Column(Integer, nullable=False) 270 | """ 271 | The size of the declared variable, in bytes. 272 | """ 273 | 274 | def __repr__(self): 275 | return f"" 276 | 277 | 278 | class Call(Base): 279 | """ 280 | Represents a function call. 281 | """ 282 | 283 | __tablename__ = "calls" 284 | 285 | id = Column(Integer, primary_key=True) 286 | """ 287 | This call's database ID. 288 | """ 289 | 290 | module_name = Column(String, ForeignKey("modules.name")) 291 | """ 292 | The name of the `Module` that this statement is in. 293 | """ 294 | 295 | module = relationship("Module", back_populates="calls") 296 | """ 297 | The `Module` that this statement is in. 298 | """ 299 | 300 | function_id = Column(Integer, ForeignKey("functions.id")) 301 | """ 302 | The ID of the `Function` that this call is present in. 303 | """ 304 | 305 | function = relationship("Function", back_populates="calls") 306 | """ 307 | The `Function` that this call is present in. 308 | """ 309 | 310 | expr = Column(String, nullable=False) 311 | """ 312 | The call expression itself. 313 | """ 314 | 315 | name = Column(String, nullable=False) 316 | """ 317 | The name of the callee. 318 | """ 319 | 320 | start_line = Column(Integer, nullable=False) 321 | """ 322 | The line that this call begins on. 323 | """ 324 | 325 | start_column = Column(Integer, nullable=False) 326 | """ 327 | The column that this call begins on. 328 | """ 329 | 330 | end_line = Column(Integer, nullable=False) 331 | """ 332 | The line that this call ends on. 333 | """ 334 | 335 | end_column = Column(Integer, nullable=False) 336 | """ 337 | The column that this call ends on. 338 | """ 339 | 340 | arguments = relationship("Argument", uselist=True) 341 | """ 342 | The `Argument`s associated with this call. 
343 | """ 344 | 345 | @property 346 | def start_coordinate(self): 347 | """ 348 | Returns a `SourceCoordinate` representing where this call 349 | begins in its source file. 350 | """ 351 | return SourceCoordinate(self.start_line, self.start_column) 352 | 353 | @property 354 | def end_coordinate(self): 355 | """ 356 | Returns a `SourceCoordinate` representing where this call ends 357 | in its source file. 358 | """ 359 | return SourceCoordinate(self.end_line, self.end_column) 360 | 361 | @property 362 | def location(self): 363 | """ 364 | Returns a `Location` representing this call's source file and start 365 | coordinate. 366 | """ 367 | return Location(Path(self.module_name), self.start_coordinate) 368 | 369 | def __repr__(self): 370 | return f"" 371 | 372 | 373 | class Argument(Base): 374 | """ 375 | Represents an argument to a function call. 376 | """ 377 | 378 | __tablename__ = "call_arguments" 379 | 380 | id = Column(Integer, primary_key=True) 381 | """ 382 | This argument's database ID. 383 | """ 384 | 385 | call_id = Column(Integer, ForeignKey("calls.id")) 386 | """ 387 | The ID of the `Call` that this argument is in. 388 | """ 389 | 390 | call = relationship("Call", back_populates="arguments") 391 | """ 392 | The `Call` that this argument is in. 393 | """ 394 | 395 | name = Column(String, nullable=False) 396 | """ 397 | The name of the argument. 398 | """ 399 | 400 | type_ = Column(String, nullable=False) 401 | """ 402 | The type of the argument. 403 | """ 404 | 405 | def __repr__(self): 406 | return f"" 407 | 408 | 409 | class Statement(Base): 410 | """ 411 | Represents a primitive (i.e., non-compound) C or C++ statement. 412 | """ 413 | 414 | __tablename__ = "statements" 415 | 416 | id = Column(Integer, primary_key=True) 417 | """ 418 | This statement's database ID. 419 | """ 420 | 421 | module_name = Column(String, ForeignKey("modules.name")) 422 | """ 423 | The name of the `Module` that this statement is in. 
424 | """ 425 | 426 | module = relationship("Module", back_populates="statements") 427 | """ 428 | The `Module` that this statement is in. 429 | """ 430 | 431 | function_id = Column(Integer, ForeignKey("functions.id")) 432 | """ 433 | The ID of the `Function` that this statement is in. 434 | """ 435 | 436 | function = relationship("Function", back_populates="statements") 437 | """ 438 | The `Function` that this statement is in. 439 | """ 440 | 441 | start_line = Column(Integer, nullable=False) 442 | """ 443 | The line that this statement begins on. 444 | """ 445 | 446 | start_column = Column(Integer, nullable=False) 447 | """ 448 | The column that this statement begins on. 449 | """ 450 | 451 | end_line = Column(Integer, nullable=False) 452 | """ 453 | The line that this statement ends on. 454 | """ 455 | 456 | end_column = Column(Integer, nullable=False) 457 | """ 458 | The column that this statement ends on. 459 | """ 460 | 461 | expr = Column(String, nullable=False) 462 | """ 463 | The expression text of this statement. 464 | """ 465 | 466 | @property 467 | def start_coordinate(self): 468 | """ 469 | Returns a `SourceCoordinate` representing where this statement 470 | begins in its source file. 471 | """ 472 | return SourceCoordinate(self.start_line, self.start_column) 473 | 474 | @property 475 | def end_coordinate(self): 476 | """ 477 | Returns a `SourceCoordinate` representing where this statement ends 478 | in its source file. 479 | """ 480 | return SourceCoordinate(self.end_line, self.end_column) 481 | 482 | @property 483 | def location(self): 484 | """ 485 | Returns a `Location` representing this statement's source file and start 486 | coordinate. 487 | """ 488 | return Location(Path(self.module_name), self.start_coordinate) 489 | 490 | def __repr__(self): 491 | return f"" 492 | 493 | 494 | class DB: 495 | """ 496 | A convenience class for querying the database. 
497 | """ 498 | 499 | @classmethod 500 | def create(cls, db_path, echo=False): 501 | """ 502 | Creates a new database at the given path. 503 | """ 504 | 505 | engine = create_engine(f"sqlite:///{db_path}", echo=echo) 506 | 507 | session = sessionmaker(bind=engine)() 508 | Base.metadata.create_all(engine) 509 | 510 | return cls(session, db_path) 511 | 512 | def __init__(self, session, db_path): 513 | self.session = session 514 | self.db_path = db_path 515 | 516 | def query(self, *args, **kwargs): 517 | """ 518 | Forwards a SQLAlchemy-style `query` to the database. 519 | """ 520 | 521 | return self.session.query(*args, **kwargs) 522 | 523 | def function_at(self, location: Location) -> Optional[Function]: 524 | return ( 525 | self.query(Function) 526 | .filter( 527 | (str(location.filename) == Function.module_name) 528 | & (location.line >= Function.start_line) 529 | & (location.column >= Function.start_column) 530 | & ( 531 | ( 532 | (location.line == Function.end_line) 533 | & (location.column <= Function.end_column) 534 | ) 535 | | ((location.line < Function.end_line)) 536 | ) 537 | ) 538 | .one_or_none() 539 | ) 540 | 541 | def statement_at(self, location: Location) -> Optional[Statement]: 542 | statement = ( 543 | self.query(Statement) 544 | .filter( 545 | (Statement.module_name == str(location.filename)) 546 | & (Statement.start_line == location.line) 547 | & (Statement.start_column == location.column) 548 | ) 549 | .one_or_none() 550 | ) 551 | 552 | return statement 553 | -------------------------------------------------------------------------------- /tourniquet/patch_lang.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from abc import ABC, abstractmethod 3 | from typing import Callable, Iterator, List, Optional 4 | 5 | from . 
class Expression(ABC):
    """
    Represents an abstract source "expression".

    `Expression` is an abstract base class whose interfaces should be consumed via
    specific derived classes.
    """

    @abstractmethod
    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this expression into its candidate source strings.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings; the base implementation yields nothing
        """
        yield from ()

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns a human-readable description of this expression.
        """
        return "Expression()"


class Lit(Expression):
    """
    Represents a literal source string.

    Template authors can use this to insert pre-formed source code, or code that doesn't need
    to be concretized by location.
    """

    def __init__(self, expr: str):
        """
        Create a new `Lit` with the given source string.

        Args:
            expr: The source literal
        """
        self.expr = expr

    def concretize(self, _db, _location) -> Iterator[str]:
        """
        Concretize this literal into its underlying source string.
        """
        yield self.expr

    def view(self, _db, _location) -> str:
        """
        Returns a human-readable description of this literal.
        """
        return f"Lit({self.expr})"


class Variable(Expression):
    """
    Represents an abstract source variable.
    """

    # NOTE(ww): This is probably slightly unsound, in terms of concretizations
    # produced: the set of variables in a function is not the set of variables
    # guaranteed to be in scope at a line number.
    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `Variable` into its potential names.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a concrete variable name

        Raises:
            PatchConcretizationError: If the scope of the variable can't be resolved
        """

        # TODO(ww): We should also concretize with the available globals.
        function = db.function_at(location)
        if function is None:
            raise PatchConcretizationError(
                f"no function contains ({location.line}, {location.column})"
            )

        for var_decl in function.var_decls:
            yield var_decl.name

    def view(self, _db, _location) -> str:
        """
        Returns a human-readable description of this variable placeholder.
        """
        return "Variable()"


class StaticBufferSize(Expression):
    """
    Represents an abstract "sizeof(...)" expression.
    """

    # Query for all static array types and return list of sizeof()..
    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `StaticBufferSize` into its potential sizes.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a concrete `sizeof(...)` expression

        Raises:
            PatchConcretizationError: If the scope of the location is not a function
        """

        # TODO(ww): We should also concretize with available globals.
        function = db.function_at(location)

        if function is None:
            raise PatchConcretizationError(
                f"no function contains ({location.line}, {location.column})"
            )

        for var_decl in function.var_decls:
            # Only array declarations are considered.
            if var_decl.is_array:
                yield f"sizeof({var_decl.name})"

    def view(self, _db, _location) -> str:
        """
        Returns a human-readable description of this size placeholder.
        """
        return "StaticBufferSize()"


class BinaryMathOperator(Expression):
    """
    Represents the possible binary math operators, along with their lhs and rhs expressions.
    """

    def __init__(self, lhs: Expression, rhs: Expression):
        """
        Create a new `BinaryMathOperator` with the given `lhs` and `rhs`.

        Args:
            lhs: An `Expression` to use for the left-hand side
            rhs: An `Expression` to use for the right-hand side
        """
        self.lhs = lhs
        self.rhs = rhs

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `BinaryMathOperator` into its possible operator expressions.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a concrete binary math expression
        """
        lhs_exprs = self.lhs.concretize(db, location)
        rhs_exprs = self.rhs.concretize(db, location)
        for (lhs, rhs) in itertools.product(lhs_exprs, rhs_exprs):
            yield f"{lhs} + {rhs}"
            yield f"{lhs} - {rhs}"
            yield f"{lhs} / {rhs}"
            yield f"{lhs} * {rhs}"
            yield f"{lhs} << {rhs}"
            # TODO(ww): Missing for unknown reasons: >>, &, |, ^, %

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns a human-readable description of this operator and both of its operands.
        """
        lhs_view = self.lhs.view(db, location)
        rhs_view = self.rhs.view(db, location)
        return f"BinaryMathOperator({lhs_view}, {rhs_view})"


class BinaryBoolOperator(Expression):
    """
    Represents the possible binary boolean operators, along with their lhs and rhs expressions.
    """

    def __init__(self, lhs: Expression, rhs: Expression):
        """
        Create a new `BinaryBoolOperator` with the given `lhs` and `rhs`.

        Args:
            lhs: An `Expression` to use for the left-hand side
            rhs: An `Expression` to use for the right-hand side
        """
        self.lhs = lhs
        self.rhs = rhs

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `BinaryBoolOperator` into its possible operator expressions.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a concrete binary boolean expression
        """
        lhs_exprs = self.lhs.concretize(db, location)
        rhs_exprs = self.rhs.concretize(db, location)
        for (lhs, rhs) in itertools.product(lhs_exprs, rhs_exprs):
            yield f"{lhs} == {rhs}"
            yield f"{lhs} != {rhs}"
            yield f"{lhs} <= {rhs}"
            yield f"{lhs} < {rhs}"
            yield f"{lhs} >= {rhs}"
            yield f"{lhs} > {rhs}"
            # TODO(ww): If location is in a C++ source file, maybe add <=>

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns a human-readable description of this operator and both of its operands.
        """
        lhs_view = self.lhs.view(db, location)
        rhs_view = self.rhs.view(db, location)
        return f"BinaryBoolOperator({lhs_view}, {rhs_view})"


class LessThanExpr(Expression):
    """
    Represents the less-than boolean operator, along with its lhs and rhs expressions.
    """

    def __init__(self, lhs: Expression, rhs: Expression):
        """
        Create a new `LessThanExpr` with the given `lhs` and `rhs`.

        Args:
            lhs: An `Expression` to use for the left-hand side.
            rhs: An `Expression` to use for the right-hand side.
        """
        self.lhs = lhs
        self.rhs = rhs

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `LessThanExpr` into its possible operator expressions.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a concrete less-than boolean expression
        """

        lhs_exprs = self.lhs.concretize(db, location)
        rhs_exprs = self.rhs.concretize(db, location)
        for (lhs, rhs) in itertools.product(lhs_exprs, rhs_exprs):
            yield f"{lhs} < {rhs}"

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns a human-readable description of this comparison and both of its operands.
        """
        return self.lhs.view(db, location) + " < " + self.rhs.view(db, location)


class Statement(ABC):
    """
    Represents a source "statement".

    `Statement` is an abstract base whose interfaces should be consumed via specific
    derived classes.
    """

    @abstractmethod
    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this statement into its candidate source strings.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at
        """
        pass

    @abstractmethod
    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns a human-readable description of this statement.
        """
        pass


class StatementList:
    """
    Represents a sequence of statements.
    """

    def __init__(self, *args: Statement):
        """
        Create a new `StatementList`.

        Args:
            args: The statements in the list, in source order
        """
        self.statements: List[Statement] = list(args)

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `StatementList` into its possible statement sequences.

        This involves concretizing each statement and taking their unique
        product, such that every possible combination of the statements'
        concretizations is produced exactly once.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a concrete sequence of statements
        """
        concretized = [stmt.concretize(db, location) for stmt in self.statements]
        # dict.fromkeys drops duplicate combinations while keeping first-seen
        # order, so candidates come out in the same order on every run.
        for items in dict.fromkeys(itertools.product(*concretized)):
            yield "\n".join(items)

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns the newline-terminated views of every statement in the list.
        """
        return "".join(stmt.view(db, location) + "\n" for stmt in self.statements)


class IfStmt(Statement):
    """
    Represents an `if(...) { ... }` statement.
    """

    def __init__(self, cond_expr: Expression, *args: Statement):
        """
        Creates a new `IfStmt`.

        Args:
            cond_expr: The expression to evaluate
            args: The statements to use in the body of the `if`
        """
        self.cond_expr = cond_expr
        self.statement_list = StatementList(*args)

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `IfStmt` into its possible statements.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is an `if` statement
        """
        cond_list = self.cond_expr.concretize(db, location)
        stmt_list = self.statement_list.concretize(db, location)
        for (cond, stmt) in itertools.product(cond_list, stmt_list):
            yield "if (" + cond + ") {\n" + stmt + "\n}\n"

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns the `if` statement rendered from the views of its condition and body.
        """
        if_str = "if (" + self.cond_expr.view(db, location) + ") {\n"
        if_str += self.statement_list.view(db, location)
        if_str += "\n}\n"
        return if_str


class ElseStmt(Statement):
    """
    Represents an `else { ... }` statement.
    """

    def __init__(self, *args: Statement):
        """
        Creates a new `ElseStmt`.

        Args:
            args: The statements to use in the body of the `else`
        """

        self.statement_list = StatementList(*args)

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `ElseStmt` into its possible statements.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is an `else` statement
        """
        for stmt in self.statement_list.concretize(db, location):
            yield "else {\n" + stmt + "\n}\n"

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns the `else` statement rendered from the view of its body.
        """
        return "else {\n" + self.statement_list.view(db, location) + "\n}\n"


class ReturnStmt(Statement):
    """
    Represents a `return ...;` statement.
    """

    def __init__(self, expr: Expression):
        """
        Creates a new `ReturnStmt`.

        Args:
            expr: The expression to concretize for the `return`
        """
        self.expr = expr

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `ReturnStmt` into its possible statements.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a `return` statement
        """
        for exp in self.expr.concretize(db, location):
            yield f"return {exp};"

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns the `return` statement rendered from the view of its expression.
        """
        return f"return {self.expr.view(db, location)};"


# TODO This should take a Node, which is something from... the DB? Maybe?
# TODO(ww): This should probably be renamed to ASTStmt or similar.
class NodeStmt(Statement):
    """
    Represents a statement from the AST database, identified by location.
    """

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `NodeStmt` into its concrete statement.

        Args:
            db: The AST database to concretize against
            location: The statement's location

        Returns:
            A generator of strings, whose single item is the extracted AST statement

        Raises:
            PatchConcretizationError: If there is no statement at the supplied location
        """
        yield self.fetch_node(db, location)

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns the source text of the statement at the given location.

        Raises:
            PatchConcretizationError: If there is no statement at the supplied location
        """
        return self.fetch_node(db, location)

    def fetch_node(self, db: models.DB, location: Location) -> str:
        """
        Fetches the statement at the given location.

        Args:
            db: The AST database to query against.
            location: The location of the statement.

        Returns:
            The statement, including trailing semicolon.

        Raises:
            PatchConcretizationError: If there is no statement at the supplied location.
        """
        statement = db.statement_at(location)
        if statement is None:
            raise PatchConcretizationError(f"no statement at ({location.line}, {location.column})")

        # Append a terminating semicolon only when the stored text lacks one.
        expr = statement.expr
        return expr if expr.endswith(";") else f"{expr};"


class FixPattern:
    """
    Represents a program "fix" that can be concretized into a sequence of candidate patches.
    """

    def __init__(self, *args: Statement):
        """
        Create a new `FixPattern`.

        Args:
            args: The statements to use in the fix
        """
        self.statement_list = StatementList(*args)

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize this `FixPattern` into a sequence of candidate patches.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a candidate patch
        """
        yield from self.statement_list.concretize(db, location)

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns the view of the fix's statement list.
        """
        return self.statement_list.view(db, location)


class PatchTemplate:
    """
    Represents a templatized patch, including a matcher callable.
    """

    def __init__(
        self, fix_pattern: FixPattern, matcher_func: Optional[Callable[[int, int], bool]] = None
    ):
        """
        Create a new `PatchTemplate`.

        Args:
            fix_pattern: The fix pattern to concretize into patches
            matcher_func: The callable to filter patch locations with, if any
        """
        self.matcher_func = matcher_func
        self.fix_pattern = fix_pattern

    def matches(self, line: int, col: int) -> bool:
        """
        Returns whether the given line and column are a suitable patch location.

        Calls `self.matcher_func` internally, if supplied.

        Args:
            line: The line to patch on
            col: The column to patch on

        Returns:
            `True` if `matcher_func` was not supplied or returns `True`, `False` otherwise
        """
        if self.matcher_func is None:
            return True
        return self.matcher_func(line, col)

    def concretize(self, db: models.DB, location: Location) -> Iterator[str]:
        """
        Concretize the inner `FixPattern` into a sequence of patch candidates.

        Args:
            db: The AST database to concretize against
            location: The location to concretize at

        Returns:
            A generator of strings, each of which is a candidate patch
        """

        yield from self.fix_pattern.concretize(db, location)

    def view(self, db: models.DB, location: Location) -> str:
        """
        Returns the view of the inner `FixPattern`.
        """
        return self.fix_pattern.view(db, location)