├── .gitattributes ├── .gitignore ├── MANIFEST.in ├── README.md ├── binaryai_bindiffmatch ├── __init__.py ├── __main__.py ├── bindiffmatch.py ├── main.py ├── metricsutils.py ├── models.py ├── py.typed ├── similarity_matrix │ ├── __init__.py │ ├── basic.py │ └── lowmem.py └── utils.py ├── pyproject.toml └── scripts ├── diaphora-3.0-b91a9e7abe03de45bf47d4619eda7f8b3f0357bb.patch └── metrics.py /.gitattributes: -------------------------------------------------------------------------------- 1 | data/** filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.pyc 3 | __pycache__ 4 | 5 | /dist 6 | /build 7 | *.egg-info 8 | 9 | *.idb 10 | *.i64 11 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune data -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BinaryAI BindiffMatch algorithm 2 | 3 | This repo contains the implementation of the [BinaryAI]( https://www.binaryai.cn/ ) file comparison algorithm, along with the datasets and metric scripts. 4 | 5 | ## Project 6 | 7 | The `binaryai_bindiffmatch` directory contains the BinaryAI BindiffMatch algorithm, not including the BAI-2.0 model and embedding implementation. 8 | 9 | The `data` directory contains the metric datasets. (You can download them from the [release assets]( https://github.com/binaryai/bindiffmatch/releases ).) 10 | 11 | `data/files` contains unstripped files and stripped files. 12 | We use binaries from the `coreutils`, `diffutils` and `findutils` projects as testcases. These binaries are the experiment data of the [DeepBinDiff]( https://github.com/yueduan/DeepBinDiff/tree/master/experiment_data ) project; get them from the original project. 13 | We manually built several versions of the `openssl` project and chose two files as the example case. Here are the sources: [openssl-1.1.1u]( https://www.openssl.org/source/openssl-1.1.1u.tar.gz ), [openssl-3.1.1]( https://www.openssl.org/source/openssl-3.1.1.tar.gz ) 14 | 15 | `data/labeleds` contains pre-generated info about the functions in each binary file. The basic info, pseudocode, callees, and names are produced by [Ghidra]( https://github.com/NationalSecurityAgency/ghidra ), and the feature embedding vectors are produced by the BinaryAI BAI-2.0 model. Scripts to generate these files are not included in this project. 16 | 17 | `data/matchresults` contains pre-generated match results on the testcases and the example, produced by the BinaryAI BindiffMatch algorithm and [Diaphora]( https://github.com/joxeankoret/diaphora/tree/3.0 ), as well as the groundtruth results. 18 | BinaryAI BindiffMatch results can be generated by running `python -m binaryai_bindiffmatch <doc1> <doc2> -o <matchresult>` on each pair of files. 19 | Diaphora results are generated by first applying the [patch]( scripts/diaphora-3.0-b91a9e7abe03de45bf47d4619eda7f8b3f0357bb.patch ) to this [commit]( https://github.com/joxeankoret/diaphora/tree/3.0 ), then using IDA headless mode to export the `.sqlite` databases. After that, run the offline Diaphora script to generate `.diaphora` results (with `relaxed_ratio` set to True and other options kept at their defaults), and finally convert them to JSON in the same format as the BinaryAI results. Scripts for these steps are not included in this project.
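The BindiffMatch step can also be invoked programmatically. Here is a minimal sketch (the input paths are placeholders; any pair of labeled docs under `data/labeleds` works):

```python
from binaryai_bindiffmatch import bindiffmatch, defaultconfig
from binaryai_bindiffmatch.utils import dump_matchresult, load_doc

# load two pre-generated labeled docs (function infos plus BAI-2.0 embeddings)
doc1 = load_doc("path/to/first.strip.json")
doc2 = load_doc("path/to/second.strip.json")
# run the three-stage BindiffMatch algorithm with the default thresholds
matchresult = bindiffmatch(doc1, doc2, config=defaultconfig)
dump_matchresult(matchresult, "path/to/matchresult.json")
```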
20 | 21 | ## Install 22 | 23 | Requires Python >= 3.10 24 | Run `pip install .[lowmem]` to install this package and its dependencies 25 | 26 | ## Metric 27 | 28 | `python scripts/metrics.py testcases binaryai`: get the metric result on the full testcases for the BinaryAI BindiffMatch algorithm 29 | `python scripts/metrics.py testcases diaphora`: get the metric result on the full testcases for Diaphora 30 | `python scripts/metrics.py example binaryai`: get the metric result on the [example]( https://www.binaryai.cn/compare/eyJzaGEyNTYiOiJiNDQzYjRjMmNiMzlkYWNmMTkwNzA3NTI1NGE3MWJkYTg1ZjU2OTczNDk3YjgxNmUyZWRjNTNlZGQ2OTE4MTllIiwidGFyZ2V0Ijp7ImJpbmRpZmYiOnsic2hhMjU2IjoiZTMwZWRjOGQ2YjYyN2U5YmRjMTRmNWQyMTViNzZiYTUxYzFjMTNhODZjOWNjYzEzYzY1YmEyNGIzZTdmODRiMCJ9fX0= ) case for the BinaryAI BindiffMatch algorithm 31 | `python scripts/metrics.py example diaphora`: get the metric result on the [example]( https://www.binaryai.cn/compare/eyJzaGEyNTYiOiJiNDQzYjRjMmNiMzlkYWNmMTkwNzA3NTI1NGE3MWJkYTg1ZjU2OTczNDk3YjgxNmUyZWRjNTNlZGQ2OTE4MTllIiwidGFyZ2V0Ijp7ImJpbmRpZmYiOnsic2hhMjU2IjoiZTMwZWRjOGQ2YjYyN2U5YmRjMTRmNWQyMTViNzZiYTUxYzFjMTNhODZjOWNjYzEzYzY1YmEyNGIzZTdmODRiMCJ9fX0= ) case for Diaphora 32 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import BindiffMatchConfig, bindiffmatch, defaultconfig 2 | 3 | __all__ = [ 4 | "bindiffmatch", 5 | "BindiffMatchConfig", 6 | "defaultconfig", 7 | ] 8 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/__main__.py: -------------------------------------------------------------------------------- 1 | from .main import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/bindiffmatch.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections.abc import Iterable 3 | from typing import Any 4 | 5 | import networkx as nx # type: ignore[import] 6 | import numpy as np 7 | 8 | from .models import AlgorithmOutputMatchPair, BinaryFile, MatchResult 9 | from .utils import build_matchresult_with_matchpairs, is_valid_function 10 | 11 | try: 12 | from .similarity_matrix.lowmem import ( 13 | WrapMatrix, 14 | build_similarity_score_matrix, 15 | create_numpy_array, 16 | wrap_linear_sum_assignment, 17 | ) 18 | except ImportError: 19 | from .similarity_matrix.basic import ( # type: ignore 20 | WrapMatrix, 21 | build_similarity_score_matrix, 22 | create_numpy_array, 23 | wrap_linear_sum_assignment, 24 | ) 25 | 26 | 27 | def build_callgraph(doc: BinaryFile, *, with_attributes: bool = True) -> nx.DiGraph: 28 | assert doc.functions is not None 29 | g = nx.DiGraph() 30 | func_count = len(doc.functions) 31 | g.graph["func_count"] = func_count 32 | for i, func in enumerate(doc.functions): 33 | assert func.embedding is not None 34 | if with_attributes: 35 | g.add_node(func.addr, **{"id": i}) 36 | else: 37 | g.add_node(func.addr) 38 | for func in doc.functions: 39 | addr = func.addr 40 | if func.callees: 41 | for callee in func.callees: 42 | assert g.has_node(addr), addr 43 | assert g.has_node(callee), callee 44 | g.add_edge(addr, callee) 45 | assert g.number_of_nodes() == func_count 46 | return g 47 | 48 | 49 | def filter_mwm_match_results( 50 | doc1: BinaryFile, 51 | doc2: BinaryFile, 52 | similarity_score_matrix: 
WrapMatrix, 53 | row_ind: Iterable[int], 54 | col_ind: Iterable[int], 55 | *, 56 | threshold: float, 57 | ) -> list[AlgorithmOutputMatchPair]: 58 | assert doc1.functions is not None 59 | assert doc2.functions is not None 60 | matched = [] 61 | for id1, id2 in zip(row_ind, col_ind): 62 | score = float(similarity_score_matrix.get_value(id1, id2)) 63 | if score >= threshold: 64 | func1 = doc1.functions[id1] 65 | func2 = doc2.functions[id2] 66 | addr1 = func1.addr 67 | addr2 = func2.addr 68 | if is_valid_function(func1) and is_valid_function(func2): # XXX 69 | matched.append( 70 | AlgorithmOutputMatchPair( 71 | addr_1=addr1, 72 | addr_2=addr2, 73 | score=round(score, 2), 74 | ) 75 | ) 76 | return matched 77 | 78 | 79 | def do_mwm_on_full_matrix( 80 | doc1: BinaryFile, doc2: BinaryFile, similarity_score_matrix: WrapMatrix, *, threshold: float 81 | ) -> list[AlgorithmOutputMatchPair]: 82 | assert doc1.functions is not None 83 | assert doc2.functions is not None 84 | row_ind, col_ind = wrap_linear_sum_assignment(similarity_score_matrix, True) 85 | matched = filter_mwm_match_results(doc1, doc2, similarity_score_matrix, row_ind, col_ind, threshold=threshold) 86 | return matched 87 | 88 | 89 | def do_mwm_on_sub_rows_cols( 90 | doc1: BinaryFile, 91 | doc2: BinaryFile, 92 | similarity_score_matrix: WrapMatrix, 93 | subrowindexes: Iterable[int], 94 | subcolindexes: Iterable[int], 95 | *, 96 | threshold: float, 97 | ) -> list[AlgorithmOutputMatchPair]: 98 | assert doc1.functions is not None 99 | assert doc2.functions is not None 100 | 101 | subrowindexes = subrowindexes if isinstance(subrowindexes, list) else list(subrowindexes) 102 | subcolindexes = subcolindexes if isinstance(subcolindexes, list) else list(subcolindexes) 103 | if not subrowindexes or not subcolindexes: 104 | return [] 105 | row_ind, col_ind = wrap_linear_sum_assignment(similarity_score_matrix, True, subrowindexes, subcolindexes) 106 | 107 | matched = filter_mwm_match_results(doc1, doc2, similarity_score_matrix, row_ind, col_ind, threshold=threshold) 108 | return matched 109 | 110 | 111 | def do_mwm_on_sub_pairs( 112 | doc1: BinaryFile, 113 | doc2: BinaryFile, 114 | similarity_score_matrix: WrapMatrix, 115 | subnodeindexpairs: Iterable[tuple[int, int]], # if duplicated, only process once 116 | *, 117 | threshold: float, 118 | ) -> list[AlgorithmOutputMatchPair]: 119 | assert doc1.functions is not None 120 | assert doc2.functions is not None 121 | 122 | subnodeindexpairs = subnodeindexpairs if isinstance(subnodeindexpairs, list) else list(subnodeindexpairs) 123 | 124 | subrowindexes = list(set(id1 for id1, _ in subnodeindexpairs)) 125 | subcolindexes = list(set(id2 for _, id2 in subnodeindexpairs)) 126 | subrowindex_map = dict(zip(subrowindexes, itertools.count())) 127 | subcolindex_map = dict(zip(subcolindexes, itertools.count())) 128 | 129 | sub_similarity_score_matrix = create_numpy_array(-1, np.float32, (len(subrowindexes), len(subcolindexes))) 130 | for id1, id2 in subnodeindexpairs: 131 | id1_index, id2_index = subrowindex_map[id1], subcolindex_map[id2] 132 | sub_similarity_score_matrix[id1_index, id2_index] = similarity_score_matrix.get_value(id1, id2) 133 | 134 | row_index_ind, col_index_ind = wrap_linear_sum_assignment(WrapMatrix(sub_similarity_score_matrix), True) 135 | row_ind = (subrowindexes[i] for i in row_index_ind) 136 | col_ind = (subcolindexes[j] for j in col_index_ind) 137 | 138 | matched = filter_mwm_match_results(doc1, doc2, similarity_score_matrix, row_ind, col_ind, threshold=threshold) 139 | return matched 140 | 
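# Editorial sketch (not part of the original source): the do_mwm_* helpers above
# all reduce to the same primitive -- a maximum-weight bipartite matching
# (linear_sum_assignment with maximize=True) over some view of the similarity
# matrix. do_mwm_on_sub_pairs additionally compacts an arbitrary candidate set of
# (id1, id2) pairs into a small dense sub-matrix first. A hypothetical run of
# that index compaction (set() iteration order may vary):
#
#     subnodeindexpairs = [(3, 7), (3, 9), (5, 7)]
#     subrowindexes = [3, 5]   ->  subrowindex_map = {3: 0, 5: 1}
#     subcolindexes = [7, 9]   ->  subcolindex_map = {7: 0, 9: 1}
#
# The 2x2 sub-matrix is prefilled with -1, the minimum possible cosine
# similarity, so a pair that was never a candidate scores lowest and is dropped
# afterwards by the threshold filter; the assignment result is finally mapped
# back through subrowindexes/subcolindexes to the original function indexes.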
141 | 142 | def get_callrelation_neighbors(g: nx.DiGraph, root_node: Any, *, hop: int) -> tuple[Iterable[Any], Iterable[Any]]: 143 | def get_callrelation_neighbors_internal( 144 | g: nx.DiGraph, root_node: Any, *, hop: int, reverse: bool = False 145 | ) -> Iterable[Any]: 146 | nodes_set = set() 147 | seen_nodes = {root_node} 148 | layer_nodes = {root_node} 149 | for _ in range(hop): 150 | new_layer_nodes = set() 151 | for layer_node in layer_nodes: 152 | for e in g.in_edges(layer_node) if reverse else g.out_edges(layer_node): 153 | next_layer_node = e[0] if reverse else e[1] 154 | if next_layer_node not in seen_nodes: 155 | seen_nodes.add(next_layer_node) 156 | new_layer_nodes.add(next_layer_node) 157 | nodes_set.update(new_layer_nodes) 158 | layer_nodes = new_layer_nodes 159 | return nodes_set 160 | 161 | outedge_nodes_set = get_callrelation_neighbors_internal(g, root_node, hop=hop, reverse=False) 162 | inedge_nodes_set = get_callrelation_neighbors_internal(g, root_node, hop=hop, reverse=True) 163 | 164 | return outedge_nodes_set, inedge_nodes_set 165 | 166 | 167 | def do_spread( 168 | doc1: BinaryFile, 169 | doc2: BinaryFile, 170 | similarity_score_matrix: WrapMatrix, 171 | initial_matchpairs: list[AlgorithmOutputMatchPair], 172 | *, 173 | hop: int, 174 | threshold: float, 175 | ) -> list[AlgorithmOutputMatchPair]: 176 | assert doc1.functions is not None 177 | assert doc2.functions is not None 178 | 179 | g1 = build_callgraph(doc1, with_attributes=False) 180 | g2 = build_callgraph(doc2, with_attributes=False) 181 | addr_to_id_map1 = {f.addr: i for i, f in enumerate(doc1.functions)} 182 | addr_to_id_map2 = {f.addr: i for i, f in enumerate(doc2.functions)} 183 | 184 | processed_addrs_1 = set() 185 | processed_addrs_2 = set() 186 | for matchpair in initial_matchpairs: 187 | processed_addrs_1.add(matchpair.addr_1) 188 | processed_addrs_2.add(matchpair.addr_2) 189 | 190 | spread_matchpairs = list[AlgorithmOutputMatchPair]() 191 | 192 | turn = 0 193 | while True: 194 | turn += 1 195 | 196 | subnodepairs = [] 197 | for matchpair in initial_matchpairs: 198 | neighbors_1_out, neighbors_1_in = get_callrelation_neighbors(g1, matchpair.addr_1, hop=hop) 199 | neighbors_2_out, neighbors_2_in = get_callrelation_neighbors(g2, matchpair.addr_2, hop=hop) 200 | new_neighbors_1_out = [addr1 for addr1 in neighbors_1_out if addr1 not in processed_addrs_1] 201 | new_neighbors_1_in = [addr1 for addr1 in neighbors_1_in if addr1 not in processed_addrs_1] 202 | new_neighbors_2_out = [addr2 for addr2 in neighbors_2_out if addr2 not in processed_addrs_2] 203 | new_neighbors_2_in = [addr2 for addr2 in neighbors_2_in if addr2 not in processed_addrs_2] 204 | for addr1, addr2 in itertools.chain( 205 | itertools.product(new_neighbors_1_out, new_neighbors_2_out), 206 | itertools.product(new_neighbors_1_in, new_neighbors_2_in), 207 | ): 208 | id1 = addr_to_id_map1[addr1] 209 | id2 = addr_to_id_map2[addr2] 210 | subnodepairs.append((id1, id2)) 211 | 212 | processed_addrs_1.add(addr1) 213 | processed_addrs_2.add(addr2) 214 | 215 | if not subnodepairs: 216 | break 217 | 218 | new_matchpairs = do_mwm_on_sub_pairs(doc1, doc2, similarity_score_matrix, subnodepairs, threshold=threshold) 219 | 220 | spread_matchpairs.extend(new_matchpairs) 221 | initial_matchpairs = new_matchpairs 222 | 223 | return spread_matchpairs 224 | 225 | 226 | def get_remained_nodepair_indexes( 227 | doc1: BinaryFile, doc2: BinaryFile, matched_pairs: Iterable[AlgorithmOutputMatchPair] 228 | ) -> tuple[list[int], list[int]]: 229 | assert doc1.functions is 
not None 230 | assert doc2.functions is not None 231 | addr1_inpair = set() 232 | addr2_inpair = set() 233 | for m in matched_pairs: 234 | addr1_inpair.add(m.addr_1) 235 | addr2_inpair.add(m.addr_2) 236 | index1_notinpair = [i for i, f in enumerate(doc1.functions) if f.addr not in addr1_inpair] 237 | index2_notinpair = [i for i, f in enumerate(doc2.functions) if f.addr not in addr2_inpair] 238 | return index1_notinpair, index2_notinpair 239 | 240 | 241 | def do_bindiffmatch( 242 | doc1: BinaryFile, 243 | doc2: BinaryFile, 244 | *, 245 | threshold_high: float, 246 | threshold_low: float, 247 | threshold_remain: float, 248 | hop: int, 249 | ) -> MatchResult: 250 | # prepare 251 | similarity_score_matrix = build_similarity_score_matrix(doc1, doc2) 252 | 253 | # stage 1 254 | initial_matchpairs = do_mwm_on_full_matrix(doc1, doc2, similarity_score_matrix, threshold=threshold_high) 255 | 256 | # stage 2 257 | spread_matchpairs = do_spread( 258 | doc1, doc2, similarity_score_matrix, initial_matchpairs, hop=hop, threshold=threshold_low 259 | ) 260 | 261 | # stage 3 262 | remained_nodes1, remained_nodes2 = get_remained_nodepair_indexes( 263 | doc1, doc2, itertools.chain(initial_matchpairs, spread_matchpairs) 264 | ) 265 | remained_matchpairs = do_mwm_on_sub_rows_cols( 266 | doc1, doc2, similarity_score_matrix, remained_nodes1, remained_nodes2, threshold=threshold_remain 267 | ) 268 | 269 | # merge 270 | matchresult = build_matchresult_with_matchpairs( 271 | doc1, doc2, itertools.chain(initial_matchpairs, spread_matchpairs, remained_matchpairs) 272 | ) 273 | return matchresult 274 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | 4 | from .bindiffmatch import do_bindiffmatch 5 | from .models import BinaryFile, MatchResult 6 | from .utils import dump_matchresult, load_doc 7 | 8 | 9 | @dataclasses.dataclass 10 | class BindiffMatchConfig: 11 | threshold_high: float 12 | threshold_low: float 13 | threshold_remain: float 14 | hop: int 15 | 16 | 17 | defaultconfig = BindiffMatchConfig( 18 | threshold_high=0.74, 19 | threshold_low=0.4, 20 | threshold_remain=0.66, 21 | hop=1, 22 | ) 23 | 24 | 25 | def bindiffmatch(doc1: BinaryFile, doc2: BinaryFile, *, config: BindiffMatchConfig) -> MatchResult: 26 | return do_bindiffmatch(doc1, doc2, **dataclasses.asdict(config)) 27 | 28 | 29 | def parse_args() -> argparse.Namespace: 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("first") 32 | parser.add_argument("second") 33 | parser.add_argument("-o", "--output", required=True) 34 | return parser.parse_args() 35 | 36 | 37 | def main() -> None: 38 | args = parse_args() 39 | first = args.first 40 | second = args.second 41 | output = args.output 42 | doc1 = load_doc(first) 43 | doc2 = load_doc(second) 44 | matchresult = do_bindiffmatch(doc1, doc2, **dataclasses.asdict(defaultconfig)) 45 | dump_matchresult(matchresult, output) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/metricsutils.py: -------------------------------------------------------------------------------- 1 | # import functools 2 | # import itertools 3 | import os 4 | from collections.abc import Generator 5 | 6 | """ 7 | stripped raw binaries: (filename ends with .strip) 8 | 
{data}/files/{library}/stripped_binaries/{library}-{version}-{optimazation}/{filename} 9 | docs with embeddings and funcname labels: (functions produced by Ghidra, embeddings produced by the BinaryAI BAI-2.0 model) 10 | {data}/labeleds/{library}/stripped_binaries/{library}-{version}-{optimazation}/{filename}.json 11 | diaphora exported database: 12 | {data}/diaphora_sqlites/{library}/stripped_binaries/{library}-{version}-{optimazation}/{filename}.sqlite 13 | 14 | groundtruth matchresult: 15 | {data}/matchresults/groundtruth/{library}/{library}-{version}-{optimazation}__vs__{library}-{version}-{optimazation}/{filename}.json 16 | diaphora matchresult: 17 | {data}/matchresults/diaphora/{library}/{library}-{version}-{optimazation}__vs__{library}-{version}-{optimazation}/{filename}.json 18 | binaryai matchresult: 19 | {data}/matchresults/binaryai/{library}/{library}-{version}-{optimazation}__vs__{library}-{version}-{optimazation}/{filename}.json 20 | """ 21 | 22 | libraries = ["coreutils", "diffutils", "findutils"] 23 | versions = { 24 | "coreutils": ["5.93", "6.4", "7.6", "8.1", "8.30"], 25 | "diffutils": ["2.8", "3.1", "3.4", "3.6"], 26 | "findutils": ["4.233", "4.41", "4.6"], 27 | "openssl": ["1.1.1u", "3.1.1"], 28 | } 29 | optimazations = { 30 | "coreutils": ["O0", "O1", "O2", "O3"], 31 | "diffutils": ["O0", "O1", "O2", "O3"], 32 | "findutils": ["O0", "O1", "O2", "O3"], 33 | "openssl": ["gcc_x64_O3", "gcc_arm_O0"], 34 | } 35 | 36 | 37 | def build_testcase_cross_version_pairs_on_library( 38 | library: str, 39 | ) -> Generator[tuple[tuple[str, str, str], tuple[str, str, str]], None, None]: 40 | optimazation = "O1" 41 | to_compare_version = versions[library][-1] 42 | for version in versions[library][:-1]: 43 | assert version != to_compare_version, (version, to_compare_version) 44 | yield ((library, version, optimazation), (library, to_compare_version, optimazation)) 45 | 46 | 47 | def build_testcase_cross_optimization_pairs_on_library( 48 | library: str, 49 | ) -> Generator[tuple[tuple[str, str, str], tuple[str, str, str]], None, None]: 50 | for version in versions[library]: 51 | to_compare_optimazation = "O3" 52 | for optimazation in optimazations[library][:-1]: 53 | assert optimazation != to_compare_optimazation, (optimazation, to_compare_optimazation) 54 | yield ((library, version, optimazation), (library, version, to_compare_optimazation)) 55 | 56 | 57 | def get_stripped_filenames(datadir: str, library: str) -> list[str]: 58 | # assert library in libraries 59 | labeled_doc_relpath = get_labeled_doc_relpath( 60 | library, versions[library][-1], optimazations[library][-1], None 61 | ) 62 | result = [] 63 | for filename in os.listdir(os.path.join(datadir, labeled_doc_relpath)): 64 | assert filename.endswith(".strip.json") 65 | filename = filename.removesuffix(".json") 66 | result.append(filename) 67 | return result 68 | 69 | 70 | def get_stripped_binary_relpath(library: str, version: str, optimazation: str, filename: str | None) -> str: 71 | return f"files/{library}/stripped_binaries/{library}-{version}-{optimazation}/{filename if filename else ''}" 72 | 73 | 74 | def get_labeled_doc_relpath(library: str, version: str, optimazation: str, filename: str | None) -> str: 75 | return f"labeleds/{library}/stripped_binaries/{library}-{version}-{optimazation}/{filename+'.json' if filename else ''}" # noqa: E501 76 | 77 | 78 | def get_groundtruth_matchresult_relpath( 79 | library: str, version1: str, optimazation1: str, version2: str, optimazation2: str, filename: str | None 80 | ) -> str: 81 | return 
f"matchresults/groundtruth/{library}/{library}-{version1}-{optimazation1}__vs__{library}-{version2}-{optimazation2}/{filename+'.json' if filename else ''}" # noqa: E501 82 | 83 | 84 | def get_diaphora_matchresult_relpath( 85 | library: str, version1: str, optimazation1: str, version2: str, optimazation2: str, filename: str | None 86 | ) -> str: # noqa: E501 87 | return f"matchresults/diaphora/{library}/{library}-{version1}-{optimazation1}__vs__{library}-{version2}-{optimazation2}/{filename+'.json' if filename else ''}" # noqa: E501 88 | 89 | 90 | def get_algorithm_matchresult_relpath( 91 | algorithm: str, 92 | library1: str, 93 | version1: str, 94 | optimazation1: str, 95 | library2: str, 96 | version2: str, 97 | optimazation2: str, 98 | filename: str | None, 99 | ) -> str: 100 | library = library1 if library1 == library2 else "_other" 101 | return f"matchresults/{algorithm}/{library}/{library1}-{version1}-{optimazation1}__vs__{library2}-{version2}-{optimazation2}/{filename+'.json' if filename else ''}" # noqa: E501 102 | 103 | 104 | def get_matchresult_cross_vs_named_filepair( 105 | library: str, 106 | version1: str, 107 | optimazation1: str, 108 | version2: str, 109 | optimazation2: str, 110 | ) -> str: # noqa: E501 111 | return f"{library}-{version1}-{optimazation1}__vs__{library}-{version2}-{optimazation2}" 112 | 113 | 114 | coreutils_cross_version_pairs = list(build_testcase_cross_version_pairs_on_library("coreutils")) 115 | findutils_cross_version_pairs = list(build_testcase_cross_version_pairs_on_library("findutils")) 116 | diffutils_cross_version_pairs = list(build_testcase_cross_version_pairs_on_library("diffutils")) 117 | 118 | coreutils_cross_optimization_pairs = list(build_testcase_cross_optimization_pairs_on_library("coreutils")) 119 | findutils_cross_optimization_pairs = list(build_testcase_cross_optimization_pairs_on_library("findutils")) 120 | diffutils_cross_optimization_pairs = list(build_testcase_cross_optimization_pairs_on_library("diffutils")) 121 | 122 | 123 | testcase_pairs = ( 124 | coreutils_cross_version_pairs 125 | + findutils_cross_version_pairs 126 | + diffutils_cross_version_pairs 127 | + coreutils_cross_optimization_pairs 128 | + findutils_cross_optimization_pairs 129 | + diffutils_cross_optimization_pairs 130 | ) 131 | example_pair = (("openssl", "1.1.1u", "gcc_arm_O0"), ("openssl", "3.1.1", "gcc_x64_O3")) 132 | example_filename = "openssl.strip" 133 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/models.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Optional 3 | 4 | 5 | @dataclasses.dataclass(slots=True, kw_only=True) 6 | class BasicInfo: 7 | base_address: int 8 | file_type: Optional[str] = None 9 | machine_type: Optional[str] = None 10 | platform_type: Optional[str] = None 11 | endianness: Optional[str] = None 12 | loader: Optional[str] = None 13 | entrypoint: Optional[int] = None 14 | 15 | 16 | @dataclasses.dataclass(slots=True, kw_only=True) 17 | class Function: 18 | addr: int 19 | name: Optional[str] = None 20 | pseudocode: Optional[str] = None 21 | callees: Optional[list[int]] = None 22 | strings: Optional[list[str]] = None 23 | embedding: Optional[list[float]] = None 24 | linecount: Optional[int] = None 25 | 26 | 27 | @dataclasses.dataclass(slots=True, kw_only=True) 28 | class BinaryFile: 29 | sha256: str 30 | basic_info: BasicInfo 31 | functions: Optional[list[Function]] = None 32 | 33 | 34 | # 
-------------------------------------- 35 | 36 | 37 | @dataclasses.dataclass(slots=True, kw_only=True) 38 | class AlgorithmOutputMatchPair: 39 | addr_1: int 40 | addr_2: int 41 | score: Optional[float] = None 42 | 43 | 44 | @dataclasses.dataclass(slots=True, kw_only=True) 45 | class MatchPair: 46 | function_1: Function 47 | function_2: Function 48 | score: Optional[float] = None 49 | 50 | 51 | @dataclasses.dataclass(slots=True, kw_only=True) 52 | class MatchResult: 53 | file_1: BinaryFile 54 | file_2: BinaryFile 55 | matches: list[MatchPair] 56 | unmatches_1: list[Function] 57 | unmatches_2: list[Function] 58 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binaryai/bindiffmatch/68af1791392b19fd784200444ddca4563cef0d0a/binaryai_bindiffmatch/py.typed -------------------------------------------------------------------------------- /binaryai_bindiffmatch/similarity_matrix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binaryai/bindiffmatch/68af1791392b19fd784200444ddca4563cef0d0a/binaryai_bindiffmatch/similarity_matrix/__init__.py -------------------------------------------------------------------------------- /binaryai_bindiffmatch/similarity_matrix/basic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | from scipy.optimize import linear_sum_assignment # type: ignore[import] 6 | 7 | from ..models import BinaryFile 8 | 9 | 10 | class WrapMatrix: 11 | __slots__ = ["_raw_matrix"] 12 | 13 | def __init__(self, raw_matrix: npt.ArrayLike) -> None: 14 | self._raw_matrix = np.asarray(raw_matrix) 15 | assert self._raw_matrix.ndim == 2 16 | 17 | def get_matrix(self) -> npt.NDArray[Any]: 18 | return self._raw_matrix 19 | 20 | def get_value(self, i: int, j: int) -> Any: 21 | return self._raw_matrix[i, j] 22 | 23 | def get_raw_matrix(self) -> npt.NDArray[Any]: 24 | return self._raw_matrix 25 | 26 | 27 | def wrap_linear_sum_assignment( 28 | similarity_score_matrix: WrapMatrix, 29 | maximize: bool = False, 30 | subrows: Optional[npt.ArrayLike] = None, 31 | subcols: Optional[npt.ArrayLike] = None, 32 | ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]: 33 | if not subrows or not subcols: 34 | return linear_sum_assignment(similarity_score_matrix.get_raw_matrix(), maximize) # type: ignore[no-any-return] 35 | sub_similarity_score_matrix = similarity_score_matrix.get_raw_matrix()[np.ix_(subrows, subcols)] # type: ignore[arg-type] # noqa: E501 36 | row_index_ind, col_index_ind = linear_sum_assignment(sub_similarity_score_matrix, True) 37 | row_ind = (subrows[i] for i in row_index_ind) # type: ignore[index] 38 | col_ind = (subcols[j] for j in col_index_ind) # type: ignore[index] 39 | return row_ind, col_ind # type: ignore[return-value] 40 | 41 | 42 | def create_numpy_array( 43 | data: npt.ArrayLike | None, dtype: npt.DTypeLike, shape: int | tuple[int, ...] 
44 | ) -> npt.NDArray[Any]: 45 | a = np.empty(shape=shape, dtype=dtype) 46 | if data is not None: 47 | a[:] = data 48 | return a 49 | 50 | 51 | def build_similarity_score_matrix(doc1: BinaryFile, doc2: BinaryFile) -> WrapMatrix: 52 | def get_embeddings_list(doc: BinaryFile) -> list[npt.NDArray[Any]]: 53 | assert doc.functions is not None 54 | r = list[npt.NDArray[Any]]() 55 | for func in doc.functions: 56 | assert func.embedding is not None 57 | emb = np.asarray(func.embedding) 58 | assert emb.ndim == 1 59 | emb = emb / np.linalg.norm(emb, 2) # l2 normalize 60 | emb[emb < -1] = -1 61 | emb[emb > 1] = 1 62 | r.append(emb) 63 | return r 64 | 65 | assert doc1.functions is not None 66 | assert doc2.functions is not None 67 | 68 | func_count1 = len(doc1.functions) 69 | func_count2 = len(doc2.functions) 70 | embedding_len = len(doc1.functions[0].embedding) if func_count1 > 0 else 0 # type: ignore[arg-type] 71 | embedding_len_2 = len(doc2.functions[0].embedding) if func_count2 > 0 else 0 # type: ignore[arg-type] 72 | assert embedding_len == embedding_len_2, f"error: embedding length diff: {embedding_len} != {embedding_len_2}" 73 | 74 | embeddings_list1 = get_embeddings_list(doc1) 75 | embeddings1 = np.asarray(embeddings_list1) 76 | del embeddings_list1 77 | embeddings_list2 = get_embeddings_list(doc2) 78 | embeddings2 = np.asarray(embeddings_list2) 79 | del embeddings_list2 80 | 81 | similarity_score_matrix_raw = embeddings1 @ embeddings2.transpose() 82 | similarity_score_matrix = WrapMatrix(similarity_score_matrix_raw) 83 | return similarity_score_matrix 84 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/similarity_matrix/lowmem.py: -------------------------------------------------------------------------------- 1 | import math 2 | import platform 3 | import tempfile 4 | from typing import Any, Optional 5 | 6 | import numpy as np 7 | import numpy.typing as npt 8 | from nanolsap import linear_sum_assignment 9 | 10 | from ..models import BinaryFile 11 | 12 | USE_MEMMAP_NUMPY_ARRAY = False if platform.system() == "Windows" else True 13 | WRAP_MATRIX_SCALE: int = 127 14 | WRAP_MATRIX_DTYPE = np.int8 15 | BLOCK_MATMUL_UNIT_LIMIT = 5000 16 | 17 | 18 | class WrapMatrix: 19 | __slots__ = ["_raw_matrix", "_transpose", "_scale"] 20 | 21 | def __init__(self, raw_matrix: npt.ArrayLike, *, transpose: bool = False, scale: int = 1) -> None: 22 | self._raw_matrix = np.asarray(raw_matrix) 23 | self._transpose = transpose 24 | self._scale = scale 25 | assert self._raw_matrix.ndim == 2 26 | assert self._scale != 0 27 | 28 | def get_matrix(self) -> npt.NDArray[Any]: 29 | if self._transpose: 30 | r = self._raw_matrix.transpose() 31 | else: 32 | r = self._raw_matrix 33 | if self._scale != 1: 34 | r = r / self._scale 35 | return r 36 | 37 | def get_value(self, i: int, j: int) -> Any: 38 | if self._transpose: 39 | return self._raw_matrix[j, i] / self._scale 40 | return self._raw_matrix[i, j] / self._scale 41 | 42 | def get_raw_matrix(self) -> npt.NDArray[Any]: 43 | return self._raw_matrix 44 | 45 | def is_transpose(self) -> bool: 46 | return self._transpose 47 | 48 | 49 | def wrap_linear_sum_assignment( 50 | similarity_score_matrix: WrapMatrix, 51 | maximize: bool = False, 52 | subrows: Optional[npt.ArrayLike] = None, 53 | subcols: Optional[npt.ArrayLike] = None, 54 | ) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]: 55 | # nanolsap.linear_sum_assignment works much slower when nr > nc, so we manually do a shadow transpose 56 | if
similarity_score_matrix.is_transpose(): 57 | col_ind, row_ind = linear_sum_assignment(similarity_score_matrix.get_raw_matrix(), maximize, subcols, subrows) 58 | else: 59 | row_ind, col_ind = linear_sum_assignment(similarity_score_matrix.get_raw_matrix(), maximize, subrows, subcols) 60 | return row_ind, col_ind 61 | 62 | 63 | def create_numpy_array( 64 | data: npt.ArrayLike | None, dtype: npt.DTypeLike, shape: int | tuple[int, ...] 65 | ) -> npt.NDArray[Any]: 66 | def prod(d: int | tuple[int, ...]) -> int: 67 | if isinstance(d, int): 68 | return d 69 | r = 1 70 | for c in d: 71 | r *= c 72 | return r 73 | 74 | if not USE_MEMMAP_NUMPY_ARRAY: 75 | a = np.empty(shape=shape, dtype=dtype) 76 | else: 77 | with tempfile.NamedTemporaryFile(prefix="binaryai_bindiffmatch_tmp_") as fp: 78 | memsize = np.dtype(dtype).itemsize * prod(shape) 79 | fp.truncate(memsize) 80 | a = np.memmap(fp.name, dtype=dtype, shape=shape) 81 | 82 | if data is not None: 83 | a[:] = data 84 | return a 85 | 86 | 87 | def block_matmul( 88 | x1: npt.ArrayLike, 89 | x2: npt.ArrayLike, 90 | *, 91 | out: npt.NDArray[Any] | None = None, 92 | out_dtype: npt.DTypeLike = np.float64, 93 | unit_limit: int = 0, 94 | ) -> npt.NDArray[Any]: 95 | a = np.asarray(x1) 96 | b = np.asarray(x2) 97 | assert a.ndim == 2 98 | assert b.ndim == 2 99 | assert a.shape[1] == b.shape[0] 100 | n, m = a.shape[0], b.shape[1] 101 | u = max(n, m) if unit_limit == 0 else unit_limit 102 | r = np.empty(shape=(n, m), dtype=out_dtype) if out is None else out 103 | for i in range((n + u - 1) // u): 104 | for j in range((m + u - 1) // u): 105 | tmp_x = a[i * u : (i + 1) * u, :] 106 | tmp_y = b[:, j * u : (j + 1) * u] 107 | r[i * u : (i + 1) * u, j * u : (j + 1) * u] = tmp_x @ tmp_y 108 | return r 109 | 110 | 111 | def build_similarity_score_matrix(doc1: BinaryFile, doc2: BinaryFile) -> WrapMatrix: 112 | def get_embeddings_list_with_scale(doc: BinaryFile, scale: int = 1) -> list[npt.NDArray[Any]]: 113 | assert doc.functions is not None 114 | r = list[npt.NDArray[Any]]() 115 | for func in doc.functions: 116 | assert func.embedding is not None 117 | emb = np.asarray(func.embedding) 118 | assert emb.ndim == 1 119 | emb = emb / np.linalg.norm(emb, 2) # l2 normalize 120 | emb[emb < -1] = -1 121 | emb[emb > 1] = 1 122 | if scale != 1: 123 | sqrt_scale = math.sqrt(scale) 124 | # note this sqrt: the scale is applied once per factor, since the two embedding matrices are multiplied together later 125 | emb = emb * sqrt_scale 126 | r.append(emb) 127 | return r 128 | 129 | assert doc1.functions is not None 130 | assert doc2.functions is not None 131 | 132 | func_count1 = len(doc1.functions) 133 | func_count2 = len(doc2.functions) 134 | embedding_len = len(doc1.functions[0].embedding) if func_count1 > 0 else 0 # type: ignore[arg-type] 135 | embedding_len_2 = len(doc2.functions[0].embedding) if func_count2 > 0 else 0 # type: ignore[arg-type] 136 | assert embedding_len == embedding_len_2, f"error: embedding length diff: {embedding_len} != {embedding_len_2}" 137 | 138 | transpose = func_count1 > func_count2 139 | scale = WRAP_MATRIX_SCALE 140 | tmp_dtype = np.float32 141 | wrap_dtype = WRAP_MATRIX_DTYPE 142 | tmp_wrap_dtype = wrap_dtype 143 | 144 | scale_embeddings_list1 = get_embeddings_list_with_scale(doc1, scale) 145 | embeddings1 = create_numpy_array(scale_embeddings_list1, tmp_dtype, (func_count1, embedding_len)) 146 | del scale_embeddings_list1 147 | scale_embeddings_list2 = get_embeddings_list_with_scale(doc2, scale) 148 | embeddings2 = create_numpy_array(scale_embeddings_list2, tmp_dtype, (func_count2, 
embedding_len)) 149 | del scale_embeddings_list2 150 | 151 | if transpose: 152 | tmp_func_count1, tmp_func_count2 = func_count2, func_count1 153 | tmp_embeddings1, tmp_embeddings2 = embeddings2, embeddings1 154 | else: 155 | tmp_func_count1, tmp_func_count2 = func_count1, func_count2 156 | tmp_embeddings1, tmp_embeddings2 = embeddings1, embeddings2 157 | similarity_score_matrix_raw = create_numpy_array(None, tmp_wrap_dtype, (tmp_func_count1, tmp_func_count2)) 158 | # notice: the matmul here may cause a sudden increase then decrease in memory usage. 159 | # Only when the two input matrices and the output matrix have the same dtype can numpy compute in place; 160 | # otherwise, if the output dtype differs from the input dtype, numpy makes internal copies of the input matrices, 161 | # so it cannot benefit from memmap swapping to save memory, and may cause OOM on huge input matrices. 162 | # np.matmul(tmp_embeddings1, tmp_embeddings2.transpose(), out=similarity_score_matrix_raw, casting="unsafe") 163 | block_matmul( 164 | tmp_embeddings1, 165 | tmp_embeddings2.transpose(), 166 | out=similarity_score_matrix_raw, 167 | out_dtype=tmp_wrap_dtype, 168 | unit_limit=BLOCK_MATMUL_UNIT_LIMIT, 169 | ) 170 | 171 | del tmp_embeddings1 172 | del tmp_embeddings2 173 | del embeddings1 174 | del embeddings2 175 | 176 | similarity_score_matrix_raw = similarity_score_matrix_raw.astype(wrap_dtype, copy=False) 177 | similarity_score_matrix = WrapMatrix(similarity_score_matrix_raw, transpose=transpose, scale=scale) 178 | return similarity_score_matrix 179 | -------------------------------------------------------------------------------- /binaryai_bindiffmatch/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import dataclasses 3 | import json 4 | import os 5 | import typing 6 | from collections.abc import Generator, Iterable 7 | from pathlib import Path 8 | from typing import Any 9 | 10 | from .models import AlgorithmOutputMatchPair, BasicInfo, BinaryFile, Function, MatchPair, MatchResult 11 | 12 | 13 | def tranverse_path(src: str, dst: str) -> Generator[tuple[str, str], None, None]: 14 | try: 15 | f = open(src, "rb") 16 | f.close() 17 | yield (src, dst) 18 | except IsADirectoryError: 19 | for srcpath in Path(src).glob("**/*"): 20 | if not srcpath.is_file(): 21 | continue 22 | rel_srcpath = srcpath.relative_to(src) 23 | dstpath = Path(dst).joinpath(rel_srcpath) 24 | yield (str(srcpath), str(dstpath)) 25 | 26 | 27 | def tranverse_path_two(src1: str, src2: str, dst: str) -> Generator[tuple[str, str, str], None, None]: 28 | try: 29 | f = open(src1, "rb") 30 | f.close() 31 | yield (src1, src2, dst) 32 | except IsADirectoryError: 33 | for src1path in Path(src1).glob("**/*"): 34 | if not src1path.is_file(): 35 | continue 36 | rel_src1path = src1path.relative_to(src1) 37 | src2path = Path(src2).joinpath(rel_src1path) 38 | dstpath = Path(dst).joinpath(rel_src1path) 39 | yield (str(src1path), str(src2path), str(dstpath)) 40 | 41 | 42 | def unserialize(data: Any, objtype: Any, *, deepcopy: bool = True) -> Any: 43 | def _get_possible_type_of_optional(_objtype: Any) -> Any: 44 | if typing.get_origin(_objtype) is typing.Union: 45 | args = typing.get_args(_objtype) 46 | if len(args) == 2 and args[1] is type(None): # noqa: E721 47 | return args[0] 48 | return None 49 | 50 | def _get_real_type(_objtype: type) -> Any: 51 | if (r := typing.get_origin(_objtype)) is not None: 52 | return r 53 | return _objtype 54 | 55 | def _get_matched(_t: Any, _declare: Any) -> Any: 56 | if
_declare is Any or _t is _declare: 57 | return _declare 58 | elif (_origin := typing.get_origin(_declare)) is not None: 59 | if _origin is typing.Union: # this implied typing.Optional 60 | for _union_arg in typing.get_args(_declare): 61 | matched = _get_matched(_t, _union_arg) 62 | if matched is not None: 63 | return matched 64 | return None 65 | elif issubclass(_t, _origin): # for typing.GenericAlias 66 | return _declare # should return origin one 67 | else: 68 | return None 69 | elif issubclass(_t, _declare): # for basic type 70 | return _declare 71 | else: 72 | return None 73 | 74 | def _get_list_type_arg(_objtype: Any) -> Any: 75 | _real = _get_real_type(_objtype) 76 | assert _real is Any or issubclass(_real, (list, tuple)), _real 77 | objtype_args = typing.get_args(_objtype) 78 | item_type = objtype_args[0] if len(objtype_args) >= 1 else Any 79 | return item_type 80 | 81 | def _get_dict_type_arg(_objtype: Any) -> tuple[Any, Any]: 82 | _real = _get_real_type(_objtype) 83 | assert _real is Any or issubclass(_real, dict), _real 84 | objtype_args = typing.get_args(_objtype) 85 | key_type = objtype_args[0] if len(objtype_args) >= 1 else Any 86 | val_type = objtype_args[1] if len(objtype_args) >= 2 else Any 87 | return key_type, val_type 88 | 89 | def _unserialize_internal(_data: Any, _objtype: Any, *, deepcopy: bool) -> Any: 90 | matched = _get_matched(type(_data), _objtype) 91 | if isinstance(_data, (list, tuple)): 92 | assert matched is not None 93 | item_type = _get_list_type_arg(matched) 94 | return type(_data)(_unserialize_internal(v, item_type, deepcopy=deepcopy) for v in _data) 95 | elif isinstance(_data, dict): 96 | if matched is not None: # normal dict 97 | key_type, val_type = _get_dict_type_arg(matched) 98 | return type(_data)( 99 | ( 100 | _unserialize_internal(k, key_type, deepcopy=deepcopy), 101 | _unserialize_internal(v, val_type, deepcopy=deepcopy), 102 | ) 103 | for k, v in _data.items() 104 | ) 105 | else: # maybe nesting dataclasses 106 | maybe_dataclass = _get_possible_type_of_optional(_objtype) 107 | if maybe_dataclass is None: 108 | maybe_dataclass = _objtype 109 | field_types = {f.name: f.type for f in dataclasses.fields(maybe_dataclass)} 110 | fields = {} 111 | for k, v in _data.items(): 112 | if k not in field_types: 113 | continue # XXX: ignore extra items 114 | assert isinstance(k, str) 115 | fields[k] = _unserialize_internal(v, field_types[k], deepcopy=deepcopy) 116 | return maybe_dataclass(**fields) 117 | else: 118 | if not matched: 119 | return _objtype(_data) # try to convert 120 | return copy.deepcopy(_data) if deepcopy else _data 121 | 122 | return _unserialize_internal(data, objtype, deepcopy=deepcopy) 123 | 124 | 125 | def dict_factory_ignore_none(x: list[tuple[str, Any]]) -> dict[Any, Any]: 126 | return {k: v for (k, v) in x if v is not None} 127 | 128 | 129 | # -------------------------------------- 130 | 131 | 132 | def load_doc(filename: str) -> BinaryFile: 133 | with open(filename, "rb") as f: 134 | doc_json = json.load(f) 135 | doc: BinaryFile = unserialize(doc_json, BinaryFile) 136 | return doc 137 | 138 | 139 | def dump_doc(doc: BinaryFile, filename: str) -> None: 140 | doc_json = dataclasses.asdict(doc, dict_factory=dict_factory_ignore_none) 141 | if os.path.dirname(filename): 142 | os.makedirs(os.path.dirname(filename), exist_ok=True) 143 | with open(filename, "w") as f: 144 | json.dump(doc_json, f) 145 | 146 | 147 | def load_matchresult(filename: str) -> MatchResult: 148 | with open(filename, "rb") as f: 149 | matchresult_json = json.load(f) 
150 | matchresult: MatchResult = unserialize(matchresult_json, MatchResult) 151 | return matchresult 152 | 153 | 154 | def dump_matchresult(matchresult: MatchResult, filename: str) -> None: 155 | matchresult_json = dataclasses.asdict(matchresult, dict_factory=dict_factory_ignore_none) 156 | if os.path.dirname(filename): 157 | os.makedirs(os.path.dirname(filename), exist_ok=True) 158 | with open(filename, "w") as f: 159 | json.dump(matchresult_json, f) 160 | 161 | 162 | def build_matchresult_with_matchpairs( 163 | doc1: BinaryFile, doc2: BinaryFile, matched_pair: Iterable[AlgorithmOutputMatchPair] 164 | ) -> MatchResult: 165 | assert doc1.functions is not None 166 | assert doc2.functions is not None 167 | 168 | matched_addrs1 = set() 169 | matched_addrs2 = set() 170 | matches = [] 171 | 172 | for m in matched_pair: 173 | matched_addrs1.add(m.addr_1) 174 | matched_addrs2.add(m.addr_2) 175 | matches.append( 176 | MatchPair( 177 | function_1=Function(addr=m.addr_1), 178 | function_2=Function(addr=m.addr_2), 179 | score=m.score, 180 | ) 181 | ) 182 | 183 | unmatches_1 = [Function(addr=func1.addr) for func1 in doc1.functions if func1.addr not in matched_addrs1] 184 | unmatches_2 = [Function(addr=func2.addr) for func2 in doc2.functions if func2.addr not in matched_addrs2] 185 | 186 | match_result = MatchResult( 187 | file_1=BinaryFile(sha256=doc1.sha256, basic_info=BasicInfo(base_address=doc1.basic_info.base_address)), 188 | file_2=BinaryFile(sha256=doc2.sha256, basic_info=BasicInfo(base_address=doc2.basic_info.base_address)), 189 | matches=matches, 190 | unmatches_1=unmatches_1, 191 | unmatches_2=unmatches_2, 192 | ) 193 | 194 | return match_result 195 | 196 | 197 | # -------------------------------------- 198 | 199 | 200 | def calculate_code_linecount(code: str) -> int: 201 | return code.strip().count("\n") 202 | 203 | 204 | def is_valid_function(func: Function) -> bool: 205 | if func.linecount is not None: 206 | return func.linecount >= 7 207 | else: 208 | assert func.pseudocode is not None 209 | return func.pseudocode.strip().count("\n") >= 7 210 | 211 | 212 | def filter_out_invalid_function(doc1: BinaryFile, doc2: BinaryFile, matchresult: MatchResult) -> MatchResult: 213 | assert doc1.functions is not None 214 | assert doc2.functions is not None 215 | valid_addrs_1 = set(func1.addr for func1 in doc1.functions if is_valid_function(func1)) 216 | valid_addrs_2 = set(func2.addr for func2 in doc2.functions if is_valid_function(func2)) 217 | 218 | result = MatchResult( 219 | file_1=copy.deepcopy(matchresult.file_1), 220 | file_2=copy.deepcopy(matchresult.file_2), 221 | matches=[ 222 | copy.deepcopy(p) 223 | for p in matchresult.matches 224 | if p.function_1.addr in valid_addrs_1 and p.function_2.addr in valid_addrs_2 225 | ], 226 | unmatches_1=[copy.deepcopy(u1) for u1 in matchresult.unmatches_1 if u1.addr in valid_addrs_1], 227 | unmatches_2=[copy.deepcopy(u2) for u2 in matchresult.unmatches_2 if u2.addr in valid_addrs_2], 228 | ) 229 | 230 | return result 231 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "wheel", "setuptools_scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "binaryai-bindiffmatch" 7 | authors = [ 8 | {name = "hzqmwne", email = "huangzhengqmwne@sina.cn"}, 9 | {name = "KeenLab", email = "KeenSecurityLab@tencent.com"}, 10 | ] 11 | requires-python = ">=3.10" 
12 | dependencies = [ 13 | "networkx>=3,<4", 14 | "numpy>=1,<2", 15 | "scipy>=1,<2", 16 | ] 17 | dynamic = ["version"] 18 | 19 | [project.optional-dependencies] 20 | lowmem = ["nanolsap"] 21 | dev = ["flake8", "black", "isort", "mypy"] 22 | 23 | [project.urls] 24 | homepage = "https://github.com/binaryai/bindiffmatch" 25 | documentation = "https://github.com/binaryai/bindiffmatch" 26 | repository = "https://github.com/binaryai/bindiffmatch" 27 | 28 | [tool.setuptools.packages.find] 29 | include = ["binaryai_bindiffmatch"] 30 | 31 | [tool.setuptools.package-data] 32 | "*" = ["py.typed", "MANIFEST.in"] 33 | 34 | [tool.setuptools_scm] 35 | fallback_version = "0.0.0" 36 | -------------------------------------------------------------------------------- /scripts/diaphora-3.0-b91a9e7abe03de45bf47d4619eda7f8b3f0357bb.patch: -------------------------------------------------------------------------------- 1 | diff --git a/database/schema.py b/database/schema.py 2 | index a2a0e16..7e13bf7 100644 3 | --- a/database/schema.py 4 | +++ b/database/schema.py 5 | @@ -181,5 +181,10 @@ TABLES = [ 6 | """ create table if not exists compilation_unit_functions ( 7 | id integer primary key, 8 | cu_id integer not null references compilation_units(id) on delete cascade, 9 | - func_id integer not null references functions(id) on delete cascade)""" 10 | + func_id integer not null references functions(id) on delete cascade)""", 11 | + """create table if not exists program_basic_info ( 12 | + id integer primary key, 13 | + program_id integer not null unique, 14 | + file_sha256 text not null, 15 | + base_address text not null)""", 16 | ] 17 | diff --git a/diaphora.py b/diaphora.py 18 | index 435530d..50a9f54 100755 19 | --- a/diaphora.py 20 | +++ b/diaphora.py 21 | @@ -633,6 +633,24 @@ class CBinDiff: 22 | 23 | return ret 24 | 25 | + def add_program_basic_info(self, program_id, file_sha256, base_address): 26 | + cur = self.db_cursor() 27 | + sql = "insert into main.program_basic_info (program_id, file_sha256, base_address) values (?, ?, ?)" 28 | + values = (program_id, file_sha256, str(base_address)) 29 | + cur.execute(sql, values) 30 | + cur.close() 31 | + 32 | + def get_program_basic_info(self, program_id): 33 | + cur = self.db_cursor() 34 | + sql = "select program_id, file_sha256, base_address from program_basic_info where program_id = ?" 35 | + cur.execute(sql, (program_id,)) 36 | + row = cur.fetchone() 37 | + basic_info = {} 38 | + if row is not None: 39 | + basic_info = dict(row) 40 | + cur.close() 41 | + return basic_info 42 | + 43 | def add_program_data(self, type_name, key, value): 44 | """ 45 | Add a row of program data to the database. 
46 | @@ -690,6 +708,9 @@ class CBinDiff: 47 | for instruction in bb_data[key]: 48 | instruction_properties = [] 49 | for instruction_property in instruction: 50 | + if (isinstance(instruction_property, int) and 51 | + (instruction_property > 0xFFFFFFFF or instruction_property < -0xFFFFFFFF)): 52 | + instruction_property = str(instruction_property) 53 | if isinstance(instruction_property, (list, set)): 54 | instruction_properties.append( 55 | json.dumps( 56 | @@ -2484,6 +2505,19 @@ class CBinDiff: 57 | 58 | cur = results_db.cursor() 59 | try: 60 | + sql = """create table if not exists program_basic_info ( 61 | + id integer primary key, 62 | + program_id integer not null unique, 63 | + file_sha256 text not null, 64 | + base_address text not null)""" 65 | + cur.execute(sql) 66 | + 67 | + sql = "insert into main.program_basic_info (program_id, file_sha256, base_address) values (?, ?, ?)" 68 | + values = (1, self.basic_info_1["file_sha256"], str(self.basic_info_1["base_address"])) 69 | + cur.execute(sql, values) 70 | + values = (2, self.basic_info_2["file_sha256"], str(self.basic_info_2["base_address"])) 71 | + cur.execute(sql, values) 72 | + 73 | sql = "create table config (main_db text, diff_db text, version text, date text)" 74 | cur.execute(sql) 75 | 76 | @@ -3466,6 +3500,15 @@ class CBinDiff: 77 | f"WARNING: The database is from a different version (current {VERSION_VALUE}, database {row[0]})!" 78 | ) 79 | 80 | + self.basic_info_1 = self.get_program_basic_info(1) 81 | + sql = "select program_id, file_sha256, base_address from diff.program_basic_info where program_id = ?" 82 | + cur.execute(sql, (1,)) 83 | + row = cur.fetchone() 84 | + basic_info_2 = {} 85 | + if row is not None: 86 | + basic_info_2 = dict(row) 87 | + self.basic_info_2 = basic_info_2 88 | + 89 | try: 90 | t0 = time.monotonic() 91 | cur_thread = threading.current_thread() 92 | @@ -3608,6 +3651,8 @@ if __name__ == "__main__": 93 | parser.add_argument("db1") 94 | parser.add_argument("db2") 95 | parser.add_argument("-o", "--outfile", help="Write output to ") 96 | + if not IS_IDA: 97 | + parser.add_argument("--relaxed_ratio", action="store_true") 98 | args = parser.parse_args() 99 | db1 = args.db1 100 | db2 = args.db2 101 | @@ -3621,6 +3666,7 @@ if __name__ == "__main__": 102 | if do_diff: 103 | bd = CBinDiff(db1) 104 | if not IS_IDA: 105 | + bd.relaxed_ratio = args.relaxed_ratio 106 | bd.ignore_all_names = False 107 | 108 | bd.db = sqlite3_connect(db1) 109 | diff --git a/diaphora_ida.py b/diaphora_ida.py 110 | index 9728497..42a59e9 100644 111 | --- a/diaphora_ida.py 112 | +++ b/diaphora_ida.py 113 | @@ -1200,6 +1200,7 @@ class CIDABinDiff(diaphora.CBinDiff): 114 | self.db.execute("BEGIN transaction") 115 | 116 | md5sum = GetInputFileMD5() 117 | + self.add_program_basic_info(1, retrieve_input_file_sha256(), self.get_base_address()) 118 | self.save_callgraph( 119 | str(callgraph_primes), json.dumps(callgraph_all_primes), md5sum 120 | ) 121 | @@ -3867,6 +3868,7 @@ def remove_file(filename): 122 | "function_bblocks", 123 | "compilation_units", 124 | "compilation_unit_functions", 125 | + "program_basic_info", 126 | ] 127 | for func in funcs: 128 | db.execute(f"drop table if exists {func}") 129 | -------------------------------------------------------------------------------- /scripts/metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import os 4 | from collections.abc import Iterable 5 | from concurrent.futures import Future, 
ProcessPoolExecutor 6 | 7 | from binaryai_bindiffmatch import metricsutils 8 | from binaryai_bindiffmatch.models import Function, MatchResult 9 | from binaryai_bindiffmatch.utils import load_matchresult 10 | 11 | global_executor = ProcessPoolExecutor(100) 12 | 13 | global_ignore_external_functions = True 14 | global_ignore_small_functions = True 15 | 16 | 17 | @dataclasses.dataclass(slots=True, kw_only=True) 18 | class MetricValue: 19 | groundtruth_match_count: int 20 | groundtruth_unmatch_1_count: int 21 | groundtruth_unmatch_2_count: int 22 | groundtruth_ignored_1_count: int 23 | groundtruth_ignored_2_count: int 24 | result_match_count: int 25 | result_unmatch_1_count: int 26 | result_unmatch_2_count: int 27 | result_match_in_groundtruth_count: int 28 | # result_unmatch_1_in_groundtruth_count: int 29 | # result_unmatch_2_in_groundtruth_count: int 30 | result_correct_match_count: int 31 | 32 | def get_precision(self) -> float: 33 | if self.result_match_in_groundtruth_count == 0: 34 | return float("nan") 35 | return self.result_correct_match_count / (self.result_match_in_groundtruth_count) 36 | 37 | def get_recall(self) -> float: 38 | if self.groundtruth_match_count == 0: 39 | return float("nan") 40 | return self.result_correct_match_count / self.groundtruth_match_count 41 | 42 | def get_f1(self) -> float: 43 | precision = self.get_precision() 44 | recall = self.get_recall() 45 | if precision + recall == 0: 46 | return float("nan") 47 | return 2 * precision * recall / (precision + recall) 48 | 49 | def display(self) -> str: 50 | groundtruth_file_1_total = ( 51 | self.groundtruth_match_count + self.groundtruth_unmatch_1_count + self.groundtruth_ignored_1_count 52 | ) 53 | groundtruth_file_2_total = ( 54 | self.groundtruth_match_count + self.groundtruth_unmatch_2_count + self.groundtruth_ignored_2_count 55 | ) 56 | result_file_1_total = self.result_match_count + self.result_unmatch_1_count 57 | result_file_2_total = self.result_match_count + self.result_unmatch_2_count 58 | return f"""\ 59 | groundtruth_file_1_total: {groundtruth_file_1_total}, groundtruth_file_2_total: {groundtruth_file_2_total} 60 | result_file_1_total: {result_file_1_total}, result_file_2_total: {result_file_2_total} 61 | groundtruth_match_count = {self.groundtruth_match_count} 62 | result_match_count = {self.result_match_count} 63 | result_match_in_groundtruth_count = {self.result_match_in_groundtruth_count} 64 | result_correct_match_count = {self.result_correct_match_count} 65 | precision = {self.get_precision()} 66 | recall = {self.get_recall()} 67 | f1 = {self.get_f1()} 68 | """ 69 | 70 | def merge(self, another: "MetricValue") -> "MetricValue": 71 | r = MetricValue( 72 | groundtruth_match_count=self.groundtruth_match_count + another.groundtruth_match_count, 73 | groundtruth_unmatch_1_count=self.groundtruth_unmatch_1_count + another.groundtruth_unmatch_1_count, 74 | groundtruth_unmatch_2_count=self.groundtruth_unmatch_2_count + another.groundtruth_unmatch_2_count, 75 | groundtruth_ignored_1_count=self.groundtruth_ignored_1_count + another.groundtruth_ignored_1_count, 76 | groundtruth_ignored_2_count=self.groundtruth_ignored_2_count + another.groundtruth_ignored_2_count, 77 | result_match_count=self.result_match_count + another.result_match_count, 78 | result_unmatch_1_count=self.result_unmatch_1_count + another.result_unmatch_1_count, 79 | result_unmatch_2_count=self.result_unmatch_2_count + another.result_unmatch_2_count, 80 | result_match_in_groundtruth_count=( 81 | self.result_match_in_groundtruth_count + 
another.result_match_in_groundtruth_count 82 | ), 83 | result_correct_match_count=self.result_correct_match_count + another.result_correct_match_count, 84 | ) 85 | return r 86 | 87 | @staticmethod 88 | def mergelist(values: "Iterable[MetricValue]") -> "MetricValue": 89 | sum_groundtruth_match_count = 0 90 | sum_groundtruth_unmatch_1_count = 0 91 | sum_groundtruth_unmatch_2_count = 0 92 | sum_groundtruth_ignored_1_count = 0 93 | sum_groundtruth_ignored_2_count = 0 94 | sum_result_match_count = 0 95 | sum_result_unmatch_1_count = 0 96 | sum_result_unmatch_2_count = 0 97 | sum_result_match_in_groundtruth_count = 0 98 | sum_result_correct_match_count = 0 99 | for v in values: 100 | sum_groundtruth_match_count += v.groundtruth_match_count 101 | sum_groundtruth_unmatch_1_count += v.groundtruth_unmatch_1_count 102 | sum_groundtruth_unmatch_2_count += v.groundtruth_unmatch_2_count 103 | sum_groundtruth_ignored_1_count += v.groundtruth_ignored_1_count 104 | sum_groundtruth_ignored_2_count += v.groundtruth_ignored_2_count 105 | sum_result_match_count += v.result_match_count 106 | sum_result_unmatch_1_count += v.result_unmatch_1_count 107 | sum_result_unmatch_2_count += v.result_unmatch_2_count 108 | sum_result_match_in_groundtruth_count += v.result_match_in_groundtruth_count 109 | sum_result_correct_match_count += v.result_correct_match_count 110 | r = MetricValue( 111 | groundtruth_match_count=sum_groundtruth_match_count, 112 | groundtruth_unmatch_1_count=sum_groundtruth_unmatch_1_count, 113 | groundtruth_unmatch_2_count=sum_groundtruth_unmatch_2_count, 114 | groundtruth_ignored_1_count=sum_groundtruth_ignored_1_count, 115 | groundtruth_ignored_2_count=sum_groundtruth_ignored_2_count, 116 | result_match_count=sum_result_match_count, 117 | result_unmatch_1_count=sum_result_unmatch_1_count, 118 | result_unmatch_2_count=sum_result_unmatch_2_count, 119 | result_match_in_groundtruth_count=sum_result_match_in_groundtruth_count, 120 | result_correct_match_count=sum_result_correct_match_count, 121 | ) 122 | return r 123 | 124 | 125 | def evaluation(matchresult: MatchResult, groundtruth: MatchResult) -> MetricValue: 126 | def normalize_matchresult_addr_1(addr: int) -> int: 127 | return addr - matchresult.file_1.basic_info.base_address + groundtruth.file_1.basic_info.base_address 128 | 129 | def normalize_matchresult_addr_2(addr: int) -> int: 130 | return addr - matchresult.file_2.basic_info.base_address + groundtruth.file_2.basic_info.base_address 131 | 132 | def is_external_function(func: Function) -> bool: 133 | return func.name is not None and func.name.startswith("::") # XXX 134 | 135 | def is_small_function(func: Function) -> bool: 136 | return func.linecount is not None and func.linecount <= 6 137 | 138 | def is_valid_function(func: Function) -> bool: 139 | if global_ignore_external_functions and is_external_function(func): 140 | return False 141 | if global_ignore_small_functions and is_small_function(func): 142 | return False 143 | return True 144 | 145 | groundtruth_matched_addrpairs = set() 146 | groundtruth_ignored_1_addrs = set() 147 | groundtruth_ignored_2_addrs = set() 148 | groundtruth_unmatch_1_addrs = set() 149 | groundtruth_unmatch_2_addrs = set() 150 | 151 | for p in groundtruth.matches: 152 | addr1 = p.function_1.addr 153 | addr2 = p.function_2.addr 154 | valid_1 = is_valid_function(p.function_1) 155 | valid_2 = is_valid_function(p.function_2) 156 | if valid_1 and valid_2: # only keep a match when BOTH funcs are valid 157 | groundtruth_matched_addrpairs.add((addr1, addr2)) 158 | else: # 
159 |             groundtruth_ignored_1_addrs.add(addr1)
160 |             groundtruth_ignored_2_addrs.add(addr2)
161 |     for f in groundtruth.unmatches_1:
162 |         addr = f.addr
163 |         valid = is_valid_function(f)
164 |         if valid:
165 |             groundtruth_unmatch_1_addrs.add(addr)
166 |         else:
167 |             groundtruth_ignored_1_addrs.add(addr)
168 |     for f in groundtruth.unmatches_2:
169 |         addr = f.addr
170 |         valid = is_valid_function(f)
171 |         if valid:
172 |             groundtruth_unmatch_2_addrs.add(addr)
173 |         else:
174 |             groundtruth_ignored_2_addrs.add(addr)
175 | 
176 |     groundtruth_file_1_valid_addrs = set(a for a, _ in groundtruth_matched_addrpairs) | groundtruth_unmatch_1_addrs
177 |     groundtruth_file_2_valid_addrs = set(b for _, b in groundtruth_matched_addrpairs) | groundtruth_unmatch_2_addrs
178 | 
179 |     groundtruth_match_count = len(groundtruth_matched_addrpairs)
180 |     groundtruth_unmatch_1_count = len(groundtruth_unmatch_1_addrs)
181 |     groundtruth_unmatch_2_count = len(groundtruth_unmatch_2_addrs)
182 |     groundtruth_ignored_1_count = len(groundtruth_ignored_1_addrs)
183 |     groundtruth_ignored_2_count = len(groundtruth_ignored_2_addrs)
184 | 
185 |     result_match_count = len(matchresult.matches)
186 |     result_unmatch_1_count = len(matchresult.unmatches_1)
187 |     result_unmatch_2_count = len(matchresult.unmatches_2)
188 |     result_match_in_groundtruth_count = 0
189 |     result_correct_match_count = 0
190 | 
191 |     result_match_different_addr1_count = len(set(p.function_1.addr for p in matchresult.matches))
192 |     result_match_different_addr2_count = len(set(p.function_2.addr for p in matchresult.matches))
193 |     if not (result_match_count == result_match_different_addr1_count == result_match_different_addr2_count):
194 |         # print(
195 |         #     "[warning]: count mismatch",
196 |         #     result_match_count,
197 |         #     result_match_different_addr1_count,
198 |         #     result_match_different_addr2_count,
199 |         # )
200 |         pass  # the match result is expected to be one-to-one; enable the print above to debug duplicates
201 | 
202 |     for p in matchresult.matches:
203 |         addr1 = normalize_matchresult_addr_1(p.function_1.addr)
204 |         addr2 = normalize_matchresult_addr_2(p.function_2.addr)
205 |         if addr1 not in groundtruth_file_1_valid_addrs or addr2 not in groundtruth_file_2_valid_addrs:
206 |             # only matches whose addrs are BOTH contained in the groundtruth are considered
207 |             continue
208 |         result_match_in_groundtruth_count += 1
209 |         if (addr1, addr2) in groundtruth_matched_addrpairs:
210 |             result_correct_match_count += 1
211 | 
212 |     r = MetricValue(
213 |         groundtruth_match_count=groundtruth_match_count,
214 |         groundtruth_unmatch_1_count=groundtruth_unmatch_1_count,
215 |         groundtruth_unmatch_2_count=groundtruth_unmatch_2_count,
216 |         groundtruth_ignored_1_count=groundtruth_ignored_1_count,
217 |         groundtruth_ignored_2_count=groundtruth_ignored_2_count,
218 |         result_match_count=result_match_count,
219 |         result_unmatch_1_count=result_unmatch_1_count,
220 |         result_unmatch_2_count=result_unmatch_2_count,
221 |         result_match_in_groundtruth_count=result_match_in_groundtruth_count,
222 |         result_correct_match_count=result_correct_match_count,
223 |     )
224 |     return r
225 | 
226 | 
227 | def evaluation_on_matchresultfile(matchresult_file: str, groundtruth_file: str) -> MetricValue:
228 |     matchresult = load_matchresult(matchresult_file)
229 |     groundtruth = load_matchresult(groundtruth_file)
230 |     result = evaluation(matchresult, groundtruth)
231 |     return result
232 | 
233 | 
234 | def batch_evaluation_files(filepairs: Iterable[tuple[str, str]]) -> list[MetricValue]:
235 |     tasks: list[Future[MetricValue]] = []
236 |     for algoresult_file, groundtruth_file in filepairs:
237 |         task = global_executor.submit(evaluation_on_matchresultfile, algoresult_file, groundtruth_file)
238 |         tasks.append(task)
239 |     results = [task.result() for task in tasks]
240 |     return results
241 | 
242 | 
243 | def evaluation_on_testcase(datadir: str, algorithm: str) -> None:
244 |     library_filenames: dict[str, list[str]] = {}
245 |     filepairs: list[tuple[str, str]] = []
246 |     for (library1, version1, optimization1), (library2, version2, optimization2) in metricsutils.testcase_pairs:
247 |         assert library1 == library2
248 |         library = library1
249 |         if library in library_filenames:  # cache the per-library filename listing
250 |             filenames = library_filenames[library]
251 |         else:
252 |             filenames = metricsutils.get_stripped_filenames(datadir, library)
253 |             library_filenames[library] = filenames
254 |         for filename in filenames:
255 |             testcase_matchresult_filename = metricsutils.get_algorithm_matchresult_relpath(
256 |                 algorithm, library1, version1, optimization1, library2, version2, optimization2, filename
257 |             )
258 |             groundtruth_matchresult_filename = metricsutils.get_algorithm_matchresult_relpath(
259 |                 "groundtruth", library1, version1, optimization1, library2, version2, optimization2, filename
260 |             )
261 |             matchresult_file = os.path.join(datadir, testcase_matchresult_filename)
262 |             groundtruth_file = os.path.join(datadir, groundtruth_matchresult_filename)
263 |             if os.path.exists(matchresult_file) and os.path.exists(groundtruth_file):
264 |                 filepairs.append(
265 |                     (
266 |                         matchresult_file,
267 |                         groundtruth_file,
268 |                     )
269 |                 )
270 |             else:
271 |                 # print(f"warning: file `{matchresult_file}` or `{groundtruth_file}` not found")
272 |                 pass
273 |     results = batch_evaluation_files(filepairs)
274 |     metricresult = MetricValue.mergelist(results)
275 |     print(metricresult.display())
276 | 
277 | 
278 | def evaluation_on_example(datadir: str, algorithm: str) -> None:
279 |     (library1, version1, optimization1), (library2, version2, optimization2) = metricsutils.example_pair
280 |     filename = metricsutils.example_filename
281 |     example_matchresult_filename = metricsutils.get_algorithm_matchresult_relpath(
282 |         algorithm, library1, version1, optimization1, library2, version2, optimization2, filename
283 |     )
284 |     groundtruth_matchresult_filename = metricsutils.get_algorithm_matchresult_relpath(
285 |         "groundtruth", library1, version1, optimization1, library2, version2, optimization2, filename
286 |     )
287 |     metricresult = evaluation_on_matchresultfile(
288 |         os.path.join(datadir, example_matchresult_filename), os.path.join(datadir, groundtruth_matchresult_filename)
289 |     )
290 |     print(metricresult.display())
291 | 
292 | 
293 | def parse_args() -> argparse.Namespace:
294 |     parser = argparse.ArgumentParser()
295 |     parser.add_argument("choose", choices=["testcases", "example"])
296 |     parser.add_argument("algorithm", choices=["diaphora", "binaryai"])
297 |     parser.add_argument("--datadir", required=False, default="./data")
298 |     return parser.parse_args()
299 | 
300 | 
301 | def main() -> None:
302 |     args = parse_args()
303 |     choose = args.choose
304 |     algorithm = args.algorithm
305 |     datadir = args.datadir
306 | 
307 |     match choose:
308 |         case "testcases":
309 |             evaluation_on_testcase(datadir, algorithm)
310 |         case "example":
311 |             evaluation_on_example(datadir, algorithm)
312 |         case _:
313 |             print(f"[error] unknown command `{choose}`")
314 | 
315 | 
316 | if __name__ == "__main__":
317 |     main()
318 | 
--------------------------------------------------------------------------------
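
A minimal usage sketch for the metric pipeline above, in case you want to score a single matchresult/groundtruth pair without going through the CLI. The two JSON paths are hypothetical placeholders (the real layout under `data/matchresults` is produced by `metricsutils.get_algorithm_matchresult_relpath`); only `load_matchresult` and `evaluation` come from the code in this repository.

```python
# Sketch: evaluate one pair of result files directly.
# NOTE: both paths are hypothetical placeholders -- substitute any
# matchresult/groundtruth JSON pair from the released data directory.
import sys

sys.path.insert(0, "scripts")  # make scripts/metrics.py importable as `metrics`

from binaryai_bindiffmatch.utils import load_matchresult
from metrics import evaluation

matchresult = load_matchresult("data/matchresults/<algorithm>/<pair>.json")
groundtruth = load_matchresult("data/matchresults/groundtruth/<pair>.json")
metric = evaluation(matchresult, groundtruth)
print(metric.display())  # per-pair precision / recall / f1
```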