├── frontend ├── __init__.py ├── parser │ ├── __init__.py │ └── ast_util.py ├── python │ ├── collect_all.py │ ├── collect_source.py │ ├── navigate.py │ ├── collect_focal.py │ └── collect_test.py ├── cpp │ ├── collect_focal.py │ └── collect_all.py ├── go │ ├── collect_focal.py │ └── collect_all.py ├── java │ ├── collect_focal.py │ └── collect_all.py ├── javascript │ ├── js_util.py │ └── collect_all.py ├── rust │ ├── rust_util.py │ ├── collect_all.py │ └── collect_fuzz.py └── util.py ├── scripts ├── __init__.py ├── env.sh ├── common.py ├── decompress_repos.py ├── download_repos.py └── check_repo_stats.py ├── unitsyncer ├── __init__.py ├── util.py ├── extract_def.py ├── rust_syncer.py └── source_code.py ├── data └── repos │ ├── rust_example │ ├── .gitignore │ ├── src │ │ └── lib.rs │ ├── tests │ │ └── test_add.rs │ ├── Cargo.lock │ └── Cargo.toml │ ├── c_example │ ├── lib.h │ └── main.c │ ├── cpp_example │ ├── lib.hpp │ └── main.cpp │ ├── py_example │ ├── src │ │ ├── add.py │ │ └── classes.py │ ├── main.py │ └── tests │ │ └── test_python_example.py │ └── java_example │ └── Add.java ├── evaluation ├── dockerfiles │ ├── Dockerfile.python │ └── Dockerfile.eval ├── rust │ ├── rust_test_coverage.sh │ ├── coverage.py │ └── compile.py ├── README.md ├── exec_docker.py ├── data_quality.py └── execution.py ├── mypy.ini ├── requirements.txt ├── .github └── workflows │ ├── mypy.yml │ └── pylint.yml ├── tests ├── evaluation │ ├── test_rust_coverage.py │ ├── test_rust_compile.py │ └── test_eval.py ├── test_source_code.py ├── frontend │ ├── test_go.py │ ├── test_cpp.py │ ├── test_rust.py │ └── test_java.py └── test_prompt_header.py ├── .gitignore ├── README.md └── main.py /frontend/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /unitsyncer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/repos/rust_example/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /data/repos/c_example/lib.h: -------------------------------------------------------------------------------- 1 | int add(int a, int b) { return a + b; } -------------------------------------------------------------------------------- /data/repos/cpp_example/lib.hpp: -------------------------------------------------------------------------------- 1 | int add(int a, int b) { return a + b; } -------------------------------------------------------------------------------- /data/repos/py_example/src/add.py: -------------------------------------------------------------------------------- 1 | def add(x: int, y: int) -> int: 2 | return x + y 3 | -------------------------------------------------------------------------------- /data/repos/c_example/main.c: -------------------------------------------------------------------------------- 1 | #include "lib.h" 2 | 3 | int main() { 4 | int c = add(1, 2); 5 | return 0; 6 | } -------------------------------------------------------------------------------- /data/repos/cpp_example/main.cpp: 
-------------------------------------------------------------------------------- 1 | #include "lib.hpp" 2 | 3 | int main() { 4 | int c = add(1, 2); 5 | return 0; 6 | } -------------------------------------------------------------------------------- /data/repos/py_example/main.py: -------------------------------------------------------------------------------- 1 | from src.add import add 2 | 3 | if __name__ == "__main__": 4 | z = add(1, 2) 5 | -------------------------------------------------------------------------------- /data/repos/rust_example/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub fn add(left: usize, right: usize) -> usize { 2 | left + right 3 | } 4 | -------------------------------------------------------------------------------- /data/repos/rust_example/tests/test_add.rs: -------------------------------------------------------------------------------- 1 | use rust_example::add; 2 | 3 | #[test] 4 | fn it_adds_two() { 5 | assert_eq!(4, add(1, 3)); 6 | } 7 | -------------------------------------------------------------------------------- /scripts/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Environment setup for non-docker user. 3 | export UNITSYNCER_HOME=`pwd` 4 | export CORES=`nproc` 5 | 6 | export PYTHONPATH=$PYTHONPATH:$UNITSYNCER_HOME -------------------------------------------------------------------------------- /data/repos/py_example/src/classes.py: -------------------------------------------------------------------------------- 1 | class Person: 2 | def __init__(self, name) -> None: 3 | self.name = name 4 | 5 | def greet(self): 6 | return f"Hello, {self.name}" 7 | -------------------------------------------------------------------------------- /data/repos/rust_example/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "rust_example" 7 | version = "0.1.0" 8 | -------------------------------------------------------------------------------- /data/repos/rust_example/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust_example" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | -------------------------------------------------------------------------------- /data/repos/py_example/tests/test_python_example.py: -------------------------------------------------------------------------------- 1 | from src.add import add 2 | from src.classes import Person 3 | 4 | 5 | def test_add(): 6 | assert add(1, 2) == 3 7 | 8 | 9 | def test_greet(): 10 | p = Person("John") 11 | assert p.greet() == "Hello, John" 12 | -------------------------------------------------------------------------------- /evaluation/dockerfiles/Dockerfile.python: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | RUN apt-get update && \ 5 | apt-get -y upgrade 6 | RUN apt-get install -y -q git build-essential wget python3 python3-pip && \ 7 | apt-get clean 8 | 9 | RUN pip install pytest pytest-cov 10 | 11 | WORKDIR /home 12 | -------------------------------------------------------------------------------- /evaluation/rust/rust_test_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TARGET=$1 4 | RUSTFLAGS="-C instrument-coverage" LLVM_PROFILE_FILE="$TARGET-%p-%m.profraw" cargo test $TARGET &> /dev/null 5 | grcov $TARGET*.profraw -s . 
--binary-path ./target/debug/ -t html --branch --ignore-not-existing -o ./target/debug/coverage/$TARGET 6 | rm $TARGET*.profraw 7 | 8 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.10 3 | warn_return_any = True 4 | warn_unused_configs = True 5 | ignore_missing_imports = True 6 | check_untyped_defs = True 7 | allow_redefinition = false 8 | ignore_errors = false 9 | implicit_reexport = false 10 | local_partial_types = true 11 | no_implicit_optional = true 12 | strict_equality = true 13 | strict_optional = true 14 | warn_no_return = true 15 | warn_redundant_casts = true 16 | ; warn_unreachable = true 17 | warn_unused_ignores = true 18 | plugins = 19 | returns.contrib.mypy.returns_plugin -------------------------------------------------------------------------------- /frontend/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from tree_sitter import Language 2 | import tree_sitter_rust as tsrust 3 | import tree_sitter_java as tsjava 4 | import tree_sitter_javascript as tsjavascript 5 | import tree_sitter_go as tsgo 6 | import tree_sitter_python as python 7 | import tree_sitter_cpp as tscpp 8 | 9 | JAVA_LANGUAGE = Language(tsjava.language(), "java") 10 | JAVASCRIPT_LANGUAGE = Language(tsjavascript.language(), "javascript") 11 | RUST_LANGUAGE = Language(tsrust.language(), "rust") 12 | GO_LANGUAGE = Language(tsgo.language(), "go") 13 | CPP_LANGUAGE = Language(tscpp.language(), "cpp") 14 | PYTHON_LANGUAGE = Language(python.language(), "python") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/yeger00/pylspclient.git@bd58ae7ee5773fd3b4df8c28480a8df9279cd59a # use older version to avoid conflict with sansio-lsp-server 2 | python-lsp-server>=1.8.2 3 | jedi>=0.19.1 4 | returns[compatible-mypy]>=0.23.0 5 | jsonlines>=4.0.0 6 | tqdm>=4.64.1 7 | lambdas>=0.2.0 8 | sansio-lsp-client==0.10.0 9 | pygithub==2.3.0 10 | pytest==8.3.2 11 | astor==0.8.1 12 | jedi==0.19.1 13 | fire==0.6.0 14 | python-Levenshtein==0.25.1 15 | fuzzywuzzy==0.18.0 16 | cytoolz==0.12.3 17 | orjson==3.10.6 18 | pathos==0.3.2 19 | funcy==2.0 20 | funcy_chain>=0.2.0 21 | matplotlib==3.8.3 22 | pandas==2.2.2 23 | 24 | # tree-sitter dependencies 25 | tree-sitter==0.21.3 26 | tree-sitter-go==0.21.0 27 | tree-sitter-cpp==0.22.3 28 | tree-sitter-java==0.21.0 29 | tree-sitter-javascript==0.21.4 30 | tree-sitter-python==0.21.0 31 | tree-sitter-rust==0.21.2 -------------------------------------------------------------------------------- /data/repos/java_example/Add.java: -------------------------------------------------------------------------------- 1 | 2 | import org.junit.jupiter.api.AfterAll; 3 | import org.junit.jupiter.api.Assertions; 4 | import org.junit.jupiter.api.BeforeAll; 5 | import org.junit.jupiter.api.Test; 6 | 7 | /** 8 | * Add 9 | */ 10 | public class Add { 11 | public static void main(String[] args) { 12 | int a = 10; 13 | int b = 20; 14 | int c = add(a, b); 15 | System.out.println("Sum of a and b is: " + c); 16 | } 17 | 18 | public static int add(int a, int b) { 19 | return a + b; 20 | } 21 | 22 | @Deprecated 23 | public static int sub(int a, int b) { 24 | return a - b; 25 | } 26 | 27 | 28 | @Test 29 | public void testAdd() { 30 | Assertions.assertEquals(30, 
add(10, 20)); 31 | } 32 | 33 | public void testSub() { 34 | Assertions.assertEquals(10, sub(20, 10)); 35 | } 36 | } -------------------------------------------------------------------------------- /evaluation/dockerfiles/Dockerfile.eval: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | RUN apt-get update && \ 5 | apt-get -y upgrade 6 | 7 | RUN apt-get update 8 | RUN apt-get install -y -q git build-essential wget unzip && \ 9 | apt-get clean 10 | 11 | WORKDIR /home 12 | 13 | # python dependency 14 | RUN apt-get install -y python3 python3-pip 15 | RUN python3 -m pip install coverage 16 | 17 | # C++ dependency 18 | RUN apt-get install -y llvm clang 19 | 20 | # Java dependency 21 | RUN apt-get install -y openjdk-18-jdk openjdk-18-jre 22 | RUN wget https://github.com/jacoco/jacoco/releases/download/v0.8.11/jacoco-0.8.11.zip && unzip jacoco-0.8.11.zip 23 | RUN rm -r coverage doc test index.html 24 | 25 | # JS dependency 26 | RUN apt-get install -y nodejs npm 27 | RUN npm install -g nyc 28 | 29 | # Go dependency 30 | RUN apt-get install -y golang-go 31 | 32 | 33 | RUN pip install fire tqdm 34 | COPY execution.py . 35 | 36 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: MyPy Type Checking 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | type-check: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@v2 11 | with: 12 | fetch-depth: 1 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: "3.10.12" # Replace with the version you need 18 | 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt 23 | pip install mypy>=1.7.1 24 | - name: Type Check Frontend 25 | run: mypy frontend/*.py frontend/parser frontend/java frontend/go frontend/javascript frontend/rust frontend/cpp 26 | - name: Type Check Backend 27 | run: mypy unitsyncer main.py 28 | - name: Type Check Evaluation Code 29 | run: mypy evaluation 30 | # - name: Type Check Scripts 31 | # run: mypy scripts 32 | -------------------------------------------------------------------------------- /tests/evaluation/test_rust_coverage.py: -------------------------------------------------------------------------------- 1 | """tests for rust coverage module""" 2 | 3 | import json 4 | from typing import Iterable 5 | import fire 6 | import os 7 | from tree_sitter import Node 8 | from frontend.parser import RUST_LANGUAGE 9 | from frontend.parser.ast_util import ASTUtil 10 | from unitsyncer.util import replace_tabs 11 | from evaluation.rust.coverage import get_testcase_coverages 12 | import unittest 13 | import logging 14 | 15 | 16 | class TestRustCoverage(unittest.TestCase): 17 | def test_base64(self): 18 | workspace_dir = os.path.abspath( 19 | "data/rust_repos//marshallpierce-rust-base64/marshallpierce-rust-base64-4ef33cc" 20 | ) 21 | 22 | cov_map = get_testcase_coverages(workspace_dir) 23 | self.assertTrue(len(cov_map) >= 13) # there are 13 hand-written tests 24 | 25 | self.assertEqual(cov_map["encode_all_ascii"], 3.94) 26 | self.assertEqual(cov_map["encode_all_bytes"], 4.95) 27 | 28 | 29 | if __name__ == "__main__": 30 | logging.basicConfig(level=logging.INFO) 31 | unittest.main() 32 | -------------------------------------------------------------------------------- 
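(Editor's note: the `3.94`/`4.95` fixtures in the test above come from the badge JSON that grcov writes next to its HTML report; `evaluation/rust/coverage.py` later in this dump reads the `label` and `message` fields of `target/debug/coverage/<test>/coverage.json`. Below is a minimal, self-contained sketch of that parsing step; the function name and the extra `schemaVersion`/`color` fields are illustrative assumptions, not part of the repo.)

```python
import json
from returns.maybe import Maybe, Nothing, Some


def parse_grcov_badge(path: str) -> Maybe[float]:
    """Parse the coverage percentage out of a grcov badge file, e.g. (assumed shape)
    {"schemaVersion": 1, "label": "coverage", "message": "3.94%", "color": "red"}
    """
    try:
        with open(path) as f:
            obj = json.load(f)
    except FileNotFoundError:
        return Nothing
    # same sanity checks as evaluation/rust/coverage.py performs
    if obj.get("label") != "coverage" or "message" not in obj:
        return Nothing
    # "3.94%" -> 3.94
    return Some(float(obj["message"].rstrip("%")))
```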
/frontend/python/collect_all.py:
--------------------------------------------------------------------------------
1 | """main script for Python frontend"""
2 | import frontend.python.collect_test as collect_test
3 | import frontend.python.collect_focal as collect_focal_new
4 | import frontend.python.collect_focal_org as collect_focal_org
5 | import fire
6 | 
7 | 
8 | def main(
9 |     repo_id: str = "ageitgey/face_recognition",
10 |     test_root: str = "data/tests",
11 |     repo_root: str = "data/repos",
12 |     focal_root: str = "data/focal",
13 |     timeout: int = 300,
14 |     nprocs: int = 0,
15 |     original_collect_focal: bool = False,
16 |     limits: int = -1,
17 | ):
18 |     collect_test.main(
19 |         repo_id=repo_id,
20 |         test_root=test_root,
21 |         repo_root=repo_root,
22 |         timeout=timeout,
23 |         nprocs=nprocs,
24 |         limits=limits,
25 |     )
26 |     collect_focal = collect_focal_org if original_collect_focal else collect_focal_new
27 |     collect_focal.main(
28 |         repo_id=repo_id,
29 |         test_root=test_root,
30 |         repo_root=repo_root,
31 |         focal_root=focal_root,
32 |         timeout=timeout,
33 |         nprocs=nprocs,
34 |         limits=limits,
35 |     )
36 | 
37 | 
38 | if __name__ == "__main__":
39 |     fire.Fire(main)
40 | 
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 | 
3 | on: [pull_request]
4 | 
5 | jobs:
6 |   build:
7 |     runs-on: ubuntu-latest
8 |     strategy:
9 |       matrix:
10 |         python-version: ["3.10.12"]
11 |     steps:
12 |       - uses: actions/checkout@v3
13 |       - name: Set up Python ${{ matrix.python-version }}
14 |         uses: actions/setup-python@v3
15 |         with:
16 |           python-version: ${{ matrix.python-version }}
17 |       - name: Install dependencies
18 |         run: |
19 |           python -m pip install --upgrade pip
20 |           pip install pylint
21 |           pip install -r requirements.txt
22 |       # - name: Analysing the code with pylint
23 |       #   run: |
24 |       #     pylint $(git ls-files '*.py')
25 |       - name: Run Pylint
26 |         run: |
27 |           PYLINT_OUTPUT=$(pylint $(git ls-files '*.py') || true)
28 |           PYLINT_SCORE=$(echo "$PYLINT_OUTPUT" | grep 'rated at' | sed 's/.*rated at \([0-9.]*\)\/10.*/\1/')
29 |           echo "PYLINT_SCORE=$PYLINT_SCORE" >> $GITHUB_ENV
30 |           echo "$PYLINT_OUTPUT"
31 | 
32 |       - name: Fail if below threshold (8)
33 |         run: |
34 |           if (( $(echo "$PYLINT_SCORE < 8.0" | bc -l) )); then
35 |             exit 1
36 |           fi
37 | 
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # UniTSyncer Evaluation
2 | 
3 | ## Requirements
4 | 
5 | ### Python
6 | 
7 | Coverage reporting is provided by [coverage.py](https://coverage.readthedocs.io/en/7.4.0/);
8 | install it via
9 | 
10 | ```sh
11 | python3 -m pip install coverage
12 | ```
13 | 
14 | ### C++
15 | 
16 | Coverage reporting is provided by the [LLVM toolchain](https://github.com/llvm/llvm-project);
17 | install it via
18 | 
19 | ```sh
20 | bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)"
21 | ```
22 | 
23 | or from a GitHub release.
24 | 
25 | ### Java
26 | 
27 | Coverage reporting is provided by [jacoco](https://github.com/jacoco/jacoco);
28 | please download `jacoco-0.8.11.zip` from [its website](https://www.jacoco.org/jacoco/).
29 | 
30 | ### JavaScript
31 | 
32 | We use [istanbuljs/nyc](https://github.com/istanbuljs/nyc) to compute coverage.
33 | Please install it globally via
34 | 
35 | ```sh
36 | npm install -g nyc
37 | ```
38 | 
39 | We also require [nodejs](https://nodejs.org/en/download/current),
40 | so please install it as well.
41 | 
42 | ### Golang
43 | 
44 | Golang's coverage support is built into the compiler, so there is no need to install additional dependencies.
45 | Note, however, that Go only supports **statement coverage**.
46 | 
47 | ## Running Evaluation
48 | 
49 | ```sh
50 | python3 execution.py -j test_evaluation.jsonl
51 | ```
52 | 
53 | ## Docker
54 | 
55 | In `UniTSyncer/evaluation`, run
56 | 
57 | ```sh
58 | docker build --tag unitsyncer-eval . -f dockerfiles/Dockerfile.eval
59 | ```
60 | 
--------------------------------------------------------------------------------
/tests/test_source_code.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import logging
4 | from unitsyncer.source_code import get_function_code
5 | from returns.maybe import Nothing, Some
6 | from pylspclient.lsp_structs import Location, Position, Range
7 | 
8 | from unitsyncer.util import path2uri
9 | 
10 | 
11 | class TestSourceCode(unittest.TestCase):
12 |     def test_python_get_function_code(self):
13 |         uri = f"file://{os.getcwd()}/data/repos/py_example/src/add.py"
14 |         range_ = Range(Position(0, 4), Position(0, 7))
15 |         loc = Location(uri, range_)
16 |         add_src = "def add(x: int, y: int) -> int:\n    return x + y"
17 | 
18 |         self.assertEqual(get_function_code(loc, "python").unwrap()[0], add_src)
19 | 
20 |     def test_python_get_function_code_not_found(self):
21 |         uri = f"file://{os.getcwd()}/data/repos/py_example/src/add.py"
22 |         range_ = Range(Position(1, 4), Position(1, 7))
23 |         loc = Location(uri, range_)
24 | 
25 |         self.assertEqual(get_function_code(loc, "python"), Nothing)
26 | 
27 |     def test_java_get_function_code(self):
28 |         uri = f"file://{os.getcwd()}/data/repos/java_example/Add.java"
29 |         # range_ = Range(Position(29, 36), Position(29, 37))
30 |         range_ = Range(Position(17, 22), Position(17, 23))
31 |         loc = Location(uri, range_)
32 | 
33 |         add_src = "public static int add(int a, int b) {\n    return a + b;\n}"
34 | 
35 |         self.assertEqual(get_function_code(loc, "java").unwrap()[0], add_src)
36 | 
37 |     def test_java_get_function_code_w_annotation(self):
38 |         uri = f"file://{os.getcwd()}/data/repos/java_example/Add.java"
39 |         range_ = Range(Position(22, 3), Position(22, 4))
40 |         loc = Location(uri, range_)
41 |         sub_src = (
42 |             "@Deprecated\npublic static int sub(int a, int b) {\n    return a - b;\n}"
43 |         )
44 | 
45 |         self.assertEqual(get_function_code(loc, "java").unwrap()[0], sub_src)
46 | 
47 |     def test_real_java_method_w_annotation(self):
48 |         uri = f"file://{os.getcwd()}/data/java_repos/spring-cloud-spring-cloud-netflix/spring-cloud-spring-cloud-netflix-630151f/spring-cloud-netflix-eureka-client/src/main/java/org/springframework/cloud/netflix/eureka/EurekaInstanceConfigBean.java"
49 | 
50 |         range_ = Range(Position(297, 22), Position(297, 23))
51 |         loc = Location(uri, range_)
52 |         self.assertIsNotNone(get_function_code(loc, "java").value_or(None))
53 | 
54 | 
55 | if __name__ == "__main__":
56 |     logging.basicConfig(level=logging.INFO)
57 |     unittest.main()
58 | 
--------------------------------------------------------------------------------
/tests/frontend/test_go.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import logging
4 | from returns.maybe import Nothing, Some
5 | from frontend.go.collect_focal import get_focal_call, is_test_fn
6 | from frontend.parser.ast_util import ASTUtil 7 | from frontend.parser import GO_LANGUAGE 8 | 9 | 10 | class TestGoFrontend(unittest.TestCase): 11 | def test_is_test_fn(self): 12 | code = """ 13 | func TestDatasets(t *testing.T) { 14 | defer setupZPool(t).cleanUp() 15 | 16 | _, err := zfs.Datasets("") 17 | ok(t, err) 18 | 19 | ds, err := zfs.GetDataset("test") 20 | ok(t, err) 21 | equals(t, zfs.DatasetFilesystem, ds.Type) 22 | equals(t, "", ds.Origin) 23 | if runtime.GOOS != "solaris" { 24 | assert(t, ds.Logicalused != 0, "Logicalused is not greater than 0") 25 | } 26 | }""" 27 | 28 | ast_util = ASTUtil(code) 29 | tree = ast_util.tree(GO_LANGUAGE) 30 | root_node = tree.root_node 31 | fn = ast_util.get_all_nodes_of_type(root_node, "function_declaration")[0] 32 | self.assertTrue(is_test_fn(fn, ast_util)) 33 | 34 | code = """ 35 | func GetDataset(name string) (*Dataset, error) { 36 | out, err := zfsOutput("list", "-Hp", "-o", dsPropListOptions, name) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | ds := &Dataset{Name: name} 42 | for _, line := range out { 43 | if err := ds.parseLine(line); err != nil { 44 | return nil, err 45 | } 46 | } 47 | 48 | return ds, nil 49 | } 50 | """ 51 | ast_util = ASTUtil(code) 52 | tree = ast_util.tree(GO_LANGUAGE) 53 | root_node = tree.root_node 54 | fn = ast_util.get_all_nodes_of_type(root_node, "function_declaration")[0] 55 | self.assertFalse(is_test_fn(fn, ast_util)) 56 | 57 | def test_focal(self): 58 | code = """func TestDatasets(t *testing.T) { 59 | defer setupZPool(t).cleanUp() 60 | 61 | _, err := zfs.Datasets("") 62 | ok(t, err) 63 | 64 | ds, err := zfs.GetDataset("test") 65 | ok(t, err) 66 | equals(t, zfs.DatasetFilesystem, ds.Type) 67 | equals(t, "", ds.Origin) 68 | if runtime.GOOS != "solaris" { 69 | assert(t, ds.Logicalused != 0, "Logicalused is not greater than 0") 70 | } 71 | }""" 72 | 73 | ast_util = ASTUtil(code) 74 | tree = ast_util.tree(GO_LANGUAGE) 75 | root_node = tree.root_node 76 | 77 | fn = ast_util.get_all_nodes_of_type(root_node, "function_declaration")[0] 78 | 79 | name, loc = get_focal_call(ast_util, fn).value_or((None, None)) 80 | self.assertEqual(name, "Datasets") 81 | self.assertEqual(loc, (3, 15)) 82 | 83 | 84 | if __name__ == "__main__": 85 | logging.basicConfig(level=logging.INFO) 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /scripts/common.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import time 3 | 4 | """ 5 | Functions from github ranking repo: 6 | https://github.com/EvanLi/Github-Ranking/blob/master/source/ 7 | """ 8 | 9 | 10 | def check_metadata_decorator(func): 11 | """Decorator function to check metadata keys and value types returned by Github GraphQL API""" 12 | 13 | def check_metadata(*args, **kwargs): 14 | required_keys = [ 15 | ("id", str), 16 | ("owner", dict), 17 | ("name", str), 18 | ("url", str), 19 | ("isArchived", bool), 20 | ("isFork", bool), 21 | ("isMirror", bool), 22 | ("primaryLanguage", dict), 23 | ("pushedAt", str), 24 | ("stargazerCount", int), 25 | ("object", dict), 26 | ] 27 | improper_fields = [] 28 | for key, t in required_keys: 29 | if not (key in args[0] and type(args[0][key]) is t): 30 | improper_fields.append(key) 31 | 32 | if len(improper_fields) > 0: 33 | print(f"Metadata JSON does not contain the proper keys: {improper_fields}") 34 | return False 35 | return func(*args, **kwargs) 36 | 37 | return check_metadata 38 | 39 | 40 | def get_access_token(): 41 | 
with open("./oauth", "r") as f: 42 | access_token = f.read().strip() 43 | return access_token 44 | 45 | 46 | def get_graphql_data(GQL: str) -> dict: 47 | """ 48 | use graphql to get data 49 | """ 50 | access_token = get_access_token() 51 | headers = { 52 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/81.0.4044.113 Safari/537.36", 53 | "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", 54 | "Accept-Language": "zh-CN,zh;q=0.9", 55 | "Authorization": "bearer {}".format(access_token), 56 | } 57 | s = requests.session() 58 | s.keep_alive = False # don't keep the session 59 | graphql_api = "https://api.github.com/graphql" 60 | for _ in range(5): 61 | time.sleep(2) # not get so fast 62 | try: 63 | # requests.packages.urllib3.disable_warnings() # disable InsecureRequestWarning of verify=False, 64 | r = requests.post( 65 | url=graphql_api, json={"query": GQL}, headers=headers, timeout=30 66 | ) 67 | if r.status_code != 200: 68 | print( 69 | f"Can not retrieve from {GQL}. Response status is {r.status_code}, content is {r.content}." 70 | ) 71 | else: 72 | return r.json() 73 | except Exception as e: 74 | print(e) 75 | time.sleep(5) 76 | 77 | -------------------------------------------------------------------------------- /scripts/decompress_repos.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collects all the test functions from projects following 3 | "Conventions for Python test discovery" in 4 | https://docs.pytest.org/en/7.4.x/explanation/goodpractices.html#test-discovery 5 | """ 6 | 7 | import os 8 | import re 9 | import sys 10 | import ast 11 | import fire 12 | import tarfile 13 | from tqdm import tqdm 14 | from multiprocessing import Pool 15 | 16 | from frontend.util import wrap_repo 17 | 18 | 19 | def decompress(task: tuple[str, str], optimize_storage: bool = True): 20 | """decompress tar.gz file 21 | 22 | Args: 23 | task (tuple[str, str]): input tar path, output path 24 | optimize_storage (bool, optional): ignore some file extensions. Defaults to True. 
25 | NOTE: this might affect efficiency 26 | 27 | Returns: 28 | int: extract status, 0 if success, 1 if input file not found, 2 if error 29 | """ 30 | ipath, opath = task 31 | if not os.path.exists(ipath): 32 | return 1 33 | # if os.path.exists(opath): return 2 34 | try: 35 | if optimize_storage: 36 | exclude_extensions = {".png", ".jpg", ".JPEG", ".jpeg", ".bin", ".pkl"} 37 | with tarfile.open(ipath, "r:gz") as tar: 38 | for member in tar.getmembers(): 39 | if not any(member.name.endswith(ext) for ext in exclude_extensions): 40 | tar.extract(member, path=opath) 41 | else: 42 | tarfile.open(ipath).extractall(opath) 43 | except: 44 | return 2 45 | return 0 46 | 47 | 48 | def main( 49 | repo_id_list: str = "ageitgey/face_recognition", 50 | timeout: int = -1, 51 | iroot: str = "data/repos_tarball/", 52 | oroot: str = "data/repos/", 53 | ): 54 | # if repo_id_list is a file then load lines 55 | # otherwise it is the id of a specific repo 56 | try: 57 | repo_id_list = [l.strip() for l in open(repo_id_list, "r").readlines()] 58 | except: 59 | repo_id_list = [repo_id_list] 60 | print(f"Loaded {len(repo_id_list)} repos to be processed") 61 | 62 | tasks = [ 63 | ( 64 | os.path.join(iroot, wrap_repo(repo_id)) + ".tar.gz", 65 | os.path.join(oroot, wrap_repo(repo_id)), 66 | ) 67 | for repo_id in repo_id_list 68 | ] 69 | results = [] 70 | with Pool(16) as p: 71 | with tqdm(total=len(tasks)) as pbar: 72 | for status in p.imap_unordered(decompress, tasks): 73 | results.append(status) 74 | pbar.update() 75 | failed = {"input": 0, "output": 0} 76 | failed["input"] = sum([i == 1 for i in results]) 77 | failed["output"] = sum([i == 2 for i in results]) 78 | if sum(failed.values()): 79 | print("Failed:", {key: val for key, val in failed.items() if val}) 80 | print("Done!") 81 | 82 | 83 | if __name__ == "__main__": 84 | fire.Fire(main) 85 | -------------------------------------------------------------------------------- /tests/frontend/test_cpp.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import logging 4 | from returns.maybe import Nothing, Some 5 | from frontend.cpp.collect_focal import get_focal_call 6 | from frontend.parser.ast_util import ASTUtil 7 | from frontend.parser import CPP_LANGUAGE, JAVA_LANGUAGE 8 | 9 | 10 | class TestCppFrontend(unittest.TestCase): 11 | """testing C++ frontend for gtest""" 12 | 13 | def __test_focal_helper(self, code: str): 14 | ast_util = ASTUtil(code) 15 | tree = ast_util.tree(CPP_LANGUAGE) 16 | root_node = tree.root_node 17 | 18 | fn = ast_util.get_all_nodes_of_type(root_node, "function_definition")[0] 19 | return get_focal_call(ast_util, fn) 20 | 21 | def test_focal(self): 22 | """ 23 | For a regular @Test function, with function call in `EXPECT`, 24 | that function call should be the focal 25 | """ 26 | code = """ 27 | TEST(OpenACCTest, DirectiveHelpers) { 28 | EXPECT_EQ(getOpenACCDirectiveKind(""), ACCD_unknown); 29 | EXPECT_EQ(getOpenACCDirectiveKind("dummy"), ACCD_unknown); 30 | EXPECT_EQ(getOpenACCDirectiveKind("atomic"), ACCD_atomic); 31 | EXPECT_EQ(getOpenACCDirectiveKind("cache"), ACCD_cache); 32 | EXPECT_EQ(getOpenACCDirectiveKind("data"), ACCD_data); 33 | EXPECT_EQ(getOpenACCDirectiveKind("declare"), ACCD_declare); 34 | } 35 | """ 36 | 37 | name, loc = self.__test_focal_helper(code).unwrap() 38 | self.assertEqual(name, "getOpenACCDirectiveKind") 39 | self.assertEqual(loc, (2, 12)) 40 | 41 | def test_focal_not_in_assert(self): 42 | """ 43 | if there is no function call in the first 
`EXPECT`/`ASSERT` macro,
44 |         then the last call before the first assertion is the focal
45 |         """
46 |         code = """
47 | TEST(AsmWriterTest, DebugPrintDetachedArgument) {
48 |   LLVMContext Ctx;
49 |   auto Ty = Type::getInt32Ty(Ctx);
50 |   auto Arg = new Argument(Ty);
51 | 
52 |   std::string S;
53 |   raw_string_ostream OS(S);
54 |   Arg->print(OS);
55 |   EXPECT_EQ(S, "i32 <badref>");
56 |   delete Arg;
57 | }"""
58 | 
59 |         name, loc = self.__test_focal_helper(code).unwrap()
60 |         self.assertEqual(name, "print")
61 |         self.assertEqual(loc, (8, 7))
62 | 
63 |         code = """
64 | TEST(BFSTest, InstantiateGraphFromEdges)
65 | {
66 |     Graph g({ {1, 2}, {1, 3}, {2, 3} });
67 | 
68 |     std::vector<int> bfs = g.BFS(1);
69 |     std::vector<int> expected{ 1, 2, 3 };
70 | 
71 |     ASSERT_EQ(bfs, expected);
72 | }
73 | """
74 |         name, loc = self.__test_focal_helper(code).unwrap()
75 |         self.assertEqual(name, "BFS")
76 |         self.assertEqual(loc, (5, 29))
77 | 
78 |     def test_focal_not_assert(self):
79 |         """If no assert in the test function, then fail"""
80 | 
81 |         code = """
82 | TEST(TestNothing, SomeTest) {
83 |     int a = 1 + 2;
84 |     int b = 2 + 3;
85 | }
86 | """
87 |         self.assertEqual(self.__test_focal_helper(code), Nothing)
88 | 
89 | 
90 | if __name__ == "__main__":
91 |     logging.basicConfig(level=logging.INFO)
92 |     unittest.main()
93 | 
--------------------------------------------------------------------------------
/evaluation/exec_docker.py:
--------------------------------------------------------------------------------
1 | """Execution Runtime Metrics for Python, Java, Go, C++, and JS"""
2 | 
3 | import logging
4 | import fire
5 | import json
6 | import docker
7 | import contextlib
8 | import os
9 | import tempfile
10 | import subprocess
11 | import random
12 | from multiprocessing import Pool
13 | from tqdm import tqdm
14 | 
15 | 
16 | @contextlib.contextmanager
17 | def build_image(repo_id: str):
18 |     _, workdir = repo_id.split("/")
19 |     dockerfile = f"""
20 | FROM unitsyncer-eval:python
21 | ENV DEBIAN_FRONTEND noninteractive
22 | RUN git clone https://github.com/{repo_id}
23 | WORKDIR {workdir}
24 | RUN pip install -r requirements.txt
25 | """
26 |     temp_file = tempfile.NamedTemporaryFile(prefix="unitsyncer_")
27 |     with open(temp_file.name, "w") as f:
28 |         f.write(dockerfile)
29 |     try:
30 |         subprocess.run(
31 |             ["docker", "build", "--tag", repo_id, ".", "-f", temp_file.name],
32 |             check=True,
33 |             stdout=subprocess.DEVNULL,
34 |             stderr=subprocess.DEVNULL,
35 |         )
36 |         yield
37 |     finally:
38 |         subprocess.run(
39 |             ["docker", "rmi", repo_id, "-f"],
40 |             stdout=subprocess.DEVNULL,
41 |             stderr=subprocess.DEVNULL,
42 |         )
43 | 
44 | 
45 | def parse_pytest_output_coverage(stdout: str) -> float | None:
46 |     lines = stdout.splitlines()
47 |     for line in reversed(lines):
48 |         if "TOTAL" in line:
49 |             elems = line.split(" ")
50 |             return float(elems[-1].strip("%"))
51 |     return None
52 | 
53 | 
54 | def get_py_coverage(repo_id: str):
55 |     try:
56 |         with build_image(repo_id):
57 |             client = docker.from_env()
58 |             res = client.containers.run(repo_id, "pytest --cov=. 
tests") 59 | if isinstance(res, bytes): 60 | stdout = res.decode("utf-8") 61 | return parse_pytest_output_coverage(stdout) 62 | except: 63 | with open("coverage.log", "a") as fp: 64 | fp.write(repo_id + "\n") 65 | return None 66 | 67 | return None 68 | 69 | 70 | def main(repo_list_path: str, lang: str, nproc: int = 20, seed: int = 0): 71 | with open(repo_list_path, "r") as fp: 72 | repo_list = fp.read().splitlines()[:10000] 73 | 74 | with open("coverage.log", "r") as fp: 75 | skip_list = fp.read().splitlines() 76 | 77 | repo_list = [repo for repo in repo_list if repo not in skip_list] 78 | 79 | # with Pool(nproc) as pool: 80 | # covs = list(tqdm(pool.imap(get_py_coverage, repo_list), total=len(repo_list))) 81 | 82 | with open(f"{lang}_coverage.jsonl", "a") as writer: 83 | for repo_id in tqdm(repo_list): 84 | cov = get_py_coverage(repo_id) 85 | if cov is not None: 86 | writer.write(json.dumps({"repo_id": repo_id, "coverage": cov}) + "\n") 87 | else: 88 | logging.warning(f"{repo_id} no coverage") 89 | 90 | 91 | if __name__ == "__main__": 92 | logging.basicConfig(level=logging.INFO) 93 | fire.Fire(main) 94 | -------------------------------------------------------------------------------- /unitsyncer/util.py: -------------------------------------------------------------------------------- 1 | """util functions for UniTSyncer backend""" 2 | import threading 3 | from returns.maybe import Maybe, Nothing, Some 4 | from pathos.multiprocessing import ProcessPool 5 | import sys 6 | import io 7 | from itertools import chain 8 | from typing import Callable, Iterable, TypeVar, overload 9 | from functools import reduce 10 | from operator import add 11 | 12 | from frontend.parser.ast_util import ASTUtil 13 | from tree_sitter import Node 14 | 15 | 16 | class ReadPipe(threading.Thread): 17 | """source: 18 | https://github.com/yeger00/pylspclient/blob/master/examples/python-language-server.py#L10 19 | """ 20 | 21 | def __init__(self, pipe): 22 | threading.Thread.__init__(self) 23 | self.pipe = pipe 24 | 25 | def run(self): 26 | line = self.pipe.readline().decode("utf-8") 27 | while line: 28 | line = self.pipe.readline().decode("utf-8") 29 | 30 | 31 | def uri2path(uri: str) -> Maybe[str]: 32 | if uri.startswith("file://"): 33 | return Some(uri[7:]) 34 | return Nothing 35 | 36 | 37 | def path2uri(path: str) -> str: 38 | """ 39 | Args: 40 | path (str): absolute path to file 41 | 42 | Returns: 43 | str: uri format of path 44 | """ 45 | return "file://" + path 46 | 47 | 48 | def parallel_starmap(f, args, jobs=1): 49 | with ProcessPool(jobs) as p: 50 | rnt = p.map(lambda x: f(*x), args) 51 | return rnt 52 | 53 | 54 | def replace_tabs(text: str, n_space=4) -> str: 55 | """replace each tab with 4 spaces""" 56 | return text.replace("\t", " " * n_space) 57 | 58 | 59 | def silence(func): 60 | """Execute a function with suppressed stdout.""" 61 | 62 | def wrapper(*args, **kwargs): 63 | original_stdout = sys.stdout 64 | try: 65 | # Redirect stdout to a dummy file-like object 66 | sys.stdout = io.StringIO() 67 | return func(*args, **kwargs) 68 | finally: 69 | # Restore original stdout 70 | sys.stdout = original_stdout 71 | 72 | return wrapper 73 | 74 | 75 | def convert_to_seconds(s: str) -> int: 76 | seconds_per_unit = {"s": 1, "m": 60, "h": 3600, "d": 86400, "w": 604800} 77 | return int(s[:-1]) * seconds_per_unit[s[-1]] 78 | 79 | 80 | def get_cpp_func_name(ast_util: ASTUtil, node: Node) -> Maybe[str]: 81 | """extract function name from function_definition node""" 82 | for child in node.children: 83 | if child.type == 
"function_declarator": 84 | declarator = ast_util.get_source_from_node(child) 85 | func_name = declarator.split("(")[0] 86 | return Some(func_name) 87 | 88 | return Nothing 89 | 90 | 91 | T = TypeVar("T") 92 | U = TypeVar("U") 93 | 94 | 95 | def concatMap(func: Callable[[T], Iterable[U]], iterable: Iterable[T]) -> Iterable[U]: 96 | """creates a list from a list generating function by application of this function 97 | on all elements in a list passed as the second argument 98 | 99 | 100 | Args: 101 | func: T -> [U] 102 | iterable: [T] 103 | 104 | Returns: [U] 105 | """ 106 | return reduce(add, map(func, iterable)) 107 | -------------------------------------------------------------------------------- /evaluation/rust/coverage.py: -------------------------------------------------------------------------------- 1 | """Functions for running cargo-fuzz and get coverage for test cases""" 2 | 3 | import json 4 | from typing import Iterable, Iterator 5 | import fire 6 | import os 7 | from tree_sitter import Node 8 | from frontend.parser import RUST_LANGUAGE 9 | from frontend.parser.ast_util import ASTUtil 10 | from unitsyncer.util import replace_tabs 11 | import subprocess 12 | from unitsyncer.common import UNITSYNCER_HOME 13 | from returns.maybe import Maybe, Some, Nothing 14 | from frontend.rust.rust_util import get_test_functions 15 | from frontend.rust.collect_all import collect_test_files 16 | from unitsyncer.util import concatMap 17 | 18 | 19 | def clean_workspace(workspace_dir: str): 20 | subprocess.run(["rm", "rust_test_coverage.sh"], cwd=workspace_dir) 21 | subprocess.run(["rm", "-r", "target"], cwd=workspace_dir) 22 | 23 | 24 | def init_workspace(workspace_dir: str): 25 | cov_script_path = f"{UNITSYNCER_HOME}/evaluation/rust/rust_test_coverage.sh" 26 | subprocess.run(["cp", cov_script_path, workspace_dir]) 27 | 28 | 29 | def get_coverage( 30 | workspace_dir: str, test_target: str, clean_run: bool = False 31 | ) -> Maybe[float]: 32 | if clean_run: 33 | clean_workspace(workspace_dir) 34 | 35 | init_workspace(workspace_dir) 36 | 37 | subprocess.run( 38 | ["./rust_test_coverage.sh", test_target], 39 | cwd=workspace_dir, 40 | stdout=subprocess.DEVNULL, 41 | stderr=subprocess.DEVNULL, 42 | ) 43 | 44 | try: 45 | cov_path = f"{workspace_dir}/target/debug/coverage/{test_target}/coverage.json" 46 | with open(cov_path) as f: 47 | cov_obj = json.load(f) 48 | except FileNotFoundError: 49 | return Nothing 50 | 51 | if cov_obj["label"] != "coverage" or "message" not in cov_obj: 52 | return Nothing 53 | return Some(float(cov_obj["message"][:-1])) 54 | 55 | 56 | def get_tests(workspace_dir: str) -> Iterable[str]: 57 | """get all test targets from project by looking for fn with #[test] in tests dir""" 58 | 59 | def get_tests_from_file(fpath: str) -> list[str]: 60 | with open(fpath) as f: 61 | code = f.read() 62 | ast_util = ASTUtil(code) 63 | tree = ast_util.tree(RUST_LANGUAGE) 64 | test_nodes = get_test_functions(ast_util, tree.root_node) 65 | return [t.unwrap() for t in map(ast_util.get_name, test_nodes) if t != Nothing] 66 | 67 | return concatMap( 68 | get_tests_from_file, 69 | collect_test_files(os.path.join(workspace_dir, "tests"), False), 70 | ) 71 | 72 | 73 | def get_testcase_coverages(workspace_dir: str) -> dict[str, float]: 74 | """get coverage of each individual testcase in the tests sub-directory 75 | 76 | Args: 77 | workspace_dir (str): root of the project workspace 78 | 79 | Returns: 80 | dict[str, float]: {testcase_name: its coverage} 81 | """ 82 | coverages = {} 83 | for test_name in 
get_tests(workspace_dir): 84 | cov = get_coverage(workspace_dir, test_name).unwrap() 85 | coverages[test_name] = cov 86 | return coverages 87 | 88 | 89 | def main(): 90 | workspace_dir = os.path.abspath( 91 | "data/repos/marshallpierce-rust-base64/marshallpierce-rust-base64-4ef33cc" 92 | ) 93 | 94 | print(get_testcase_coverages(workspace_dir)) 95 | 96 | 97 | if __name__ == "__main__": 98 | fire.Fire(main) 99 | -------------------------------------------------------------------------------- /frontend/cpp/collect_focal.py: -------------------------------------------------------------------------------- 1 | """find focal call in Java test functions""" 2 | 3 | from typing import Optional 4 | import fire 5 | import re 6 | from tree_sitter import Node 7 | from frontend.parser import CPP_LANGUAGE 8 | from frontend.parser.ast_util import ASTUtil, ASTLoc, flatten_postorder 9 | from returns.maybe import Maybe, Nothing, Some, maybe 10 | from unitsyncer.util import get_cpp_func_name 11 | 12 | 13 | def is_test_fn(n: Node, ast_util: ASTUtil): 14 | def is_gtest_testcase(node: Node): 15 | return ( 16 | get_cpp_func_name(ast_util, node) 17 | .map(lambda name: name == "TEST") 18 | .value_or(False) 19 | ) 20 | 21 | return is_gtest_testcase(n) 22 | 23 | 24 | def get_focal_call(ast_util: ASTUtil, func: Node) -> Maybe[tuple[str, ASTLoc]]: 25 | """Find the focal call in the given function 26 | 27 | Args: 28 | ast_util (ASTUtil): ASTUtil for the file 29 | func (Node): a method_declaration node 30 | 31 | Returns: 32 | Maybe[tuple[str, ASTLoc]]: focal call and its location 33 | """ 34 | 35 | calls = flatten_postorder(func, "call_expression") 36 | 37 | # reverse for postorder 38 | func_calls = [ast_util.get_source_from_node(call) for call in calls] 39 | calls_before_assert: list[str] = [] 40 | has_assert = False 41 | for call in func_calls: 42 | if "EXPECT" in call or "ASSERT" in call: 43 | has_assert = True 44 | break 45 | calls_before_assert.append(call) 46 | 47 | if not has_assert or not calls_before_assert: 48 | return Nothing 49 | 50 | def get_loc(call: str) -> Maybe[tuple[str, ASTLoc]]: 51 | """add offset to nested function calls 52 | 53 | Args: 54 | call (str): code str of a method_invocation 55 | 56 | Returns: 57 | Maybe[tuple[str, ASTLoc]]: (method_name, its location) 58 | """ 59 | idx = func_calls.index(call) 60 | node = calls[idx] 61 | lineno, col = node.start_point 62 | 63 | # regular expression to find method names 64 | pattern = r"(\w+)\s*\(" 65 | matches = list(re.finditer(pattern, call)) 66 | 67 | if matches: 68 | last_match = matches[-1] 69 | method_name = last_match.group(1) 70 | offset = last_match.start(1) 71 | loc = (lineno, col + offset) 72 | return Some((method_name, loc)) 73 | 74 | return Nothing 75 | 76 | return get_loc(calls_before_assert[-1]) 77 | 78 | 79 | def main(): 80 | code = """ 81 | TEST(OpenACCTest, DirectiveHelpers) { 82 | EXPECT_EQ(getOpenACCDirectiveKind(""), ACCD_unknown); 83 | EXPECT_EQ(getOpenACCDirectiveKind("dummy"), ACCD_unknown); 84 | EXPECT_EQ(getOpenACCDirectiveKind("atomic"), ACCD_atomic); 85 | EXPECT_EQ(getOpenACCDirectiveKind("cache"), ACCD_cache); 86 | EXPECT_EQ(getOpenACCDirectiveKind("data"), ACCD_data); 87 | EXPECT_EQ(getOpenACCDirectiveKind("declare"), ACCD_declare); 88 | } 89 | """ 90 | 91 | ast_util = ASTUtil(code) 92 | tree = ast_util.tree(CPP_LANGUAGE) 93 | root_node = tree.root_node 94 | 95 | func_delcs = ast_util.get_all_nodes_of_type(root_node, "function_definition") 96 | func = func_delcs[0] 97 | 98 | print(get_focal_call(ast_util, func)) 99 | 100 | 101 
| if __name__ == "__main__": 102 | fire.Fire(main) 103 | -------------------------------------------------------------------------------- /frontend/go/collect_focal.py: -------------------------------------------------------------------------------- 1 | """find focal call in Golang test function""" 2 | 3 | import fire 4 | import re 5 | from tree_sitter import Node 6 | from frontend.parser import GO_LANGUAGE 7 | from frontend.parser.ast_util import ASTUtil, ASTLoc, flatten_postorder 8 | from returns.maybe import Maybe, Nothing, Some, maybe 9 | 10 | 11 | def get_focal_call(ast_util: ASTUtil, func: Node) -> Maybe[tuple[str, ASTLoc]]: 12 | """Find the focal call in the given function 13 | 14 | Args: 15 | ast_util (ASTUtil): ASTUtil for the file 16 | func (Node): a method_declaration node 17 | 18 | Returns: 19 | Maybe[tuple[str, ASTLoc]]: focal call and its location 20 | """ 21 | 22 | # todo: find better heuristic to match object on imports 23 | # get the focal call from the given function 24 | 25 | calls = flatten_postorder(func, "call_expression") 26 | 27 | def get_basename(call: Node) -> str: 28 | func_name = ast_util.get_all_nodes_of_type(call, "identifier")[0] 29 | return ast_util.get_source_from_node(func_name) 30 | 31 | calls_before_assert: list[tuple[str, Node]] = [] 32 | has_assert = False 33 | for call in calls: 34 | base_name = get_basename(call) 35 | if base_name in {"ok", "equals", "assert"}: 36 | has_assert = True 37 | break 38 | 39 | full_name = ast_util.get_source_from_node(call).split("(")[0] 40 | calls_before_assert.append((full_name, call)) 41 | 42 | if not has_assert or not calls_before_assert: 43 | return Nothing 44 | 45 | # todo: check if focal is imported from workdir package 46 | def get_loc(n: tuple[str, Node]) -> tuple[str, ASTLoc]: 47 | full_name, node = n 48 | lineno, col = node.start_point 49 | match full_name.split("."): 50 | case [obj_name, *_, method_name]: # pylint: disable=unused-variable 51 | offset = len(full_name) - len(method_name) 52 | return method_name, (lineno, col + offset) 53 | case _: 54 | return full_name, (lineno, col) 55 | 56 | return Some(get_loc(calls_before_assert[-1])) 57 | 58 | 59 | def is_test_fn(node: Node, ast_util: ASTUtil) -> bool: 60 | """check if the function is a Golang test function 61 | 62 | Golang requires a test function to have `t *testing.T` input parameter 63 | 64 | Args: 65 | node (Node): a function_delc node 66 | ast_util (ASTUtil): node's util object 67 | 68 | Returns: 69 | bool: true if node is a unit test function, otherwise not 70 | """ 71 | # func delc always have parameter_list 72 | params = ast_util.get_all_nodes_of_type(node, "parameter_list")[0] 73 | param_types = ast_util.get_all_nodes_of_type(params, "qualified_type") 74 | 75 | return "testing.T" in map(ast_util.get_source_from_node, param_types) 76 | 77 | 78 | def main(): 79 | code = """func TestDatasets(t *testing.T) { 80 | defer setupZPool(t).cleanUp() 81 | 82 | _, err := zfs.Datasets("") 83 | ok(t, err) 84 | 85 | ds, err := zfs.GetDataset("test") 86 | ok(t, err) 87 | equals(t, zfs.DatasetFilesystem, ds.Type) 88 | equals(t, "", ds.Origin) 89 | if runtime.GOOS != "solaris" { 90 | assert(t, ds.Logicalused != 0, "Logicalused is not greater than 0") 91 | } 92 | }""" 93 | 94 | ast_util = ASTUtil(code) 95 | tree = ast_util.tree(GO_LANGUAGE) 96 | root_node = tree.root_node 97 | 98 | fn = ast_util.get_all_nodes_of_type(root_node, "function_declaration")[0] 99 | print(get_focal_call(ast_util, fn)) 100 | 101 | 102 | if __name__ == "__main__": 103 | fire.Fire(main) 104 | 
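# Editor's note: for the TestDatasets fixture in main() above, get_focal_call
# returns Some(("Datasets", (3, 15))): the last zfs call before the first
# ok(...) assertion, with the column offset pointing past the "zfs." qualifier.
# tests/frontend/test_go.py earlier in this dump asserts exactly this pair.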
-------------------------------------------------------------------------------- /frontend/java/collect_focal.py: -------------------------------------------------------------------------------- 1 | """find focal call in Java test functions""" 2 | 3 | from typing import Optional 4 | import fire 5 | import re 6 | from tree_sitter import Node 7 | from frontend.parser import JAVA_LANGUAGE 8 | from frontend.parser.ast_util import ASTUtil, ASTLoc, flatten_postorder 9 | from returns.maybe import Maybe, Nothing, Some, maybe 10 | 11 | 12 | def is_test_fn(node: Node, ast_util: ASTUtil): 13 | def has_test_modifier(node: Node): 14 | modifiers = ast_util.get_method_modifiers(node) 15 | return modifiers.map(lambda x: "@Test" in x).value_or(False) 16 | 17 | return has_test_modifier(node) 18 | 19 | 20 | def fuzzy_focal_name(test_func_name: str) -> str: 21 | patterns = [r"test_(\w+)", r"(\w+)_test", r"Test(\w+)", r"(\w+)Test"] 22 | 23 | for pattern in patterns: 24 | match = re.search(pattern, test_func_name, re.IGNORECASE) 25 | if match: 26 | return match.group(1) 27 | 28 | return test_func_name 29 | 30 | 31 | def get_focal_call(ast_util: ASTUtil, func: Node) -> Maybe[tuple[str, ASTLoc]]: 32 | """Find the focal call in the given function 33 | 34 | Args: 35 | ast_util (ASTUtil): ASTUtil for the file 36 | func (Node): a method_declaration node 37 | 38 | Returns: 39 | Maybe[tuple[str, ASTLoc]]: focal call and its location 40 | """ 41 | 42 | # todo: find better heuristic to match object on imports 43 | calls = flatten_postorder(func, "method_invocation") 44 | 45 | # reverse for postorder 46 | func_calls = [ast_util.get_source_from_node(call) for call in calls] 47 | calls_before_assert: list[str] = [] 48 | has_assert = False 49 | for call in func_calls: 50 | if "assert" in call: 51 | has_assert = True 52 | break 53 | calls_before_assert.append(call) 54 | 55 | if not has_assert or not calls_before_assert: 56 | return Nothing 57 | 58 | def get_loc(call: str) -> Maybe[tuple[str, ASTLoc]]: 59 | """add offset to nested function calls 60 | 61 | Args: 62 | call (str): code str of a method_invocation 63 | 64 | Returns: 65 | Maybe[tuple[str, ASTLoc]]: (method_name, its location) 66 | """ 67 | idx = func_calls.index(call) 68 | node = calls[idx] 69 | lineno, col = node.start_point 70 | 71 | # regular expression to find method names 72 | pattern = r"(\w+)\s*\(" 73 | matches = list(re.finditer(pattern, call)) 74 | 75 | if matches: 76 | last_match = matches[-1] 77 | method_name = last_match.group(1) 78 | offset = last_match.start(1) 79 | loc = (lineno, col + offset) 80 | return Some((method_name, loc)) 81 | else: 82 | return Nothing 83 | 84 | return get_loc(calls_before_assert[-1]) 85 | 86 | 87 | def main(): 88 | code = """ 89 | @Test 90 | void catalogLoads() { 91 | @SuppressWarnings("rawtypes") 92 | ResponseEntity entity = new TestRestTemplate() 93 | .getForEntity("http://localhost:" + this.port + "/context/eureka/apps", Map.class); 94 | String computedPath = entity.getHeaders().getFirst("X-Version-Filter-Computed-Path"); 95 | assertThat(computedPath).isEqualTo("/context/eureka/v2/apps"); 96 | }""" 97 | ast_util = ASTUtil(code) 98 | tree = ast_util.tree(JAVA_LANGUAGE) 99 | root_node = tree.root_node 100 | 101 | func_delcs = ast_util.get_all_nodes_of_type(root_node, "method_declaration") 102 | func = func_delcs[0] 103 | 104 | print(get_focal_call(ast_util, func)) 105 | 106 | 107 | if __name__ == "__main__": 108 | fire.Fire(main) 109 | -------------------------------------------------------------------------------- 
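(Editor's note: a few doctest-style examples of the `fuzzy_focal_name` heuristic above; the values follow from the four regex patterns tried in order, case-insensitively, and the session itself is illustrative rather than part of the repo.)

```python
from frontend.java.collect_focal import fuzzy_focal_name

assert fuzzy_focal_name("testAdd") == "Add"    # r"Test(\w+)" matches case-insensitively
assert fuzzy_focal_name("AddTest") == "Add"    # r"(\w+)Test" strips the suffix
assert fuzzy_focal_name("greet") == "greet"    # no pattern matches, name returned unchanged
```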
/frontend/python/collect_source.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import ast
4 | import fire
5 | import json
6 | import astor
7 | import traceback
8 | import astunparse
9 | from tqdm import tqdm
10 | 
11 | from frontend.util import wrap_repo, timestamp
12 | from navigate import load_ast_func
13 | 
14 | 
15 | def ast2source(node: ast.AST):
16 |     """convert an ast node into its source code using astor
17 |     astunparse is not used because it mishandles docstrings when unparsing;
18 |     astor's to_source is not used to avoid its pretty_source pass, which re-wraps lines that are too long
19 |     """
20 |     generator = astor.SourceGenerator(" " * 4)
21 |     generator.visit(node)
22 |     generator.result.append("\n")
23 |     if set(generator.result[0]) == set("\n"):
24 |         generator.result[0] = ""
25 |     return "".join(generator.result)
26 | 
27 | 
28 | def collect_source(func_id: str, repo_root: str):
29 |     """collect the source code given its function id"""
30 |     func, mod = None, None
31 |     try:
32 |         func, mod = load_ast_func(os.path.join(repo_root, func_id), return_nav=True)
33 |     except:
34 |         pass
35 |     if func is None:
36 |         return False, None
37 | 
38 |     try:
39 |         source = ast2source(func)
40 |     except:
41 |         source = ""
42 | 
43 |     try:
44 |         docstr = ast.get_docstring(func)
45 |     except:
46 |         docstr = None
47 |     if docstr is None:
48 |         docstr = ""
49 |     return True, (source, docstr)
50 | 
51 | 
52 | def main(
53 |     repo_id_list: str = "ageitgey/face_recognition",
54 |     repo_root: str = "data/repos",
55 |     focal_root: str = "data/focal",
56 |     source_path: str = "data/source/all.jsonl",
57 |     timeout: int = 5,
58 |     nprocs: int = 0,
59 |     limits: int = -1,
60 | ):
61 |     try:
62 |         repo_id_list = [l.strip() for l in open(repo_id_list, "r").readlines()]
63 |     except FileNotFoundError:
64 |         repo_id_list = [repo_id_list]
65 |     if limits > 0:
66 |         repo_id_list = repo_id_list[:limits]
67 |     print(f"Loaded {len(repo_id_list)} repos to be processed")
68 |     if os.path.exists(source_path):
69 |         os.remove(source_path)
70 |     total, failed = 0, {"repo": 0, "func": 0}
71 |     for repo_id in (pbar := tqdm(repo_id_list)):
72 |         pbar.set_description(f"{timestamp()} Processing {repo_id}")
73 |         # load test-focal pairs to be processed
74 |         path = os.path.join(focal_root, wrap_repo(repo_id) + ".jsonl")
75 |         if not os.path.exists(path):
76 |             failed["repo"] += 1
77 |             continue
78 |         data = [json.loads(l.strip()) for l in open(path, "r").readlines()]
79 |         total += len(data)
80 |         for item in data:
81 |             test_id, focal_id = item["test"], item["focal"]
82 |             test_status, test = collect_source(test_id, repo_root)
83 |             focal_status, focal = collect_source(focal_id, repo_root)
84 |             if not test_status or not focal_status:
85 |                 failed["func"] += 1
86 |                 continue
87 |             with open(source_path, "a") as ofile:
88 |                 dict2write = {
89 |                     "test_id": test_id,
90 |                     "test": test[0],
91 |                     "code_id": focal_id,
92 |                     "code": focal[0],
93 |                     "docstring": focal[1],
94 |                 }
95 |                 ofile.write(json.dumps(dict2write) + "\n")
96 |     print(f'Processed {len(repo_id_list) - failed["repo"]}/{len(repo_id_list)} repos')
97 |     print(f'Collected {total - failed["func"]}/{total} code-test pairs')
98 | 
99 | 
100 | if __name__ == "__main__":
101 |     fire.Fire(main)
102 | 
--------------------------------------------------------------------------------
/evaluation/data_quality.py:
--------------------------------------------------------------------------------
1 | """script to reproduce the plots in the data quality analysis"""
2 | 
3 | import json
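# Pipeline: group the rows of all.jsonl by project, count test/code lines and
# assertions per project, then plot per-language histograms of test-to-code
# ratio and assertion density (see analyze(), test_to_code_ratio(), get_density()).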
4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from multiprocessing import Pool, cpu_count 7 | import random 8 | import pandas as pd 9 | from tqdm import tqdm 10 | import dataclasses 11 | from funcy import lmap, lfilter 12 | import fire 13 | import os 14 | 15 | 16 | plt.style.use("_mpl-gallery") 17 | plt.rcParams["pdf.fonttype"] = 42 18 | plt.rcParams["ps.fonttype"] = 42 19 | 20 | 21 | def repo_id(test_id: str) -> str: 22 | return test_id.split("/")[0] 23 | 24 | 25 | @dataclasses.dataclass 26 | class ProjStat: 27 | repo_id: str 28 | n_test_lines: int 29 | n_code_lines: int 30 | n_assert: int 31 | 32 | 33 | def analyze(objs: list[dict]) -> pd.DataFrame: 34 | proj_map: dict[str, ProjStat] = {} 35 | for obj in tqdm(objs): 36 | repo = repo_id(obj["test_id"]) 37 | test: str = obj["test"] 38 | code: str = obj["code"] 39 | n_test_lines = len(test.splitlines()) 40 | n_code_lines = len(code.splitlines()) 41 | 42 | lower_test = test.lower() 43 | n_assert = lower_test.count("assert") 44 | if obj["lang"] in ("js", "cpp"): 45 | n_expect = lower_test.count("expect") 46 | n_assert += n_expect 47 | if obj["lang"] == "js": 48 | n_test = lower_test.count("test") 49 | n_assert += n_test 50 | 51 | if repo in proj_map: 52 | proj_map[repo].n_test_lines += n_test_lines 53 | proj_map[repo].n_code_lines += n_code_lines 54 | proj_map[repo].n_assert += n_assert 55 | else: 56 | proj_map[repo] = ProjStat(repo, n_test_lines, n_code_lines, n_assert) 57 | return pd.DataFrame([o.__dict__ for o in proj_map.values()]) 58 | 59 | 60 | def test_to_code_ratio(df: pd.DataFrame): 61 | return df["n_test_lines"] / df["n_code_lines"] 62 | 63 | 64 | def get_density(df: pd.DataFrame): 65 | return df["n_assert"] / df["n_test_lines"] 66 | 67 | 68 | def main( 69 | input_dataset_path: str = "data/source/all.jsonl", 70 | nproc: int = cpu_count(), 71 | alpha: float = 1, 72 | fontsize: int = 18, 73 | ): 74 | 75 | with open(input_dataset_path, "r") as fp: 76 | lines = fp.readlines() 77 | 78 | with Pool(nproc) as p: 79 | objs = p.map(json.loads, lines) 80 | 81 | langs = ["python", "java", "go", "cpp", "js"] 82 | 83 | objss = lmap(lambda l: lfilter(lambda o: o["lang"] == l, objs), langs) 84 | 85 | dfs = lmap(analyze, objss) 86 | 87 | np.random.seed(0) 88 | bins = np.linspace(0, 10, 100) 89 | 90 | os.makedirs("ratio", exist_ok=True) 91 | ratios = lmap(test_to_code_ratio, dfs) 92 | ticks = list(range(10)) 93 | for name, ratio in zip(langs, ratios): 94 | plt.figure(figsize=(12, 6)) 95 | plt.hist(ratio, bins, alpha=alpha) # type: ignore 96 | plt.xticks(ticks) 97 | plt.xlabel("Test-to-code Ratio", fontsize=fontsize) 98 | plt.ylabel("Per-project Frequency", fontsize=fontsize) 99 | plt.rc("axes", labelsize=18) 100 | plt.savefig(f"ratio/{name}.pdf", dpi=500, bbox_inches="tight") 101 | 102 | os.makedirs("density", exist_ok=True) 103 | ds = lmap(get_density, dfs) 104 | bins = np.linspace(0, 1, 100) 105 | for name, density in zip(langs, ds): 106 | print(name) 107 | plt.figure(figsize=(12, 6)) 108 | plt.hist(density, bins, alpha=alpha) # type: ignore 109 | # plt.legend(loc='upper right', fontsize=18) 110 | # plt.yscale("log") 111 | plt.xlabel("Assertion Density", fontsize=fontsize) 112 | plt.ylabel("Per-project Frequency", fontsize=fontsize) 113 | plt.rc("axes", labelsize=18) 114 | plt.savefig(f"density/{name}.pdf", dpi=500, bbox_inches="tight") 115 | 116 | 117 | if __name__ == "__main__": 118 | fire.Fire(main) 119 | -------------------------------------------------------------------------------- /unitsyncer/extract_def.py: 
-------------------------------------------------------------------------------- 1 | """extract test def header for LLM to generate the body""" 2 | import fire 3 | import ast 4 | from frontend.parser.ast_util import ASTUtil 5 | from frontend.parser import ( 6 | GO_LANGUAGE, 7 | JAVASCRIPT_LANGUAGE, 8 | CPP_LANGUAGE, 9 | JAVA_LANGUAGE, 10 | ) 11 | from itertools import takewhile 12 | from tqdm import tqdm 13 | import json 14 | import os 15 | import logging 16 | 17 | 18 | def py_get_def(code: str) -> str | None: 19 | try: 20 | tree = ast.parse(code) 21 | except SyntaxError: 22 | return code.split(":\n")[0] 23 | 24 | for node in ast.walk(tree): 25 | if isinstance(node, ast.FunctionDef): 26 | func_name = node.name 27 | args = [arg.arg for arg in node.args.args] 28 | return f"def {func_name}({', '.join(args)}):\n" 29 | 30 | return None 31 | 32 | 33 | def go_get_def(code: str) -> str | None: 34 | ast_util = ASTUtil(code) 35 | tree = ast_util.tree(GO_LANGUAGE) 36 | root = tree.root_node 37 | func_delcs = ast_util.get_all_nodes_of_type(root, "function_declaration") 38 | if not func_delcs: 39 | return None 40 | 41 | test_delc = func_delcs[0] 42 | test_name_node = ast_util.get_all_nodes_of_type(test_delc, "identifier")[0] 43 | test_name = ast_util.get_source_from_node(test_name_node) 44 | return f"func {test_name}(t *testing.T) {{\n" 45 | 46 | 47 | def js_get_def(code: str) -> str | None: 48 | ast_util = ASTUtil(code) 49 | tree = ast_util.tree(JAVASCRIPT_LANGUAGE) 50 | root = tree.root_node 51 | func_delcs = ast_util.get_all_nodes_of_type(root, "lexical_declaration") 52 | if not func_delcs: 53 | return None 54 | 55 | test_delc = func_delcs[0] 56 | test_name_node = ast_util.get_all_nodes_of_type(test_delc, "identifier")[0] 57 | test_name = ast_util.get_source_from_node(test_name_node) 58 | return f"const {test_name} = () => {{\n" 59 | 60 | 61 | def cpp_get_def(code: str) -> str | None: 62 | ast_util = ASTUtil(code) 63 | tree = ast_util.tree(CPP_LANGUAGE) 64 | root = tree.root_node 65 | func_delcs = ast_util.get_all_nodes_of_type(root, "function_definition") 66 | if not func_delcs: 67 | return None 68 | 69 | test_delc = func_delcs[0] 70 | test_name_node = ast_util.get_all_nodes_of_type(test_delc, "identifier")[0] 71 | test_name = ast_util.get_source_from_node(test_name_node) 72 | test_params = ast_util.get_all_nodes_of_type(test_delc, "parameter_declaration")[:2] 73 | return f"{test_name}({', '.join(map(ast_util.get_source_from_node, test_params))}) {{\n" 74 | 75 | 76 | def java_get_def(code: str) -> str | None: 77 | return "".join(takewhile(lambda c: c != "{", code)) + "{\n" 78 | 79 | 80 | def get_def_header(code: str, lang: str) -> str | None: 81 | header: str | None = None 82 | if lang == "python": 83 | header = py_get_def(code) 84 | elif lang == "cpp": 85 | header = cpp_get_def(code) 86 | elif lang == "java": 87 | header = java_get_def(code) 88 | elif lang == "go": 89 | header = go_get_def(code) 90 | elif lang in ("js", "javascript"): 91 | header = js_get_def(code) 92 | 93 | return header 94 | 95 | 96 | def main(in_path: str, out_path: str): 97 | with tqdm(total=os.path.getsize(in_path)) as p_bar: 98 | with open(in_path, "r") as in_f, open(out_path, "a") as out_f: 99 | for j_line in in_f: 100 | j = json.loads(j_line) 101 | test = j["test"] 102 | lang = j["lang"] 103 | 104 | try: 105 | header = get_def_header(test, lang) 106 | j["test_header"] = header 107 | out_f.write(json.dumps(j) + "\n") 108 | except Exception as e: 109 | logging.error(e) 110 | logging.error(j) 111 | 112 | 
p_bar.update(len(j_line))
113 | 
114 | 
115 | if __name__ == "__main__":
116 |     fire.Fire(main)
117 | 
--------------------------------------------------------------------------------
/frontend/javascript/js_util.py:
--------------------------------------------------------------------------------
1 | """util functions for JS frontend"""
2 | 
3 | from typing import Iterable
4 | from tree_sitter import Node
5 | from frontend.parser import JAVASCRIPT_LANGUAGE
6 | from frontend.parser.ast_util import ASTUtil, ASTLoc, flatten_postorder
7 | from returns.maybe import Maybe, Nothing, Some
8 | from unitsyncer.util import replace_tabs
9 | 
10 | 
11 | def is_test_fn(n: Node, ast_util: ASTUtil):
12 | 
13 |     def is_call_to_test(node: Node):
14 |         return ast_util.get_name(node).map(lambda n: n == "describe").value_or(False)
15 | 
16 |     return is_call_to_test(n)
17 | 
18 | 
19 | def js_get_test_args(
20 |     ast_util: ASTUtil, test_call_expr: Node
21 | ) -> Maybe[tuple[str, Node]]:
22 |     """extract the test name and the test function from a call to `describe`
23 | 
24 |     Returns:
25 |         Maybe[tuple[str, Node]]: test name/description, and the test function
26 |     """
27 |     if test_call_expr.type != "call_expression":
28 |         return Nothing
29 | 
30 |     # get the first argument of the call expression
31 |     args_node = None
32 |     for child in test_call_expr.children:
33 |         if child.type == "arguments":
34 |             args_node = child
35 |             break
36 |     if args_node is None:
37 |         return Nothing
38 | 
39 |     # a call to `describe` has the following structure:
40 |     # describe(test_name, test_func)
41 |     # args_node.children should be a list of 5 nodes:
42 |     # "(", "test_name", ",", "test_func", ")"
43 |     args = args_node.children
44 |     if len(args) != 5 or args[1].type != "string" or args[3].type != "function":
45 |         return Nothing
46 | 
47 |     return Some((ast_util.get_source_from_node(args[1]), args[3]))
48 | 
49 | 
50 | def get_focal_call(ast_util: ASTUtil, test_func: Node) -> Maybe[tuple[str, ASTLoc]]:
51 |     calls = flatten_postorder(test_func, "call_expression")
52 | 
53 |     func_calls = [ast_util.get_source_from_node(call) for call in calls]
54 |     calls_before_expect = []
55 |     has_expect = False
56 |     for call in func_calls:
57 |         if "expect" in call or "test" in call:
58 |             has_expect = True
59 |             break
60 |         calls_before_expect.append(call)
61 | 
62 |     if not has_expect:
63 |         return Nothing
64 |     if len(calls_before_expect) == 0:
65 |         return Nothing
66 | 
67 |     def get_loc(call: str) -> tuple[str, ASTLoc]:
68 |         idx = func_calls.index(call)
69 |         node = calls[idx]
70 |         lineno, col = node.start_point
71 |         match call.split("."):
72 |             case [*_, method_name]:
73 |                 offset = len(call) - len(method_name)
74 |                 method_name = method_name.split("(")[0]
75 |                 return method_name, (lineno, col + offset)
76 |             case _:
77 |                 return call, (lineno, col)
78 | 
79 |     return Some(get_loc(calls_before_expect[-1]))
80 | 
81 | 
82 | def main():
83 |     code = """
84 | describe('loading a non-existent value (from memory and disk)', function () {
85 | fixtureUtils.mkdir({folderName: 'non-existent'});
86 | storeUtils.init();
87 | storeUtils.get('nothing');
88 | 
89 | it('calls back with `null`', function () {
90 | expect(this.err).to.equal(null);
91 | expect(this.val).to.equal(null);
92 | });
93 | });
94 | """
95 |     ast_util = ASTUtil(replace_tabs(code))
96 |     tree = ast_util.tree(JAVASCRIPT_LANGUAGE)
97 |     root_node = tree.root_node
98 | 
99 |     # print(root_node.sexp())
100 | 
101 |     # js test function is a higher order function that takes a function as input
102 |     call_exprs = ast_util.get_all_nodes_of_type(root_node,
"call_expression") 103 | 104 | def is_call_to_test(node: Node): 105 | return ast_util.get_name(node).map(lambda n: n == "describe").value_or(False) 106 | 107 | focal = list(filter(is_call_to_test, call_exprs))[0] 108 | test_name, test_func = js_get_test_args(ast_util, focal).unwrap() 109 | print(test_name) 110 | print(ast_util.get_source_from_node(test_func)) 111 | 112 | print(get_focal_call(ast_util, test_func)) 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # VScode 163 | .vscode/ 164 | 165 | 166 | # repository files 167 | /data/* 168 | /data/repos/* 169 | !/data/repos/*_example 170 | !/data/repo_meta 171 | 172 | java-language-server/ 173 | /java-language-server/* 174 | 175 | oauth 176 | debug/ 177 | **/target/ 178 | 179 | stat 180 | *.jsonl 181 | *.csv -------------------------------------------------------------------------------- /tests/evaluation/test_rust_compile.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable 3 | import fire 4 | import os 5 | from tree_sitter import Node 6 | from frontend.parser import RUST_LANGUAGE 7 | from frontend.parser.ast_util import ASTUtil 8 | from unitsyncer.util import replace_tabs 9 | import json 10 | from evaluation.rust.compile import flatten_use_delc, construct_use_delcs 11 | import unittest 12 | import os 13 | import logging 14 | 15 | 16 | class TestRustCompile(unittest.TestCase): 17 | def test_flatten_base64(self): 18 | code = "use rand::{Rng, SeedableRng};" 19 | expected = ["use rand::Rng;", "use rand::SeedableRng;"] 20 | self.assertEqual(flatten_use_delc(code), expected) 21 | 22 | code = "use base64::engine::{general_purpose::STANDARD, Engine};" 23 | expected = [ 24 | "use base64::engine::general_purpose::STANDARD;", 25 | "use base64::engine::Engine;", 26 | ] 27 | self.assertEqual(flatten_use_delc(code), expected) 28 | 29 | code = "use base64::engine::general_purpose::{GeneralPurpose, NO_PAD};" 30 | expected = [ 31 | "use base64::engine::general_purpose::GeneralPurpose;", 32 | "use base64::engine::general_purpose::NO_PAD;", 33 | ] 34 | self.assertEqual(flatten_use_delc(code), expected) 35 | 36 | # use wildcard only 37 | code = "use base64::*;" 38 | expected = ["use base64::*;"] 39 | self.assertEqual(flatten_use_delc(code), expected) 40 | 41 | # use wildcard in list 42 | code = "use base64::{alphabet::URL_SAFE, engine::general_purpose::PAD, engine::general_purpose::STANDARD, *,};" 43 | expected = [ 44 | "use base64::alphabet::URL_SAFE;", 45 | "use base64::engine::general_purpose::PAD;", 46 | "use base64::engine::general_purpose::STANDARD;", 47 | "use base64::*;", 48 | ] 49 | self.assertEqual(flatten_use_delc(code), expected) 50 | 51 | # use_as clause in list 52 | code = "use base64::{engine::general_purpose::STANDARD, Engine as _};" 53 | expected = [ 54 | "use 
base64::engine::general_purpose::STANDARD;",
55 |             "use base64::Engine as _;",
56 |         ]
57 |         self.assertEqual(flatten_use_delc(code), expected)
58 | 
59 |         # with indent
60 |         code = """
61 | use base64::{
62 |     alphabet::URL_SAFE, engine::general_purpose::PAD, engine::general_purpose::STANDARD, *,
63 | };"""
64 |         expected = [
65 |             "use base64::alphabet::URL_SAFE;",
66 |             "use base64::engine::general_purpose::PAD;",
67 |             "use base64::engine::general_purpose::STANDARD;",
68 |             "use base64::*;",
69 |         ]
70 |         self.assertEqual(flatten_use_delc(code), expected)
71 | 
72 |     def test_construct_use_lists(self):
73 |         workspace_dir = os.path.abspath(
74 |             "data/repos/marshallpierce-rust-base64/marshallpierce-rust-base64-4ef33cc"
75 |         )
76 | 
77 |         tests_expected = {
78 |             "use base64::alphabet::URL_SAFE;",
79 |             "use base64::engine::general_purpose::PAD;",
80 |             "use base64::engine::general_purpose::STANDARD;",
81 |             "use base64::*;",
82 |             "use rand::Rng;",
83 |             "use rand::SeedableRng;",
84 |             "use base64::engine::Engine;",
85 |             "use base64::engine::general_purpose::GeneralPurpose;",
86 |             "use base64::engine::general_purpose::NO_PAD;",
87 |         }
88 | 
89 |         self.assertEqual(construct_use_delcs(workspace_dir, "tests"), tests_expected)
90 | 
91 |         fuzz_expected = {
92 |             "use base64::Engine as _;",
93 |             "use base64::engine::general_purpose::STANDARD;",
94 |             "use self::rand::SeedableRng;",
95 |             "use self::rand::Rng;",
96 |             "use base64::*;",
97 |             "use base64::alphabet;",
98 |         }
99 |         self.assertEqual(construct_use_delcs(workspace_dir, "fuzz"), fuzz_expected)
100 | 
101 | 
102 | if __name__ == "__main__":
103 |     logging.basicConfig(level=logging.INFO)
104 |     unittest.main()
105 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UniTSyn
2 | 
3 | Multilingual **Uni**t **T**est and Function Source **Syn**chronization for CodeLLM.
4 | Code for our ISSTA 2024 paper https://arxiv.org/abs/2402.03396.
5 | 
6 | ## Requirements
7 | 
8 | - Python 3.10+
9 | - `requirements.txt`
10 | - [rustfmt](https://github.com/rust-lang/rustfmt) to use `frontend/rust/collect_fuzz.py`
11 | 
12 | ### Language Server
13 | 
14 | To run this script on a new project, you need to install the corresponding language server:
15 | 
16 | | Language   | Language Server | Frontend | Backend |
17 | | ---------- | ------------------------------------------------------------------------------------------------------ | -------- | -------- |
18 | | Python     | [pylsp](https://github.com/python-lsp/python-lsp-server) | ✔ | ✔ |
19 | | Java       | [java-language-server](https://github.com/georgewfraser/java-language-server)\* | ✔ | ✔ |
20 | | JavaScript | [typescript-language-server](https://github.com/typescript-language-server/typescript-language-server) | ✔ | ✔ |
21 | | Go         | [gopls](https://pkg.go.dev/golang.org/x/tools/gopls) | ✔ | ✔ |
22 | | C/C++      | [clangd](https://clangd.llvm.org/installation.html) | ✔ | ✔ |
23 | 
24 | \*NOTE: you need to git clone the repo into the working directory of this project, then follow the instructions in the repo to install the language server.
25 | 
26 | You can find language servers for other languages at
27 | [language-server-protocol/implementors/servers](https://microsoft.github.io/language-server-protocol/implementors/servers/).
28 | Other languages are not supported yet, but will be as the research progresses.
29 | To support a new language, you need a frontend to do the following:
30 | 
31 | 1.
Collect the unit test locations and focal function locations in the repo (see `frontend/python/collect_test.py` and `frontend/python/collect_focal.py` for the Python frontend).
32 | 2. Given a `Location` of a function declaration, extract the function source code (see `unitsyncer/source_code.py`).
33 | 
34 | ## Setup
35 | 
36 | ```bash
37 | mkdir -p data/focal data/repos data/repos_tarball data/tests
38 | source ./scripts/env.sh
39 | ```
40 | 
41 | ## Run
42 | 
43 | ```bash
44 | python3 scripts/download_repos.py
45 | python3 scripts/decompress_repos.py
46 | 
47 | python3 frontend/<language>/collect_all.py
48 | python3 main.py
49 | ```
50 | 
51 | ## Automated Repo Mining
52 | 
53 | Automatic repo mining is supported through `scripts/find_repos.py`.
54 | Note: Please run `source ./scripts/env.sh` from the root of the repo before mining.
55 | 
56 | The currently supported checks are:
57 | 
58 | - "stars"
59 | - "latest commit"
60 | - "language"
61 | - "fuzzers"
62 | 
63 | The corresponding value in `reqs` to check against should be at the same index as the check in `checks_list`.
64 | 
65 | ```bash
66 | # Command template
67 | python3 scripts/find_repos.py --language='<language>' --checks_list='[<checks>]' --reqs='[<requirements>]' --num_searches='<n>'
68 | 
69 | # Rust example
70 | python3 scripts/find_repos.py --language='Rust' --checks_list='["stars", "latest commit", "language", "fuzzers"]' --reqs='["10", "2020-1-1", "Rust", None]' --num_searches='1'
71 | 
72 | # Python example
73 | python3 scripts/find_repos.py --language='Python' --checks_list='["stars", "latest commit", "language"]' --reqs='["10", "2020-1-1", "Python"]' --num_searches='1'
74 | ```
75 | 
76 | Cursors representing where the search left off are saved to `data/repo_cursors/<language>_cursor.txt`. `find_repos.py` will automatically use and update this cursor to avoid mining duplicate repos.
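
## Dataset Format

Each line of the merged `data/source/all.jsonl` is one JSON object whose fields are produced by `frontend/python/collect_source.py`: `test_id`, `test`, `code_id`, `code`, and `docstring`. Downstream scripts such as `unitsyncer/extract_def.py` and `evaluation/data_quality.py` also read a `lang` field, so the minimal inspection sketch below treats `lang` as optional (this snippet is illustrative and not part of the pipeline):

```python
import json

# Print the first record of the merged dataset.
# "lang" may be absent right after collect_source.py runs.
with open("data/source/all.jsonl") as fp:
    record = json.loads(fp.readline())
    print(record["test_id"], record.get("lang", "<no lang>"))
    print(record["test"])
```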
77 | 78 | # Reference 79 | 80 | Please cite our work in your publications if it helps your research: 81 | 82 | ```biblatex 83 | @inproceedings{he2024unitsyn, 84 | author = {He, Yifeng and Huang, Jiabo and Rong, Yuyang and Guo, Yiwen and Wang, Ethan and Chen, Hao}, 85 | title = {UniTSyn: A Large-Scale Dataset Capable of Enhancing the Prowess of Large Language Models for Program Testing}, 86 | booktitle = {International Symposium on Software Testing and Analysis (ISSTA)}, 87 | date = {2024-09-16/2024-09-20}, 88 | address = {Vienna, Austria}, 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /tests/frontend/test_rust.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import logging 4 | from returns.maybe import Nothing, Some 5 | from frontend.parser.ast_util import ASTUtil 6 | from frontend.parser import RUST_LANGUAGE 7 | from frontend.rust.rust_util import get_focal_call, get_test_functions 8 | from unitsyncer.util import replace_tabs 9 | 10 | 11 | class TestRustFrontend(unittest.TestCase): 12 | def test_get_focal_wo_unwrap(self): 13 | code = """#[test] 14 | fn encode_all_bytes_url() { 15 | let bytes: Vec = (0..=255).collect(); 16 | 17 | assert_eq!( 18 | "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0\ 19 | -P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn\ 20 | -AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq\ 21 | -wsbKztLW2t7i5uru8vb6_wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t_g4eLj5OXm5-jp6uvs7e7v8PHy\ 22 | 8_T19vf4-fr7_P3-_w==", 23 | &engine::GeneralPurpose::new(&URL_SAFE, PAD).encode(bytes) 24 | ); 25 | } 26 | """ 27 | ast_util = ASTUtil(replace_tabs(code)) 28 | tree = ast_util.tree(RUST_LANGUAGE) 29 | root_node = tree.root_node 30 | 31 | test_func = get_test_functions(ast_util, root_node)[0] 32 | focal_call = get_focal_call(ast_util, test_func) 33 | name, _ = focal_call.unwrap() 34 | self.assertEqual( 35 | name, "engine::GeneralPurpose::new(&URL_SAFE, PAD).encode(bytes)" 36 | ) 37 | 38 | def test_get_focal_w_unwrap(self): 39 | code = '#[test]\nfn encode_engine_slice_error_when_buffer_too_small() {\n for num_triples in 1..100 {\n let input = "AAA".repeat(num_triples);\n let mut vec = vec![0; (num_triples - 1) * 4];\n assert_eq!(\n EncodeSliceError::OutputSliceTooSmall,\n STANDARD.encode_slice(&input, &mut vec).unwrap_err()\n );\n vec.push(0);\n assert_eq!(\n EncodeSliceError::OutputSliceTooSmall,\n STANDARD.encode_slice(&input, &mut vec).unwrap_err()\n );\n vec.push(0);\n assert_eq!(\n EncodeSliceError::OutputSliceTooSmall,\n STANDARD.encode_slice(&input, &mut vec).unwrap_err()\n );\n vec.push(0);\n assert_eq!(\n EncodeSliceError::OutputSliceTooSmall,\n STANDARD.encode_slice(&input, &mut vec).unwrap_err()\n );\n vec.push(0);\n assert_eq!(\n num_triples * 4,\n STANDARD.encode_slice(&input, &mut vec).unwrap()\n );\n }\n}' 40 | 41 | ast_util = ASTUtil(replace_tabs(code)) 42 | tree = ast_util.tree(RUST_LANGUAGE) 43 | root_node = tree.root_node 44 | 45 | test_func = get_test_functions(ast_util, root_node)[0] 46 | focal_call = get_focal_call(ast_util, test_func) 47 | name, _ = focal_call.unwrap() 48 | self.assertEqual(name, "STANDARD.encode_slice(&input, &mut vec)") 49 | 50 | def test_no_focal_in_assert(self): 51 | code = """ 52 | #[test] 53 | fn test_1() { 54 | let data = []; 55 | let engine = utils::random_engine(data); 56 | let encoded = engine.encode(data); 57 | let decoded = 
engine.decode(&encoded).unwrap(); 58 | assert_eq!(data, decoded.as_slice()); 59 | } 60 | """ 61 | 62 | ast_util = ASTUtil(replace_tabs(code)) 63 | tree = ast_util.tree(RUST_LANGUAGE) 64 | root_node = tree.root_node 65 | 66 | test_func = get_test_functions(ast_util, root_node)[0] 67 | 68 | focal_call = get_focal_call(ast_util, test_func) 69 | name, _ = focal_call.unwrap() 70 | self.assertEqual(name, "engine.decode(&encoded)") 71 | 72 | def test_focal_no_assert(self): 73 | code = """ 74 | #[test] 75 | fn test_23() { 76 | let data = []; 77 | let engine = utils::random_engine(data); 78 | let _ = engine.decode(data); 79 | } 80 | """ 81 | ast_util = ASTUtil(replace_tabs(code)) 82 | tree = ast_util.tree(RUST_LANGUAGE) 83 | root_node = tree.root_node 84 | 85 | test_func = get_test_functions(ast_util, root_node)[0] 86 | 87 | focal_call = get_focal_call(ast_util, test_func) 88 | name, _ = focal_call.unwrap() 89 | self.assertEqual(name, "engine.decode(data)") 90 | 91 | 92 | if __name__ == "__main__": 93 | logging.basicConfig(level=logging.INFO) 94 | unittest.main() 95 | -------------------------------------------------------------------------------- /frontend/parser/ast_util.py: -------------------------------------------------------------------------------- 1 | """Helper Class and Functions for tree-sitter AST""" 2 | 3 | from tree_sitter import Language, Parser, Tree 4 | from tree_sitter import Node 5 | from returns.maybe import Maybe, Nothing, Some 6 | from unitsyncer.common import UNITSYNCER_HOME 7 | from typing import Optional, Tuple 8 | 9 | ASTLoc = tuple[int, int] 10 | 11 | 12 | class ASTUtil: 13 | """Helper Class to build/read/manipulate AST with tree-sitter""" 14 | 15 | def __init__(self, source_code: str) -> None: 16 | self.src = source_code 17 | 18 | def tree(self, lang: Language) -> Tree: 19 | parser = Parser() 20 | parser.set_language(lang) 21 | return parser.parse(bytes(self.src, "utf8")) 22 | 23 | def get_source_from_node(self, node: Node) -> str: 24 | match node.type: 25 | case "method_declaration": 26 | start = node.start_point[0] 27 | end = node.end_point[0] 28 | src_lines = self.src.splitlines()[start : end + 1] 29 | src_lines = remove_leading_spaces(src_lines) 30 | return "\n".join(src_lines) 31 | case _: 32 | start = node.start_byte 33 | end = node.end_byte 34 | return self.src[start:end] 35 | 36 | def get_method_name(self, method_node: Node) -> Maybe[str]: 37 | if method_node.type != "method_declaration": 38 | return Nothing 39 | 40 | # for a method decl, its name is the first identifier 41 | for child in method_node.children: 42 | if child.type == "identifier": 43 | return Some(self.get_source_from_node(child)) 44 | 45 | return Nothing 46 | 47 | def get_name(self, node: Node) -> Maybe[str]: 48 | for child in node.children: 49 | if child.type == "identifier": 50 | return Some(self.get_source_from_node(child)) 51 | 52 | return Nothing 53 | 54 | def get_method_modifiers(self, method_node: Node) -> Maybe[list[str]]: 55 | if method_node.type != "method_declaration": 56 | return Nothing 57 | 58 | modifiers = [] 59 | for child in method_node.children: 60 | if child.type == "modifiers": 61 | for modifier_child in child.children: 62 | modifiers.append(self.get_source_from_node(modifier_child)) 63 | return Some(modifiers) 64 | 65 | def get_all_nodes_of_type( 66 | self, root: Node, node_type: str | None, max_level=50 67 | ) -> list[Node]: 68 | """walk on AST and collect all nodes of the given type 69 | 70 | Args: 71 | root (Node): root node of tree 72 | node_type (str | None): type of node 
to collect; if None, collect all nodes
73 |             max_level (int, optional): maximum recursion level. Defaults to 50.
74 | 
75 |         Returns:
76 |             list[Node]: collected nodes
77 |         """
78 |         nodes: list[Node] = []
79 |         if max_level == 0:
80 |             return nodes
81 | 
82 |         for child in root.children:
83 |             if node_type is None or child.type == node_type:
84 |                 nodes.append(child)
85 |             nodes += self.get_all_nodes_of_type(
86 |                 child, node_type, max_level=max_level - 1
87 |             )
88 |         return nodes
89 | 
90 | 
91 | def remove_leading_spaces(lines: list[str]) -> list[str]:
92 |     """remove leading spaces from each line"""
93 |     space_idx = len(lines[0]) - len(lines[0].lstrip())
94 |     return [s[space_idx:] for s in lines]
95 | 
96 | 
97 | def flatten_postorder(
98 |     root: Node, node_type: Optional[str] = None, max_level=50
99 | ) -> list[Node]:
100 |     """flatten a tree in postorder
101 | 
102 |     Args:
103 |         root (Node): root of tree
104 |         node_type (str | None): type of node to collect; if None, collect all nodes
105 |         max_level (int, optional): maximum recursion level. Defaults to 50.
106 | 
107 |     Returns:
108 |         list[Node]: flattened tree
109 |     """
110 |     nodes: list[Node] = []
111 |     if max_level == 0:
112 |         return nodes
113 | 
114 |     for child in root.children:
115 |         nodes += flatten_postorder(child, node_type, max_level - 1)
116 | 
117 |     if node_type is None or root.type == node_type:
118 |         nodes.append(root)
119 |     return nodes
120 | 
--------------------------------------------------------------------------------
/frontend/rust/rust_util.py:
--------------------------------------------------------------------------------
1 | """Util functions for rust frontend"""
2 | 
3 | from typing import Iterable, Optional
4 | from tree_sitter import Node
5 | from frontend.parser import RUST_LANGUAGE
6 | from frontend.parser.ast_util import ASTUtil, ASTLoc, flatten_postorder
7 | from returns.maybe import Maybe, Nothing, Some, maybe
8 | from unitsyncer.util import replace_tabs
9 | 
10 | 
11 | def get_test_functions(ast_util: ASTUtil, root_node: Node) -> list[Node]:
12 |     if root_node.type != "source_file":
13 |         return []
14 | 
15 |     def has_test_annotation(idx: int):
16 |         if idx < 0 or root_node.children[idx].type != "attribute_item":
17 |             return False
18 | 
19 |         # check if the attributes have #[test]
20 |         for child in root_node.children[idx].children:
21 |             for node in ast_util.get_all_nodes_of_type(child, "identifier"):
22 |                 if ast_util.get_source_from_node(node) == "test":
23 |                     return True
24 |         return False
25 | 
26 |     function_items = []
27 |     for idx, child in enumerate(root_node.children):
28 |         if child.type == "function_item" and has_test_annotation(idx - 1):
29 |             function_items.append(child)
30 |     return function_items
31 | 
32 | 
33 | def get_first_assert(ast_util: ASTUtil, test_func: Node) -> Maybe[Node]:
34 |     macro_invocations = ast_util.get_all_nodes_of_type(test_func, "macro_invocation")
35 | 
36 |     for macro_invocation in macro_invocations:
37 |         if (
38 |             ast_util.get_name(macro_invocation)
39 |             .map(lambda name: "assert" in name)
40 |             .value_or(False)
41 |         ):
42 |             return Some(macro_invocation)
43 |     return Nothing
44 | 
45 | 
46 | @maybe
47 | def get_first_valid_call(calls: list[Node], ast_util: ASTUtil) -> Optional[Node]:
48 |     """find the first call that is a valid focal-call candidate
49 | 
50 |     Args:
51 |         calls (list[Node]): a list of candidate call nodes
52 |         ast_util (ASTUtil): ast_util built from the source code
53 | 
54 |     Returns:
55 |         Optional[Node]: the first node that should not be skipped; see `do_skip` for details
56 |     """
57 | 
58 |     def do_skip(call_node: Node) -> bool:
59 | 
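        # the substrings below mark accessor/conversion calls such as
        # `decoded.unwrap()` or `data.as_slice()`, which only wrap the value of
        # interest; they are passed over when searching for the focal call
        # (see tests/frontend/test_rust.py for concrete cases)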
skip_list = ["unwrap", "len", "as_slice", "into_iter"] 60 | call_node_name = ast_util.get_source_from_node(call_node) 61 | return any(skip_str in call_node_name for skip_str in skip_list) 62 | 63 | return next( 64 | (call for call in calls if not do_skip(call)), 65 | None, 66 | ) 67 | 68 | 69 | def get_focal_call(ast_util: ASTUtil, test_func: Node) -> Maybe[tuple[str, ASTLoc]]: 70 | """Get the focal call from the given test function 71 | 72 | Heuristic: 73 | 1. find the first assert macro 74 | 2. expand the macro and find the first call expression in the macro 75 | 3. if no call expression in the macro, back track to find the last call before assert 76 | """ 77 | 78 | def expand_assert_and_get_call(assert_macro: Node) -> Maybe[tuple[str, ASTLoc]]: 79 | token_tree = ast_util.get_all_nodes_of_type(assert_macro, "token_tree")[0] 80 | code = ast_util.get_source_from_node(token_tree) 81 | assert_ast_util = ASTUtil(code) 82 | assert_ast = assert_ast_util.tree(RUST_LANGUAGE) 83 | assert_root = assert_ast.root_node 84 | 85 | match assert_ast_util.get_all_nodes_of_type(assert_root, "call_expression"): 86 | case []: 87 | # todo: no call expression in assert macro, 88 | # back track to find the last call before assert 89 | return Nothing 90 | case calls: 91 | 92 | def to_result(node: Node) -> tuple[str, ASTLoc]: 93 | name = assert_ast_util.get_source_from_node(node) 94 | lineno = node.start_point[0] + assert_macro.start_point[0] 95 | col = node.start_point[1] 96 | return name, (lineno, col) 97 | 98 | return get_first_valid_call(calls, assert_ast_util).map(to_result) 99 | 100 | return Nothing 101 | 102 | focal_in_assert = get_first_assert(ast_util, test_func).bind( 103 | expand_assert_and_get_call 104 | ) 105 | if focal_in_assert != Nothing: 106 | return focal_in_assert 107 | 108 | match flatten_postorder(test_func, "call_expression"): 109 | case []: 110 | return Nothing 111 | case calls: 112 | 113 | def to_name_loc(n: Node): 114 | return (ast_util.get_source_from_node(n), n.start_point) 115 | 116 | return get_first_valid_call(calls[::-1], ast_util).map(to_name_loc) 117 | return Nothing 118 | 119 | 120 | def main(): 121 | code = """ 122 | #[test] 123 | fn test_1() { 124 | let data = []; 125 | let engine = utils::random_engine(data); 126 | let encoded = engine.encode(data); 127 | let decoded = engine.decode(&encoded).unwrap(); 128 | assert_eq!(data, decoded.as_slice()); 129 | } 130 | """ 131 | ast_util = ASTUtil(replace_tabs(code)) 132 | tree = ast_util.tree(RUST_LANGUAGE) 133 | root_node = tree.root_node 134 | 135 | test_funcs = get_test_functions(ast_util, root_node) 136 | func = test_funcs[0] 137 | print(ast_util.get_source_from_node(func)) 138 | 139 | focal_call = get_focal_call(ast_util, func) 140 | print(focal_call) 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /evaluation/rust/compile.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable 3 | import fire 4 | import os 5 | from tree_sitter import Node 6 | from frontend.parser import RUST_LANGUAGE 7 | from frontend.parser.ast_util import ASTUtil 8 | import json 9 | import subprocess 10 | from unitsyncer.common import UNITSYNCER_HOME 11 | from returns.maybe import Maybe, Some, Nothing 12 | from functools import reduce 13 | from operator import add 14 | from unitsyncer.util import concatMap 15 | 16 | 17 | def flatten_use_delc(use_delc_code: str) -> list[str]: 18 | """flatten 
a nested use declaration line to multiple use declarations 19 | 20 | e.g. 21 | 22 | from: 23 | use rand::{Rng, SeedableRng}; 24 | to: 25 | use rand::Rng;\n 26 | use rand::SeedableRng; 27 | 28 | 29 | Args: 30 | use_delc_code (str): code of a use delc 31 | 32 | Returns: 33 | list[str]: flattened use declarations 34 | """ 35 | ast_util = ASTUtil(use_delc_code) 36 | tree = ast_util.tree(RUST_LANGUAGE) 37 | root = tree.root_node 38 | use_delc_nodes = ast_util.get_all_nodes_of_type(root, "use_declaration") 39 | if len(use_delc_nodes) != 1: 40 | return [] 41 | 42 | delc_node = use_delc_nodes[0] 43 | scoped_use_list_nodes = ast_util.get_all_nodes_of_type(delc_node, "scoped_use_list") 44 | if len(scoped_use_list_nodes) == 0: 45 | # example: use base64::*; 46 | wildcard_nodes = ast_util.get_all_nodes_of_type(delc_node, "use_wildcard") 47 | if len(wildcard_nodes) >= 1: 48 | return [use_delc_code] 49 | return [] 50 | 51 | def fold_nodes(nodes: list[Node]) -> str: 52 | # NOTE: ignore type since `str` is a Iterable, but not compatible with `Iterable[U]` type 53 | # `type String = List[Char]` is not valid in Python since there is no `Char` 54 | # the type checker would then infer it as `Iterable[str]`, which != `str` 55 | return concatMap(ast_util.get_source_from_node, nodes) # type: ignore 56 | 57 | def get_use_src(node: Node) -> str | None: 58 | match node.type: 59 | case "identifier" | "use_wildcard" | "use_as_clause": 60 | return ast_util.get_source_from_node(node) 61 | case "scoped_identifier": 62 | return fold_nodes(node.children) 63 | case _: 64 | return None 65 | 66 | node = scoped_use_list_nodes[0] 67 | match node.children: 68 | case []: 69 | return [] 70 | case [*base_nodes, use_list_node]: 71 | if use_list_node.type != "use_list": 72 | return [] 73 | 74 | base = fold_nodes(base_nodes) 75 | use_list = map(get_use_src, use_list_node.children) 76 | return [f"use {base + u};" for u in use_list if u] 77 | case _: 78 | return [] 79 | 80 | 81 | def collect_rs_files(root: str): 82 | """Get all files end with .rs in the given root directory 83 | 84 | Args: 85 | root (str): path to repo root 86 | """ 87 | for dirpath, _, filenames in os.walk(root): 88 | for filename in filenames: 89 | if filename.endswith(".rs"): 90 | yield os.path.join(dirpath, filename) 91 | 92 | 93 | def construct_use_delcs(workspace_dir: str, type: str) -> set[str]: 94 | """construct a set of unique use_list for a project from all use declarations in 95 | a subdirectory to 96 | 97 | 1. solve generated tests' dependency error 98 | 2. avoid compile error caused by duplicated imports 99 | 100 | Args: 101 | workspace_dir (str): path to project's workdir 102 | type (str): tests or fuzz to collect use_delcs. 
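            e.g. construct_use_delcs(workspace_dir, "tests") gathers the use
            declarations from every .rs file under <workspace_dir>/tests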
103 | 
104 |     Returns:
105 |         set[str]: set of use declarations to write to generated test files
106 |     """
107 |     subdir = os.path.join(workspace_dir, type)
108 | 
109 |     def get_use_list_from_file(fpath: str) -> Iterable[str]:
110 |         with open(fpath) as f:
111 |             code = f.read()
112 |         ast_util = ASTUtil(code)
113 |         tree = ast_util.tree(RUST_LANGUAGE)
114 |         use_list_nodes = ast_util.get_all_nodes_of_type(
115 |             tree.root_node, "use_declaration"
116 |         )
117 |         return list(map(ast_util.get_source_from_node, use_list_nodes))
118 | 
119 |     use_lists = concatMap(get_use_list_from_file, collect_rs_files(subdir))
120 | 
121 |     # flatten and remove duplicates
122 |     return set(concatMap(flatten_use_delc, use_lists))
123 | 
124 | 
125 | def write_tests_to_workspace(workspace_dir: str, tests: list[str], test_type: str):
126 |     """write the generated test cases into the workspace's tests directory
127 | 
128 |     Args:
129 |         workspace_dir (str): path to the project/crate's workspace dir
130 |         tests (list[str]): list of generated test functions
131 |         test_type (str): tests (unittest) or fuzz
132 |     """
133 |     use_delc = "\n".join(construct_use_delcs(workspace_dir, test_type))
134 |     for i, test in enumerate(tests):
135 |         p = os.path.join(workspace_dir, "tests", f"generated_{test_type}_{i}.rs")
136 |         with open(p, "w") as f:
137 |             f.write(use_delc + "\n\n" + "#[test]\n" + test)
138 | 
139 | 
140 | def main():
141 |     workspace_dir = os.path.abspath(
142 |         "data/repos/marshallpierce-rust-base64/marshallpierce-rust-base64-4ef33cc"
143 |     )
144 | 
145 |     print(construct_use_delcs(workspace_dir, "tests"))
146 | 
147 | 
148 | if __name__ == "__main__":
149 |     fire.Fire(main)
150 | 
--------------------------------------------------------------------------------
/frontend/go/collect_all.py:
--------------------------------------------------------------------------------
1 | """main script of Golang frontend"""
2 | 
3 | from typing import Iterable
4 | import fire
5 | import os
6 | from tree_sitter import Node
7 | from frontend.parser import GO_LANGUAGE
8 | from frontend.parser.ast_util import ASTUtil
9 | from returns.maybe import Maybe, Nothing, Some
10 | from unitsyncer.util import replace_tabs
11 | import json
12 | from frontend.util import mp_map_repos, wrap_repo, run_with_timeout
13 | from collections import Counter
14 | from frontend.go.collect_focal import get_focal_call, is_test_fn
15 | 
16 | 
17 | def has_test(file_path):
18 |     with open(file_path, "r", errors="replace") as f:
19 |         code = f.read()
20 | 
21 |     return '"testing"' in code
22 | 
23 | 
24 | def collect_test_files(root: str):
25 |     """Get all .go files with "test" in their name under the given root directory
26 | 
27 |     Args:
28 |         root (str): path to repo root
29 |     """
30 |     for dirpath, _, filenames in os.walk(root):
31 |         for filename in filenames:
32 |             if filename.endswith(".go") and "test" in filename:
33 |                 if has_test(p := os.path.join(dirpath, filename)):
34 |                     yield p
35 | 
36 | 
37 | def collect_test_funcs(ast_util: ASTUtil) -> Iterable[Node]:
38 |     """collect testing functions from the target file"""
39 | 
40 |     tree = ast_util.tree(GO_LANGUAGE)
41 |     root_node = tree.root_node
42 | 
43 |     decls = ast_util.get_all_nodes_of_type(root_node, "function_declaration")
44 | 
45 |     return filter(lambda n: is_test_fn(n, ast_util), decls)
46 | 
47 | 
48 | def collect_test_n_focal(file_path: str):
49 |     with open(file_path, "r", errors="replace") as f:
50 |         ast_util = ASTUtil(replace_tabs(f.read()))
51 | 
52 |     def get_focal_for_test(test_func: Node):
53 |         test_name = ast_util.get_method_name(test_func).value_or(None)
54 |         focal, focal_loc =
get_focal_call(ast_util, test_func).value_or((None, None)) 55 | return { 56 | "test_id": test_name, 57 | "test_loc": test_func.start_point, 58 | "test": ast_util.get_source_from_node(test_func), 59 | "focal_id": focal, 60 | "focal_loc": focal_loc, 61 | } 62 | 63 | return map(get_focal_for_test, collect_test_funcs(ast_util)) 64 | 65 | 66 | @run_with_timeout 67 | def collect_from_repo( 68 | repo_id: str, repo_root: str, test_root: str, focal_root: str 69 | ): # pylint: disable=unused-argument 70 | """collect all test functions in the given project 71 | return (status, nfile, ntest) 72 | status can be 0: success, 1: repo not found, 2: test not found, 3: skip when output file existed 73 | """ 74 | repo_path = os.path.join(repo_root, wrap_repo(repo_id)) 75 | if not os.path.exists(repo_path) or not os.path.isdir(repo_path): 76 | return 1, 0, 0 77 | focal_path = os.path.join(focal_root, wrap_repo(repo_id) + ".jsonl") 78 | # skip if exist 79 | if os.path.exists(focal_path): 80 | return 3, 0, 0 81 | # collect potential testing modules 82 | all_files = collect_test_files(repo_path) 83 | all_files = list(all_files) 84 | tests = {} 85 | for f in all_files: 86 | funcs = collect_test_n_focal(f) 87 | tests[f] = funcs 88 | 89 | if len(tests.keys()) == 0: 90 | return 2, 0, sum(len(list(v)) for v in tests.values()) 91 | # save to disk 92 | n_test_func = 0 93 | n_focal_func = 0 94 | with open(focal_path, "w") as outfile: 95 | for k, ds in tests.items(): 96 | for d in ds: 97 | test_id = f"{k.removeprefix(repo_root)}::{d['test_id']}" 98 | d["test_id"] = test_id[1:] if test_id[0] == "/" else test_id 99 | if d["focal_loc"] is None: 100 | continue 101 | outfile.write(json.dumps(d) + "\n") 102 | n_test_func += int(d["test_loc"] is not None) 103 | n_focal_func += int(d["focal_loc"] is not None) 104 | return 0, n_test_func, n_focal_func 105 | 106 | 107 | def main( 108 | repo_id: str = "mistifyio/go-zfs", 109 | repo_root: str = "data/repos/", 110 | test_root: str = "data/tests/", 111 | focal_root: str = "data/focal/", 112 | timeout: int = 120, 113 | nprocs: int = 0, 114 | limits: int = -1, 115 | ): 116 | try: 117 | repo_id_list = [l.strip() for l in open(repo_id, "r").readlines()] 118 | except FileNotFoundError: 119 | repo_id_list = [repo_id] 120 | if limits > 0: 121 | repo_id_list = repo_id_list[:limits] 122 | print(f"Loaded {len(repo_id_list)} repos to be processed") 123 | 124 | # collect focal function from each repo 125 | status_ntest_nfocal = mp_map_repos( 126 | collect_from_repo, 127 | repo_id_list=repo_id_list, 128 | nprocs=nprocs, 129 | timeout=timeout, 130 | repo_root=repo_root, 131 | test_root=test_root, 132 | focal_root=focal_root, 133 | ) 134 | 135 | filtered_results = [i for i in status_ntest_nfocal if i is not None] 136 | if len(filtered_results) < len(status_ntest_nfocal): 137 | print(f"{len(status_ntest_nfocal) - len(filtered_results)} repos timeout") 138 | status, ntest, nfocal = zip(*filtered_results) 139 | status_counter: Counter[int] = Counter(status) 140 | print( 141 | f"Processed {sum(status_counter.values())} repos with", 142 | f"{status_counter[3]} skipped, {status_counter[1]} not found,", 143 | f"and {status_counter[2]} failed to locate any focal functions", 144 | ) 145 | print(f"Collected {sum(nfocal)} focal functions for {sum(ntest)} tests") 146 | print("Done!") 147 | 148 | 149 | if __name__ == "__main__": 150 | fire.Fire(main) 151 | -------------------------------------------------------------------------------- /tests/frontend/test_java.py: 
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import logging
4 | from returns.maybe import Nothing, Some
5 | from frontend.java.collect_focal import get_focal_call
6 | from frontend.parser.ast_util import ASTUtil
7 | from frontend.parser import JAVA_LANGUAGE
8 | 
9 | 
10 | class TestJavaFrontend(unittest.TestCase):
11 |     def __test_focal_helper(self, code: str):
12 |         ast_util = ASTUtil(code)
13 |         tree = ast_util.tree(JAVA_LANGUAGE)
14 |         root_node = tree.root_node
15 | 
16 |         fn = ast_util.get_all_nodes_of_type(root_node, "method_declaration")[0]
17 |         return get_focal_call(ast_util, fn)
18 | 
19 |     def test_focal(self):
20 |         """
21 |         For a regular @Test function with a function call inside `assertThat`,
22 |         that function call should be the focal
23 |         """
24 |         code = """
25 | @Test
26 | void catalogLoads() {
27 | @SuppressWarnings("rawtypes")
28 | ResponseEntity entity = new TestRestTemplate()
29 | .getForEntity("http://localhost:" + this.port + "/context/eureka/apps", Map.class);
30 | assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK);
31 | String computedPath = entity.getHeaders().getFirst("X-Version-Filter-Computed-Path");
32 | assertThat(computedPath).isEqualTo("/context/eureka/v2/apps");
33 | }"""
34 | 
35 |         name, loc = self.__test_focal_helper(code).unwrap()
36 |         self.assertEqual(name, "getStatusCode")
37 |         self.assertEqual(loc, (6, 19))
38 | 
39 |         code = """
40 | @Test
41 | void testAdd() {
42 | assertThat(add(1, 2)).isEqualTo(3);
43 | }"""
44 |         name, loc = self.__test_focal_helper(code).unwrap()
45 |         self.assertEqual(name, "add")
46 |         self.assertEqual(loc, (3, 15))
47 | 
48 |         code = """
49 | @Test
50 | void testCompareTo() {
51 | assertTrue(0 == Status.INITIAL.compareTo(Status.INITIAL));
52 | assertTrue(0 > Status.INITIAL.compareTo(Status.TRANSLATED));
53 | assertTrue(0 > Status.INITIAL.compareTo(Status.VERIFIED));
54 | assertTrue(0 > Status.INITIAL.compareTo(Status.SPECIAL));
55 | 
56 | assertTrue(0 < Status.TRANSLATED.compareTo(Status.INITIAL));
57 | assertTrue(0 == Status.TRANSLATED.compareTo(Status.TRANSLATED));
58 | assertTrue(0 > Status.TRANSLATED.compareTo(Status.VERIFIED));
59 | assertTrue(0 > Status.TRANSLATED.compareTo(Status.SPECIAL));
60 | 
61 | assertTrue(0 < Status.VERIFIED.compareTo(Status.INITIAL));
62 | assertTrue(0 < Status.VERIFIED.compareTo(Status.TRANSLATED));
63 | assertTrue(0 == Status.VERIFIED.compareTo(Status.VERIFIED));
64 | assertTrue(0 > Status.VERIFIED.compareTo(Status.SPECIAL));
65 | 
66 | assertTrue(0 < Status.SPECIAL.compareTo(Status.INITIAL));
67 | assertTrue(0 < Status.SPECIAL.compareTo(Status.TRANSLATED));
68 | assertTrue(0 < Status.SPECIAL.compareTo(Status.VERIFIED));
69 | assertTrue(0 == Status.SPECIAL.compareTo(Status.SPECIAL));
70 | }"""
71 |         name, loc = self.__test_focal_helper(code).unwrap()
72 |         self.assertEqual(name, "compareTo")
73 |         self.assertEqual(loc, (3, 35))
74 | 
75 |     def test_focal_not_in_assert(self):
76 |         """
77 |         if there is no function call in the first `assertThat`,
78 |         then the last call before the first `assertThat` is the focal
79 |         """
80 |         code = """
81 | @Test
82 | void catalogLoads() {
83 | @SuppressWarnings("rawtypes")
84 | ResponseEntity entity = new TestRestTemplate()
85 | .getForEntity("http://localhost:" + this.port + "/context/eureka/apps", Map.class);
86 | String computedPath = entity.getHeaders().getFirst("X-Version-Filter-Computed-Path");
87 | assertThat(computedPath).isEqualTo("/context/eureka/v2/apps");
88 | }"""
89 | 
90 |         name, loc =
self.__test_focal_helper(code).unwrap()
91 |         self.assertEqual(name, "getFirst")
92 |         self.assertEqual(loc, (6, 43))
93 | 
94 |         code = """
95 | @Test
96 | void testAdd() {
97 | int z = add(1, 2);
98 | assertThat(z).isEqualTo(3);
99 | }"""
100 |         name, loc = self.__test_focal_helper(code).unwrap()
101 |         self.assertEqual(name, "add")
102 |         self.assertEqual(loc, (3, 12))
103 | 
104 |     def test_focal_not_assert(self):
105 |         """If there is no assert in the test function, then fail"""
106 | 
107 |         code = "@Test\nvoid testNothing() {\n}"
108 |         self.assertEqual(self.__test_focal_helper(code), Nothing)
109 | 
110 |     def test_focal_in_branch(self):
111 |         code = """
112 | @Test
113 | public void testInputParts(ServiceTransformationEngine transformationEngine, @All ServiceManager serviceManager) throws Exception {
114 | 
115 | //check and import services
116 | checkAndImportServices(transformationEngine, serviceManager);
117 | 
118 | URI op = findServiceURI(serviceManager, "serv1323166560");
119 | String[] expected = {"con241744282", "con1849951292", "con1653328292"};
120 | if (op != null) {
121 | Set ops = serviceManager.listOperations(op);
122 | Set inputs = serviceManager.listInputs(ops.iterator().next());
123 | Set parts = new HashSet(serviceManager.listMandatoryParts(inputs.iterator().next()));
124 | assertTrue(parts.size() == 3);
125 | for (URI part : parts) {
126 | boolean valid = false;
127 | for (String expectedInput : expected) {
128 | if (part.toASCIIString().contains(expectedInput)) {
129 | valid = true;
130 | break;
131 | }
132 | }
133 | assertTrue(valid);
134 | }
135 | } else {
136 | fail();
137 | }
138 | 
139 | serviceManager.shutdown();
140 | }"""
141 |         name, loc = self.__test_focal_helper(code).unwrap()
142 |         self.assertEqual(name, "size")
143 |         self.assertEqual(loc, (13, 25))
144 | 
145 | 
146 | if __name__ == "__main__":
147 |     logging.basicConfig(level=logging.INFO)
148 |     unittest.main()
149 | 
--------------------------------------------------------------------------------
/frontend/cpp/collect_all.py:
--------------------------------------------------------------------------------
1 | """main script for C/C++ frontend"""
2 | 
3 | from typing import Iterable
4 | import fire
5 | import os
6 | from tree_sitter import Node
7 | from frontend.parser import CPP_LANGUAGE
8 | from frontend.parser.ast_util import ASTUtil
9 | from returns.maybe import Maybe, Nothing, Some
10 | from frontend.cpp.collect_focal import get_focal_call, is_test_fn
11 | from unitsyncer.util import replace_tabs
12 | import json
13 | from frontend.util import mp_map_repos, wrap_repo, run_with_timeout
14 | from collections import Counter
15 | from unitsyncer.util import get_cpp_func_name
16 | 
17 | 
18 | def has_test(file_path):
19 |     def has_google_test(code):
20 |         return '#include "gtest/gtest.h"' in code
21 | 
22 |     try:
23 |         with open(file_path, "r", errors="replace") as f:
24 |             code = f.read()
25 |     except FileNotFoundError:
26 |         return False
27 |     return has_google_test(code)
28 | 
29 | 
30 | def collect_test_files(root: str):
31 |     """Get all C/C++ source files (.c, .cpp, .cxx, .cc) in the given root directory
32 | 
33 |     Args:
34 |         root (str): path to repo root
35 |     """
36 |     for dirpath, _, filenames in os.walk(root):
37 |         for filename in filenames:
38 |             if filename.endswith((".c", ".cpp", ".cxx", ".cc")):
39 |                 if has_test(p := os.path.join(dirpath, filename)):
40 |                     yield p
41 | 
42 | 
43 | def collect_test_funcs(ast_util: ASTUtil) -> Iterable[Node]:
44 |     """collect testing functions from the target file"""
45 | 
46 |     tree = ast_util.tree(CPP_LANGUAGE)
47 |     root_node =
tree.root_node 48 | 49 | defns = ast_util.get_all_nodes_of_type(root_node, "function_definition") 50 | 51 | return filter(lambda n: is_test_fn(n, ast_util), defns) 52 | 53 | 54 | def collect_test_n_focal(file_path: str): 55 | with open(file_path, "r", errors="replace") as f: 56 | ast_util = ASTUtil(replace_tabs(f.read())) 57 | 58 | def get_focal_for_test(test_func: Node): 59 | focal, focal_loc = get_focal_call(ast_util, test_func).value_or((None, None)) 60 | 61 | test_params = ast_util.get_all_nodes_of_type(test_func, "parameter_declaration") 62 | test_name = ast_util.get_source_from_node(test_params[0]) 63 | return { 64 | "test_id": test_name, 65 | "test_loc": test_func.start_point, 66 | "test": ast_util.get_source_from_node(test_func), 67 | "focal_id": focal, 68 | "focal_loc": focal_loc, 69 | } 70 | 71 | return map(get_focal_for_test, collect_test_funcs(ast_util)) 72 | 73 | 74 | @run_with_timeout 75 | def collect_from_repo( 76 | repo_id: str, repo_root: str, test_root: str, focal_root: str 77 | ): # pylint: disable=unused-argument 78 | """collect all test functions in the given project 79 | return (status, nfile, ntest) 80 | status can be 0: success, 1: repo not found, 2: test not found, 3: skip when output file existed 81 | """ 82 | repo_path = os.path.join(repo_root, wrap_repo(repo_id)) 83 | if not os.path.exists(repo_path) or not os.path.isdir(repo_path): 84 | return 1, 0, 0 85 | focal_path = os.path.join(focal_root, wrap_repo(repo_id) + ".jsonl") 86 | # skip if exist 87 | if os.path.exists(focal_path): 88 | return 3, 0, 0 89 | # collect potential testing modules 90 | all_files = collect_test_files(repo_path) 91 | tests = {} 92 | for f in all_files: 93 | funcs = collect_test_n_focal(f) 94 | tests[f] = funcs 95 | if len(tests.keys()) == 0: 96 | return 2, 0, sum(len(list(v)) for v in tests.values()) 97 | # save to disk 98 | n_test_func = 0 99 | n_focal_func = 0 100 | with open(focal_path, "w") as outfile: 101 | for k, ds in tests.items(): 102 | for d in ds: 103 | test_id = f"{k.removeprefix(repo_root)}::{d['test_id']}" 104 | d["test_id"] = test_id[1:] if test_id[0] == "/" else test_id 105 | if d["focal_loc"] is None: 106 | continue 107 | outfile.write(json.dumps(d) + "\n") 108 | n_test_func += int(d["test_loc"] is not None) 109 | n_focal_func += int(d["focal_loc"] is not None) 110 | return 0, n_test_func, n_focal_func 111 | 112 | 113 | def main( 114 | repo_id: str = "llvm/llvm-project", 115 | repo_root: str = "data/repos/", 116 | test_root: str = "data/tests/", 117 | focal_root: str = "data/focal/", 118 | timeout: int = 120, 119 | nprocs: int = 0, 120 | limits: int = -1, 121 | ): 122 | try: 123 | repo_id_list = [l.strip() for l in open(repo_id, "r").readlines()] 124 | except FileNotFoundError: 125 | repo_id_list = [repo_id] 126 | if limits > 0: 127 | repo_id_list = repo_id_list[:limits] 128 | print(f"Loaded {len(repo_id_list)} repos to be processed") 129 | 130 | # collect focal function from each repo 131 | status_ntest_nfocal = mp_map_repos( 132 | collect_from_repo, 133 | repo_id_list=repo_id_list, 134 | nprocs=nprocs, 135 | timeout=timeout, 136 | repo_root=repo_root, 137 | test_root=test_root, 138 | focal_root=focal_root, 139 | ) 140 | 141 | filtered_results = [i for i in status_ntest_nfocal if i is not None] 142 | if len(filtered_results) < len(status_ntest_nfocal): 143 | print(f"{len(status_ntest_nfocal) - len(filtered_results)} repos timeout") 144 | status, ntest, nfocal = zip(*filtered_results) 145 | status_counter: Counter[int] = Counter(status) 146 | print( 147 | f"Processed 
{sum(status_counter.values())} repos with",
148 |         f"{status_counter[3]} skipped, {status_counter[1]} not found,",
149 |         f"and {status_counter[2]} failed to locate any focal functions",
150 |     )
151 |     print(f"Collected {sum(nfocal)} focal functions for {sum(ntest)} tests")
152 |     print("Done!")
153 | 
154 | 
155 | if __name__ == "__main__":
156 |     fire.Fire(main)
157 | 
--------------------------------------------------------------------------------
/frontend/util.py:
--------------------------------------------------------------------------------
1 | """util functions for UniTSyncer frontend"""
2 | import os
3 | import sys
4 | import ast
5 | import json
6 | import time
7 | import signal
8 | import datetime
9 | import functools
10 | import contextlib
11 | from tqdm import tqdm
12 | from typing import (
13 |     Optional,
14 |     Callable,
15 |     List,
16 |     Any,
17 |     Dict,
18 |     Iterable,
19 |     Set,
20 |     Tuple,
21 |     TypeVar,
22 | )
23 | from pathos.multiprocessing import ProcessPool
24 | import subprocess
25 | 
26 | 
27 | class Timing:
28 |     """providing timing as context or functions"""
29 | 
30 |     queue: list = []
31 | 
32 |     def __enter__(self):
33 |         self.tic()
34 |         return self
35 | 
36 |     def __exit__(self, *args):
37 |         self.tac()
38 | 
39 |     @staticmethod
40 |     def tic():
41 |         Timing.queue.append(time.time())
42 | 
43 |     @staticmethod
44 |     def tac():
45 |         assert Timing.queue, "Call Timing.tic first"
46 |         start_at = Timing.queue.pop()
47 |         print(f"Elapsed {datetime.timedelta(seconds=time.time() - start_at)}")
48 | 
49 | 
50 | def log_or_skip(path: Optional[str] = None, handler=json.dumps, **kwargs):
51 |     """log kwargs if path is provided, using handler for preprocessing"""
52 |     if not path:
53 |         return
54 |     with open(path, "a") as outfile:
55 |         to_log = kwargs
56 |         if handler:
57 |             to_log = handler(to_log)
58 |         outfile.write(f"{to_log}\n")
59 | 
60 | 
61 | def wrap_repo(name: str):
62 |     """wrap repo name from username/repo into username-repo"""
63 |     return "-".join(name.split("/"))
64 | 
65 | 
66 | class TimeoutException(Exception):
67 |     """Wrapper Exception for Frontend Timeout"""
68 | 
69 |     pass  # pylint: disable=unnecessary-pass
70 | 
71 | 
72 | @contextlib.contextmanager
73 | def time_limit(seconds: float):
74 |     def signal_handler(signum, frame):
75 |         raise TimeoutException("Timed out!")
76 | 
77 |     signal.setitimer(signal.ITIMER_REAL, seconds)
78 |     signal.signal(signal.SIGALRM, signal_handler)
79 |     try:
80 |         yield
81 |     finally:
82 |         signal.setitimer(signal.ITIMER_REAL, 0)
83 | 
84 | 
85 | def timestamp(frmt="%Y-%m-%d %H:%M:%S"):
86 |     return datetime.datetime.now().strftime(frmt)
87 | 
88 | 
89 | def timeout_wrapper(handler: Callable, timeout: int = -1):
90 |     """return None if timeout instead of raising error"""
91 | 
92 |     def inner(*args, **kwargs):
93 |         if timeout <= 0:
94 |             return handler(*args, **kwargs)
95 |         try:
96 |             with time_limit(timeout):
97 |                 return handler(*args, **kwargs)
98 |         except TimeoutException:
99 |             pass
100 |         return None
101 | 
102 |     return inner
103 | 
104 | 
105 | def mp_map_repos(handler: Callable, repo_id_list: List[str], nprocs: int = 0, **kwargs):
106 |     """conduct an unordered map at the level of repos using handler
107 |     make sure handler takes a repo_id string as its first positional arg
108 |     other args can be passed as keyword args via kwargs
109 |     """
110 |     results = []
111 |     if nprocs < 1:
112 |         for repo_id in (pbar := tqdm(repo_id_list)):
113 |             pbar.set_description(f"{timestamp()} Processing {repo_id}")
114 |             results.append(handler(repo_id, **kwargs))
115 |     else:
116 |         with ProcessPool(nprocs) as p:
117 | 
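            # uimap is pathos' unordered imap: results arrive in completion
            # order, so `results` is not aligned with the order of repo_id_list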
with tqdm(total=len(repo_id_list)) as pbar: 118 | for status in p.uimap( 119 | functools.partial(handler, **kwargs), repo_id_list 120 | ): 121 | results.append(status) 122 | pbar.set_description(timestamp()) 123 | pbar.update() 124 | return results 125 | 126 | 127 | def run_with_timeout(func: Callable): 128 | """run a function with a timeout""" 129 | 130 | def wrapper(*args, timeout=-1, **kwargs): 131 | if timeout <= 0: 132 | return func(*args, **kwargs) 133 | try: 134 | with time_limit(timeout): 135 | return func(*args, **kwargs) 136 | except TimeoutException: 137 | pass 138 | return None 139 | 140 | return wrapper 141 | 142 | 143 | __T = TypeVar("__T") 144 | __R = TypeVar("__R") 145 | 146 | 147 | def parallel_subprocess( 148 | iterable: Iterable[__T], 149 | jobs: int, 150 | subprocess_creator: Callable[[__T], subprocess.Popen], 151 | on_exit: Optional[Callable[[subprocess.Popen], __R]] = None, 152 | use_tqdm=True, 153 | tqdm_leave=True, 154 | tqdm_msg="", 155 | ) -> Dict[__T, __R]: 156 | """ 157 | Creates `jobs` subprocesses that run in parallel. 158 | `iterable` contains the inputs that are sent to the subprocesses. 159 | `subprocess_creator` creates the subprocess and returns a `Popen`. 160 | After each subprocess ends, `on_exit` collects a user-defined result from it. 161 | The return value is a dictionary mapping inputs to outputs. 162 | 163 | The user has to guarantee that elements in `iterable` are unique, or the output may be incorrect. 164 | """ 165 | ret = {} 166 | processes: Set[Tuple[subprocess.Popen, __T]] = set() 167 | if use_tqdm: 168 | iterable = tqdm(iterable, leave=tqdm_leave, desc=tqdm_msg) 169 | for _input in iterable: 170 | processes.add((subprocess_creator(_input), _input)) 171 | if len(processes) >= jobs: 172 | # wait for a child process to exit 173 | os.wait() 174 | exited_processes = [(p, i) for p, i in processes if p.poll() is not None] 175 | for p, i in exited_processes: 176 | processes.remove((p, i)) 177 | if on_exit is not None: 178 | ret[i] = on_exit(p) 179 | # wait for remaining processes to exit 180 | for p, i in processes: 181 | p.wait() 182 | # let `on_exit` decide whether to wait for or kill the process 183 | if on_exit is not None: 184 | ret[i] = on_exit(p) 185 | return ret 186 | -------------------------------------------------------------------------------- /frontend/python/navigate.py: -------------------------------------------------------------------------------- 1 | """util functions to navigate a python repo""" 2 | 3 | import ast 4 | from typing import Optional, Callable, List, Union 5 | 6 | 7 | class ModuleNavigator: 8 | """provide util functions using ast""" 9 | 10 | def __init__(self, path: str): 11 | self.path = path 12 | with open(path, "r", errors="replace") as fp: 13 | self.ast = ast.parse(fp.read()) 14 | self.nodes, self.parents = flatten(self.ast) 15 | 16 | @staticmethod 17 | def build(path: str): 18 | try: 19 | nav = ModuleNavigator(path) 20 | return nav 21 | except SyntaxError: 22 | return None 23 | 24 | def find_all(self, ntype: Union[type, Callable], root: Optional[ast.AST] = None): 25 | if root is None: 26 | root, nodes = self.ast, self.nodes 27 | else: 28 | nodes = None 29 | return find_all(root, ntype, nodes=nodes) 30 | 31 | def find_by_name(self, name: str, root: Optional[ast.AST] = None): 32 | if root is None: 33 | root, nodes = self.ast, self.nodes 34 | else: 35 | nodes = None 36 | return find_by_name(root, name, nodes=nodes) 37 | 38 | def get_path_to(self, node: ast.AST): 39 | return get_path_to(node, self.nodes, self.parents) 40 | 41 | def postorder(self,
root: Optional[ast.AST] = None): 42 | nodes = [] 43 | 44 | def walk(n): 45 | children: list[ast.AST] = [] 46 | for f in getattr(n, "_fields", []): 47 | field = getattr(n, f, []) 48 | if isinstance(field, (tuple, list)): 49 | children.extend(field) 50 | else: 51 | children.append(field) 52 | for child in children: 53 | walk(child) 54 | nodes.append(n) 55 | 56 | walk(root if root is not None else self.ast) 57 | return nodes 58 | 59 | @property 60 | def total_lines(self) -> int: 61 | line_numbers = { 62 | node.lineno for node in ast.walk(self.ast) if hasattr(node, "lineno") 63 | } 64 | return len(line_numbers) 65 | 66 | def __str__(self): 67 | return ast.dump(self.ast) 68 | 69 | 70 | def flatten(root: ast.AST): 71 | """flatten an ast pre-order""" 72 | nodes: list[ast.AST] = [] 73 | parents: list[int] = [] 74 | 75 | def walk(n, p=None): 76 | nidx = len(nodes) 77 | nodes.append(n) 78 | parents.append(p) 79 | children: list[ast.AST] = [] 80 | for f in getattr(n, "_fields", []): 81 | field = getattr(n, f, []) 82 | if isinstance(field, (tuple, list)): 83 | children.extend(field) 84 | else: 85 | children.append(field) 86 | for child in children: 87 | walk(child, nidx) 88 | 89 | walk(root) 90 | assert len(nodes) == len(parents) 91 | return nodes, parents 92 | 93 | 94 | def find_all( 95 | root: ast.AST, 96 | condition: Union[type, Callable], 97 | nodes: Optional[List[ast.AST]] = None, 98 | ): 99 | """return all nodes of the desired type""" 100 | if nodes is None: 101 | nodes, _ = flatten(root) 102 | if isinstance(condition, type): 103 | _filter = lambda x: isinstance(x, condition) 104 | else: 105 | _filter = condition 106 | return [node for node in nodes if _filter(node)] 107 | 108 | 109 | def find_by_name(root: ast.AST, name: str, nodes: Optional[List[ast.AST]] = None): 110 | """find node by name, return the first one if duplicated""" 111 | if nodes is None: 112 | nodes, _ = flatten(root) 113 | for node in nodes: 114 | if getattr(node, "name", None) == name: 115 | return node 116 | return None 117 | 118 | 119 | def get_path_to( 120 | target: ast.AST, 121 | nodes: list[ast.AST], 122 | parents: list[int], 123 | ): 124 | 125 | # find the path to target bottom-up 126 | try: 127 | target_idx = nodes.index(target) 128 | except ValueError: 129 | return None 130 | path = [] 131 | while target_idx is not None: 132 | path.append(nodes[target_idx]) 133 | target_idx = parents[target_idx] 134 | return path[::-1] 135 | 136 | 137 | def dump_ast_func( 138 | func: ast.FunctionDef, 139 | path: str, 140 | nav: Optional[ModuleNavigator] = None, 141 | ancestors: Optional[List[ast.AST]] = None, 142 | return_nav: Optional[bool] = False, 143 | ): 144 | """converts an ast node of function into string""" 145 | if nav is None: 146 | nav = ModuleNavigator(path) 147 | if ancestors is None: 148 | ancestors = nav.get_path_to(func) 149 | classes = [n.name for n in ancestors if isinstance(n, ast.ClassDef)] 150 | func_id = "::".join([path] + classes + [func.name]) 151 | if not return_nav: 152 | return func_id 153 | return func_id, nav 154 | 155 | 156 | def load_ast_func( 157 | func_id: str, 158 | nav: Optional[ModuleNavigator] = None, 159 | return_nav: Optional[bool] = False, 160 | ): 161 | """convert a string to an ast node of function""" 162 | ancestors, node = func_id.split("::"), None 163 | path = ancestors.pop(0) 164 | if nav is None: 165 | nav = ModuleNavigator(path) 166 | while ancestors: 167 | node_id = ancestors.pop(0) 168 | node = nav.find_by_name(node_id, root=node) 169 | if not return_nav: 170 | return node 171 | return 
node, nav 172 | 173 | 174 | def is_assert(node: ast.AST): 175 | """tell if a node is an assertion""" 176 | if isinstance(node, ast.Assert): 177 | return True 178 | if isinstance(node, ast.Call): 179 | func = node.func 180 | if isinstance(func, ast.Name) and func.id.startswith("assert"): 181 | return True 182 | if isinstance(func, ast.Attribute) and func.attr.startswith("assert"): 183 | return True 184 | return False 185 | -------------------------------------------------------------------------------- /frontend/java/collect_all.py: -------------------------------------------------------------------------------- 1 | """main script for Java frontend""" 2 | 3 | from typing import Iterable 4 | import fire 5 | import os 6 | from tree_sitter import Node 7 | from frontend.parser import JAVA_LANGUAGE 8 | from frontend.parser.ast_util import ASTUtil 9 | from returns.maybe import Maybe, Nothing, Some 10 | from frontend.java.collect_focal import get_focal_call, is_test_fn 11 | from unitsyncer.util import replace_tabs 12 | import json 13 | from frontend.util import mp_map_repos, wrap_repo, run_with_timeout 14 | from collections import Counter 15 | 16 | 17 | def has_test(file_path): 18 | # follow TeCo to check for JUnit4 and JUnit5 19 | # todo: support different usage as in google/closure-compiler 20 | def has_junit4(code): 21 | return "@Test" in code and "import org.junit.Test" in code 22 | 23 | def has_junit5(code): 24 | return "@Test" in code and "import org.junit.jupiter.api.Test" in code 25 | 26 | with open(file_path, "r", errors="replace") as f: 27 | code = f.read() 28 | return has_junit4(code) or has_junit5(code) 29 | 30 | 31 | def collect_test_files(root: str): 32 | """Get all files end with .java in the given root directory 33 | 34 | Args: 35 | root (str): path to repo root 36 | """ 37 | for dirpath, _, filenames in os.walk(root): 38 | for filename in filenames: 39 | if filename.endswith(".java"): 40 | if has_test(p := os.path.join(dirpath, filename)): 41 | yield p 42 | 43 | 44 | def collect_test_funcs(ast_util: ASTUtil) -> Iterable[Node]: 45 | """collect testing functions from the target file""" 46 | 47 | tree = ast_util.tree(JAVA_LANGUAGE) 48 | root_node = tree.root_node 49 | 50 | decls = ast_util.get_all_nodes_of_type(root_node, "method_declaration") 51 | 52 | return filter(lambda n: is_test_fn(n, ast_util), decls) 53 | 54 | 55 | def collect_test_n_focal(file_path: str): 56 | with open(file_path, "r", errors="replace") as f: 57 | ast_util = ASTUtil(replace_tabs(f.read())) 58 | 59 | def get_focal_for_test(test_func: Node): 60 | test_name = ast_util.get_method_name(test_func).value_or(None) 61 | focal, focal_loc = get_focal_call(ast_util, test_func).value_or((None, None)) 62 | return { 63 | "test_id": test_name, 64 | "test_loc": test_func.start_point, 65 | "test": ast_util.get_source_from_node(test_func), 66 | "focal_id": focal, 67 | "focal_loc": focal_loc, 68 | } 69 | 70 | return map(get_focal_for_test, collect_test_funcs(ast_util)) 71 | 72 | 73 | @run_with_timeout 74 | def collect_from_repo( 75 | repo_id: str, repo_root: str, test_root: str, focal_root: str 76 | ): # pylint: disable=unused-argument 77 | """collect all test functions in the given project 78 | return (status, nfile, ntest) 79 | status can be 0: success, 1: repo not found, 2: test not found, 3: skip when output file existed 80 | """ 81 | repo_path = os.path.join(repo_root, wrap_repo(repo_id)) 82 | if not os.path.exists(repo_path) or not os.path.isdir(repo_path): 83 | return 1, 0, 0 84 | focal_path = os.path.join(focal_root, 
wrap_repo(repo_id) + ".jsonl") 85 | # skip if exist 86 | if os.path.exists(focal_path): 87 | return 3, 0, 0 88 | # collect potential testing modules 89 | all_files = collect_test_files(repo_path) 90 | tests = {} 91 | for f in all_files: 92 | funcs = collect_test_n_focal(f) 93 | tests[f] = funcs 94 | if len(tests.keys()) == 0: 95 | return 2, 0, sum(len(list(v)) for v in tests.values()) 96 | # save to disk 97 | n_test_func = 0 98 | n_focal_func = 0 99 | with open(focal_path, "w") as outfile: 100 | for k, ds in tests.items(): 101 | for d in ds: 102 | test_id = f"{k.removeprefix(repo_root)}::{d['test_id']}" 103 | d["test_id"] = test_id[1:] if test_id[0] == "/" else test_id 104 | if d["focal_loc"] is None: 105 | continue 106 | outfile.write(json.dumps(d) + "\n") 107 | n_test_func += int(d["test_loc"] is not None) 108 | n_focal_func += int(d["focal_loc"] is not None) 109 | return 0, n_test_func, n_focal_func 110 | 111 | 112 | def main( 113 | repo_id: str = "spring-cloud/spring-cloud-netflix", 114 | repo_root: str = "data/repos/", 115 | test_root: str = "data/tests/", 116 | focal_root: str = "data/focal/", 117 | timeout: int = 120, 118 | nprocs: int = 0, 119 | limits: int = -1, 120 | ): 121 | try: 122 | repo_id_list = [l.strip() for l in open(repo_id, "r").readlines()] 123 | except FileNotFoundError: 124 | repo_id_list = [repo_id] 125 | if limits > 0: 126 | repo_id_list = repo_id_list[:limits] 127 | print(f"Loaded {len(repo_id_list)} repos to be processed") 128 | 129 | # collect focal function from each repo 130 | status_ntest_nfocal = mp_map_repos( 131 | collect_from_repo, 132 | repo_id_list=repo_id_list, 133 | nprocs=nprocs, 134 | timeout=timeout, 135 | repo_root=repo_root, 136 | test_root=test_root, 137 | focal_root=focal_root, 138 | ) 139 | 140 | filtered_results = [i for i in status_ntest_nfocal if i is not None] 141 | if len(filtered_results) < len(status_ntest_nfocal): 142 | print(f"{len(status_ntest_nfocal) - len(filtered_results)} repos timeout") 143 | status, ntest, nfocal = zip(*filtered_results) 144 | status_counter: Counter[int] = Counter(status) 145 | print( 146 | f"Processed {sum(status_counter.values())} repos with", 147 | f"{status_counter[3]} skipped, {status_counter[1]} not found,", 148 | f"and {status_counter[2]} failed to locate any focal functions", 149 | ) 150 | print(f"Collected {sum(nfocal)} focal functions for {sum(ntest)} tests") 151 | print("Done!") 152 | 153 | 154 | if __name__ == "__main__": 155 | fire.Fire(main) 156 | -------------------------------------------------------------------------------- /frontend/rust/collect_all.py: -------------------------------------------------------------------------------- 1 | """Main script for Rust frontend""" 2 | 3 | from typing import Iterable 4 | import fire 5 | import os 6 | from tree_sitter import Node 7 | from frontend.parser import RUST_LANGUAGE 8 | from frontend.parser.ast_util import ASTUtil 9 | from frontend.rust.rust_util import get_test_functions, get_focal_call 10 | from unitsyncer.util import replace_tabs 11 | import json 12 | from frontend.util import mp_map_repos, wrap_repo, run_with_timeout 13 | from collections import Counter 14 | 15 | 16 | def has_test(file_path): 17 | with open(file_path, "r", errors="replace") as f: 18 | code = f.read() 19 | return "#[test]" in code 20 | 21 | 22 | def is_fuzz_test(filepath): 23 | return "tests-gen" in filepath and ".inputs.rs" in filepath 24 | 25 | 26 | def collect_test_files(root: str, fuzz: bool): 27 | """Get all files ending with .rs in the given root directory 28 |
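Only files containing a `#[test]` attribute are yielded; when `fuzz` is
True, only generated fuzz inputs (tests-gen/*.inputs.rs) are kept, and
otherwise those are excluded.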
Args: 30 | root (str): path to repo root 31 | """ 32 | for dirpath, _, filenames in os.walk(root): 33 | if "test" not in dirpath: 34 | continue 35 | for filename in filenames: 36 | if filename.endswith(".rs"): 37 | if has_test(p := os.path.join(dirpath, filename)): 38 | if fuzz: 39 | if is_fuzz_test(p): 40 | yield p 41 | else: 42 | if not is_fuzz_test(p): 43 | yield p 44 | 45 | 46 | def collect_test_funcs(ast_util: ASTUtil) -> Iterable[Node]: 47 | """collect testing functions from the target file""" 48 | 49 | tree = ast_util.tree(RUST_LANGUAGE) 50 | root_node = tree.root_node 51 | 52 | return get_test_functions(ast_util, root_node) 53 | 54 | 55 | def collect_test_n_focal(file_path: str): 56 | with open(file_path, "r", errors="replace") as f: 57 | ast_util = ASTUtil(replace_tabs(f.read())) 58 | 59 | def get_focal_for_test(test_func: Node): 60 | test_name = ast_util.get_name(test_func).value_or(None) 61 | focal, focal_loc = get_focal_call(ast_util, test_func).value_or((None, None)) 62 | return { 63 | "test_id": test_name, 64 | "test_loc": test_func.start_point, 65 | "test": ast_util.get_source_from_node(test_func), 66 | "focal_id": focal, 67 | "focal_loc": focal_loc, 68 | } 69 | 70 | return map(get_focal_for_test, collect_test_funcs(ast_util)) 71 | 72 | 73 | @run_with_timeout 74 | def collect_from_repo( 75 | repo_id: str, repo_root: str, test_root: str, focal_root: str, fuzz: bool 76 | ): # pylint: disable=unused-argument 77 | """collect all test functions in the given project 78 | return (status, nfile, ntest) 79 | status can be 0: success, 1: repo not found, 2: test not found, 3: skip when output file existed 80 | """ 81 | repo_path = os.path.join(repo_root, wrap_repo(repo_id)) 82 | if not os.path.exists(repo_path) or not os.path.isdir(repo_path): 83 | return 1, 0, 0 84 | focal_path = os.path.join(focal_root, wrap_repo(repo_id) + ".jsonl") 85 | # skip if exist 86 | if os.path.exists(focal_path): 87 | return 3, 0, 0 88 | # collect potential testing modules 89 | all_files = collect_test_files(repo_path, fuzz) 90 | tests = {} 91 | for f in all_files: 92 | funcs = collect_test_n_focal(f) 93 | tests[f] = funcs 94 | if len(tests.keys()) == 0: 95 | return 2, 0, sum(len(list(v)) for v in tests.values()) 96 | # save to disk 97 | n_test_func = 0 98 | n_focal_func = 0 99 | with open(focal_path, "w") as outfile: 100 | for k, ds in tests.items(): 101 | for d in ds: 102 | if d is None: 103 | continue 104 | test_id = f"{k.removeprefix(repo_root)}::{d['test_id']}" 105 | d["test_id"] = test_id[1:] if test_id[0] == "/" else test_id 106 | if d["focal_loc"] is None: 107 | continue 108 | outfile.write(json.dumps(d) + "\n") 109 | n_test_func += int(d["test_loc"] is not None) 110 | n_focal_func += int(d["focal_loc"] is not None) 111 | return 0, n_test_func, n_focal_func 112 | 113 | 114 | def main( 115 | repo_id: str = "marshallpierce/rust-base64", 116 | repo_root: str = "data/repos/", 117 | test_root: str = "data/tests/", 118 | focal_root: str = "data/focal/", 119 | timeout: int = 120, 120 | nprocs: int = 0, 121 | limits: int = -1, 122 | fuzz: bool = True, 123 | ): 124 | try: 125 | repo_id_list = [l.strip() for l in open(repo_id, "r").readlines()] 126 | except FileNotFoundError: 127 | repo_id_list = [repo_id] 128 | if limits > 0: 129 | repo_id_list = repo_id_list[:limits] 130 | print(f"Loaded {len(repo_id_list)} repos to be processed") 131 | 132 | # collect focal function from each repo 133 | status_ntest_nfocal = mp_map_repos( 134 | collect_from_repo, 135 | repo_id_list=repo_id_list, 136 | nprocs=nprocs, 137 | 
timeout=timeout, 138 | repo_root=repo_root, 139 | test_root=test_root, 140 | focal_root=focal_root, 141 | fuzz=fuzz, 142 | ) 143 | 144 | filtered_results = [i for i in status_ntest_nfocal if i is not None] 145 | if len(filtered_results) < len(status_ntest_nfocal): 146 | print(f"{len(status_ntest_nfocal) - len(filtered_results)} repos timeout") 147 | status, ntest, nfocal = zip(*filtered_results) 148 | status_counter: Counter[int] = Counter(status) 149 | print( 150 | f"Processed {sum(status_counter.values())} repos with", 151 | f"{status_counter[3]} skipped, {status_counter[1]} not found,", 152 | f"and {status_counter[2]} failed to locate any focal functions", 153 | ) 154 | print(f"Collected {sum(nfocal)} focal functions for {sum(ntest)} tests") 155 | print("Done!") 156 | 157 | 158 | if __name__ == "__main__": 159 | fire.Fire(main) 160 | -------------------------------------------------------------------------------- /tests/test_prompt_header.py: -------------------------------------------------------------------------------- 1 | """tests for get_def_header module in evaluation/extract_def.py""" 2 | import os 3 | from unitsyncer.extract_def import get_def_header 4 | import unittest 5 | import logging 6 | from unitsyncer.common import UNITSYNCER_HOME 7 | 8 | 9 | class TestGetDefHeader(unittest.TestCase): 10 | def test_py(self): 11 | test = "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(has_close_elements):\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n\ncheck(has_close_elements)" 12 | 13 | self.assertEqual( 14 | get_def_header(test, "python"), "def check(has_close_elements):\n" 15 | ) 16 | 17 | def test_go(self): 18 | test = "func TestHasCloseElements(t *testing.T) {\n assert := assert.New(t)\n assert.Equal(true, HasCloseElements([]float64{11.0, 2.0, 3.9, 4.0, 5.0, 2.2}, 0.3))\n assert.Equal(false, HasCloseElements([]float64{1.0, 2.0, 3.9, 4.0, 5.0, 2.2}, 0.05))\n assert.Equal(true, HasCloseElements([]float64{1.0, 2.0, 5.9, 4.0, 5.0}, 0.95))\n assert.Equal(false, HasCloseElements([]float64{1.0, 2.0, 5.9, 4.0, 5.0}, 0.8))\n assert.Equal(true, HasCloseElements([]float64{1.0, 2.0, 3.0, 4.0, 5.0, 2.0}, 0.1))\n assert.Equal(true, HasCloseElements([]float64{1.1, 2.2, 3.1, 4.1, 5.1}, 1.0))\n assert.Equal(false, HasCloseElements([]float64{1.1, 2.2, 3.1, 4.1, 5.1}, 0.5))\n}\n" 19 | self.assertEqual( 20 | get_def_header(test, "go"), 21 | "func TestHasCloseElements(t *testing.T) {\n", 22 | ) 23 | 24 | def test_js(self): 25 | test = "const testHasCloseElements = () => {\n console.assert(hasCloseElements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) === true)\n console.assert(\n hasCloseElements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) === false\n )\n console.assert(hasCloseElements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) === true)\n console.assert(hasCloseElements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) === false)\n console.assert(hasCloseElements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) === true)\n console.assert(hasCloseElements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) === true)\n console.assert(hasCloseElements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) === 
false)\n}\n\ntestHasCloseElements()\n" 26 | 27 | self.assertEqual( 28 | get_def_header(test, "js"), 29 | "const testHasCloseElements = () => {\n", 30 | ) 31 | 32 | def test_cpp(self): 33 | test = """ 34 | TEST(AsmWriterTest, DebugPrintDetachedArgument) { 35 | LLVMContext Ctx; 36 | auto Ty = Type::getInt32Ty(Ctx); 37 | auto Arg = new Argument(Ty); 38 | 39 | std::string S; 40 | raw_string_ostream OS(S); 41 | Arg->print(OS); 42 | EXPECT_EQ(S, "i32 "); 43 | delete Arg; 44 | }""" 45 | self.assertEqual( 46 | get_def_header(test, "cpp"), 47 | "TEST(AsmWriterTest, DebugPrintDetachedArgument) {\n", 48 | ) 49 | 50 | code = """ 51 | TEST(BFSTest, InstantiateGraphFromEdges) 52 | { 53 | Graph g({ {1, 2}, {1, 3}, {2, 3} }); 54 | 55 | std::vector bfs = g.BFS(1); 56 | std::vector expected{ 1, 2, 3 }; 57 | 58 | ASSERT_EQ(bfs, expected); 59 | } 60 | """ 61 | self.assertEqual( 62 | get_def_header(code, "cpp"), 63 | "TEST(BFSTest, InstantiateGraphFromEdges) {\n", 64 | ) 65 | 66 | def test_java(self): 67 | code = """@Test 68 | void catalogLoads() { 69 | @SuppressWarnings("rawtypes") 70 | ResponseEntity entity = new TestRestTemplate() 71 | .getForEntity("http://localhost:" + this.port + "/context/eureka/apps", Map.class); 72 | assertThat(entity.getStatusCode()).isEqualTo(HttpStatus.OK); 73 | String computedPath = entity.getHeaders().getFirst("X-Version-Filter-Computed-Path"); 74 | assertThat(computedPath).isEqualTo("/context/eureka/v2/apps"); 75 | }""" 76 | 77 | self.assertEqual(get_def_header(code, "java"), "@Test\nvoid catalogLoads() {\n") 78 | 79 | code = """@Test 80 | void testAdd() { 81 | assertThat(add(1, 2)).isEqualTo(3); 82 | }""" 83 | self.assertEqual(get_def_header(code, "java"), "@Test\nvoid testAdd() {\n") 84 | 85 | code = """@Test 86 | public void testInputParts(ServiceTransformationEngine transformationEngine, @All ServiceManager serviceManager) throws Exception { 87 | 88 | //check and import services 89 | checkAndImportServices(transformationEngine, serviceManager); 90 | 91 | URI op = findServiceURI(serviceManager, "serv1323166560"); 92 | String[] expected = {"con241744282", "con1849951292", "con1653328292"}; 93 | if (op != null) { 94 | Set ops = serviceManager.listOperations(op); 95 | Set inputs = serviceManager.listInputs(ops.iterator().next()); 96 | Set parts = new HashSet(serviceManager.listMandatoryParts(inputs.iterator().next())); 97 | assertTrue(parts.size() == 3); 98 | for (URI part : parts) { 99 | boolean valid = false; 100 | for (String expectedInput : expected) { 101 | if (part.toASCIIString().contains(expectedInput)) { 102 | valid = true; 103 | break; 104 | } 105 | } 106 | assertTrue(valid); 107 | } 108 | } else { 109 | fail(); 110 | } 111 | 112 | serviceManager.shutdown(); 113 | }""" 114 | 115 | self.assertEqual( 116 | get_def_header(code, "java"), 117 | "@Test\npublic void testInputParts(ServiceTransformationEngine transformationEngine, @All ServiceManager serviceManager) throws Exception {\n", 118 | ) 119 | 120 | 121 | if __name__ == "__main__": 122 | logging.basicConfig(level=logging.INFO) 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /frontend/python/collect_focal.py: -------------------------------------------------------------------------------- 1 | """script to collect focal with cursor location""" 2 | import os 3 | import sys 4 | import ast 5 | import fire 6 | import json 7 | import jedi 8 | import pathlib 9 | import traceback 10 | import astunparse 11 | from tqdm import tqdm 12 | from typing import Optional 13 | 
from collections import Counter 14 | 15 | from frontend.util import wrap_repo, mp_map_repos, run_with_timeout, TimeoutException 16 | from navigate import ModuleNavigator, load_ast_func, dump_ast_func, is_assert 17 | from frontend.python.collect_focal_org import ( 18 | NotFoundException, 19 | parse_func_name, 20 | is_subpath, 21 | jedi2ast, 22 | collect_from_repo, 23 | ) 24 | 25 | 26 | def parse_focal_call(test_func: ast.AST, module: ModuleNavigator, repo: jedi.Project): 27 | """guess the target focal function from a testing function according to 28 | 1. trace back all the function calls from the first assertion and 29 | return the first call that invokes a function definition in the repo 30 | 2. if found focal class by removing "Test" in name, 31 | """ 32 | script = jedi.Script(path=module.path, project=repo) 33 | last_call = None # pylint: disable=unused-variable 34 | calls_before_assert, found_assert = [], False 35 | for node in module.postorder(root=test_func): 36 | found_assert |= is_assert(node) 37 | if isinstance(node, ast.Call): 38 | if not found_assert: 39 | calls_before_assert.append(node) 40 | last_call = node 41 | while calls_before_assert: 42 | node = calls_before_assert.pop() 43 | # col_offset should be shifted to get the definition of the target function 44 | # say the function being called is api.load_image_file 45 | # and its lineno and col_offset are x and y 46 | # if we call script.goto(x, y), 47 | # jedi will try to find the definition of api rather than load_image_file 48 | # so we need to split api.load_image_file into (api, load_image_file) 49 | # then drop the last item to get (api,) 50 | # after that, shifted_col_offset is computed as col_offset + len(api) 51 | # then we can get the definition of the function of interest 52 | node_name = parse_func_name(node.func) 53 | shift_col_offset = node.col_offset + ( 54 | len(node_name) - len(node_name.split(".")[-1]) 55 | ) 56 | defs = script.goto( 57 | node.lineno, 58 | shift_col_offset, 59 | follow_imports=True, 60 | follow_builtin_imports=False, 61 | ) 62 | if ( 63 | len(defs) > 0 # and defs[0].type == 'function' 64 | and not defs[0].in_builtin_module() 65 | and is_subpath(repo.path, defs[0].module_path) 66 | ): 67 | return node, defs[0], node.lineno, shift_col_offset 68 | return None 69 | 70 | 71 | def collect_focal_func( 72 | repo_id: str = "ageitgey/face_recognition", 73 | test_id: str = "ageitgey-face_recognition/ageitgey-face_recognition-59cff93/tests/test_face_recognition.py::Test_face_recognition::test_load_image_file", 74 | iroot: str = "data/repos", 75 | repo: Optional[jedi.Project] = None, 76 | ): 77 | # construct jedi project if it is not given 78 | if repo is None: 79 | repo = jedi.Project(os.path.join(iroot, wrap_repo(repo_id))) 80 | # load testing function from its id 81 | test_func, test_mod = load_ast_func(os.path.join(iroot, test_id), return_nav=True) 82 | # find call to the potential focal function 83 | result = parse_focal_call(test_func, test_mod, repo) 84 | if result is not None: 85 | ( 86 | focal_call, # pylint: disable=unused-variable 87 | focal_func_jedi, 88 | line, 89 | col, 90 | ) = result 91 | else: 92 | raise NotFoundException(f"Failed to find potential focal call in {test_id}") 93 | # convert focal_func from jedi Name to ast object 94 | result = jedi2ast(focal_func_jedi) 95 | if result is not None: 96 | focal_func, focal_mod = result 97 | else: 98 | raise NotFoundException( 99 | f"Failed to locate focal function {focal_func_jedi.full_name} for {test_id}" 100 | ) 101 | # get focal path
to dump focal func 102 | focal_path = str(focal_func_jedi.module_path.relative_to(os.path.abspath(iroot))) 103 | focal_id = dump_ast_func(focal_func, focal_path, focal_mod) 104 | return focal_id, (line, col), (test_func.lineno, test_func.col_offset) 105 | 106 | 107 | def main( 108 | repo_id: str = "ageitgey/face_recognition", 109 | test_root: str = "data/tests", 110 | repo_root: str = "data/repos", 111 | focal_root: str = "data/focal", 112 | timeout: int = 300, 113 | nprocs: int = 0, 114 | limits: int = -1, 115 | ): 116 | try: 117 | repo_id_list = [l.strip() for l in open(repo_id, "r").readlines()] 118 | except FileNotFoundError: 119 | repo_id_list = [repo_id] 120 | if limits > 0: 121 | repo_id_list = repo_id_list[:limits] 122 | print(f"Loaded {len(repo_id_list)} repos to be processed") 123 | 124 | # collect focal function from each repo 125 | status_ntest_nfocal = mp_map_repos( 126 | collect_from_repo, 127 | repo_id_list=repo_id_list, 128 | nprocs=nprocs, 129 | timeout=timeout, 130 | repo_root=repo_root, 131 | test_root=test_root, 132 | focal_root=focal_root, 133 | ) 134 | 135 | filtered_results = [i for i in status_ntest_nfocal if i is not None] 136 | if len(filtered_results) < len(status_ntest_nfocal): 137 | print(f"{len(status_ntest_nfocal) - len(filtered_results)} repos timeout") 138 | status, ntest, nfocal = zip(*filtered_results) 139 | status_counter: Counter[int] = Counter(status) 140 | print( 141 | f"Processed {sum(status_counter.values())} repos with", 142 | f"{status_counter[3]} skipped, {status_counter[1]} not found,", 143 | f"and {status_counter[2]} failed to locate any focal functions", 144 | ) 145 | print(f"Collected {sum(nfocal)} focal functions for {sum(ntest)} tests") 146 | print("Done!") 147 | 148 | 149 | if __name__ == "__main__": 150 | fire.Fire(main) 151 | -------------------------------------------------------------------------------- /unitsyncer/rust_syncer.py: -------------------------------------------------------------------------------- 1 | """Replacement Synchronizer for Rust""" 2 | from typing import Optional 3 | from pip._vendor import tomli 4 | from os.path import join as pjoin, isfile, isdir, abspath 5 | import os 6 | from returns.maybe import Maybe, Nothing, Some 7 | from returns.result import Result, Success, Failure 8 | from frontend.parser.ast_util import ASTLoc, ASTUtil 9 | from frontend.parser import RUST_LANGUAGE 10 | from tree_sitter import Language, Parser, Tree 11 | from tree_sitter import Node 12 | from pylspclient.lsp_structs import Location, LANGUAGE_IDENTIFIER, Range, Position 13 | from unitsyncer.source_code import get_function_code 14 | from unitsyncer.util import path2uri, uri2path 15 | from returns.converters import maybe_to_result 16 | from unitsyncer.sync import Synchronizer 17 | from fuzzywuzzy import process 18 | from functools import partial 19 | 20 | 21 | class RustSynchronizer(Synchronizer): 22 | def __init__(self, workspace_dir: str, language="rust") -> None: 23 | super().__init__(workspace_dir, LANGUAGE_IDENTIFIER.RUST) 24 | self.file_func_map: dict[str, list[tuple[str, Node]]] = {} 25 | 26 | def initialize(self, timeout: int = 10): 27 | """index all files and functions in the workdir/src""" 28 | for root, _, files in os.walk(self.workspace_dir): 29 | for file in files: 30 | if file.endswith(".rs"): 31 | file_path = pjoin(root, file) 32 | funcs = self._get_file_functions(file_path) 33 | self.file_func_map[file_path] = funcs 34 | 35 | def _get_file_functions(self, file_path: str) -> list[tuple[str, Node]]: 36 | """get all function 
items in the given file 37 | 38 | Args: 39 | file_path (str): path to source code file 40 | 41 | Returns: 42 | list[tuple[str, Node]]: [(function_name, function_node)] 43 | """ 44 | with open(file_path) as code_file: 45 | ast_util = ASTUtil(code_file.read()) 46 | tree = ast_util.tree(RUST_LANGUAGE) 47 | nodes = ast_util.get_all_nodes_of_type(tree.root_node, "function_item") 48 | names = [ast_util.get_name(node).value_or("") for node in nodes] 49 | return list(zip(names, nodes)) 50 | 51 | def get_source_of_call( 52 | self, 53 | focal_name: str, 54 | file_path: Optional[str] = None, 55 | line: Optional[int] = None, 56 | col: Optional[int] = None, 57 | verbose: bool = False, 58 | ) -> Result[tuple[str, str | None, str | None], str]: 59 | match self.goto_definition(focal_name): 60 | case [] | None: 61 | return Failure("No Definition Found") 62 | case [(_, loc), *_]: 63 | # todo: find best match based on imports of test file 64 | 65 | def not_found_error(_): 66 | file_path = uri2path(loc.uri).value_or("") 67 | lineno = loc.range.start.line 68 | col_offset = loc.range.start.character 69 | return f"Source code not found: {file_path}:{lineno}:{col_offset}" 70 | 71 | return ( 72 | maybe_to_result(get_function_code(loc, LANGUAGE_IDENTIFIER.RUST)) 73 | .alt(not_found_error) 74 | .bind( 75 | lambda t: Failure("Empty Source Code") 76 | if t[0] == "" 77 | else Success(t) 78 | ) 79 | ) 80 | case _: 81 | return Failure("Unexpected Error") 82 | 83 | def goto_definition(self, focal_name: str) -> list[tuple[str, Location]]: 84 | """get the definition of the given function name 85 | 86 | Args: 87 | focal_name (str): name of the function 88 | 89 | Returns: 90 | list[tuple[str, Location]]: [(source_file_path, location)], 91 | source_file_path is used for sorting the results 92 | """ 93 | results: list[tuple[str, Location]] = [] # [(file_path, location)] 94 | include_name: str 95 | base_name: str 96 | 97 | match focal_name.split("."): 98 | case [obj_name, *xs, method_name]: 99 | include_name = obj_name 100 | 101 | # if method_name is unwrap, use the previously split name as method_name 102 | if "unwrap" in method_name: 103 | method_name = obj_name if len(xs) == 0 else xs[-1] 104 | base_name = method_name.split("(")[0] 105 | case _: 106 | temp_name = focal_name.split("(")[0] 107 | include_name = temp_name 108 | base_name = temp_name 109 | 110 | for file_path, funcs in self.file_func_map.items(): 111 | for name, node in funcs: 112 | if name == base_name: 113 | uri = path2uri(file_path) 114 | range_ = Range( 115 | Position(*node.start_point), 116 | Position(*node.end_point), 117 | ) 118 | results.append((file_path, Location(uri, range_))) 119 | 120 | # sort by fuzzy match with include name 121 | return sorted( 122 | results, 123 | key=partial(self.fuzzy_comparator, include_name), 124 | reverse=True, 125 | ) 126 | 127 | def fuzzy_comparator(self, include_name: str, x: tuple[str, Location]) -> float: 128 | """similarity score of file path with include_name 129 | 130 | Args: 131 | include_name (str): **engine::GeneralPurpose::new(&URL_SAFE, PAD)**.encode(bytes) 132 | x (tuple[str, Location]): (file_path, location) 133 | 134 | Returns: 135 | float: similarity score 136 | """ 137 | file_path = x[0].split(self.workspace_dir)[-1] 138 | match process.extractOne(file_path, [include_name]): 139 | case (_, score): 140 | return score 141 | case _: 142 | return 0 143 | 144 | def stop(self): 145 | pass 146 | 147 | 148 | def main(): 149 | workdir = "data/repos/marshallpierce-rust-base64/marshallpierce-rust-base64-4ef33cc" 150 |
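# minimal usage sketch: index all .rs files in the crate, then resolve a
# focal-call string back to its definition's source code (assumes the repo
# above was already downloaded, e.g. via scripts/download_repos.py)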
lsp = RustSynchronizer(workdir) 151 | lsp.initialize() 152 | 153 | print( 154 | lsp.get_source_of_call( 155 | "engine::GeneralPurpose::new(&URL_SAFE, PAD).encode(bytes)" 156 | ) 157 | ) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /scripts/download_repos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import fire 4 | import json 5 | import time 6 | import random 7 | import tarfile 8 | import requests 9 | import calendar 10 | from tqdm import tqdm 11 | from github import Github, Repository, Auth 12 | from typing import Tuple, Union, Optional 13 | 14 | from frontend.util import log_or_skip, wrap_repo, time_limit, TimeoutException 15 | 16 | 17 | def fetch_repo(repo_id: str, timeout: int, hub: Optional[Github]): 18 | """fetch a repo""" 19 | hub = hub if hub is not None else Github() 20 | try: 21 | with time_limit(timeout): 22 | repo = hub.get_repo(repo_id) 23 | return True, repo 24 | except Exception as e: 25 | return False, f"Fetch repo failed: {e}" 26 | 27 | 28 | def fetch_archive(repo: Repository.Repository): 29 | """fetch the archive of a repo 30 | latest release -> latest tag -> latest commit 31 | """ 32 | # try latest release 33 | try: 34 | return True, repo.get_latest_release() 35 | except: 36 | pass 37 | # try latest tag 38 | try: 39 | return True, next(iter(repo.get_tags())) 40 | except: 41 | pass 42 | # try latest commit 43 | try: 44 | commit = next(iter(repo.get_commits())) 45 | commit.tarball_url = ( 46 | f"https://api.github.com/repos/{repo.full_name}/tarball/{commit.sha}" 47 | ) 48 | return True, commit 49 | except Exception as e: 50 | return False, f"Fetch archive failed: {e}" 51 | 52 | 53 | def download_archive(path: str, url: str, timeout: int): 54 | """download file from url to path with timeout limit""" 55 | try: 56 | with time_limit(timeout): 57 | resp = requests.get(url) 58 | resp.raise_for_status() 59 | except Exception as e: 60 | return False, f"Download failed: {e}" 61 | with open(path, "wb") as outfile: 62 | outfile.write(resp.content) 63 | return True, "" 64 | 65 | 66 | def download_repo( 67 | hub: Github, repo_id: str, path: str, fetch_timeout: int, download_timeout: int 68 | ): 69 | """return status: 70 | 0. succeeded 71 | 1. fetch repo failed 72 | 2. fetch archive failed 73 | 3.
download archive failed 74 | """ 75 | # fetch the repo 76 | status, repo = fetch_repo(repo_id, timeout=fetch_timeout, hub=hub) 77 | if not status: 78 | return 1, (repo,) 79 | # fetch archive 80 | status, archive = fetch_archive(repo) 81 | if not status: 82 | return 2, (repo, archive) 83 | # download archive, skip downloading if target path existed 84 | if os.path.exists(path): 85 | status = True 86 | else: 87 | status, msg = download_archive(path, archive.tarball_url, download_timeout) 88 | if status: 89 | return 0, (repo, archive, path) 90 | return 3, (repo, archive, msg) 91 | 92 | 93 | def main( 94 | repo_id_list: str = "ageitgey/face_recognition", 95 | fetch_timeout: int = 30, 96 | download_timeout: int = 300, 97 | delay: Union[Tuple[int], int] = -1, 98 | oroot: str = "data/repos_tarball/", 99 | log: Optional[str] = "meta.jsonl", 100 | limits: int = -1, 101 | decompress: bool = False, 102 | oauth: str = None, 103 | ): 104 | if log: 105 | log = os.path.join(oroot, log) 106 | # declare github object 107 | # oauth is provided for rate limit: https://docs.github.com/en/rest/overview/resources-in-the-rest-api?apiVersion=2022-11-28#rate-limiting 108 | # 5k calls per hours if authorised, otherwise, 60 calls or some 109 | hub = None 110 | if oauth: 111 | try: 112 | hub = Github(auth=Auth.Token(oauth)) 113 | except: 114 | pass 115 | if not hub: 116 | hub = Github() 117 | # if repo_id_list is a file then load lines 118 | # otherwise it is the id of a specific repo 119 | try: 120 | repo_id_list = [l.strip() for l in open(repo_id_list, "r").readlines()] 121 | except: 122 | repo_id_list = [repo_id_list] 123 | if limits >= 0: 124 | repo_id_list = repo_id_list[:limits] 125 | print(f"Loaded {len(repo_id_list)} repos to be downloaded") 126 | failed = {"repo": 0, "archive": 0, "download": 0} 127 | for repo_id in (pbar := tqdm(repo_id_list)): 128 | # log repo_id and rate limits 129 | rate = hub.get_rate_limit() 130 | pbar.set_description( 131 | f"Downloading {repo_id}, Rate: {rate.core.remaining}/{rate.core.limit}" 132 | ) 133 | # download repo 134 | path = os.path.join(oroot, wrap_repo(repo_id)) + ".tar.gz" 135 | status, results = download_repo( 136 | hub, repo_id, path, fetch_timeout, download_timeout 137 | ) 138 | if status == 0: 139 | repo, archive, path = results 140 | err_msg = "" 141 | log_or_skip( 142 | log, 143 | repo_id=repo_id, 144 | repo=repo.clone_url, 145 | archive=archive.tarball_url, 146 | download=path, 147 | ) 148 | if decompress: 149 | try: 150 | tarfile.open(path).extractall(".".join(path.split(".")[:-2])) 151 | except: 152 | pass 153 | elif status == 1: 154 | failed["repo"] += 1 155 | err_msg = results[0] 156 | log_or_skip(log, repo_id=repo_id, repo=err_msg) 157 | elif status == 2: 158 | failed["archive"] += 1 159 | repo, err_msg = results 160 | log_or_skip(log, repo_id=repo_id, repo=repo.clone_url, archive=err_msg) 161 | elif status == 3: 162 | failed["download"] += 1 163 | repo, archive, err_msg = results 164 | log_or_skip( 165 | log, 166 | repo_id=repo_id, 167 | repo=repo.clone_url, 168 | archive=archive.tarball_url, 169 | download=err_msg, 170 | ) 171 | # delay 172 | sleep_time = delay if isinstance(delay, int) else random.randint(*delay) 173 | if "rate limit exceeded" in err_msg: 174 | reset_at = calendar.timegm(rate.core.reset.timetuple()) 175 | sleep_time = reset_at - calendar.timegm(time.gmtime()) + 5 176 | print(f"Rate limits exceeded, sleep for {sleep_time} seconds") 177 | if sleep_time > 0: 178 | time.sleep(sleep_time) 179 | 180 | if sum(failed.values()): 181 | 
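# report only the failure categories that actually occurred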
print("Failed:", {key: val for key, val in failed.items() if val}) 182 | print("Done!") 183 | 184 | 185 | if __name__ == "__main__": 186 | fire.Fire(main) 187 | -------------------------------------------------------------------------------- /frontend/javascript/collect_all.py: -------------------------------------------------------------------------------- 1 | """main script for Javascript frontend""" 2 | 3 | from typing import Iterable 4 | import fire 5 | import os 6 | from tree_sitter import Node 7 | from frontend.parser import JAVASCRIPT_LANGUAGE 8 | from frontend.parser.ast_util import ASTUtil 9 | from returns.maybe import Maybe, Nothing, Some 10 | from unitsyncer.util import replace_tabs 11 | import json 12 | from frontend.util import mp_map_repos, wrap_repo, run_with_timeout 13 | from collections import Counter 14 | from frontend.javascript.js_util import js_get_test_args, get_focal_call, is_test_fn 15 | 16 | 17 | def has_test(file_path): 18 | # follow TeCo to check for JUnit4 and JUnit5 19 | # todo: support different usage as in google/closure-compiler 20 | def has_chai(code): 21 | return "require('chai')" in code or 'require("chai")' in code 22 | 23 | def has_jest(code): 24 | return "from '@jest/globals'" in code 25 | 26 | try: 27 | with open(file_path, "r", errors="replace") as f: 28 | code = f.read() 29 | except FileNotFoundError: 30 | return False 31 | return has_chai(code) or has_jest(code) or "describe(" in code 32 | 33 | 34 | def collect_test_files(root: str): 35 | """Get all files end with .java in the given root directory 36 | 37 | Args: 38 | root (str): path to repo root 39 | """ 40 | for dirpath, _, filenames in os.walk(root): 41 | # find js test dir 42 | if "test" not in os.path.relpath(dirpath, root): 43 | continue 44 | if "node_modules" in os.path.relpath(dirpath, root): 45 | continue 46 | 47 | for filename in filenames: 48 | if filename.endswith(".js"): 49 | if has_test(p := os.path.join(dirpath, filename)): 50 | yield p 51 | 52 | 53 | def collect_test_funcs(ast_util: ASTUtil) -> Iterable[Maybe[tuple[str, Node]]]: 54 | """collect testing functions from the target file""" 55 | 56 | tree = ast_util.tree(JAVASCRIPT_LANGUAGE) 57 | root_node = tree.root_node 58 | 59 | # js test function is a higher order function that takes a function as input 60 | call_exprs = ast_util.get_all_nodes_of_type(root_node, "call_expression") 61 | return map( 62 | lambda node: js_get_test_args(ast_util, node), 63 | filter(lambda n: is_test_fn(n, ast_util), call_exprs), 64 | ) 65 | 66 | 67 | def collect_test_n_focal(file_path: str): 68 | with open(file_path, "r", errors="replace") as f: 69 | ast_util = ASTUtil(replace_tabs(f.read())) 70 | 71 | def get_focal_for_test(test_func_pair: Maybe[tuple[str, Node]]): 72 | match test_func_pair: 73 | case Some((test_name, test_func)): 74 | focal, focal_loc = get_focal_call(ast_util, test_func).value_or( 75 | (None, None) 76 | ) 77 | return { 78 | "test_id": test_name, 79 | "test_loc": test_func.start_point, 80 | "test": ast_util.get_source_from_node(test_func), 81 | "focal_id": focal, 82 | "focal_loc": focal_loc, 83 | } 84 | case Nothing: 85 | return None 86 | 87 | return filter(None, map(get_focal_for_test, collect_test_funcs(ast_util))) 88 | 89 | 90 | @run_with_timeout 91 | def collect_from_repo( 92 | repo_id: str, repo_root: str, test_root: str, focal_root: str 93 | ): # pylint: disable=unused-argument 94 | """collect all test functions in the given project 95 | return (status, nfile, ntest) 96 | status can be 0: success, 1: repo not found, 2: test not 
found, 3: skip when output file existed 97 | """ 98 | repo_path = os.path.join(repo_root, wrap_repo(repo_id)) 99 | if not os.path.exists(repo_path) or not os.path.isdir(repo_path): 100 | return 1, 0, 0 101 | focal_path = os.path.join(focal_root, wrap_repo(repo_id) + ".jsonl") 102 | # skip if exist 103 | if os.path.exists(focal_path): 104 | return 3, 0, 0 105 | # collect potential testing modules 106 | all_files = collect_test_files(repo_path) 107 | tests = {} 108 | for f in all_files: 109 | funcs = collect_test_n_focal(f) 110 | tests[f] = funcs 111 | if len(tests.keys()) == 0: 112 | return 2, 0, sum(len(list(v)) for v in tests.values()) 113 | # save to disk 114 | n_test_func = 0 115 | n_focal_func = 0 116 | with open(focal_path, "w") as outfile: 117 | for k, ds in tests.items(): 118 | for d in ds: 119 | test_id = f"{k.removeprefix(repo_root)}::{d['test_id']}" 120 | d["test_id"] = test_id[1:] if test_id[0] == "/" else test_id 121 | if d["focal_loc"] is None: 122 | continue 123 | outfile.write(json.dumps(d) + "\n") 124 | n_test_func += int(d["test_loc"] is not None) 125 | n_focal_func += int(d["focal_loc"] is not None) 126 | return 0, n_test_func, n_focal_func 127 | 128 | 129 | def main( 130 | repo_id: str = "twolfson/fs-memory-store", 131 | repo_root: str = "data/repos/", 132 | test_root: str = "data/tests/", 133 | focal_root: str = "data/focal/", 134 | timeout: int = 120, 135 | nprocs: int = 0, 136 | limits: int = -1, 137 | ): 138 | try: 139 | repo_id_list = [l.strip() for l in open(repo_id, "r").readlines()] 140 | except FileNotFoundError: 141 | repo_id_list = [repo_id] 142 | if limits > 0: 143 | repo_id_list = repo_id_list[:limits] 144 | print(f"Loaded {len(repo_id_list)} repos to be processed") 145 | 146 | # collect focal function from each repo 147 | status_ntest_nfocal = mp_map_repos( 148 | collect_from_repo, 149 | repo_id_list=repo_id_list, 150 | nprocs=nprocs, 151 | timeout=timeout, 152 | repo_root=repo_root, 153 | test_root=test_root, 154 | focal_root=focal_root, 155 | ) 156 | 157 | filtered_results = [i for i in status_ntest_nfocal if i is not None] 158 | if len(filtered_results) < len(status_ntest_nfocal): 159 | print(f"{len(status_ntest_nfocal) - len(filtered_results)} repos timeout") 160 | status, ntest, nfocal = zip(*filtered_results) 161 | status_counter: Counter[int] = Counter(status) 162 | print( 163 | f"Processed {sum(status_counter.values())} repos with", 164 | f"{status_counter[3]} skipped, {status_counter[1]} not found,", 165 | f"and {status_counter[2]} failed to locate any focal functions", 166 | ) 167 | print(f"Collected {sum(nfocal)} focal functions for {sum(ntest)} tests") 168 | print("Done!") 169 | 170 | 171 | if __name__ == "__main__": 172 | fire.Fire(main) 173 | -------------------------------------------------------------------------------- /scripts/check_repo_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script checks if a repo fulfills certain requirements 3 | Can customize what requirements to check and values to check against 4 | 5 | 6 | References 7 | 8 | Parts 1 and 2 of a medium blog explaining the query: 9 | https://fabiomolinar.medium.com/using-githubs-graphql-to-retrieve-a-list-of-repositories-their-commits-and-some-other-stuff-ccbbb4e96d78 10 | https://fabiomolinar.medium.com/using-githubs-graphql-to-retrieve-a-list-of-repositories-their-commits-and-some-other-stuff-ce2f73432f7 11 | 12 | Repository Object: 13 | https://docs.github.com/en/graphql/reference/objects#repository 14 | 15 | """ 16 |
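# Example invocation (hypothetical repo list file; list-valued flags use
# python-fire's syntax, as in the `--reqs` example near main() below):
#   python scripts/check_repo_stats.py \
#       --repo_id_list=data/rust_repos.txt \
#       --checks_list='["stars", "language"]' \
#       --reqs='["50", "rust"]'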
from datetime import datetime 17 | import fire 18 | import orjson 19 | import os 20 | import sys 21 | from typing import Callable, Optional 22 | 23 | from scripts.common import get_graphql_data, check_metadata_decorator 24 | 25 | 26 | #### Requirement Callables #### 27 | @check_metadata_decorator 28 | def req_enough_stars(metadata: dict, req_stars: str = "10") -> bool: 29 | """Checks if Github repository has enough stars""" 30 | return metadata["stargazerCount"] >= int(req_stars) 31 | 32 | 33 | @check_metadata_decorator 34 | def req_latest_commit(metadata: dict, date_str: str = "2020-1-1") -> bool: 35 | """Checks if Github repository has a recent enough latest commit""" 36 | # latest_commit_date = datetime.fromisoformat(metadata["pushedAt"]) 37 | date_values = metadata["pushedAt"].split("T")[0].split("-") 38 | latest_commit_date = datetime( 39 | int(date_values[0]), int(date_values[1]), int(date_values[2]) 40 | ) 41 | date = date_str.split("-") 42 | req_date = datetime(int(date[0]), int(date[1]), int(date[2])) 43 | return latest_commit_date.date() > req_date.date() 44 | 45 | 46 | @check_metadata_decorator 47 | def req_language(metadata: dict, language: str = "java") -> bool: 48 | """Checks if Github repository has the expected primary language""" 49 | return metadata["primaryLanguage"]["name"].lower() == language.lower() 50 | 51 | 52 | @check_metadata_decorator 53 | def req_fuzzers(metadata: dict) -> bool: 54 | """Checks if a Rust Github repository has a fuzz directory""" 55 | contents = metadata["object"]["entries"] 56 | for item in contents: 57 | if item["name"] == "fuzz" and item["type"] == "tree": 58 | return True 59 | return False 60 | 61 | 62 | #### End Requirement Callables #### 63 | 64 | 65 | def check_requirements( 66 | repo: str, 67 | requirements: list[Callable[[dict, Optional[str]], bool]], 68 | reqs: list[str], 69 | metadata: Optional[dict] = None, 70 | ) -> bool: 71 | """Checks if Github repository meets requirements 72 | 73 | Args: 74 | repo (str): Github repository in repo_owner/repo_name format 75 | requirements (list[callable]): List of requirement callables that check if repo meets requirement 76 | reqs (list[str]): List of values to check against for each callable 77 | - each req is passed to the callable at the same index in requirements 78 | 79 | Returns: 80 | bool: True if repo meets requirements, False otherwise 81 | """ 82 | 83 | repo_query = repo.split("/") 84 | 85 | # Get repo data 86 | gql_format = """ 87 | query { 88 | rateLimit { 89 | cost 90 | remaining 91 | resetAt 92 | } 93 | repository(name:"%s", owner:"%s"){ 94 | id 95 | owner { 96 | login 97 | } 98 | name 99 | url 100 | isArchived 101 | isFork 102 | isMirror 103 | primaryLanguage { 104 | name 105 | } 106 | pushedAt 107 | stargazerCount 108 | object(expression: "HEAD:") { 109 | ...
on Tree { 110 | entries { 111 | name 112 | type 113 | } 114 | } 115 | } 116 | } 117 | } 118 | """ 119 | # Example of format of metadata: 120 | # { 121 | # 'data': { 122 | # 'rateLimit': {'cost': , 'remaining': , 'resetAt': }, 123 | # 'repository': {'primaryLanguage': , 'pushedAt': , 'stargazerCount': } 124 | # } 125 | # } 126 | data = metadata 127 | if metadata is None: # Query if metadata is not provided 128 | metadata = get_graphql_data(gql_format % (repo_query[1], repo_query[0])) 129 | data = metadata["data"]["repository"] 130 | if "errors" in metadata: 131 | sys.exit(f"Fetching repo metadata error: {metadata['errors']}") 132 | 133 | # print(f"{repo} metadata: {metadata}") 134 | 135 | # If repo is archived, mirror, or fork, automatic fail 136 | if data["isArchived"]: 137 | print("Repo is archived") 138 | return False 139 | elif data["isFork"]: 140 | print("Repo is fork") 141 | return False 142 | elif data["isMirror"]: 143 | print("Repo is mirror") 144 | return False 145 | 146 | # Check requirements 147 | for i in range(len(requirements)): 148 | if requirements[i] == req_fuzzers and not requirements[i](data): 149 | print(f"Error with req {requirements[i].__name__}") 150 | return False 151 | elif requirements[i] != req_fuzzers and not requirements[i](data, reqs[i]): 152 | print( 153 | f"Error with req {requirements[i].__name__} with requirement {reqs[i]}" 154 | ) 155 | return False 156 | 157 | # Save metadata to file to avoid repeat queries for repos that pass checks 158 | # print("Saving metadata") 159 | for key, value in data.items(): 160 | if not os.path.exists(f"./data/repo_metadata/{key}.json"): 161 | f = open(f"./data/repo_metadata/{key}.json", "x") 162 | f.close() 163 | 164 | with open(f"./data/repo_metadata/{key}.json", "rb") as f: 165 | try: 166 | dic = orjson.loads(f.read()) 167 | except ValueError: 168 | dic = {} 169 | dic[repo] = value 170 | with open(f"./data/repo_metadata/{key}.json", "wb") as f: 171 | f.write(orjson.dumps(dic)) 172 | 173 | return True 174 | 175 | 176 | # Map elements of checks_list to callables 177 | check_map = { 178 | "stars": req_enough_stars, 179 | "latest commit": req_latest_commit, 180 | "language": req_language, 181 | "fuzzers": req_fuzzers, 182 | } 183 | 184 | 185 | # Pass checks_list and reqs with this : --checks_list='' --reqs='' 186 | # Ex. 
--reqs='["0", "2020-1-1"]'template 187 | # If checking Rust fuzz path, put null in place of where the req should be in the reqs list 188 | def main( 189 | repo_id_list: str = "ethanbwang/test", 190 | checks_list: list[str] = ["stars", "latest commit"], 191 | reqs: list[str] = ["10", "2020-1-1"], # Year format should be -- 192 | ): 193 | # if repo_id_list is a file then load lines 194 | # otherwise it is the id of a specific repo 195 | try: 196 | repo_id_list = [l.strip() for l in open(repo_id_list, "r").readlines()] 197 | except: 198 | repo_id_list = [repo_id_list] 199 | 200 | checks = [check_map[check] for check in checks_list] 201 | 202 | for repo in repo_id_list: 203 | if check_requirements(repo, checks, reqs): 204 | print(f"{repo} meets the requirements\n") 205 | else: 206 | print(f"{repo} does not meet the requirements\n") 207 | 208 | 209 | if __name__ == "__main__": 210 | fire.Fire(main) 211 | 212 | -------------------------------------------------------------------------------- /tests/evaluation/test_eval.py: -------------------------------------------------------------------------------- 1 | """tests for get_coverage module in evaluation/execution.py""" 2 | import os 3 | from evaluation.execution import get_coverage 4 | import unittest 5 | import logging 6 | from unitsyncer.common import UNITSYNCER_HOME 7 | 8 | 9 | class TestEvaluationCoverage(unittest.TestCase): 10 | """tests for evaluation/execution.py""" 11 | 12 | def test_python(self): 13 | focal = 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n """ Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n """\n for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n\n return False\n' 14 | test = "\n\nMETADATA = {\n 'author': 'jt',\n 'dataset': 'test'\n}\n\n\ndef check(has_close_elements):\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n assert has_close_elements([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n assert has_close_elements([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n assert has_close_elements([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n assert has_close_elements([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n\ncheck(has_close_elements)" 15 | self.assertEqual(get_coverage(focal, test, "python"), (None, 100, 100)) 16 | 17 | def test_cpp(self): 18 | focal = "/*\nCheck if in given vector of numbers, are any two numbers closer to each other than\ngiven threshold.\n>>> has_close_elements({1.0, 2.0, 3.0}, 0.5)\nfalse\n>>> has_close_elements({1.0, 2.8, 3.0, 4.0, 5.0, 2.0}, 0.3)\ntrue\n*/\n#include\n#include\n#include\nusing namespace std;\nbool has_close_elements(vector numbers, float threshold){\n int i,j;\n \n for (i=0;i Maybe[tuple[str, str | None, str | None]]: 25 | """Extract the source code of a function from a Location LSP response 26 | 27 | Args: 28 | func_location (Location): location of function responded by LS 29 | lang (str): language of the file as in LANGUAGE_IDENTIFIER 30 | 31 | Returns: 32 | Maybe[tuple[str, str | None, str | None]]: source code of function, its docstring, code_id 33 | """ 34 | lineno = 
func_location.range.start.line 35 | col_offset = func_location.range.start.character # pylint: disable=unused-variable 36 | 37 | def _get_function_code(file_path) -> Maybe[tuple[str, str | None, str | None]]: 38 | try: 39 | with open(file_path, "r", errors="replace") as file: 40 | code = file.read() 41 | except FileNotFoundError: 42 | return Nothing 43 | 44 | match lang: 45 | case LANGUAGE_IDENTIFIER.PYTHON: 46 | node = ast.parse(code, filename=file_path) 47 | return py_get_def(node, lineno).map( 48 | lambda node: (ast.unparse(node), ast.get_docstring(node), None) 49 | ) 50 | case LANGUAGE_IDENTIFIER.JAVA: 51 | ast_util = ASTUtil(replace_tabs(code)) 52 | tree = ast_util.tree(JAVA_LANGUAGE) 53 | return java_get_def(tree.root_node, lineno, ast_util).map( 54 | lambda node: ( 55 | ast_util.get_source_from_node(node), 56 | None, 57 | f"{file_path}::{ast_util.get_method_name(node).unwrap()}", 58 | ) 59 | ) 60 | case LANGUAGE_IDENTIFIER.JAVASCRIPT: 61 | ast_util = ASTUtil(replace_tabs(code)) 62 | tree = ast_util.tree(JAVASCRIPT_LANGUAGE) 63 | 64 | return js_get_def(tree.root_node, lineno, ast_util).map( 65 | lambda node: ( 66 | ast_util.get_source_from_node(node), 67 | None, 68 | # js focal may not be a function, so directly unwrap may fail 69 | f"{file_path}::{ast_util.get_name(node).value_or(None)}", 70 | ) 71 | ) 72 | case LANGUAGE_IDENTIFIER.RUST: 73 | ast_util = ASTUtil(replace_tabs(code)) 74 | tree = ast_util.tree(RUST_LANGUAGE) 75 | 76 | return rust_get_def(tree.root_node, lineno, ast_util).map( 77 | lambda node: ( 78 | ast_util.get_source_from_node(node), 79 | None, 80 | f"{file_path}::{ast_util.get_name(node).value_or(None)}", 81 | ) 82 | ) 83 | case LANGUAGE_IDENTIFIER.GO: 84 | ast_util = ASTUtil(replace_tabs(code)) 85 | tree = ast_util.tree(GO_LANGUAGE) 86 | 87 | return go_get_def(tree.root_node, lineno, ast_util).map( 88 | lambda node: ( 89 | ast_util.get_source_from_node(node), 90 | None, 91 | f"{file_path}::{ast_util.get_name(node).value_or(None)}", 92 | ) 93 | ) 94 | case LANGUAGE_IDENTIFIER.C | LANGUAGE_IDENTIFIER.CPP: 95 | ast_util = ASTUtil(replace_tabs(code)) 96 | tree = ast_util.tree(CPP_LANGUAGE) 97 | 98 | return cpp_get_def(tree.root_node, lineno, ast_util).map( 99 | lambda node: ( 100 | ast_util.get_source_from_node(node), 101 | None, 102 | f"{file_path}::{get_cpp_func_name(ast_util, node).value_or(None)}", 103 | ) 104 | ) 105 | 106 | case _: 107 | return Nothing 108 | 109 | return uri2path(func_location.uri).bind(_get_function_code) 110 | 111 | 112 | def py_get_def(node: ast.AST, lineno: int) -> Maybe[ast.FunctionDef]: 113 | for child in ast.iter_child_nodes(node): 114 | if ( 115 | isinstance(child, ast.FunctionDef) 116 | # AST is 1-indexed, LSP is 0-indexed 117 | and child.lineno == lineno + 1 118 | # AST count from def, LSP count from function name 119 | # and child.col_offset == col_offset - 4 120 | ): 121 | return Some(child) 122 | result = py_get_def(child, lineno) 123 | if result != Nothing: 124 | return result 125 | 126 | return Nothing 127 | 128 | 129 | def java_get_def(node: Node, lineno: int, ast_util: ASTUtil) -> Maybe[Node]: 130 | def in_modifier_range(method_node: Node, lineno: int) -> bool: 131 | n_modifier = ast_util.get_method_modifiers(method_node).map(len).value_or(0) 132 | defn_lineno: int = method_node.start_point[0] 133 | return defn_lineno + n_modifier >= lineno 134 | 135 | for defn in ast_util.get_all_nodes_of_type(node, "method_declaration"): 136 | # tree-sitter AST is 0-indexed 137 | defn_lineno = defn.start_point[0] 138 | if defn_lineno == lineno 
or in_modifier_range(defn, lineno): 139 | return Some(defn) 140 | return Nothing 141 | 142 | 143 | def js_get_def(node: Node, lineno: int, ast_util: ASTUtil) -> Maybe[Node]: 144 | for defn in ast_util.get_all_nodes_of_type(node, None): 145 | # tree-sitter AST is 0-indexed 146 | defn_lineno = defn.start_point[0] 147 | if defn_lineno == lineno: 148 | return Some(defn) 149 | return Nothing 150 | 151 | 152 | def rust_get_def(node: Node, lineno: int, ast_util: ASTUtil) -> Maybe[Node]: 153 | for defn in ast_util.get_all_nodes_of_type(node, "function_item"): 154 | # tree-sitter AST is 0-indexed 155 | defn_lineno = defn.start_point[0] 156 | if defn_lineno == lineno: 157 | return Some(defn) 158 | return Nothing 159 | 160 | 161 | def go_get_def(node: Node, lineno: int, ast_util: ASTUtil) -> Maybe[Node]: 162 | def find_in(node_type: str): 163 | for defn in ast_util.get_all_nodes_of_type(node, node_type): 164 | # tree-sitter AST is 0-indexed 165 | defn_lineno = defn.start_point[0] 166 | if defn_lineno == lineno: 167 | return defn 168 | return None 169 | 170 | # NOTE: in python 171 | # None or some_value === some_value 172 | # None or None === None 173 | return Maybe.from_optional( 174 | find_in("method_declaration") or find_in("function_declaration") 175 | ) 176 | 177 | 178 | def cpp_get_def(node: Node, lineno: int, ast_util: ASTUtil) -> Maybe[Node]: 179 | for defn in ast_util.get_all_nodes_of_type(node, "function_definition"): 180 | # tree-sitter AST is 0-indexed 181 | defn_lineno = defn.start_point[0] 182 | if defn_lineno == lineno: 183 | return Some(defn) 184 | return Nothing 185 | -------------------------------------------------------------------------------- /frontend/python/collect_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collects all the test functions from projects following 3 | "Conventions for Python test discovery" in 4 | https://docs.pytest.org/en/7.4.x/explanation/goodpractices.html#test-discovery 5 | """ 6 | 7 | import os 8 | import re 9 | import sys 10 | import ast 11 | import fire 12 | import traceback 13 | from tqdm import tqdm 14 | from pathlib import Path 15 | from collections import Counter 16 | from typing import List, Optional 17 | 18 | from frontend.util import run_with_timeout, wrap_repo, mp_map_repos, TimeoutException 19 | from navigate import ModuleNavigator, dump_ast_func 20 | 21 | 22 | def collect_test_files(root: str): 23 | """collect all files in the root folder recursively and filter to match the given patterns""" 24 | patterns = [ 25 | ".*_test\.py", # pylint: disable=anomalous-backslash-in-string 26 | "test_.*\.py", # pylint: disable=anomalous-backslash-in-string 27 | ] 28 | test_files = [] 29 | for parent, _, files in os.walk(root): 30 | for file in files: 31 | if any([re.match(ptn, file) for ptn in patterns]): 32 | test_files.append(os.path.join(parent, file)) 33 | return test_files 34 | 35 | 36 | def collect_test_funcs(module_path: str): 37 | """collect testing functions from the target file""" 38 | nav = ModuleNavigator(module_path) 39 | funcs = nav.find_all(ast.FunctionDef) 40 | # funcs = nav.find_all(lambda x:isinstance(x, (ast.FunctionDef, ast.AsyncFunctionDef))) 41 | 42 | def is_test_cls(node: ast.AST): 43 | """is a test class if 44 | 1.1 class name starts with Test 45 | 1.2 inherit from unittest.TestCase 46 | 2. 
a static class without an init function 47 | """ 48 | if not isinstance(node, ast.ClassDef): 49 | return False 50 | # if not node.name.startswith('Test'): return False 51 | test_prefix = node.name.startswith("Test") 52 | inherit_unittest_attr = any( 53 | [ 54 | isinstance(base, ast.Attribute) and base.attr == "TestCase" 55 | for base in node.bases 56 | ] 57 | ) 58 | inherit_unittest_name = any( 59 | [ 60 | isinstance(base, ast.Name) and base.id == "TestCase" 61 | for base in node.bases 62 | ] 63 | ) 64 | if not any([test_prefix, inherit_unittest_name, inherit_unittest_attr]): 65 | return False 66 | cls_funcs = nav.find_all(ast.FunctionDef, root=node) 67 | return not any(func.name == "__init__" for func in cls_funcs) 68 | 69 | def has_assert(func: ast.AST): 70 | # builtin assertion 71 | if len(nav.find_all(ast.Assert, root=func)) > 0: 72 | return True 73 | # TestCase assertions in unittest, e.g. self.assertEqual 74 | for call in nav.find_all(ast.Call, root=func): 75 | if isinstance(call.func, ast.Attribute) and call.func.attr.startswith( 76 | "assert" 77 | ): 78 | return True 79 | return False 80 | 81 | def is_test_outside_cls(func: ast.AST): 82 | """decide if the function is a testing function outside a class 83 | return true if its name starts with "test" 84 | """ 85 | return func.name.startswith("test") 86 | 87 | def is_test_inside_cls(func: ast.AST, path: List[ast.AST]): 88 | """decide if the function is a testing function inside a class 89 | return true if its enclosing class is a test class (see is_test_cls) and either 90 | + its name is prefixed by "test" 91 | + it is decorated with @staticmethod or @classmethod 92 | """ 93 | # keep only the nodes in path that are test classes 94 | cls_path = [n for n in path if is_test_cls(n)] 95 | if len(cls_path) == 0: 96 | return False 97 | if func.name.startswith("test"): 98 | return True 99 | decorators = getattr(func, "decorator_list", []) 100 | return any( 101 | isinstance(d, ast.Name) and d.id in ("staticmethod", "classmethod") 102 | for d in decorators 103 | ) 104 | 105 | test_funcs = [] 106 | for func in funcs: 107 | path = nav.get_path_to(func) 108 | is_cls = [isinstance(n, ast.ClassDef) for n in path] 109 | is_test = False 110 | is_test |= any(is_cls) and is_test_inside_cls(func, path) 111 | is_test |= not any(is_cls) and is_test_outside_cls(func) 112 | is_test &= has_assert(func) 113 | if not is_test: 114 | continue 115 | func_id = dump_ast_func(func, module_path, nav, path) 116 | test_funcs.append(func_id) 117 | 118 | return test_funcs 119 | 120 | 121 | @run_with_timeout 122 | def collect_from_repo(repo_id: str, repo_root: str, test_root: str): 123 | """collect all test functions in the given project 124 | return (status, nfile, ntest) 125 | status can be 0: success, 1: repo not found, 2: test not found, 3: skip when output file existed 126 | """ 127 | repo_path = os.path.join(repo_root, wrap_repo(repo_id)) 128 | if not os.path.exists(repo_path) or not os.path.isdir(repo_path): 129 | return 1, 0, 0 130 | test_path = os.path.join(test_root, wrap_repo(repo_id) + ".txt") 131 | # skip if output already exists 132 | if os.path.exists(test_path): 133 | return 3, 0, 0 134 | # collect potential testing modules 135 | all_files = collect_test_files(repo_path) 136 | test_files, test_funcs = [], [] 137 | for f in all_files: 138 | try: 139 | funcs = collect_test_funcs(f) 140 | except TimeoutException: 141 | raise 142 | except: # pylint: disable=bare-except 143 | funcs = None 144 | if funcs is None or len(funcs) == 0: 145 | continue 146 | test_files.append(f) 147 | test_funcs.extend(funcs)
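# NOTE: each func_id emitted by dump_ast_func is a "::"-joined id whose first component is the module path; the save loop below rewrites that path to be relative to repo_root.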
148 | if len(test_funcs) == 0: 149 | return 2, len(test_files), len(test_funcs) 150 | # save to disk 151 | with open(test_path, "w") as outfile: 152 | for func_id in test_funcs: 153 | parts = func_id.split("::") 154 | parts[0] = str( 155 | Path(os.path.abspath(parts[0])).relative_to(os.path.abspath(repo_root)) 156 | ) 157 | func_id = "::".join(parts) 158 | outfile.write(f"{func_id}\n") 159 | return 0, len(test_files), len(test_funcs) 160 | 161 | 162 | def main( 163 | repo_id: str = "ageitgey/face_recognition", 164 | repo_root: str = "data/repos/", 165 | test_root: str = "data/tests", 166 | timeout: int = 120, 167 | nprocs: int = 0, 168 | limits: int = -1, 169 | ): 170 | # if repo_id_list is a file then load lines 171 | # otherwise it is the id of a specific repo 172 | try: 173 | repo_id_list = [l.strip() for l in open(repo_id, "r").readlines()] 174 | except FileNotFoundError: 175 | repo_id_list = [repo_id] 176 | if limits > 0: 177 | repo_id_list = repo_id_list[:limits] 178 | print(f"Loaded {len(repo_id_list)} repos to be processed") 179 | 180 | status_nfile_ntest = mp_map_repos( 181 | collect_from_repo, 182 | repo_id_list=repo_id_list, 183 | nprocs=nprocs, 184 | repo_root=repo_root, 185 | test_root=test_root, 186 | timeout=timeout, 187 | ) 188 | 189 | filtered_results = [i for i in status_nfile_ntest if i is not None] 190 | if len(filtered_results) < len(status_nfile_ntest): 191 | print(f"{len(status_nfile_ntest) - len(filtered_results)} repos timed out") 192 | status, nfile, ntest = zip(*filtered_results) 193 | status_counter: Counter[int] = Counter(status) 194 | print( 195 | f"Processed {sum(status_counter.values())} repos with {status_counter[3]} skipped, {status_counter[1]} not found, and {status_counter[2]} failed to mine any testing functions" 196 | ) 197 | print(f"Collected {sum(ntest)} tests from {sum(nfile)} files in total") 198 | 199 | 200 | if __name__ == "__main__": 201 | fire.Fire(main) 202 | -------------------------------------------------------------------------------- /frontend/rust/collect_fuzz.py: -------------------------------------------------------------------------------- 1 | """script for rust fuzzing and transforming test_template""" 2 | import logging 3 | from typing import Optional 4 | import fire 5 | import os 6 | from frontend.util import wrap_repo, parallel_subprocess 7 | import subprocess 8 | from os.path import join as pjoin, abspath 9 | from tqdm import tqdm 10 | from unitsyncer.common import CORES 11 | from pathos.multiprocessing import ProcessingPool 12 | 13 | 14 | def transform_repos(repos: list[str], jobs: int): 15 | def transform_one_repo(repo_path: str): 16 | return subprocess.Popen( 17 | ["rust-fuzzer-gen", repo_path], 18 | stdout=subprocess.PIPE, 19 | stderr=subprocess.PIPE, 20 | ) 21 | 22 | logging.info(f"Running rust-fuzzer-gen on {len(repos)} repos") 23 | parallel_subprocess(repos, jobs, transform_one_repo, on_exit=None) 24 | 25 | 26 | def get_target_list(p: subprocess.Popen): 27 | match p.stdout: 28 | case None: 29 | return [] 30 | case _: 31 | return p.stdout.read().decode("utf-8").split("\n") 32 | 33 | 34 | def fuzz_one_target(target: tuple[str, str], timeout): 35 | repo_path, target_name = target 36 | with open(pjoin(repo_path, "fuzz_inputs", target_name), "w") as f: 37 | return subprocess.Popen( 38 | # todo: find out why -max_total_time doesn't work 39 | # ["cargo", "fuzz", "run", target_name, "--", f"-max_total_time={timeout}"], 40 | [ 41 | "bash", 42 | "-c", 43 | f"timeout {timeout} cargo fuzz run {target_name}", 44 | ], 45 | cwd=repo_path, 46 | stdout=f, 
stderr=subprocess.DEVNULL, 48 | ) 49 | 50 | 51 | def build(repos: list[str], jobs: int): 52 | logging.info(f"Building fuzzing targets in {len(repos)} repos") 53 | _ = parallel_subprocess( 54 | repos, 55 | jobs, 56 | lambda path: subprocess.Popen( 57 | ["cargo", "fuzz", "build"], 58 | cwd=path, 59 | stdout=subprocess.DEVNULL, 60 | stderr=subprocess.DEVNULL, 61 | ), 62 | on_exit=None, 63 | ) 64 | 65 | 66 | def fuzz_repos(repos: list[str], jobs: int, timeout: int = 60): 67 | logging.info("Collecting all fuzz targets") 68 | 69 | target_map = parallel_subprocess( 70 | repos, 71 | jobs, 72 | lambda path: subprocess.Popen( 73 | ["cargo", "fuzz", "list"], cwd=path, stdout=subprocess.PIPE 74 | ), 75 | on_exit=get_target_list, 76 | ) 77 | targets: list[tuple[str, str]] = [ 78 | (k, v) for k, vs in target_map.items() for v in vs if len(v) > 0 79 | ] 80 | for repo in repos: 81 | os.makedirs(pjoin(repo, "fuzz_inputs"), exist_ok=True) 82 | 83 | logging.info(f"Running cargo fuzz on {len(targets)} targets for {timeout} seconds") 84 | parallel_subprocess( 85 | targets, jobs, lambda p: fuzz_one_target(p, timeout), on_exit=None 86 | ) 87 | 88 | 89 | def substitute_input(template: str, input_data: str, idx: int) -> str: 90 | return template.replace( 91 | '[] ; # [doc = "This is a test template"]', f"{input_data} ; " 92 | ).replace("fn test_something ()", f"fn test_{idx} ()") 93 | 94 | 95 | def substitute_one_repo(repo: str, targets: list[str], n_fuzz): 96 | template_dir = pjoin(repo, "tests-gen") 97 | input_dir = pjoin(repo, "fuzz_inputs") 98 | for t in targets: 99 | if t == "": 100 | continue 101 | 102 | # format template before loading 103 | template_path = pjoin(template_dir, t + ".rs") 104 | try: 105 | with open(template_path) as f_template: 106 | template = f_template.read() 107 | with open(pjoin(input_dir, t), "r") as f_input: 108 | inputs = [i for i in f_input.read().splitlines() if i != "[]"][:n_fuzz] 109 | 110 | tests = [ 111 | substitute_input(template, input_data, i) 112 | for i, input_data in enumerate(inputs) 113 | ] 114 | generated_test_path = pjoin(template_dir, f"{t}.inputs.rs") 115 | with open(generated_test_path, "w") as f_template: 116 | f_template.write("\n".join(tests)) 117 | 118 | # format generated tests 119 | subprocess.run(["rustfmt", str(generated_test_path)], check=False) 120 | except FileNotFoundError: 121 | logging.debug(f"Template {template_path} not found") 122 | 123 | 124 | def testgen_repos(repos: list[str], jobs: int, n_fuzz: int = 100): 125 | """Generate tests from fuzz inputs 126 | 127 | Args: 128 | repos (list[str]): list of repo paths 129 | jobs (int): number of parallel jobs to use 130 | n_fuzz (int, optional): number of fuzz data to use. Defaults to 100. 
131 | """ 132 | target_map = parallel_subprocess( 133 | repos, 134 | jobs, 135 | lambda path: subprocess.Popen( 136 | ["cargo", "fuzz", "list"], cwd=path, stdout=subprocess.PIPE 137 | ), 138 | on_exit=get_target_list, 139 | use_tqdm=False, 140 | ) 141 | logging.info("Substituting fuzz data into test templates") 142 | with ProcessingPool(jobs) as p: 143 | _ = list( 144 | tqdm( 145 | p.map( 146 | lambda item: substitute_one_repo(item[0], item[1], n_fuzz), 147 | target_map.items(), 148 | ) 149 | ) 150 | ) 151 | 152 | 153 | def main( 154 | repo_id: str = "image-rs/image-png", 155 | repo_root: str = "data/rust_repos/", 156 | timeout: int = 60, 157 | jobs: int = CORES, 158 | limits: Optional[int] = None, 159 | pipeline: str = "transform", 160 | n_fuzz=100, 161 | ): 162 | """collect fuzzing data from rust repos 163 | 164 | Args: 165 | repo_id (str, optional): repo id. Defaults to "image-rs/image-png". 166 | repo_root (str, optional): directory containing all the repos. Defaults to "data/rust_repos/". 167 | timeout (int, optional): max_total_time to fuzz. Defaults to 60. 168 | jobs (int, optional): number of parallel jobs to use. Defaults to CORES. 169 | limits (Optional[int], optional): number of repos to process, None to use all of them. 170 | pipeline (str, optional): what to do. Defaults to "transform". 171 | n_fuzz (int, optional): number of fuzz inputs to use. Defaults to 100. 172 | """ 173 | try: 174 | repo_id_list = [ 175 | ll for line in open(repo_id, "r").readlines() if len(ll := line.strip()) > 0 176 | ] 177 | except FileNotFoundError: 178 | repo_id_list = [repo_id] 179 | if limits is not None: 180 | repo_id_list = repo_id_list[:limits] 181 | logging.info(f"Loaded {len(repo_id_list)} repos to be processed") 182 | 183 | logging.info("Collecting all rust repos") 184 | repos = [] 185 | for repo_id in repo_id_list: 186 | repo_path = os.path.join(repo_root, wrap_repo(repo_id)) 187 | if os.path.exists(repo_path) and os.path.isdir(repo_path): 188 | subdirectories = [ 189 | os.path.join(repo_path, d) 190 | for d in os.listdir(repo_path) 191 | if os.path.isdir(os.path.join(repo_path, d)) 192 | ] 193 | repos.append(abspath(subdirectories[0])) 194 | 195 | match pipeline: 196 | case "transform": 197 | transform_repos(repos, jobs) 198 | case "build": 199 | build(repos, jobs) 200 | case "fuzz": 201 | fuzz_repos(repos, jobs, timeout=timeout) 202 | case "testgen": 203 | testgen_repos(repos, jobs, n_fuzz) 204 | case "all": 205 | transform_repos(repos, jobs) 206 | build(repos, jobs) 207 | fuzz_repos(repos, jobs, timeout=timeout) 208 | testgen_repos(repos, jobs, n_fuzz) 209 | case _: 210 | logging.error(f"Unknown pipeline {pipeline}") 211 | 212 | 213 | if __name__ == "__main__": 214 | logging.basicConfig(level=logging.INFO) 215 | fire.Fire(main) 216 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """main script for UniTSyncer backend""" 2 | from tqdm import tqdm 3 | from unitsyncer.sync import Synchronizer, LSPSynchronizer 4 | from unitsyncer.rust_syncer import RustSynchronizer 5 | from unitsyncer.sansio_lsp_syncer import SansioLSPSynchronizer 6 | from pylspclient.lsp_structs import LANGUAGE_IDENTIFIER, Location, Position, Range 7 | from returns.maybe import Maybe, Nothing, Some 8 | from returns.result import Result, Success, Failure 9 | from unitsyncer.util import parallel_starmap as starmap, path2uri, convert_to_seconds 10 | from unitsyncer.common import CORES 11 | from unitsyncer.extract_def import get_def_header 12 | import math 13 | from unitsyncer.source_code import get_function_code 14 | import json 15 | import jsonlines 16 | import os 17 | from pathos.multiprocessing import ProcessPool 18 | import logging 19 | import fire 20 | from itertools import groupby 21 | import pathlib 22 | 23 | 24 | def id2path(func_id: str) -> str: 25 | return func_id.split("::")[0] 26 | 27 | 28 | def java_workdir_dict(objs: list[dict]) -> dict[str, list[dict]]: 29 | """split a list of focal objects into a dict keyed by workdir 30 | this solves the LSP TimeoutError for Java repos with too many subdirectories 31 | 32 | Args: 33 | objs (list[dict]): focal objects parsed from a jsonl focal file 34 | 35 | Returns: 36 | dict[str, list[dict]]: {workdir: [corresponding focal objects, ...], ...} 37 | """ 38 | workdir_dict: dict[str, list[dict]] = {} 39 | for obj in objs: 40 | test_id = obj["test_id"] 41 | file_path = id2path(test_id) 42 | workdir = file_path.split("/test")[0] 43 | if workdir not in workdir_dict: 44 | workdir_dict[workdir] = [] 45 | workdir_dict[workdir].append(obj) 46 | return workdir_dict 47 | 48 | 49 | def focal2result(syncer: Synchronizer, repos_root, obj): 50 | p = id2path(obj["test_id"]) 51 | file_path = os.path.join(repos_root, p) 52 | src_lineno, src_col_offset = obj["focal_loc"] 53 | test_lineno, test_col_offset = obj["test_loc"] 54 | 55 | langID = syncer.langID # pylint: disable=invalid-name 56 | 57 | # only Python's ast is 1-indexed; tree-sitter and LSP are 0-indexed 58 | match langID: 59 | case LANGUAGE_IDENTIFIER.PYTHON: 60 | src_lineno -= 1 61 | test_lineno -= 1 62 | 63 | # since the test's decl node is already captured by the frontend, it may already store the test code 64 | if "test" in obj.keys(): 65 | test = obj["test"] 66 | else: 67 | fake_loc = Location( 68 | path2uri(file_path), 69 | Range( 70 | Position(test_lineno, test_col_offset), 71 | Position(test_lineno, test_col_offset + 1), 72 | ), 73 | ) 74 | test, _, _ = get_function_code(fake_loc, syncer.langID).unwrap() 75 | 76 | result = { 77 | "test_id": obj["test_id"], 78 | "test": test, 79 | } 80 | 81 | # todo: standardize the return format on Failure 82 | match syncer.get_source_of_call( 83 | obj["focal_id"], 84 | file_path, 85 | src_lineno, 86 | src_col_offset, 87 | ): 88 | case Success((code, docstring, code_id)): 89 | result["code_id"] = ( 90 | obj["focal_id"] 91 | if code_id is None 92 | else code_id.removeprefix(repos_root + "/") 93 | ) 94 | result["code"] = code 95 | result["docstring"] = docstring 96 | result["test_header"] = get_def_header(test, langID) 97 | case Failure(e): 98 | logging.debug(e) 99 | result["error"] = e 100 | 101 | return result 102 | 103 | 104 | def process_one_focal_file( 105 | focal_file="./data/focal/ageitgey-face_recognition.jsonl", 106 | repos_root="data/repos", 107 | language="python", 108 | skip_processed=True, 109 | ) -> tuple[int, int]: 110 | with open(focal_file) as f: 111 | objs = [json.loads(line) for line in f.readlines()] 112 | 113 | if len(objs) == 0: 114 | return 0, 0 115 | 116 | n_focal = len(objs) 117 | match language: 118 | case LANGUAGE_IDENTIFIER.JAVA: 119 | wd = java_workdir_dict(objs) 120 | case _: 121 | first_test_id = objs[0]["test_id"] 122 | workdir = "/".join(id2path(first_test_id).split("/")[:2]) 123 | wd = { 124 | workdir: objs, 125 | } 126 | 127 | success_results = [] 128 | failure_results = [] 129 | source_file = focal_file.replace("focal", "source") 130 | success_file = source_file.replace(".jsonl", ".success.jsonl") 131 | failure_file = source_file.replace(".jsonl", ".failure.jsonl") 132 | 133 | # check if this file is already processed 134 | if skip_processed and os.path.exists(success_file) and os.path.exists(failure_file): 135 | with open(success_file, "rb") as f: 136 | n_succ = sum(1 for _ in f) 137 | with open(failure_file, "rb") as f: 138 | n_fail = sum(1 for _ in f) 139 | 140 | if language == LANGUAGE_IDENTIFIER.JAVA: 141 | return n_focal, n_succ 142 | if n_succ + n_fail >= n_focal: 143 | return n_focal, n_succ 144 | 145 | pathlib.Path(success_file).touch() 146 | pathlib.Path(failure_file).touch() 147 | 148 | logging.debug(f"number of workdirs: {len(wd)}") 149 | repos_root = os.path.abspath(repos_root) 150 | for workdir, workdir_objs in wd.items(): 151 | succ = [] 152 | fail = [] 153 | full_workdir = os.path.join(repos_root, workdir) 154 | logging.debug(f"workdir: {full_workdir}") 155 | syncer: Synchronizer 156 | 157 | match language: 158 | case LANGUAGE_IDENTIFIER.RUST: 159 | syncer = RustSynchronizer(full_workdir, language) 160 | case LANGUAGE_IDENTIFIER.GO: 161 | syncer = SansioLSPSynchronizer(full_workdir, language) 162 | case _: 163 | syncer = LSPSynchronizer(full_workdir, language) 164 | 165 | try: 166 | syncer.initialize(timeout=60) 167 | 168 | for obj in workdir_objs: 169 | result = focal2result(syncer, repos_root, obj) 170 | if "error" in result: 171 | fail.append(result) 172 | else: 173 | succ.append(result) 174 | 175 | syncer.stop() 176 | except Exception as e: # pylint: disable=broad-exception-caught 177 | logging.debug(e) 178 | syncer.stop() 179 | continue 180 | 181 | # append to source file in loop to avoid losing data 182 | with jsonlines.open(success_file, "a") as f: 183 | f.write_all(succ) 184 | with jsonlines.open(failure_file, "a") as f: 185 | f.write_all(fail) 186 | 187 | success_results.extend(succ) 188 | failure_results.extend(fail) 189 | 190 | return n_focal, len(success_results) 191 | 192 | 193 | def main( 194 | repos_root="data/repos", 195 | focal_path="data/focal", 196 | language="python", 197 | jobs=CORES, 198 | debug=False, 199 | timeout="30m", 200 | ): 201 | logging.basicConfig(level=logging.DEBUG if debug else logging.INFO) 202 | all_focal_files = [] 203 | if os.path.isdir(focal_path): 204 | focal_dir = focal_path 205 | for root, _, files in os.walk(os.path.abspath(focal_dir)): 206 | for file in files: 207 | if file.endswith(".jsonl"): 208 | all_focal_files.append(os.path.join(root, file)) 209 | elif os.path.isfile(focal_path): 210 | all_focal_files.append(focal_path) 211 | else: 212 | logging.error(f"{focal_path} is not a valid file or directory") 213 | exit(1) 214 | 215 | logging.info(f"Processing {len(all_focal_files)} focal files") 216 | os.makedirs("./data/source", exist_ok=True) 217 | 218 | # starting jobs / 2 since each job will spawn 2 processes (main and LSP) 219 | with ProcessPool(math.ceil(jobs / 2)) as pool: 220 | rnt = list( 221 | tqdm( 222 | pool.imap( 223 | lambda f: process_one_focal_file( 224 | f, repos_root=repos_root, language=language 225 | ), 226 | all_focal_files, 227 | ), 228 | total=len(all_focal_files), 229 | ) 230 | ) 231 | nfocal, ncode = zip(*rnt) 232 | logging.info( 233 | f"Found source code for {sum(ncode)} of {sum(nfocal)} focal functions" 234 | ) 235 | 236 | 237 | if __name__ == "__main__": 238 | fire.Fire(main) 239 | -------------------------------------------------------------------------------- /evaluation/execution.py: -------------------------------------------------------------------------------- 1 | """ 2 | coverage evaluation script for LLM-generated code-test 
pairs 3 | """ 4 | 5 | import tempfile 6 | import os 7 | import subprocess 8 | import json 9 | import csv 10 | import fire 11 | from tqdm import tqdm 12 | 13 | 14 | def get_ext(lang: str) -> str: 15 | ext: str 16 | if lang == "python": 17 | ext = ".py" 18 | elif lang == "java": 19 | ext = ".java" 20 | elif lang == "cpp": 21 | ext = ".cpp" 22 | elif lang == "js": 23 | ext = ".js" 24 | elif lang == "go": 25 | ext = ".go" 26 | else: 27 | ext = "" 28 | 29 | return ext 30 | 31 | 32 | def run_command_in(cwd: str): 33 | """Create a helper function to run a shell command in the `cwd` directory 34 | 35 | Args: 36 | cwd (str): path to a directory 37 | """ 38 | 39 | def subprocess_caller( 40 | command: str, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL 41 | ) -> subprocess.CompletedProcess: 42 | cmd_list = command.split(" ") 43 | return subprocess.run( 44 | cmd_list, cwd=cwd, stdout=stdout, stderr=stderr, check=True, text=True 45 | ) 46 | 47 | return subprocess_caller 48 | 49 | 50 | BranchCov = float | None 51 | StatCov = float | None 52 | LineCov = float | None 53 | 54 | 55 | def get_coverage( 56 | code: str, 57 | test: str, 58 | lang: str = "python", 59 | java_lib_path: str = os.path.join(os.getcwd(), "lib"), 60 | ) -> tuple[StatCov, LineCov, BranchCov] | None: 61 | """compute coverage of `test` on `code` 62 | 63 | Args: 64 | code (str): source code of focal function to be tested 65 | test (str): test function code 66 | lang (str, optional): language used. Defaults to "python". 67 | 68 | Returns: 69 | tuple[StatCov, LineCov, BranchCov] | None: tuple of coverage rates, 70 | None if compile failed 71 | """ 72 | line_cov: float | None = None 73 | branch_cov: float | None = None 74 | stat_cov: float | None = None 75 | 76 | lang = lang.lower() 77 | java_lib_path = os.path.abspath(java_lib_path) 78 | 79 | tmp_dir = tempfile.TemporaryDirectory() 80 | tmp_dir_path = tmp_dir.name 81 | run_cmd = run_command_in(tmp_dir_path) 82 | ext = get_ext(lang) 83 | 84 | focal_file_name = "focal" + ext 85 | test_file_name = "test" + ext 86 | test_file = os.path.join(tmp_dir_path, test_file_name) 87 | 88 | focal_file = os.path.join(tmp_dir_path, focal_file_name) 89 | with open(focal_file, "w") as f: 90 | f.write(code) 91 | 92 | if lang == "python": 93 | with open(test_file, "w") as fp: 94 | fp.write("from focal import *\n") 95 | fp.write(test) 96 | run_cmd("coverage run --branch test.py") 97 | run_cmd("coverage json") 98 | with open(os.path.join(tmp_dir_path, "coverage.json")) as cov_fp: 99 | j = json.load(cov_fp) 100 | try: 101 | exec_result = j["files"][focal_file_name]["summary"] 102 | except KeyError: 103 | return None 104 | 105 | covered_lines = exec_result["covered_lines"] 106 | num_statements = exec_result["num_statements"] 107 | line_cov = ( 108 | 100 * (covered_lines / num_statements) if num_statements != 0 else 100.0 109 | ) 110 | num_branches = exec_result["num_branches"] 111 | covered_branches = exec_result["covered_branches"] 112 | branch_cov = ( 113 | 100 * (covered_branches / num_branches) if num_branches != 0 else 100.0 114 | ) 115 | 116 | elif lang == "cpp": 117 | with open(test_file, "w") as fp: 118 | fp.write('#include "focal.cpp"\n') 119 | fp.write(test) 120 | try: 121 | run_cmd("clang++ -fprofile-instr-generate -fcoverage-mapping test.cpp -o test") 122 | except subprocess.CalledProcessError: 123 | # run_cmd passes check=True, so a failed compile raises instead of returning 124 | return None 125 | 126 | run_cmd("./test") 127 | run_cmd("llvm-profdata merge -sparse default.profraw -o test.profdata") 128 | llvm_cov_proc = run_cmd( 129 | "llvm-cov export ./test -instr-profile=test.profdata --format=text", 130 | stdout=subprocess.PIPE, 131 | ) 132 | j = json.loads(llvm_cov_proc.stdout) 133 | try: 134 | for d in j["data"]: 135 | for f in d["files"]: 136 | if f["filename"] == os.path.abspath(focal_file): # type: ignore 137 | branch_cnt = f["summary"]["branches"]["count"] # type: ignore 138 | percentage = f["summary"]["branches"]["percent"] # type: ignore 139 | branch_cov = 100.0 if branch_cnt == 0 else percentage 140 | 141 | lines_cnt = f["summary"]["lines"]["count"] # type: ignore 142 | percentage = f["summary"]["lines"]["percent"] # type: ignore 143 | line_cov = 100.0 if lines_cnt == 0 else percentage 144 | except KeyError: 145 | return None 146 | elif lang == "java": 147 | main_file = "Main.java" 148 | main_file_path = os.path.join(tmp_dir_path, main_file) 149 | with open(main_file_path, "w") as f: 150 | f.write(code) 151 | f.write(test) 152 | run_cmd("javac Main.java -d bin/") 153 | run_cmd( 154 | f"java -javaagent:{java_lib_path}/jacocoagent.jar=destfile=jacoco.exec -cp bin Main" 155 | ) 156 | run_cmd( 157 | f"java -jar {java_lib_path}/jacococli.jar report jacoco.exec" 158 | " --classfiles bin --sourcefiles Main.java --csv coverage.csv" 159 | ) 160 | coverage_file = os.path.join(tmp_dir_path, "coverage.csv") 161 | with open(coverage_file, newline="") as csvfile: 162 | reader = csv.DictReader(csvfile) 163 | for row in reader: 164 | if row["CLASS"] == "Solution": 165 | branch_covered = int(row["BRANCH_COVERED"]) 166 | branch_missed = int(row["BRANCH_MISSED"]) 167 | branch_num = branch_covered + branch_missed 168 | branch_cov = 100.0 * ( 169 | branch_covered / branch_num if branch_num != 0 else 1 170 | ) 171 | 172 | line_covered = int(row["LINE_COVERED"]) 173 | line_missed = int(row["LINE_MISSED"]) 174 | line_num = line_covered + line_missed 175 | line_cov = 100.0 * (line_covered / line_num if line_num != 0 else 1) 176 | break 177 | elif lang == "js": 178 | with open(focal_file, "a") as f: 179 | f.write(test) 180 | run_cmd("nyc --reporter=json-summary node focal.js") 181 | coverage_file = os.path.join(tmp_dir_path, "coverage", "coverage-summary.json") 182 | with open(coverage_file) as cov_fp: 183 | j = json.load(cov_fp) 184 | try: 185 | branch_cov = j[focal_file]["branches"]["pct"] 186 | line_cov = j[focal_file]["lines"]["pct"] 187 | stat_cov = j[focal_file]["statements"]["pct"] 188 | except KeyError: 189 | return None 190 | 191 | elif lang == "go": 192 | mod_file = os.path.join(tmp_dir_path, "go.mod") 193 | with open(mod_file, "w") as mod_fp: 194 | mod_fp.write("module go_cov\ngo 1.16\n") 195 | 196 | test_file_name = "focal_test.go" 197 | test_file = os.path.join(tmp_dir_path, test_file_name) 198 | with open(test_file, "w") as test_fp: 199 | test_fp.write("package main\n") 200 | test_fp.write('import "testing"\n') 201 | test_fp.write('import "github.com/stretchr/testify/assert"\n') 202 | test_fp.write(test) 203 | with open(focal_file, "w") as focal_fp: 204 | focal_fp.write("package main\n") 205 | focal_fp.write(code) 206 | 207 | run_cmd("go get github.com/stretchr/testify/assert") 208 | run_cmd("go test -coverprofile=coverage.out") 209 | cov_result: str = run_cmd( 210 | "go tool cover -func=coverage.out", stdout=subprocess.PIPE 211 | ).stdout 212 | try: 213 | line = cov_result.splitlines()[0] 214 | elems = line.split("\t") 215 | stat_cov = float(elems[-1][:-1]) # str 100.0% -> float 100.0 216 | except IndexError: 217 | return None 218 | 219 | else: 220 | return None 221 | 222 | tmp_dir.cleanup() 223 | return stat_cov, line_cov, branch_cov 224 | 
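# A minimal usage sketch of get_coverage (mirroring tests/evaluation/test_eval.py);
# the focal/test sources here are hypothetical and for illustration only:
#
#   focal = "def add(x, y):\n    return x + y\n"
#   test = "assert add(1, 2) == 3\n"
#   stat, line, branch = get_coverage(focal, test, "python")
#   # the Python backend reports no statement coverage, so stat is None while
#   # line/branch are percentages, e.g. (None, 100.0, 100.0) at full coverage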
225 | 226 | def main( 227 | jsonl_path: str, 228 | change_go_proxy: bool = False, 229 | java_lib_path: str = os.path.join(os.getcwd(), "lib"), 230 | ): 231 | if change_go_proxy: 232 | subprocess.run(["go", "env", "-w", "GOPROXY=https://goproxy.cn"], check=False) 233 | 234 | with open(jsonl_path, "r") as fp: 235 | j_lines = fp.readlines() 236 | 237 | with open(f"results_{jsonl_path}", "w") as fp: 238 | for j_line in tqdm(j_lines): 239 | j = json.loads(j_line) 240 | focal = j["focal"] 241 | test = j["test"] 242 | lang = j["lang"] 243 | cov = get_coverage(focal, test, lang=lang, java_lib_path=java_lib_path) 244 | j["coverage"] = cov 245 | fp.write(json.dumps(j) + "\n") 246 | 247 | 248 | if __name__ == "__main__": 249 | fire.Fire(main) 250 | --------------------------------------------------------------------------------