├── .gitignore ├── BloomFilter ├── __init__.py ├── feature_extractor.py ├── main.py └── sfbl.py ├── Dataset ├── __init__.py ├── base.py ├── normal_sample.py ├── old_new_funcs.py ├── target_project.py ├── universal-ctags │ ├── COPYING │ ├── ctags │ └── readme.txt └── utils.py ├── SyntaxFilter ├── __init__.py └── detection.py ├── TokenFilter ├── __init__.py ├── main.py └── token_extraction.py ├── Trace ├── __init__.py ├── cfg.py ├── detection.py ├── embedding.py ├── manager.py ├── norm.py ├── readme.md ├── scripts │ ├── cppparser.so │ └── taint2json.sc ├── serializer.py ├── taintflow.py └── utils.py ├── config.example.yml ├── config.py ├── dockerfile ├── environment.yml ├── main.py ├── readme.md ├── requirements.txt ├── resource ├── OldNewFuncs │ ├── readme.md │ └── sample-project │ │ └── CVE-SAMPLE │ │ ├── CVE-SAMPLE_CWE-SAMPLE_abcdef_SAMPLE.cpp__SAMPLE_FUNC_NEW.vul │ │ └── CVE-SAMPLE_CWE-SAMPLE_abcdef_SAMPLE.cpp__SAMPLE_FUNC_OLD.vul ├── codebert │ └── readme.md ├── jdk-17.0.11 │ └── readme.md ├── joern-cli │ └── readme.md ├── readme.md └── redis-7.2.3 │ └── readme.md └── server.py /.gitignore: -------------------------------------------------------------------------------- 1 | config.yml 2 | cache 3 | result 4 | venv 5 | workspace 6 | Trace/v1/oldnew -------------------------------------------------------------------------------- /BloomFilter/__init__.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor, as_completed 2 | from typing import List 3 | 4 | from loguru import logger 5 | 6 | import BloomFilter.main 7 | import config 8 | 9 | 10 | def initialization(vul_functions: List[str], rebuild=False) -> None: 11 | logger.info("Initialize BloomFilter") 12 | BloomFilter.main.initialization(vul_functions, rebuild) 13 | logger.info("BloomFilter Initialized") 14 | 15 | 16 | def detect(input_queue, output_queue, pbar_queue) -> None: 17 | with ProcessPoolExecutor(max_workers=config.bloom_filter_worker) as executor: 18 | futures = {} 19 | 20 | def process_future(future): 21 | is_vul = future.result() 22 | pbar_queue.put(("bloom", is_vul)) 23 | if is_vul: 24 | function, function_path = futures[future] 25 | output_queue.put((function, function_path, [])) 26 | 27 | while True: 28 | vul_info = input_queue.get() 29 | 30 | if vul_info[1] == "__end_of_detection__": 31 | for future in as_completed(futures.keys()): 32 | process_future(future) 33 | output_queue.put(vul_info) 34 | logger.info("Bloom Filter Finished!") 35 | return 36 | 37 | function, function_path, _ = vul_info 38 | 39 | future = executor.submit(BloomFilter.main.detect, function) 40 | futures[future] = (function, function_path) 41 | 42 | done_futures = [] 43 | 44 | for future in futures.keys(): 45 | if not future.done(): 46 | continue 47 | process_future(future) 48 | done_futures.append(future) 49 | 50 | for future in done_futures: 51 | futures.pop(future) 52 | 53 | 54 | from .feature_extractor import FeatureExtractor 55 | -------------------------------------------------------------------------------- /BloomFilter/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | import threading 4 | from typing import List, Iterable 5 | 6 | import numpy as np 7 | # noinspection PyUnresolvedReferences 8 | from pygments.lexers.c_cpp import CppLexer 9 | from pygments.token import Token 10 | 11 | 12 | class OperatorStateMachine: 13 | _Operators = frozenset(['++', '--', '+', '-', '*', '/', 
'%', '=', '+=', '-=', '*=', '/=', '%=', '<<=', 14 | '>>=', '&=', '^=', '|=', '&&', '||', '!', '==', '!=', '>=', '<=', '>', '<', '&', 15 | '|', '<<', '>>', '~', '^', '->']) # 42 16 | _double_operator = frozenset(["&", "+", "-", "|"]) 17 | _double_operator2 = frozenset(["<", ">"]) 18 | 19 | def __init__(self): 20 | self.current_state = "" 21 | 22 | def clear(self): 23 | self.current_state = "" 24 | 25 | def process(self, current=None): 26 | if current is None and self.current_state != "": 27 | op = self.current_state 28 | self.current_state = "" 29 | return op 30 | if len(self.current_state) == 0: 31 | if current == "~": 32 | return "~" 33 | if current in self._Operators: 34 | self.current_state += current 35 | return None 36 | elif len(self.current_state) == 1: 37 | op = self.current_state + current 38 | if current == "=": 39 | self.current_state = "" 40 | return op 41 | elif current == self.current_state: 42 | if current in self._double_operator: # && ++ -- || 43 | self.current_state = "" 44 | return op 45 | if current in self._double_operator2: # << >> may go to <<= and >>= 46 | self.current_state = op 47 | return None 48 | else: 49 | op = self.current_state 50 | self.current_state = current 51 | return op 52 | elif op == "->": 53 | self.current_state = "" 54 | return "->" 55 | else: 56 | op = self.current_state 57 | self.current_state = current 58 | return op 59 | else: 60 | op = self.current_state + current 61 | if current == "=": # only <<= and >>= 62 | self.current_state = "" 63 | return op 64 | else: 65 | op = self.current_state 66 | self.current_state = current 67 | return op 68 | return None 69 | 70 | 71 | class FeatureExtractor: 72 | _APIs = ['alloc', 'free', 'mem', 'copy', 'new', 'open', 'close', 'delete', 'create', 'release', 73 | 'sizeof', 'remove', 'clear', 'dequene', 'enquene', 'detach', 'Attach', 'str', 'string', 74 | 'lock', 'mutex', 'spin', 'init', 'register', 'disable', 'enable', 'put', 'get', 'up', 75 | 'down', 'inc', 'dec', 'add', 'sub', 'set', 'map', 'stop', 'start', 'prepare', 'suspend', 76 | 'resume', 'connect'] # 42 77 | 78 | _Formatted_strings = ['d', 'i', 'o', 'u', 'x', 'X', 'f', 'F', 'e', 'E', 'g', 'G', 79 | 'a', 'A', 'c', 'C', 's', 'S', 'p', 'n'] # 21 80 | 81 | _Operators = ['bitand', 'bitor', 'xor', 'not', 'not_eq', 'or', 'or_eq', 'and', '++', '--', 82 | '+', '-', '*', '/', '%', '=', '+=', '-=', '*=', '/=', '%=', '<<=', 83 | '>>=', '&=', '^=', '|=', '&&', '||', '!', '==', '!=', '>=', '<=', '>', '<', '&', 84 | '|', '<<', '>>', '~', '^', '->'] # 42 85 | 86 | _Keywords = ['asm', 'auto', 'alignas', 'alignof', 'bool', 'break', 'case', 87 | 'catch', 'char', 'char16_t', 'char32_t', 'class', 'const', 'const_cast', 88 | 'constexpr', 'continue', 'decltype', 'default', 'do', 'double', 89 | 'dynamic_cast', 'else', 'enum', 'explicit', 'export', 'extern', 'false', 'float', 90 | 'for', 'friend', 'goto', 'if', 'inline', 'int', 'long', 'mutable', 'namespace', 91 | 'noexcept', 'nullptr', 'operator', 'private', 'protected', 'public', 92 | 'reinterpret_cast', 'return', 'short', 'signed', 'static', 93 | 'static_assert', 'static_cast', 'struct', 'switch', 'template', 'this', 94 | 'thread_local', 'throw', 'true', 'try', 'typedef', 'typeid', 'typename', 'union', 95 | 'unsigned', 'using', 'virtual', 'void', 'volatile', 'wchar_t', 'while', 'compl', 96 | 'override', 'final', 'assert'] # 77 97 | 98 | def __init__(self): 99 | self._No_Formatted_string_List = self._APIs + self._Operators + self._Keywords 100 | self._No_Formatted_string_Dict = dict([word, 0] for word in 
self._No_Formatted_string_List) 101 | self._Formatted_strings_Dict = dict([word, 0] for word in self._Formatted_strings) 102 | self.lexer = CppLexer() 103 | self.operator_state_machine = OperatorStateMachine() 104 | self.n = len(self._No_Formatted_string_List) + len(self._Formatted_strings) 105 | self.lock = threading.Lock() 106 | 107 | def clean(self): 108 | self._No_Formatted_string_Dict = dict([word, 0] for word in self._No_Formatted_string_List) 109 | self._Formatted_strings_Dict = dict([word, 0] for word in self._Formatted_strings) 110 | self.lexer = CppLexer() 111 | self.operator_state_machine.clear() 112 | 113 | def _extract(self, code: str): 114 | tokens = self.lexer.get_tokens(code) 115 | for token_type, value in tokens: 116 | if token_type == Token.Operator: 117 | op = self.operator_state_machine.process(value) 118 | if op is not None: 119 | if op in self._No_Formatted_string_List: 120 | self._No_Formatted_string_Dict[op] += 1 121 | else: 122 | op = self.operator_state_machine.process() 123 | if op is not None: 124 | if op in self._No_Formatted_string_List: 125 | self._No_Formatted_string_Dict[op] += 1 126 | if token_type == Token.Literal.String: 127 | if value != '"': 128 | format_symbols = re.findall(r'%([-+0 #]{0,5}\d*(?:\.\d+)?)[lhL]?([diouxXfFeEgGaAcCsSpn])', 129 | value) 130 | for symbols in format_symbols: 131 | if symbols[-1] in self._Formatted_strings: 132 | self._Formatted_strings_Dict[symbols[-1]] += 1 133 | if token_type in [Token.Keyword, Token.Keyword.Type, Token.Keyword.Reserved, 134 | Token.Name, Token.Name.Builtin]: 135 | if value in self._No_Formatted_string_Dict: 136 | self._No_Formatted_string_Dict[value] += 1 137 | return {**self._Formatted_strings_Dict, **self._No_Formatted_string_Dict} 138 | 139 | def extract_vector(self, code: str): 140 | with self.lock: 141 | self.clean() 142 | token_dict = self._extract(code) 143 | # One-Hot Vector 144 | return np.array([1 if token_dict[key] > 0 else 0 for key in token_dict], dtype=np.uint8) 145 | 146 | def extract_from_files(self, file_list: List[str]) -> Iterable: 147 | return FeatureVectorFileListIter(self, file_list) 148 | 149 | 150 | class FeatureVectorFileListIter(Iterable): 151 | def __init__(self, extractor: FeatureExtractor, file_list: List[str]): 152 | self.extractor = extractor 153 | self.len = len(file_list) 154 | self.file_list = iter(file_list) 155 | 156 | def __iter__(self): 157 | iter_list = copy.deepcopy(self.file_list) 158 | for func_path in iter_list: 159 | with open(func_path, "r") as f: 160 | yield self.extractor.extract_vector(f.read()) 161 | 162 | def __len__(self): 163 | return self.len 164 | -------------------------------------------------------------------------------- /BloomFilter/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | from loguru import logger 6 | 7 | import BloomFilter.sfbl 8 | import BloomFilter.feature_extractor 9 | 10 | DetectFilter = None 11 | 12 | 13 | def _threshold_to_maximum_tries(threshold: int) -> int: 14 | return -threshold + 1 15 | 16 | 17 | def _default_dump(obj): 18 | """Convert numpy classes to JSON serializable objects.""" 19 | if isinstance(obj, (np.integer, np.floating, np.bool_)): 20 | return obj.item() 21 | elif isinstance(obj, np.ndarray): 22 | return obj.tolist() 23 | else: 24 | return obj 25 | 26 | 27 | def initialization(vul_functions, rebuild=False): 28 | cache_json_path = "cache/bloomFilter.json" 29 | Extractor = 
BloomFilter.feature_extractor.FeatureExtractor() 30 | try: 31 | if rebuild: 32 | raise Exception("Rebuild flag on") 33 | with open(cache_json_path) as f: 34 | threshold = json.load(f)["threshold"] 35 | except Exception as e: 36 | rebuild = True 37 | logger.warning("Threshold cache missing or rebuild flag on, recomputing threshold: {}".format(e)) 38 | threshold = -100 39 | # For convenience when generating the SFBF, the threshold-search procedure has been removed and the threshold 40 | # is fixed to -100 (100 tries). 41 | # 42 | # The non_sample and normal_dataset inputs are no longer needed here. You can simply use the vulnerable functions 43 | # you collected to generate the SFBF, following the format described in the README. 44 | # 45 | # If you want to know how we set the threshold to -100, please refer to the previous commit. 46 | # 47 | # To search for the threshold yourself, prepare a vulnerable-function dataset that contains multiple versions of 48 | # each vulnerable function, and a normal dataset consisting of functions from popular projects. 49 | os.makedirs(os.path.dirname(cache_json_path), exist_ok=True) 50 | with open(cache_json_path, "w") as f: 51 | json.dump({"threshold": threshold}, f, default=_default_dump) 52 | logger.info(f"Bloom Filter Using Threshold: {threshold}") 53 | # SFBL Bloom Filter Constructing 54 | global DetectFilter 55 | DetectFilter = BloomFilter.sfbl.SFBL(n=Extractor.n, maximum_tries=_threshold_to_maximum_tries(threshold), 56 | dropout_rate=0.17, rebuild=rebuild) 57 | if DetectFilter.rebuild: 58 | construct_vectors = Extractor.extract_from_files(vul_functions) 59 | DetectFilter.construct(construct_vectors, threshold) 60 | else: 61 | logger.info("Using Cached Bloom Filter Bins") 62 | 63 | 64 | def detect(code: str) -> bool: 65 | if isinstance(DetectFilter, BloomFilter.sfbl.SFBL): 66 | Extractor = BloomFilter.feature_extractor.FeatureExtractor() 67 | vector = Extractor.extract_vector(code) 68 | return DetectFilter.detect(vector) 69 | else: 70 | logger.critical("Bloom Filter Not Initialized") 71 | raise Exception("Bloom Filter Not Initialized") 72 | -------------------------------------------------------------------------------- /BloomFilter/sfbl.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os.path 3 | import shutil 4 | import sys 5 | from collections import namedtuple 6 | from typing import Iterable 7 | 8 | import numpy as np 9 | from bloom_filter2 import BloomFilter 10 | from loguru import logger 11 | from tqdm import tqdm 12 | 13 | ThresholdInfo = namedtuple("ThresholdInfo", "recall tnr") 14 | 15 | 16 | def get_require_threshold(threshold_info, require_recall=0.95): 17 | target_threshold, recall, tnr = math.inf, 1, 0 18 | t = sorted(threshold_info.items(), key=lambda x: x[0]) 19 | for threshold, info in t: 20 | if info["recall"] >= require_recall: 21 | target_threshold = threshold 22 | else: 23 | break 24 | return target_threshold 25 | 26 | 27 | def _query_similarity(cnt): 28 | return -cnt 29 | 30 | 31 | class SFBL: 32 | def __init__(self, n, N=10000, maximum_tries=100, dropout_rate=0.1, seed=20231031, rebuild=False, use_cache=True): 33 | self.cache_dir = "cache/sfbl" 34 | self._threshold = -(maximum_tries - 1) # default threshold related to maximum_tries 35 | self._dropout_cnt = round(n * dropout_rate) 36 | self._maximum_tries = maximum_tries 37 | self.rebuild = rebuild 38 | if use_cache: 39 | if not os.path.exists(self.cache_dir) or len(os.listdir(self.cache_dir)) != self._maximum_tries: 40 | self.rebuild = True 41 | 
shutil.rmtree(self.cache_dir, ignore_errors=True) 42 | os.makedirs(self.cache_dir) 43 | self._filters = [BloomFilter(max_elements=N, error_rate=1e-5, 44 | filename=(os.path.join(self.cache_dir, f"{i}.sfbl.bin"), -1), 45 | start_fresh=self.rebuild) for i in range(self._maximum_tries)] 46 | else: 47 | self._filters = [BloomFilter(max_elements=N, error_rate=1e-5) for _ in range(self._maximum_tries)] 48 | self._seed = seed 49 | 50 | def _vector_encode(self, vector: np.ndarray): 51 | return vector[self._dropout_cnt:].tobytes() 52 | 53 | def insert(self, vector: np.ndarray): 54 | t = vector.copy() 55 | for i in range(self._maximum_tries): 56 | rng = np.random.RandomState(self._seed + i) 57 | rng.shuffle(t) 58 | self._filters[i].add(self._vector_encode(t)) 59 | 60 | def query(self, vector: np.ndarray): 61 | t = vector.copy() 62 | for i in range(self._maximum_tries): 63 | rng = np.random.RandomState(self._seed + i) 64 | rng.shuffle(t) 65 | if self._vector_encode(t) in self._filters[i]: 66 | return _query_similarity(i) 67 | return _query_similarity(self._maximum_tries) 68 | 69 | def construct(self, construct_vec: Iterable[np.ndarray], threshold: float) -> None: 70 | if not self.rebuild: 71 | logger.critical("Constructing a constructed SFBL") 72 | raise Exception("Constructing a constructed SFBL") 73 | self._threshold = threshold 74 | for v in tqdm(construct_vec, desc="Construct Set", unit="Funcs", smoothing=0, file=sys.stdout): 75 | self.insert(v) 76 | 77 | def detect(self, target_vec: np.ndarray) -> bool: 78 | score = self.query(target_vec) # the score is a negative. The less similar, the fewer score. 79 | return score > self._threshold 80 | -------------------------------------------------------------------------------- /Dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import Dataset.normal_sample 2 | import Dataset.old_new_funcs 3 | import Dataset.target_project 4 | import Dataset.base 5 | 6 | NormalSample = Dataset.normal_sample.NormalSampleDataset 7 | old_new_funcs_filename_split = Dataset.old_new_funcs.old_new_funcs_filename_split 8 | OldNewFuncs = Dataset.old_new_funcs.OldNewFuncsDataset 9 | Project = Dataset.target_project.ProjectDataset 10 | Base = Dataset.base.BaseDataset 11 | -------------------------------------------------------------------------------- /Dataset/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import List 3 | 4 | 5 | class BaseDataset: 6 | def __init__(self, dataset_folder_path: str, seed=20231031): 7 | """ 8 | Initialize basic property of the Dataset 9 | :param dataset_folder_path: Path to the folder of Dataset 10 | :param seed: seed for random 11 | """ 12 | self.dataset_folder_path = dataset_folder_path 13 | self.seed = seed 14 | 15 | @abc.abstractmethod 16 | def get_funcs(self, size=-1, **kwargs) -> List[str]: 17 | return [] 18 | -------------------------------------------------------------------------------- /Dataset/normal_sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import sys 5 | from typing import List 6 | 7 | from loguru import logger 8 | from tqdm import tqdm 9 | 10 | import Dataset.base 11 | import Dataset.utils 12 | 13 | 14 | class NormalSampleDataset(Dataset.base.BaseDataset): 15 | def _software_path(self, software): 16 | return os.path.join(self.dataset_folder_path, software) 17 | 18 | def _function_path(self, software, 
function): 19 | return os.path.join(self._software_path(software), function) 20 | 21 | def _software_list_generator(self): 22 | def _is_dir(software): 23 | return os.path.isdir(self._software_path(software)) 24 | 25 | return filter(_is_dir, os.listdir(self.dataset_folder_path)) 26 | 27 | def _function_list_generator(self, software): 28 | return os.listdir(self._software_path(software)) 29 | 30 | def _preprocess(self, size, seed): 31 | logger.info("Preprocessing NormalSample Dataset") 32 | func_path_list = [] 33 | for software in self._software_list_generator(): 34 | for function in self._function_list_generator(software): 35 | func_path_list.append((function, self._function_path(software, function))) 36 | rng = random.Random(seed) 37 | rng.shuffle(func_path_list) 38 | i = 0 39 | with tqdm(total=size, desc="Normal", unit="Funcs", file=sys.stdout) as pbar: 40 | for function, func_path in func_path_list: 41 | with open(func_path) as f: 42 | code = Dataset.utils.function_purification(f.read()) 43 | if code == "": 44 | continue 45 | target_file = os.path.join(self.cache_dir, function) 46 | with open(target_file, "w") as f: 47 | f.write(code) 48 | i += 1 49 | pbar.update() 50 | if i == size: 51 | break 52 | logger.info("Preprocessing Finished") 53 | 54 | def __init__(self, dataset_folder_path: str, seed=20231031, size=3000, rebuild=False): 55 | """ 56 | Initialize NormalSample dataset 57 | :param dataset_folder_path: Path to the folder of NormalSample Dataset 58 | :param seed: seed for random 59 | """ 60 | super().__init__(dataset_folder_path, seed) 61 | 62 | logger.info("Initializing NormalSample Dataset") 63 | self.cache_dir = os.path.join(os.getcwd(), "cache", "normal") 64 | if rebuild or not (os.path.exists(self.cache_dir) and len(os.listdir(self.cache_dir)) == size): 65 | shutil.rmtree(self.cache_dir, ignore_errors=True) 66 | os.makedirs(self.cache_dir, exist_ok=True) 67 | self._preprocess(size, self.seed) 68 | else: 69 | logger.info("Using NormalSample preprocessed Cache") 70 | logger.info(f"NormalSample Dataset Size: {len(os.listdir(self.cache_dir))}") 71 | 72 | def get_funcs(self, **kwargs) -> List[str]: 73 | """ 74 | Get the function list of the NormalSample dataset. 75 | :return: The list of function file paths. 
76 | """ 77 | return Dataset.utils.abs_listdir(self.cache_dir) 78 | -------------------------------------------------------------------------------- /Dataset/old_new_funcs.py: -------------------------------------------------------------------------------- 1 | # OldNewFuncsDataset Folder 2 | # Main Folder - Software - CVE - sample functions(OLD=vul NEW=no vul) 3 | import os.path 4 | import random 5 | import shutil 6 | import sys 7 | from collections import namedtuple 8 | from typing import List, Tuple 9 | 10 | from loguru import logger 11 | from tqdm import tqdm 12 | 13 | import Dataset.base 14 | import Dataset.utils 15 | from Dataset.utils import abs_listdir, function_purification 16 | 17 | 18 | def old_new_funcs_filename_split(name: str) -> (str, str, str, str, str, str, str): 19 | """ 20 | A util function to parse the filename of old new funcs 21 | :param name: the filename to parse 22 | :return: Corresponding CVE, CWE, commit_hash, filename, version, function name, vulnerable/patch 23 | """ 24 | part = name.split("_") 25 | cve = part[0] 26 | cwe = part[1] 27 | commit_hash = part[2] 28 | # filename 29 | i = 3 30 | # All the filename in old_new_funcs contains any .c 31 | while part[i].rfind(".c") != -1: 32 | i += 1 33 | file_name = "_".join(part[3:i]) 34 | if part[i].rfind(".") != -1: 35 | version = part[i] 36 | i += 1 37 | else: 38 | version = "" 39 | func_name = "_".join(part[i:-1]) 40 | old_new = part[-1][:-4] 41 | return (cve.strip(), cwe.strip(), commit_hash.strip(), file_name.strip(), version.strip(), func_name.strip(), 42 | old_new.strip()) 43 | 44 | 45 | class OldNewFuncsDataset(Dataset.base.BaseDataset): 46 | FunctionInfo = namedtuple("FunctionInfo", "software cve function_name vul sample") 47 | 48 | def _software_path(self, software): 49 | return os.path.join(self.dataset_folder_path, software) 50 | 51 | def _cve_path(self, software, cve): 52 | return os.path.join(self._software_path(software), cve) 53 | 54 | def _function_path(self, software, cve, function): 55 | return os.path.join(self._cve_path(software, cve), function) 56 | 57 | def _software_list_generator(self): 58 | return filter(lambda software: os.path.isdir(self._software_path(software)), 59 | os.listdir(self.dataset_folder_path)) 60 | 61 | def _cve_list_generator(self, software): 62 | return filter(lambda cve: os.path.isdir(self._cve_path(software, cve)), 63 | os.listdir(self._software_path(software))) 64 | 65 | def _function_list_generator(self, software, cve): 66 | return os.listdir(self._cve_path(software, cve)) 67 | 68 | def _preprocess(self): 69 | logger.info("Preprocessing Old_New_Funcs Dataset") 70 | with tqdm(desc="Old_New_Funcs", unit="Funcs", file=sys.stdout) as pbar: 71 | for software in self._software_list_generator(): 72 | for cve in self._cve_list_generator(software): 73 | function_set = set() 74 | for function in self._function_list_generator(software, cve): 75 | func_path = self._function_path(software, cve, function) 76 | with open(func_path) as f: 77 | # Code Purification 78 | try: 79 | raw_code = f.read() 80 | except UnicodeDecodeError: 81 | with open(func_path, encoding="cp1252") as f2: 82 | raw_code = str(f2.read()) 83 | code = function_purification(raw_code) 84 | if code == "": 85 | continue 86 | # Function Tagging 87 | _, _, _, _, _, func_name, old_new = old_new_funcs_filename_split(function) 88 | # OLD function are vulnerable function 89 | # while NEW function are patched function of the corresponding OLD function 90 | is_vul = (old_new == "OLD") 91 | if is_vul: 92 | if func_name not in 
function_set: 93 | target_file = os.path.join(self.sample_dir, function) 94 | function_set.add(func_name) 95 | else: 96 | target_file = os.path.join(self.non_sample_dir, function) 97 | else: 98 | target_file = os.path.join(self.no_vul_dir, function) 99 | pbar.update() 100 | with open(target_file, "w") as f: 101 | f.write(code) 102 | logger.info("Preprocessing Finished") 103 | 104 | def __init__(self, dataset_folder_path, seed=20231031, rebuild=False): 105 | """ 106 | Initializing the Old New Funcs Dataset 107 | :param dataset_folder_path: Where the dataset stores 108 | :param seed: seed for random 109 | """ 110 | super().__init__(dataset_folder_path, seed) 111 | self.dataset_folder_path = dataset_folder_path 112 | self.seed = seed 113 | self.funcs_info = {} 114 | 115 | logger.info("Initializing Old_New_Funcs Dataset") 116 | self.cache_dir = os.path.join(os.getcwd(), "cache", "old_new_funcs") 117 | self.vul_dir = os.path.join(self.cache_dir, "vul") 118 | self.no_vul_dir = os.path.join(self.cache_dir, "no_vul") 119 | self.sample_dir = os.path.join(self.vul_dir, "sample") 120 | self.non_sample_dir = os.path.join(self.vul_dir, "non_sample") 121 | for chk_dir in [self.sample_dir, self.non_sample_dir, self.no_vul_dir]: 122 | if rebuild or not (os.path.exists(chk_dir) and len(os.listdir(chk_dir)) != 0): 123 | shutil.rmtree(self.cache_dir, ignore_errors=True) 124 | os.makedirs(self.sample_dir, exist_ok=True) 125 | os.makedirs(self.non_sample_dir, exist_ok=True) 126 | os.makedirs(self.no_vul_dir, exist_ok=True) 127 | self._preprocess() 128 | break 129 | else: 130 | logger.info("Using Old_New_Funcs preprocessed Cache") 131 | sample_size = len(os.listdir(self.sample_dir)) 132 | non_sample_size = len(os.listdir(self.non_sample_dir)) 133 | no_vul_size = len(os.listdir(self.no_vul_dir)) 134 | logger.info(f"Old_New_Funcs Dataset Total Size {sample_size + non_sample_size + no_vul_size}") 135 | logger.info(f"VulFunctions: {sample_size + non_sample_size}, VulSamples: {sample_size}") 136 | logger.info(f"NoVulFunctions: {no_vul_size}") 137 | 138 | def get_funcs(self, size=-1, vul=False, no_vul=False, sample=False, non_sample=False) -> List[str]: 139 | """ 140 | Get the function list of the OLD NEW FUNCS dataset. 
141 | Only one True among : vul no_vul sample non_sample 142 | vul = sample + non_sample 143 | whole_dataset = vul + no_vul 144 | :param size: Size of the return function list 145 | :param vul: All the return function list is vulnerable 146 | :param no_vul: All the return function list is not vulnerable 147 | :param sample: All the return function list is vulnerable and contains no duplicate source function.(No function 148 | has another vulnerable version in the list) 149 | :param non_sample: The function which vulnerable but not in the sample list 150 | :return: the function path list 151 | """ 152 | if vul: 153 | func_path_list = abs_listdir(self.sample_dir) + abs_listdir(self.non_sample_dir) 154 | elif no_vul: 155 | func_path_list = abs_listdir(self.no_vul_dir) 156 | elif sample: 157 | func_path_list = abs_listdir(self.sample_dir) 158 | elif non_sample: 159 | func_path_list = abs_listdir(self.non_sample_dir) 160 | else: 161 | func_path_list = (abs_listdir(self.sample_dir) + abs_listdir(self.non_sample_dir) + 162 | abs_listdir(self.no_vul_dir)) 163 | 164 | if size != -1: 165 | rng = random.Random(self.seed) 166 | func_path_list = rng.sample(func_path_list, min(size, len(func_path_list))) 167 | 168 | return func_path_list 169 | 170 | def get_func_pairs(self) -> List[Tuple[str, str]]: 171 | """ 172 | Output old and new function pairs 173 | :return: function pair list 174 | """ 175 | func_pairs = [] 176 | 177 | def _find_func_pairs(target_dir): 178 | sample_list = os.listdir(target_dir) 179 | for func_rel_path in sample_list: 180 | new_func_rel_path = func_rel_path.replace("OLD", "NEW") 181 | if os.path.exists(os.path.join(self.no_vul_dir, new_func_rel_path)): 182 | func_pairs.append((os.path.join(target_dir, func_rel_path), 183 | os.path.join(self.no_vul_dir, new_func_rel_path))) 184 | 185 | _find_func_pairs(self.sample_dir) 186 | _find_func_pairs(self.non_sample_dir) 187 | return func_pairs 188 | -------------------------------------------------------------------------------- /Dataset/target_project.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import shutil 5 | import subprocess 6 | from typing import List 7 | 8 | from loguru import logger 9 | 10 | import Dataset.base 11 | import Dataset.utils 12 | 13 | 14 | class ProjectDataset(Dataset.base.BaseDataset): 15 | """ 16 | Dataset of the target function to detect 17 | """ 18 | 19 | def _preprocess(self, project_dir): 20 | logger.info("Preprocessing Target Function Dataset") 21 | logger.info(f"Extracting function from {project_dir} to {self.cache_dir}") 22 | 23 | cmd = (f'{self.path_to_ctags} -R --kinds-C++=f -u --fields=-fP+ne --language-force=c --language-force=c++' 24 | f' --output-format=json -f - "{project_dir}"') 25 | logger.debug(f"{cmd}") 26 | all_function_list_str = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True).decode( 27 | errors="ignore") 28 | 29 | current_file = "" 30 | current_code = [] 31 | all_function_list = all_function_list_str.split("\n") 32 | self.total_functions = len(all_function_list) # All function including those < 3 lines 33 | for line in all_function_list: 34 | if line == "": 35 | continue 36 | try: 37 | info = json.loads(line) 38 | except BaseException as e: 39 | logger.error(f"Error {e} When Parsing Ctag info: ", line) 40 | continue 41 | if info["path"] != current_file: 42 | ext = os.path.splitext(info["path"])[1].lower() 43 | if ext not in [".c", ".cc", ".cxx", ".cpp", ".c++", "cp", ".h", ".hh", "hp", 
".hpp", ".hxx", ".h++"]: 44 | continue 45 | try: 46 | with open(info["path"]) as f: 47 | current_code = f.read().split("\n") 48 | current_file = info["path"] 49 | except: 50 | logger.warning(f"Fail to Parse Function in {info['path']}") 51 | continue 52 | # Get Function Range 53 | start_line = info["line"] - 1 54 | if "end" not in info: 55 | continue 56 | end_line = info["end"] 57 | 58 | # Reconstruct function declaration since sometimes they are something missing 59 | try: 60 | if "typeref" in info: 61 | func_type_parts = info["typeref"].split(":") 62 | if len(func_type_parts) > 1: 63 | if func_type_parts[0] == "typename": 64 | func_type = ":".join(func_type_parts[1:]) 65 | else: 66 | func_type = func_type_parts[0] + " " + ":".join(func_type_parts[1:]) 67 | else: 68 | func_type = func_type_parts[0] 69 | if func_type[-1] not in ["*", "&"]: 70 | func_type += " " 71 | else: 72 | func_type = "" 73 | func_decl_parts = current_code[start_line].split(info["name"], 1) 74 | if len(func_decl_parts) >= 2: 75 | current_code[start_line] = f"{func_type}{info['name']}{func_decl_parts[1]}" 76 | # Or we'll give up Reconstructing Declaration 77 | except Exception as e: 78 | logger.warning("Function Declaration Parse Error: {}".format(e)) 79 | func_body = "\n".join(current_code[start_line:end_line]) 80 | # function_body purification 81 | func_body = Dataset.utils.function_purification(func_body, self.skip_loc_threshold) 82 | if func_body == "": 83 | continue 84 | # ConstructPath 85 | relative_path = os.path.relpath(info["path"], project_dir) 86 | function_file_name = info["name"] + "@@@" + "@#@".join(relative_path.split("/")) 87 | function_file_name = function_file_name.replace("/", "%2F") 88 | function_file_name = function_file_name.replace("%", "%25") 89 | 90 | target_file = os.path.join(self.cache_dir, function_file_name) 91 | logger.debug(f"writing function to {target_file}") 92 | with open(target_file, "w") as f: 93 | f.write(func_body) 94 | logger.info("Target Function Preprocessing Finished") 95 | 96 | def __init__(self, project_dir: str, seed=20231031, rebuild=False, skip_loc_threshold=False, restore_processed=False): 97 | """ 98 | Initialize Project dataset 99 | :param project_dir: Path to the folder of Target Project Dataset 100 | :param seed: seed for random 101 | """ 102 | super().__init__(project_dir, seed) 103 | self.func_path_list = [] 104 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 105 | self.path_to_ctags = os.path.join(cur_dir, "universal-ctags/ctags") 106 | self.skip_loc_threshold = skip_loc_threshold 107 | self.restore_processed = restore_processed 108 | 109 | if not os.path.exists(self.path_to_ctags): 110 | logger.critical("Ctags Not Found In Given Path") 111 | raise Exception("Ctags Not Found In Given Path") 112 | if not os.path.exists(project_dir): 113 | logger.critical("The target Project Path is Not Exist") 114 | raise Exception("The target Project Path is Not Exist") 115 | 116 | logger.info("Initializing Project Dataset") 117 | project_name = os.path.split(project_dir.rstrip("/"))[-1] 118 | self.cache_dir = os.path.join(os.curdir, "processed", project_name) 119 | if rebuild or not (os.path.exists(self.cache_dir) and len(os.listdir(self.cache_dir)) != 0): 120 | shutil.rmtree(self.cache_dir, ignore_errors=True) 121 | os.makedirs(self.cache_dir, exist_ok=True) 122 | self._preprocess(project_dir) 123 | else: 124 | if not (os.path.exists(self.cache_dir) and len(os.listdir(self.cache_dir)) != 0): 125 | os.makedirs(self.cache_dir, exist_ok=True) 126 | 
self._preprocess(project_dir) 127 | else: 128 | logger.info("Using Target_Function preprocessed Cache") 129 | logger.info(f"Project Dataset Size: {len(os.listdir(self.cache_dir))}") 130 | 131 | def __del__(self): 132 | if not self.restore_processed: 133 | shutil.rmtree(self.cache_dir, ignore_errors=True) 134 | 135 | def get_funcs(self, size=-1, **kwargs) -> List[str]: 136 | """ 137 | Get the function list of the Project dataset. 138 | :param size: Size of the return function list. 139 | :return: The function list. 140 | """ 141 | if size != -1: 142 | rng = random.Random(self.seed) 143 | func_path_list = rng.sample(Dataset.utils.abs_listdir(self.cache_dir), 144 | min(size, len(self.func_path_list))) 145 | else: 146 | func_path_list = Dataset.utils.abs_listdir(self.cache_dir) 147 | return func_path_list -------------------------------------------------------------------------------- /Dataset/universal-ctags/COPYING: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc. 5 | 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Library General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. 
If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 
102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. 
However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 
214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 19yy 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License 307 | along with this program; if not, write to the Free Software 308 | Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 309 | 310 | 311 | Also add information on how to contact you by electronic and paper mail. 312 | 313 | If the program is interactive, make it output a short notice like this 314 | when it starts in an interactive mode: 315 | 316 | Gnomovision version 69, Copyright (C) 19yy name of author 317 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 318 | This is free software, and you are welcome to redistribute it 319 | under certain conditions; type `show c' for details. 320 | 321 | The hypothetical commands `show w' and `show c' should show the appropriate 322 | parts of the General Public License. Of course, the commands you use may 323 | be called something other than `show w' and `show c'; they could even be 324 | mouse-clicks or menu items--whatever suits your program. 325 | 326 | You should also get your employer (if you work as a programmer) or your 327 | school, if any, to sign a "copyright disclaimer" for the program, if 328 | necessary. Here is a sample; alter the names: 329 | 330 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 331 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 332 | 333 | , 1 April 1989 334 | Ty Coon, President of Vice 335 | 336 | This General Public License does not permit incorporating your program into 337 | proprietary programs. 
If your program is a subroutine library, you may 338 | consider it more useful to permit linking proprietary applications with the 339 | library. If this is what you want to do, use the GNU Library General 340 | Public License instead of this License. 341 | -------------------------------------------------------------------------------- /Dataset/universal-ctags/ctags: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CGCL-codes/FIRE/a9ff3f9439ef948003ab7677defb4899c100d701/Dataset/universal-ctags/ctags -------------------------------------------------------------------------------- /Dataset/universal-ctags/readme.txt: -------------------------------------------------------------------------------- 1 | The file ctags is a nightly-build binary of universal-ctags, which is free software licensed under GPLv2 (see the COPYING file). 2 | Version: https://github.com/universal-ctags/ctags-nightly-build/releases/tag/2023.12.20%2B293f11ef2834d540eebd9cfc976369eaf3c118d7 3 | File: https://github.com/universal-ctags/ctags-nightly-build/releases/download/2023.12.20%2B293f11ef2834d540eebd9cfc976369eaf3c118d7/uctags-2023.12.20-linux-x86_64.tar.xz -------------------------------------------------------------------------------- /Dataset/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | 5 | def function_purification(code: str, skip_loc_threshold=False) -> str: 6 | # remove comments 7 | code = re.sub(r'/\*[\w\W]*?\*/', "", code) 8 | code = re.sub(r'//.*?\n', "\n", code) 9 | # remove non-ASCII 10 | code = re.sub(r"[^\x00-\x7F]+", "", code) 11 | # remove preprocessor directives (lines starting with #) 12 | code = re.sub(r"^#.*", "", code, flags=re.MULTILINE) 13 | # Count ';' occurrences as a rough estimate of the number of code lines; we do not consider very short functions 14 | if not skip_loc_threshold and code.count(";") <= 3: 15 | return "" 16 | # remove empty lines to compact the code 17 | purified_code_lines = list(filter(lambda c: len(c.strip()) != 0, code.split("\n"))) 18 | # Count lines with more than one character; we do not consider very short functions 19 | loc = 0 20 | for i in range(len(purified_code_lines)): 21 | purified_code_lines[i] = purified_code_lines[i].strip() 22 | loc += 1 if len(purified_code_lines[i]) > 1 else 0 23 | if not skip_loc_threshold and loc <= 5: 24 | return "" 25 | return "\n".join(purified_code_lines) 26 | 27 | 28 | def abs_listdir(directory: str): 29 | return [os.path.join(directory, path) for path in os.listdir(directory)] 30 | -------------------------------------------------------------------------------- /SyntaxFilter/__init__.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor, as_completed 2 | from loguru import logger 3 | 4 | import config 5 | from .detection import detect_vulnerable_with_initialize 6 | from Trace.utils import vuln_to_patch_dict 7 | 8 | 9 | def initialization(file_pair_list): 10 | logger.info("Initialize Syntax Filter") 11 | for vuln_file, patch_file in file_pair_list: 12 | vuln_to_patch_dict[vuln_file] = patch_file 13 | return 14 | 15 | 16 | def detect( 17 | input_queue, 18 | output_queue, 19 | vulnerable_func_queue, 20 | pbar_queue, 21 | trace_all_result_queue 22 | ) -> None: 23 | with ProcessPoolExecutor(max_workers=config.syntax_worker) as executor: 24 | futures = {} 25 | 26 | def process_future(future): 27 | try: 28 | is_vul, similar_list = future.result() 29 | except Exception 
as e: 30 | pbar_queue.put(("syntax", False)) 31 | logger.error(f"{str(e)}") 32 | else: 33 | pbar_queue.put(("syntax", is_vul)) 34 | if not is_vul: 35 | return 36 | dst_func, dst_file = futures[future] 37 | output_queue.put((dst_func, dst_file, similar_list)) 38 | logger.debug(f"Syntax: Found potential vulnerable function in {dst_file}") 39 | 40 | while True: 41 | vul_info = input_queue.get() 42 | dst_func, dst_file, similar_list = vul_info 43 | 44 | if dst_file == "__end_of_detection__": 45 | for future in as_completed(futures.keys()): 46 | process_future(future) 47 | output_queue.put(vul_info) 48 | # pbar_queue.put(("__end_of_detection__", False)) 49 | 50 | trace_all_result_queue.put(0) 51 | logger.info("Syntax Finished!") 52 | break 53 | 54 | future = executor.submit( 55 | detect_vulnerable_with_initialize, dst_func, dst_file, similar_list, vulnerable_func_queue, 56 | trace_all_result_queue 57 | ) 58 | futures[future] = (dst_func, dst_file) 59 | 60 | done_futures = [] 61 | for future in futures.keys(): 62 | if not future.done(): 63 | continue 64 | process_future(future) 65 | done_futures.append(future) 66 | 67 | for future in done_futures: 68 | futures.pop(future) 69 | 70 | 71 | -------------------------------------------------------------------------------- /SyntaxFilter/detection.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import os 3 | import traceback 4 | from typing import List 5 | import ppdeep 6 | 7 | from loguru import logger 8 | 9 | import config 10 | from Trace.serializer import Serializer 11 | 12 | from Trace.manager import ( 13 | FunctionManager, 14 | FunctionPairManager, 15 | ) 16 | from Trace.utils import ( 17 | # diff_embedding_dict 18 | # error_func_list, 19 | # patch_line_hash_dict, 20 | vuln_to_patch_dict, 21 | ) 22 | 23 | 24 | def fuzzy_hash_similarity(s1, s2): 25 | return ppdeep.compare(ppdeep.hash(s1), ppdeep.hash(s2)) 26 | 27 | 28 | def get_fuzzy_hash(code, vuln_file, patch_file): 29 | with open(vuln_file, "r") as v, open(patch_file, "r") as p: # type: ignore 30 | vuln_sim = fuzzy_hash_similarity(code, v.read()) 31 | patch_sim = fuzzy_hash_similarity(code, p.read()) 32 | 33 | return vuln_sim, patch_sim 34 | 35 | 36 | # @profile 37 | def detect_vulnerable_with_initialize( 38 | code: str, 39 | dst_file: str, 40 | similar_list: List[str], 41 | vulnerable_func_queue, 42 | trace_all_result_queue=None, 43 | ast_sim_threshold_min=config.ast_sim_threshold_min, 44 | ast_sim_threshold_max=config.ast_sim_threshold_max, 45 | ) -> tuple[bool, list[str]]: 46 | try: 47 | cur_dir = "v1" 48 | dst_dir = f"{cur_dir}/oldnew" 49 | s = Serializer() 50 | 51 | # 1 52 | logger.debug(f"starting test file : {dst_file}") 53 | code_manager = FunctionManager( 54 | src_file=dst_file, 55 | src_func=code, 56 | # dst_dir=f"{cur_dir}/target", 57 | # clear=False, 58 | clear=True, 59 | gen_cfg=False, 60 | gen_taint=False, 61 | ) 62 | 63 | logger.debug("init file completed") 64 | 65 | output_list = [] 66 | cve_list = [] 67 | 68 | ast_sim_dict = {} 69 | # line_hash_dict_dict = {} 70 | near_sims_list = [] 71 | 72 | for vuln_file in similar_list: 73 | patch_file = vuln_to_patch_dict.get(vuln_file) 74 | if patch_file is None: 75 | logger.debug(f"no patch file for {vuln_file}") 76 | continue 77 | 78 | vuln_name = os.path.basename(vuln_file) 79 | cve_id = vuln_name.split("_")[0] 80 | 81 | 82 | logger.debug(f"init {vuln_file}.") 83 | 84 | vuln_manager = FunctionManager( 85 | src_file=vuln_file, 86 | dst_dir=dst_dir, 87 | 
clear=False, 88 | gen_cfg=False, 89 | gen_taint=False, 90 | ) 91 | patch_manager = FunctionManager( 92 | src_file=patch_file, 93 | dst_dir=dst_dir, 94 | clear=False, 95 | gen_cfg=False, 96 | gen_taint=False, 97 | ) 98 | 99 | func_pair_manager = FunctionPairManager(vuln_manager, patch_manager) 100 | 101 | if not s.get_patch_line(vuln_name) or not s.get_line_hash_dict(vuln_name): 102 | # diff line 103 | logger.debug(f"init {vuln_file} patch line.") 104 | 105 | diff_line = func_pair_manager.get_diff_lines_hash(filter_lines=['{', '}']) 106 | 107 | s.set_patch_line(vuln_name, diff_line) 108 | s.set_line_hash_dict( 109 | vuln_name, 110 | ( 111 | vuln_manager.hash_dict, 112 | patch_manager.hash_dict, 113 | ), 114 | ) 115 | 116 | logger.debug(f"init {vuln_file} patch line ok.") 117 | 118 | # if not s.get_fuzzy_hash(vuln_name): 119 | # s.set_fuzzy_hash( 120 | # vuln_name, (vuln_manager.fuzzy_hash, patch_manager.fuzzy_hash) 121 | # ) 122 | 123 | info = { 124 | "target_file": dst_file, 125 | "vuln_file": vuln_file, 126 | "patch_file": patch_file, 127 | } 128 | 129 | 130 | logger.debug(f"testing {vuln_file}") 131 | 132 | vuln_cond = [] 133 | 134 | 135 | def finish(): 136 | if trace_all_result_queue: 137 | trace_all_result_queue.put( 138 | {**info, **{"datail": vuln_cond, "predict": all(vuln_cond)}} 139 | ) 140 | 141 | if all(vuln_cond): 142 | output_list.append(vuln_file) 143 | cve_list.append(cve_id) 144 | 145 | 146 | dst_hash_dict = code_manager.hash_dict 147 | vuln_hash_dict, patch_hash_dict = s.get_line_hash_dict(vuln_name) 148 | vuln_hash_dict, patch_hash_dict = ( 149 | Counter(vuln_hash_dict), 150 | Counter(patch_hash_dict), 151 | ) 152 | 153 | 154 | 155 | del_lines, add_lines = s.get_patch_line(vuln_name) 156 | 157 | vuln_cond_del_lines = True 158 | for del_line in del_lines: 159 | if ( 160 | vuln_hash_dict[del_line] != patch_hash_dict[del_line] 161 | and dst_hash_dict[del_line] != vuln_hash_dict[del_line] 162 | ): 163 | vuln_cond_del_lines = False 164 | break 165 | 166 | vuln_cond.append(vuln_cond_del_lines) 167 | if del_lines == []: 168 | vuln_cond.append("no del line") 169 | if not vuln_cond_del_lines: 170 | finish() 171 | continue 172 | 173 | 174 | 175 | vuln_cond_add_lines = True 176 | for add_line in add_lines: 177 | if ( 178 | vuln_hash_dict[add_line] != patch_hash_dict[add_line] 179 | and dst_hash_dict[add_line] != vuln_hash_dict[add_line] 180 | ): 181 | vuln_cond_add_lines = False 182 | break 183 | 184 | vuln_cond.append(vuln_cond_add_lines) 185 | if add_lines == []: 186 | vuln_cond.append("no add line") 187 | if not vuln_cond_add_lines: 188 | finish() 189 | continue 190 | 191 | 192 | 193 | def jaccard_similarity(list1, list2): 194 | count1 = {} 195 | count2 = {} 196 | 197 | for item in list1: 198 | count1[item] = count1.get(item, 0) + 1 199 | 200 | for item in list2: 201 | count2[item] = count2.get(item, 0) + 1 202 | 203 | intersection = sum(min(count1.get(item, 0), count2.get(item, 0)) for item in set(list1 + list2)) 204 | union = sum(max(count1.get(item, 0), count2.get(item, 0)) for item in set(list1 + list2)) 205 | 206 | similarity = intersection / union 207 | 208 | return similarity 209 | 210 | vuln_sim, patch_sim = ( 211 | # calculate sim 212 | jaccard_similarity(code_manager.ast_nodes, vuln_manager.ast_nodes), 213 | jaccard_similarity(code_manager.ast_nodes, patch_manager.ast_nodes), 214 | ) 215 | 216 | 217 | 218 | if vuln_sim < ast_sim_threshold_min: 219 | vuln_cond.append(False) 220 | elif vuln_sim < patch_sim: 221 | if patch_sim - vuln_sim > 0.15: 222 | vuln_cond.append(True) 
223 | near_sims_list.append(vuln_file) 224 | else: 225 | vuln_cond.append(False) 226 | else: 227 | vuln_cond.append(True) 228 | 229 | ast_sim_dict[vuln_file] = vuln_sim 230 | 231 | vuln_cond.extend([vuln_sim or "0", patch_sim or "0"]) 232 | finish() 233 | 234 | 235 | if len(output_list) > 1: 236 | if len(output_list) != len(set(cve_list)): 237 | 238 | cve_dict = {} 239 | for vuln_file in output_list: 240 | cve_id = os.path.basename(vuln_file).split("_")[0] 241 | if cve_id in cve_dict: 242 | cve_dict[cve_id].append(vuln_file) 243 | else: 244 | cve_dict[cve_id] = [vuln_file] 245 | 246 | output_list = [] 247 | 248 | for cve_id, vuln_files in cve_dict.items(): 249 | if len(vuln_files) == 1: 250 | vuln_file = vuln_files[0] 251 | else: 252 | 253 | vuln_file = max(vuln_files, key = lambda x: ast_sim_dict[x]) 254 | 255 | if ast_sim_dict[vuln_file] > ast_sim_threshold_max and vuln_file not in near_sims_list: 256 | 257 | vulnerable_func_queue.put((code, dst_file, [vuln_file])) 258 | else: 259 | 260 | output_list.append(vuln_file) 261 | 262 | else: 263 | new_output_list = [] 264 | for vuln_file in output_list: 265 | if ast_sim_dict[vuln_file] > ast_sim_threshold_max and vuln_file not in near_sims_list: 266 | 267 | vulnerable_func_queue.put((code, dst_file, [vuln_file])) 268 | else: 269 | 270 | new_output_list.append(vuln_file) 271 | output_list = new_output_list 272 | 273 | 274 | return output_list != [], output_list 275 | 276 | except Exception as e: 277 | traceback.print_exc() 278 | raise Exception(f"error when process file {dst_file} : {str(e)}") 279 | -------------------------------------------------------------------------------- /TokenFilter/__init__.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor, as_completed 2 | from typing import List 3 | 4 | from loguru import logger 5 | 6 | import TokenFilter.main 7 | import config 8 | 9 | 10 | def initialization(vul_functions: List[str]) -> None: 11 | 12 | logger.info("Initialize TokenFilter") 13 | TokenFilter.main.initialization(vul_functions) 14 | logger.info("TokenFilter Initialized") 15 | 16 | 17 | def detect(input_queue, output_queue, pbar_queue) -> None: 18 | with ProcessPoolExecutor(max_workers=config.token_worker) as executor: 19 | futures = {} 20 | 21 | def process_future(future): 22 | is_vul, similar_list = future.result() 23 | pbar_queue.put(("token", is_vul)) 24 | if not is_vul: 25 | return 26 | dst_func, dst_file = futures[future] 27 | output_queue.put((dst_func, dst_file, similar_list)) 28 | # logger.debug(f"TokenFilter: Found potential vulnerable function in {dst_file}") 29 | 30 | while True: 31 | vul_info = input_queue.get() 32 | dst_func, dst_file, _ = vul_info 33 | 34 | 35 | try: 36 | if dst_file == "__end_of_detection__": 37 | for future in as_completed(futures.keys()): 38 | process_future(future) 39 | output_queue.put(vul_info) 40 | logger.info("Token Filter Finished!") 41 | return 42 | 43 | future = executor.submit(TokenFilter.main.detect, dst_func) 44 | futures[future] = (dst_func, dst_file) 45 | 46 | done_futures = [] 47 | 48 | for future in futures.keys(): 49 | if not future.done(): 50 | continue 51 | process_future(future) 52 | done_futures.append(future) 53 | 54 | for future in done_futures: 55 | futures.pop(future) 56 | 57 | except Exception as e: 58 | logger.error(f"Error detect in {dst_file}: {str(e)}") 59 | 60 | -------------------------------------------------------------------------------- /TokenFilter/main.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from multiprocessing import Pool 3 | from functools import partial 4 | import TokenFilter.token_extraction 5 | import os 6 | import config 7 | 8 | VulTokens = [] 9 | VulTokensDict = {} 10 | 11 | def initialization(vul_functions): 12 | # Find Threshold Parse 13 | global VulTokensDict 14 | pool = Pool(5) 15 | VulTokens = pool.map(partial(TokenFilter.token_extraction.get_fea), vul_functions) 16 | for vul_tokens in filter(None, VulTokens): 17 | len_of_tokens = len(vul_tokens[1]) 18 | if len_of_tokens not in VulTokensDict or VulTokensDict[len_of_tokens] is None: 19 | VulTokensDict[len_of_tokens] = [] 20 | VulTokensDict[len_of_tokens].append(vul_tokens) 21 | # VulTokens = list(filter(None, VulTokens)) 22 | 23 | 24 | def detect(code: str) -> tuple[bool, list[str]]: 25 | if VulTokensDict: 26 | is_vul = False 27 | tokens = TokenFilter.token_extraction.get_fea_code(code) 28 | len_of_tokens = len(tokens) 29 | vuln_list = [] 30 | for token_len in range(int(math.ceil(len_of_tokens * config.jaccard_sim_threshold)), 31 | int(math.floor(len_of_tokens / config.jaccard_sim_threshold)) + 1): 32 | 33 | if token_len not in VulTokensDict or VulTokensDict[token_len] is None: 34 | continue 35 | 36 | for vulnandtokens in VulTokensDict[token_len]: 37 | vuln = TokenFilter.token_extraction.get_similarity(tokens, config.jaccard_sim_threshold, vulnandtokens) 38 | if vuln: 39 | vuln_list.append(vuln) 40 | # pool = Pool(5) 41 | # vuln_list = pool.map(partial(TokenFilter.token_extraction.get_similarity, tokens, 0.65), VulTokens) 42 | vuln_list = list(filter(None, vuln_list)) 43 | if vuln_list: 44 | is_vul = True 45 | return is_vul, vuln_list 46 | else: 47 | raise Exception("Token Filter Not Initialized") 48 | 49 | 50 | if __name__ == "__main__": 51 | 52 | forderpath = 'path/to/old/new/funcs' 53 | vuln_list = [] 54 | 55 | for path, dir, files in os.walk(forderpath): 56 | for file in files: 57 | if file.split('_')[-1] == 'OLD.vul': 58 | filePath = os.path.join(path, file) 59 | vuln_list.append(filePath) 60 | print(len(vuln_list)) 61 | 62 | initialization(vuln_list) 63 | print(VulTokens[0]) 64 | a = detect('void InstructionSelector::AddInstruction(Instruction* instr) {\n if (FLAG_turbo_instruction_scheduling &&\n InstructionScheduler::SchedulerSupported()) {\n DCHECK_NOT_NULL(scheduler_);\n scheduler_->AddInstruction(instr);\n } else {\n sequence()->AddInstruction(instr);\n }\n}\n') 65 | print(a) 66 | -------------------------------------------------------------------------------- /TokenFilter/token_extraction.py: -------------------------------------------------------------------------------- 1 | import re 2 | import Levenshtein 3 | from loguru import logger 4 | from collections import Counter 5 | 6 | def isphor(s, liter): 7 | m = re.search(liter, s) 8 | if m is not None: 9 | return True 10 | else: 11 | return False 12 | 13 | 14 | def doubisphor(forward, back): 15 | double = ( 16 | '->', '--', '-=', '+=', '++', '>=', '<=', '==', '!=', '*=', '/=', '%=', '/=', '&=', '^=', '||', '&&', '>>', '<<') 17 | string = forward + back 18 | 19 | if string in double: 20 | return True 21 | else: 22 | return False 23 | 24 | 25 | def trisphor(s, t): 26 | if (s == '>>') | (s == '<<') and (t == '='): 27 | return True 28 | else: 29 | return False 30 | 31 | 32 | def create_tokens(sentence): 33 | formal = '^[_a-zA-Z][_a-zA-Z0-9]*$' 34 | phla = '[^_a-zA-Z0-9]' 35 | space = '\s' 36 | spa = '' 37 | string = [] 38 | j = 0 39 | str = sentence 40 | i = 
0 41 | 42 | while (i < len(str)): 43 | if isphor(str[i], space): 44 | if i > j: 45 | string.append(str[j:i]) 46 | j = i + 1 47 | else: 48 | j = i + 1 49 | 50 | elif isphor(str[i], phla): 51 | if (i + 1 < len(str)) and isphor(str[i + 1], phla): 52 | m = doubisphor(str[i], str[i + 1]) 53 | 54 | if m: 55 | string1 = str[i] + str[i + 1] 56 | 57 | if (i + 2 < len(str)) and (isphor(str[i + 2], phla)): 58 | if trisphor(string1, str[i + 2]): 59 | string.append(str[j:i]) 60 | string.append(str[i] + str[i + 1] + str[i + 2]) 61 | j = i + 3 62 | i = i + 2 63 | 64 | else: 65 | string.append(str[j:i]) 66 | string.append(str[i] + str[i + 1]) 67 | string.append(str[i + 2]) 68 | j = i + 3 69 | i = i + 2 70 | 71 | else: 72 | string.append(str[j:i]) 73 | string.append(str[i] + str[i + 1]) 74 | j = i + 2 75 | i = i + 1 76 | 77 | else: 78 | string.append(str[j:i]) 79 | string.append(str[i]) 80 | if str[i] != ';': 81 | string.append(str[i + 1]) 82 | j = i + 2 83 | i = i + 1 84 | else: 85 | j = i + 1 86 | 87 | else: 88 | string.append(str[j:i]) 89 | string.append(str[i]) 90 | j = i + 1 91 | 92 | i = i + 1 93 | 94 | count = 0 95 | count1 = 0 96 | sub0 = '\r' 97 | 98 | if sub0 in string: 99 | string.remove('\r') 100 | 101 | for sub1 in string: 102 | if sub1 == ' ': 103 | count1 = count1 + 1 104 | 105 | for j in range(count1): 106 | string.remove(' ') 107 | 108 | for sub in string: 109 | if sub == spa: 110 | count = count + 1 111 | 112 | for i in range(count): 113 | string.remove('') 114 | 115 | return string 116 | 117 | 118 | 119 | def get_fea(file_path): 120 | try: 121 | with open(file_path, "r", encoding='utf-8') as f: 122 | gadget = f.read() 123 | f.close() 124 | 125 | # final feature dictionary 126 | tokens_list = [] 127 | 128 | # regular expression to catch a-line comment 129 | rx_comment = re.compile('\*/\s*$') 130 | gadget = gadget.split('\n') 131 | 132 | for line in gadget: 133 | # process if not the header line and not a multi-line commented line 134 | if rx_comment.search(line) is None: 135 | 136 | # replace any non-ASCII characters with empty string 137 | ascii_line = re.sub(r'[^\x00-\x7f]', r'', line) 138 | nostrlit_line = re.sub(r'".*?"', '""', ascii_line) 139 | nocharlit_line = re.sub(r"'.*?'", "''", nostrlit_line) 140 | 141 | # tokenlization 142 | tokens = create_tokens(nocharlit_line) 143 | #tokenslist.extend(tokens) 144 | tokens_list.extend(tokens) 145 | return file_path, tokens_list 146 | except UnicodeDecodeError: 147 | pass 148 | 149 | 150 | def get_fea_code(gadget): 151 | try: 152 | # final feature dictionary 153 | tokens_list = [] 154 | 155 | # regular expression to catch a-line comment 156 | rx_comment = re.compile('\*/\s*$') 157 | gadget = gadget.split('\n') 158 | for line in gadget: 159 | # process if not the header line and not a multi-line commented line 160 | if rx_comment.search(line) is None: 161 | 162 | # replace any non-ASCII characters with empty string 163 | ascii_line = re.sub(r'[^\x00-\x7f]', r'', line) 164 | nostrlit_line = re.sub(r'".*?"', '""', ascii_line) 165 | nocharlit_line = re.sub(r"'.*?'", "''", nostrlit_line) 166 | 167 | # tokenlization 168 | tokens = create_tokens(nocharlit_line) 169 | tokens_list.extend(tokens) 170 | 171 | return tokens_list 172 | except UnicodeDecodeError: 173 | pass 174 | 175 | 176 | def jaccard_sim(list1, list2): 177 | counter1 , counter2 = Counter(list1), Counter(list2) 178 | set1,set2 = set(list1), set(list2) 179 | 180 | intersection_size = sum((min(counter1[x], counter2[x]) for x in set1.intersection(set2))) 181 | union_size = 
sum((max(counter1[x], counter2[x]) for x in set1.union(set2))) 182 | 183 | similarity = intersection_size / union_size if union_size != 0 else 0 184 | 185 | return similarity 186 | 187 | def Jaro_sim(group1, group2): 188 | 189 | 190 | sim = Levenshtein.jaro(group1, group2) 191 | return sim 192 | 193 | 194 | def Jaro_winkler_sim(group1, group2): 195 | 196 | sim = Levenshtein.jaro_winkler(group1, group2) 197 | return sim 198 | 199 | 200 | def Levenshtein_sim(group1, group2): 201 | 202 | distance = Levenshtein.distance(group1, group2) 203 | return distance 204 | 205 | 206 | def Levenshtein_ratio(group1, group2): 207 | 208 | sim = Levenshtein.ratio(group1, group2) 209 | return sim 210 | 211 | 212 | def get_similarity(funtokens, t, vulnandtokens): 213 | vuln, vulntokens = vulnandtokens 214 | 215 | try: 216 | sim = jaccard_sim(funtokens, vulntokens) 217 | if sim >= t: 218 | return vuln 219 | except Exception as e: 220 | logger.error(f"get_similarity failed : {str(e)}") 221 | 222 | 223 | 224 | -------------------------------------------------------------------------------- /Trace/__init__.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor, as_completed 2 | from loguru import logger 3 | 4 | import config 5 | from .detection import detect_vulnerable_with_initialize 6 | from .utils import vuln_to_patch_dict 7 | 8 | def initialization(file_pair_list): 9 | logger.info("Initialize Trace") 10 | for vuln_file, patch_file in file_pair_list: 11 | vuln_to_patch_dict[vuln_file] = patch_file 12 | return 13 | 14 | 15 | def detect( 16 | input_queue, 17 | output_queue, 18 | pbar_queue, 19 | trace_all_result_queue = None 20 | ) -> None: 21 | with ProcessPoolExecutor(max_workers=config.trace_worker) as executor: 22 | futures = {} 23 | 24 | def process_future(future): 25 | try: 26 | is_vul, similar_list = future.result() 27 | except Exception as e: 28 | pbar_queue.put(("trace", False)) 29 | logger.error(f"{str(e)}") 30 | else: 31 | pbar_queue.put(("trace", is_vul)) 32 | if not is_vul: 33 | return 34 | dst_func, dst_file = futures[future] 35 | output_queue.put((dst_func, dst_file, similar_list)) 36 | logger.debug(f"Trace: Found potential vulnerable function in {dst_file}") 37 | 38 | while True: 39 | vul_info = input_queue.get() 40 | dst_func, dst_file, similar_list = vul_info 41 | 42 | if dst_file == "__end_of_detection__": 43 | for future in as_completed(futures.keys()): 44 | process_future(future) 45 | output_queue.put(vul_info) 46 | pbar_queue.put(("__end_of_detection__", False)) 47 | 48 | if trace_all_result_queue: 49 | trace_all_result_queue.put(0) 50 | logger.info("Trace Finished!") 51 | break 52 | 53 | future = executor.submit( 54 | detect_vulnerable_with_initialize, dst_func, dst_file, similar_list, trace_all_result_queue 55 | ) 56 | futures[future] = (dst_func, dst_file) 57 | 58 | done_futures = [] 59 | for future in futures.keys(): 60 | if not future.done(): 61 | continue 62 | process_future(future) 63 | done_futures.append(future) 64 | 65 | for future in done_futures: 66 | futures.pop(future) 67 | 68 | 69 | -------------------------------------------------------------------------------- /Trace/cfg.py: -------------------------------------------------------------------------------- 1 | import re 2 | import networkx as nx 3 | from html import unescape 4 | 5 | 6 | 7 | class CFGExtractor: 8 | def __init__(self, filename, merge_node=True): 9 | self.filename = filename 10 | self.parse_cfg_file() 11 | 12 | if merge_node: 13 
| self.merge_nodes() 14 | 15 | def parse_label(self, label): 16 | match = re.match(r"\((.+?),(.+)\)\(\d+)\", label) 17 | if match: 18 | method_full_name = unescape(match.group(1).strip()) 19 | code = unescape(match.group(2).strip()) 20 | 21 | if method_full_name == "RETURN": 22 | # code = code.split(",")[-1] 23 | code = code[: int(len(code) / 2)] 24 | 25 | line_number = int(match.group(3)) 26 | return method_full_name, code, line_number 27 | else: 28 | return None, None, None 29 | 30 | def parse_cfg_file(self): 31 | with open(self.filename, "r") as file: 32 | content = file.read() 33 | 34 | node_pattern = re.compile(r'"([^"]+)" \[label = <(.+)> \]') 35 | edge_pattern = re.compile(r' "([^"]+)" -> "([^"]+)"') 36 | 37 | nodes = {} 38 | edges = [] 39 | 40 | for line in content.splitlines(): 41 | match = node_pattern.match(line) 42 | if match: 43 | node_id = match.group(1) 44 | label = match.group(2) 45 | method_full_name, code, line_number = self.parse_label(label) 46 | 47 | # format code 48 | if code is not None: 49 | code = code.replace("\n", " ") 50 | code = re.sub(r'\\012\s*', " ", code) 51 | code = code.replace('" "', " ") 52 | code = re.sub(r"\s+", " ", code) 53 | 54 | nodes[node_id] = { 55 | "method_full_name": method_full_name, 56 | "code": code, 57 | "line_number": line_number, 58 | } 59 | 60 | match = edge_pattern.match(line) 61 | if match: 62 | source = match.group(1) 63 | target = match.group(2) 64 | edges.append((source, target)) 65 | 66 | self.graph = nx.DiGraph() 67 | 68 | for node_id, data in nodes.items(): 69 | self.graph.add_node(node_id, **data) 70 | 71 | for source, target in edges: 72 | if source == target: 73 | continue 74 | self.graph.add_edge(source, target) 75 | 76 | def merge_nodes(self): 77 | merged = True 78 | 79 | while merged: 80 | merged = False 81 | 82 | for node1, node2 in self.graph.edges(): 83 | 84 | data1 = self.graph.nodes[node1] 85 | data2 = self.graph.nodes[node2] 86 | 87 | 88 | if data1["line_number"] == data2["line_number"]: 89 | 90 | last_node = node2 91 | next_node = node2 92 | while next_node in self.graph.successors(last_node): 93 | last_node = next_node 94 | next_node = list(self.graph.successors(last_node))[0] 95 | 96 | 97 | if self.graph.has_edge(node1, node2): 98 | 99 | for predecessor in self.graph.predecessors(node1): 100 | if predecessor != last_node: 101 | self.graph.add_edge(predecessor, last_node) 102 | 103 | 104 | for successor in self.graph.successors(node1): 105 | if successor != last_node: 106 | self.graph.add_edge(last_node, successor) 107 | 108 | 109 | nodes_to_remove = [node1] 110 | next_node = node2 111 | while next_node in self.graph.successors(last_node): 112 | nodes_to_remove.append(next_node) 113 | next_node = list(self.graph.successors(next_node))[0] 114 | 115 | for node in nodes_to_remove: 116 | self.graph.remove_node(node) 117 | 118 | merged = True 119 | break 120 | 121 | @property 122 | def node_dict(self): 123 | """ 124 | line_number : code 125 | 126 | example: 127 | { 128 | '6': 'len = strlen(str)', 129 | '14': 'len = *ar + len | is_privileged(user)', 130 | '15': 'memcpy(buf, str, len)', 131 | '16': 'ar = malloc(sizeof(str))', 132 | '17': 'return process(ar, buf, len);', 133 | '7': '!is_privileged(user)', 134 | '8': 'clear(str)', 135 | '9': 'user++', 136 | '3': 'void' 137 | } 138 | """ 139 | code_dict = {} 140 | 141 | nodes = self.graph.nodes(data=True) 142 | for node, data in nodes: 143 | line_number = data["line_number"] 144 | code = data["code"] 145 | 146 | if line_number and code: 147 | code_dict[line_number] = code 
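# note: when several CFG nodes share a line_number, the last node visited wins; its code overwrites earlier entries for that line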
148 | 149 | return code_dict 150 | 151 | def dump(self, filename): 152 | nx.draw(self.graph, with_labels=True) 153 | 154 | import matplotlib.pyplot as plt 155 | 156 | plt.savefig(filename) 157 | 158 | 159 | class CFPExtractor: 160 | def __init__(self, cfg_graph): 161 | self.cfg_graph = cfg_graph 162 | self.cfps = [] 163 | self.extract_cfps() 164 | self.extract_lines() 165 | 166 | self.count = 0 167 | 168 | def find_all_paths_basic(self, graph, start, end, path=[]): 169 | path = path + [start] 170 | if start == end: 171 | return [path] 172 | if not graph.has_node(start): 173 | return [] 174 | paths = [] 175 | for neighbor in graph.neighbors(start): 176 | if neighbor not in path: 177 | new_paths = self.find_all_paths_basic(graph, neighbor, end, path) 178 | for new_path in new_paths: 179 | paths.append(new_path) 180 | return paths 181 | 182 | def find_all_paths(self, graph, start, end, path=[], visited=[]): 183 | path = path + [start] 184 | visited = visited + [(start, path[-2] if len(path) > 1 else None)] 185 | 186 | if start == end: 187 | return [path] 188 | 189 | paths = [] 190 | 191 | for node in graph[start]: 192 | if (node, start) not in visited: 193 | newpaths = self.find_all_paths(graph, node, end, path, visited) 194 | for newpath in newpaths: 195 | paths.append(newpath) 196 | 197 | return paths 198 | 199 | def extract_cfps(self): 200 | 201 | entry_node = None 202 | exit_node = None 203 | for node, node_info in self.cfg_graph.nodes(data=True): 204 | if node_info["method_full_name"] == "METHOD": 205 | entry_node = node 206 | elif node_info["method_full_name"] == "METHOD_RETURN": 207 | exit_node = node 208 | if entry_node and exit_node: 209 | break 210 | 211 | 212 | self.paths = self.find_all_paths(self.cfg_graph, entry_node, exit_node) 213 | 214 | def extract_lines(self): 215 | for path in self.paths: 216 | cfp = [] 217 | for node in path: 218 | node_info = self.cfg_graph.nodes[node] 219 | if ( 220 | cfp == [] 221 | or cfp[len(cfp) - 1]["line_number"] != node_info["line_number"] 222 | ): 223 | cfp.append(node_info) 224 | else: 225 | cfp[len(cfp) - 1] = node_info 226 | 227 | cfp = cfp[1:-1] 228 | self.cfps.append(cfp) 229 | 230 | def __str__(self): 231 | s = "" 232 | for cfp in self.cfps: 233 | for n in cfp: 234 | s += str(n["line_number"]) + " " 235 | s += "\n" 236 | return s 237 | 238 | def __iter__(self): 239 | return self 240 | 241 | def __next__(self): 242 | if self.count < len(self.cfps): 243 | result = self.cfps[self.count] 244 | self.count += 1 245 | return result 246 | else: 247 | raise StopIteration 248 | -------------------------------------------------------------------------------- /Trace/detection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | from typing import List 4 | 5 | import numpy as np 6 | import ppdeep 7 | from loguru import logger 8 | 9 | from Trace.manager import ( 10 | FunctionManager, 11 | FunctionPairManager, 12 | ) 13 | from Trace.utils import ( 14 | vuln_to_patch_dict, 15 | ) 16 | from .embedding import CodeBertEmbedding 17 | from .serializer import Serializer 18 | 19 | 20 | def max_mean_col(matrix): 21 | max_each_col = np.nanmax(matrix, axis=0) 22 | mean_max = np.nanmean(max_each_col) 23 | return mean_max.tolist() 24 | 25 | 26 | def cos_similarity_matrix(m1, m2): 27 | return np.dot(m1, m2.T) / ( 28 | np.linalg.norm(m1, axis=1).reshape(-1, 1) * np.linalg.norm(m2, axis=1) 29 | ) 30 | 31 | 32 | def cos_similarity(m1, m2): 33 | sim_matrix = cos_similarity_matrix(m1, m2) 34 | sim = 
max_mean_col(sim_matrix) 35 | return sim 36 | 37 | 38 | def fuzzy_hash_similarity(s1, s2): 39 | return ppdeep.compare(ppdeep.hash(s1), ppdeep.hash(s2)) 40 | 41 | 42 | def get_fuzzy_hash(code, vuln_file, patch_file): 43 | with open(vuln_file, "r") as v, open(patch_file, "r") as p: # type: ignore 44 | vuln_sim = fuzzy_hash_similarity(code, v.read()) 45 | patch_sim = fuzzy_hash_similarity(code, p.read()) 46 | 47 | return vuln_sim, patch_sim 48 | 49 | 50 | # @profile 51 | def detect_vulnerable_with_initialize( 52 | code: str, 53 | dst_file: str, 54 | similar_list: List[str], 55 | trace_all_result_queue=None, 56 | ) -> tuple[bool, list[str]]: 57 | try: 58 | cur_dir = "v1" 59 | dst_dir = f"{cur_dir}/oldnew" 60 | embedder = None # type: ignore 61 | s = Serializer() 62 | 63 | 64 | logger.debug(f"starting test file : {dst_file}") 65 | code_manager = FunctionManager( 66 | src_file=dst_file, 67 | src_func=code, 68 | # embedder=embedder, 69 | # dst_dir=f"{cur_dir}/target", 70 | # clear=False, 71 | clear=True, 72 | gen_cfg=False, 73 | gen_taint=False, 74 | ) 75 | # print(code_manager.ast_seq) 76 | 77 | logger.debug("init file completed") 78 | 79 | output_list = [] 80 | cve_list = [] 81 | 82 | 83 | for vuln_file in similar_list: 84 | patch_file = vuln_to_patch_dict.get(vuln_file) 85 | if patch_file is None: 86 | logger.debug(f"no patch file for {vuln_file}") 87 | continue 88 | 89 | vuln_name = os.path.basename(vuln_file) 90 | cve_id = vuln_name.split("_")[0] 91 | 92 | 93 | if ( 94 | not s.is_error_func(vuln_name) 95 | and s.get_diff_embedding(vuln_name) is None 96 | ): 97 | 98 | embedder = CodeBertEmbedding() 99 | logger.debug(f"init {vuln_file}.") 100 | 101 | vuln_manager = FunctionManager( 102 | src_file=vuln_file, 103 | embedder=embedder, 104 | dst_dir=dst_dir, 105 | clear=False, 106 | gen_cfg=False, 107 | gen_taint=False, 108 | ) 109 | patch_manager = FunctionManager( 110 | src_file=patch_file, 111 | embedder=embedder, 112 | dst_dir=dst_dir, 113 | clear=False, 114 | gen_cfg=False, 115 | gen_taint=False, 116 | ) 117 | 118 | func_pair_manager = FunctionPairManager(vuln_manager, patch_manager) 119 | 120 | # diff embedding 121 | logger.debug(f"init {vuln_file} diff embedding.") 122 | 123 | ( 124 | vuln_diff_embedding, 125 | patch_diff_embedding, 126 | ) = func_pair_manager.get_diff_embeddings() 127 | 128 | if vuln_diff_embedding.size == 0 or patch_diff_embedding.size == 0: 129 | logger.debug( 130 | f"Embedding diff wrong: {vuln_file}:{vuln_diff_embedding.shape}/{patch_diff_embedding.shape}" 131 | ) 132 | 133 | s.set_error_func(vuln_name) 134 | else: 135 | s.set_diff_embedding( 136 | vuln_name, 137 | ( 138 | vuln_diff_embedding, 139 | patch_diff_embedding, 140 | ), 141 | ) 142 | 143 | logger.debug(f"init {vuln_file} diff embedding ok.") 144 | 145 | info = { 146 | "target_file": dst_file, 147 | "vuln_file": vuln_file, 148 | "patch_file": patch_file, 149 | } 150 | 151 | logger.debug(f"testing {vuln_file}") 152 | 153 | vuln_cond = [] 154 | 155 | def finish(): 156 | if trace_all_result_queue: 157 | trace_all_result_queue.put( 158 | {**info, **{"datail": vuln_cond, "predict": all(vuln_cond)}} 159 | ) 160 | 161 | if all(vuln_cond): 162 | output_list.append(vuln_file) 163 | cve_list.append(cve_id) 164 | 165 | if s.is_error_func(vuln_name): 166 | vuln_cond.append("no vuln_file emb") 167 | finish() 168 | continue 169 | 170 | 171 | if not code_manager.taint_line_flows: 172 | logger.debug("Empty taint line flows.") 173 | vuln_cond.append("no target_file emb") 174 | finish() 175 | continue 176 | 177 | if not 
embedder: 178 | embedder = CodeBertEmbedding() 179 | code_manager.set_embedder(embedder) 180 | 181 | code_embedding = code_manager.embeddings 182 | 183 | vuln_diff_embedding, patch_diff_embedding = s.get_diff_embedding(vuln_name) # type: ignore 184 | 185 | 186 | vuln_sim = cos_similarity(code_embedding, vuln_diff_embedding) 187 | patch_sim = cos_similarity(code_embedding, patch_diff_embedding) 188 | 189 | 190 | 191 | vuln_cond_sim = True 192 | if vuln_sim < patch_sim: 193 | vuln_cond_sim = False 194 | 195 | vuln_cond.append(vuln_cond_sim) 196 | vuln_cond.extend([vuln_sim, patch_sim]) 197 | 198 | if trace_all_result_queue: 199 | trace_all_result_queue.put( 200 | {**info, **{"datail": vuln_cond, "predict": all(vuln_cond)}} 201 | ) 202 | 203 | if all(vuln_cond): 204 | output_list.append(vuln_file) 205 | cve_list.append(cve_id) 206 | 207 | return output_list != [], output_list 208 | 209 | except Exception as e: 210 | traceback.print_exc() 211 | raise Exception(f"error when process file {dst_file} : {str(e)}") 212 | -------------------------------------------------------------------------------- /Trace/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import config 3 | from transformers import RobertaTokenizer, RobertaModel, logging 4 | 5 | logging.set_verbosity_error() 6 | 7 | codebert_model_path = config.codebert_model_path 8 | 9 | class CodeBertEmbedding(): 10 | def __init__(self, model_file = codebert_model_path) -> None: 11 | 12 | """ 13 | ['len = strlen(str)', '!is_privileged(user)', 'clear(str)', 'user++'] 14 | """ 15 | self.tokenizer = RobertaTokenizer.from_pretrained(model_file) 16 | self.model = RobertaModel.from_pretrained(model_file) 17 | 18 | self.max_m = 500 19 | 20 | 21 | def tokens(self, codes): 22 | tokens = [self.tokenizer.cls_token] 23 | 24 | for code in codes: 25 | tokens += self.tokenizer.tokenize(code) 26 | tokens += [self.tokenizer.sep_token] 27 | 28 | tokens = tokens[:-1] + [self.tokenizer.eos_token] 29 | 30 | return tokens 31 | 32 | 33 | def embeddings(self, codes): 34 | 35 | code_embeddings = [] 36 | for code in codes: 37 | code_embedding = self.embedding(code) 38 | 39 | # sometime it is zero 40 | if code_embedding.numel() == 0 or code_embedding.shape == torch.Size([]): 41 | continue 42 | 43 | # print(code_embedding, code, code_embedding.shape) 44 | 45 | code_embeddings.append(code_embedding) 46 | 47 | code_embeddings = torch.stack(code_embeddings) 48 | # print(f"code_emb: {code_embeddings.shape}") 49 | embeddings = torch.mean(code_embeddings.squeeze(), dim=0) 50 | return embeddings 51 | 52 | def embedding(self, code): 53 | tokens_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(code)) 54 | 55 | if len(tokens_ids) < self.max_m: 56 | # print(code, tokens_ids) 57 | 58 | code_embedding = self.model(torch.tensor(tokens_ids)[None, :])[0] # type: ignore 59 | 60 | # print(code_embedding.shape) 61 | code_embedding = torch.mean(code_embedding.squeeze(), dim=0) 62 | 63 | else: 64 | code_embedding = [] 65 | 66 | for i in range(0, len(tokens_ids), 500): 67 | batch_token_ids = tokens_ids[i:i+500] 68 | batch_code_embedding = self.model(torch.tensor(batch_token_ids)[None, :])[0] 69 | code_embedding.append(batch_code_embedding.squeeze()) 70 | 71 | code_embedding = torch.cat(code_embedding, dim=0) 72 | code_embedding = torch.mean(code_embedding, dim=0) 73 | return code_embedding 74 | 75 | # cbe = CodeBertEmbedding() 76 | -------------------------------------------------------------------------------- 
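A minimal usage sketch for the embedder above, assuming config.yml points codebert_model_path at a local CodeBERT checkpoint; the statement strings are illustrative only, not taken from the repository:

from Trace.embedding import CodeBertEmbedding

embedder = CodeBertEmbedding()                            # loads tokenizer + model from codebert_model_path
lines = ["len = strlen(str)", "memcpy(buf, str, len)"]    # hypothetical taint-flow statements
per_line = [embedder.embedding(code) for code in lines]   # one mean-pooled torch tensor per statement
flow_vector = embedder.embeddings(lines)                  # single tensor averaging the whole flow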
/Trace/manager.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, deque 2 | import os 3 | import signal 4 | import tempfile 5 | import numpy as np 6 | import subprocess 7 | from loguru import logger 8 | import torch 9 | import fcntl 10 | import ppdeep 11 | from tree_sitter import Language, Parser 12 | from anytree import AnyNode 13 | 14 | import config 15 | from .embedding import CodeBertEmbedding 16 | 17 | from .cfg import CFGExtractor 18 | from .taintflow import TaintFlowExtractor 19 | from .utils import TRACE_DIR, diff_lines, line_hash 20 | 21 | joern_path = config.joern_path 22 | 23 | 24 | class FileLockManager: 25 | def __init__(self, file_path): 26 | self.file_path = file_path 27 | 28 | def __enter__(self): 29 | self.file = open(self.file_path, "w") 30 | fcntl.flock(self.file, fcntl.LOCK_EX) 31 | return self.file 32 | 33 | def __exit__(self, exc_type, exc_value, traceback): 34 | fcntl.flock(self.file, fcntl.LOCK_UN) 35 | self.file.close() 36 | 37 | 38 | class FunctionManager: 39 | 40 | 41 | def __init__( 42 | self, 43 | embedder=None, 44 | src_file=None, 45 | src_func=None, 46 | dst_dir=None, 47 | clear=True, 48 | gen_cfg=True, 49 | gen_taint=True, 50 | ) -> None: 51 | if not src_file and not src_func: 52 | raise Exception("Specify at least one of the `src_file` and `src_func`.") 53 | 54 | self.need_clear = clear 55 | 56 | if not src_file: 57 | self.src_file = "default" 58 | else: 59 | self.src_file = src_file 60 | 61 | self.unique_name = os.path.basename(self.src_file) 62 | 63 | if not src_func: 64 | with open(self.src_file, "r") as f: 65 | src_func = f.read() 66 | 67 | self.src_func = src_func 68 | 69 | logger.debug(f"processing file {self.src_file}") 70 | 71 | if not dst_dir: 72 | self.base_dir = f"{tempfile.mkdtemp()}" 73 | else: 74 | self.base_dir = f"{TRACE_DIR}/{dst_dir}/{self.unique_name}" 75 | if not os.path.exists(self.base_dir): 76 | os.makedirs(self.base_dir, exist_ok=True) 77 | 78 | self.script_dir = f"{TRACE_DIR}/scripts/" 79 | 80 | self.code_file = f"{self.base_dir}/{self.unique_name}.c" 81 | 82 | self.cpg_file = f"{self.base_dir}/{self.unique_name}.cpg.bin" 83 | 84 | self.script_file = f"{self.script_dir}/taint2json.sc" 85 | 86 | self.ast_parser_file = f"{self.script_dir}/cppparser.so" 87 | 88 | self.taint_file = f"{self.base_dir}/{self.unique_name}.taint.json" 89 | 90 | self.cfg_dir = f"{self.base_dir}/cfg" 91 | 92 | self.cfg_file = f"{self.cfg_dir}/1-cfg.dot" 93 | 94 | self.npy_file = f"{self.base_dir}/{self.unique_name}.npy" 95 | 96 | self.npy_diff_file = f"{self.base_dir}/{self.unique_name}.diff.npy" 97 | 98 | if not os.path.exists(self.code_file): 99 | self.__generate_code_file() 100 | 101 | 102 | self.cpg_file_lock = f"{self.cpg_file}.lck" 103 | 104 | if ( 105 | gen_taint 106 | and not os.path.exists(self.taint_file) 107 | and not os.path.exists(f"{self.taint_file}.err") 108 | ): 109 | self.generate_taint_file(self.script_file, self.taint_file) 110 | 111 | if gen_cfg and not os.path.exists(self.cfg_file): 112 | self.generate_cfg_file() 113 | 114 | self._cfg_node_dict = None 115 | 116 | self.line_cb_embeddings_dict = {} 117 | 118 | self.code_cb_embeddings_dict = {} 119 | 120 | self._taint_line_flows = None 121 | 122 | self._taint_code_flows = None 123 | 124 | self._tcf_codebert_embeddings = None 125 | 126 | self._tcf_tracer_embeddings = None 127 | 128 | self._tcf_sent2vec_embeddings = None 129 | 130 | self.token_dict = None 131 | 132 | self.embedder = embedder 133 | 134 | self._ast_parser = 
None 135 | 136 | self._ast = None 137 | 138 | self._ast_nodes = None 139 | 140 | self._ast_hash = None 141 | 142 | self._hash_dict = None 143 | 144 | self._fuzzy_hash = None 145 | 146 | def __del__(self): 147 | if self.need_clear and hasattr(self, "base_dir"): 148 | self.clear_intermediate_file() 149 | 150 | def __generate_code_file(self): 151 | logger.debug("generating code file...") 152 | 153 | # func = self.src_func 154 | # func = re.sub(r".*?(\w+\s*\([\w\W+]*\)[\s\n]*\{)", r"void \1", func, 1) # type: ignore 155 | # func = re.sub(r"(\)\s*)\w+(\s*\{)", r"\1 \2", func, 1) 156 | # func = re.sub("(?&1 > /dev/null") 170 | 171 | def generate_taint_file( 172 | self, script_file, taint_file, extra_params={}, timeout=2 * 60 * 1000 173 | ): 174 | """ 175 | generate {self.unique_name}.taint.json 176 | """ 177 | if not os.path.exists(self.cpg_file): 178 | self.__generate_cpg_file() 179 | 180 | logger.debug("generating taint file ...") 181 | if not os.path.exists(script_file): 182 | raise Exception(f"cannot find script {script_file}, generate failed.") 183 | 184 | params = f'bin="{self.cpg_file}",file={taint_file}' 185 | for k, v in extra_params: 186 | params += f"{k}={v}" 187 | 188 | cmd = f'{os.path.join(joern_path, "joern")} --script {script_file} -p {params}' 189 | logger.debug(cmd) 190 | 191 | with FileLockManager(self.cpg_file_lock): 192 | p = subprocess.Popen( 193 | cmd, 194 | stderr=subprocess.STDOUT, 195 | stdout=subprocess.PIPE, 196 | shell=True, 197 | close_fds=True, 198 | start_new_session=True, 199 | ) 200 | try: 201 | p.communicate(timeout=timeout) 202 | # p.wait() 203 | except subprocess.TimeoutExpired: 204 | p.kill() 205 | p.terminate() 206 | os.killpg(p.pid, signal.SIGTERM) 207 | 208 | if not os.path.exists(taint_file): 209 | logger.warning( 210 | f"generate taint file failed for {self.src_file} when using {script_file}" 211 | ) 212 | f = open(f"{taint_file}.err", "w") 213 | f.close() 214 | return False 215 | 216 | logger.debug("generating taint file succeed") 217 | 218 | return True 219 | 220 | def generate_cfg_file(self): 221 | """ 222 | generate {self.unique_name}.cfg.dot 223 | """ 224 | if not os.path.exists(self.cpg_file): 225 | self.__generate_cpg_file() 226 | 227 | logger.debug("generating cfg file...") 228 | if os.path.exists(self.cfg_dir): 229 | return 230 | # os.removedirs(out_dir) 231 | cmd = f"{os.path.join(joern_path, 'joern-export')} {self.cpg_file} --repr cfg --out {self.cfg_dir}" 232 | logger.debug(cmd) 233 | with FileLockManager(self.cpg_file_lock): 234 | if not os.path.exists(self.cfg_file): 235 | os.system(cmd) 236 | 237 | if not os.path.exists(self.cfg_file): 238 | logger.warning(f"generate cfg failed for {self.src_file}\n command: {cmd}") 239 | return False 240 | 241 | logger.debug("generating cfg file succeed") 242 | return True 243 | 244 | def set_embedder(self, embedder): 245 | self.embedder = embedder 246 | 247 | @property 248 | def ast_parser(self): 249 | if not self._ast_parser: 250 | self._ast_parser = Parser() 251 | CPP_LANGUAGE = Language(self.ast_parser_file, "cpp") 252 | self._ast_parser.set_language(CPP_LANGUAGE) 253 | return self._ast_parser 254 | 255 | @property 256 | def ast(self): 257 | if not self._ast: 258 | self._ast = self.ast_parser.parse(bytes(self.src_func, "utf8")) # type: ignore 259 | return self._ast 260 | 261 | @property 262 | def ast_nodes(self): 263 | if not self._ast_nodes: 264 | root_node = self.ast.root_node 265 | nodes = [] 266 | 267 | def dfs(node): 268 | nodes.append(node.text.decode("utf-8")) 269 | for child in node.children: 
270 | dfs(child) 271 | 272 | def bfs(node): 273 | if not node: 274 | return 275 | 276 | queue = deque() 277 | queue.append(node) 278 | 279 | while queue: 280 | node = queue.popleft() 281 | nodes.append(node.text.decode("utf-8")) 282 | 283 | for child in node.children: 284 | queue.append(child) 285 | 286 | dfs(root_node) 287 | self._ast_nodes = nodes 288 | 289 | return self._ast_nodes 290 | 291 | @property 292 | def ast_edges(self): 293 | edges = [] 294 | 295 | def extract_edge(node): 296 | if not node: 297 | return 298 | 299 | for child in node.children: 300 | edges.append( 301 | (node.text.decode("utf-8"), child.text.decode("utf-8")) 302 | ) 303 | extract_edge(child) 304 | 305 | root_node = self.ast.root_node 306 | extract_edge(root_node) 307 | 308 | return edges 309 | 310 | @property 311 | def taint_line_flows(self): 312 | if not self._taint_line_flows: 313 | if not os.path.exists(self.taint_file): 314 | has_taint = self.generate_taint_file(self.script_file, self.taint_file) 315 | if not has_taint: 316 | return [] 317 | taint_line_flows = TaintFlowExtractor(self.taint_file).taint_line_flows 318 | 319 | if self.cfg_node_dict and taint_line_flows: 320 | allowed_lines = list(self.cfg_node_dict.keys()) 321 | taint_line_flows = [ 322 | [ 323 | val 324 | if val in allowed_lines 325 | else max(filter(lambda x: x < val, allowed_lines), default=val) 326 | for val in taint_line_flow 327 | ] 328 | for taint_line_flow in taint_line_flows 329 | ] 330 | 331 | self._taint_line_flows = taint_line_flows 332 | return self._taint_line_flows 333 | 334 | @property 335 | def cfg_node_dict(self): 336 | if not self._cfg_node_dict: 337 | if not os.path.exists(self.cfg_file): 338 | has_cfg = self.generate_cfg_file() 339 | if not has_cfg: 340 | return {} 341 | self._cfg_node_dict = CFGExtractor(self.cfg_file).node_dict 342 | return self._cfg_node_dict 343 | 344 | @property 345 | def taint_code_flows(self): # -> list[tuple[Any | None, ...]] | None: 346 | if not self._taint_code_flows: 347 | if self.taint_line_flows and self.cfg_node_dict: 348 | logger.debug("generating taint code flows...") 349 | 350 | self._taint_code_flows = list( 351 | map( 352 | lambda x: tuple(map(lambda line: self.cfg_node_dict[line], x)), # type: ignore 353 | self.taint_line_flows, 354 | ) 355 | ) 356 | return self._taint_code_flows 357 | 358 | def embeddings_mean(self, code_embeddings): 359 | 360 | # code_embeddings = list( 361 | # filter( 362 | # lambda code_embedding: code_embedding.numel() != 0 363 | # and code_embedding.shape != torch.Size([]), 364 | # code_embeddings, 365 | # ) 366 | # ) 367 | if len(code_embeddings) == 0: 368 | return np.array([]) 369 | # print(code_embeddings) 370 | code_embeddings = torch.stack(code_embeddings) 371 | embeddings = torch.mean(code_embeddings.squeeze(), dim=0) 372 | embeddings_numpy = embeddings.detach().numpy() 373 | return embeddings_numpy 374 | 375 | def embedding_line_flows(self, line_flows): 376 | """ 377 | [[1,2,3], [2,3,5], ...] 
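Each inner list is one taint line flow (line numbers); returns an np.array with one mean-pooled embedding vector per flow.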
378 | """ 379 | embs = [self.embedding_line_flow(tlf) for tlf in line_flows] 380 | # for tlf in line_flows: 381 | # print(tlf) 382 | # for line in tlf: 383 | # e = self.line_cb_embeddings_dict.get(line) 384 | # if e is None: 385 | # code = self.cfg_node_dict[line] 386 | # e = self.embedder.embedding(code) 387 | # self.embedding_line_flow(tlf) 388 | return np.array(embs) 389 | 390 | def embedding_line_flow(self, line_flow): 391 | 392 | if not self.embedder: 393 | self.set_embedder(CodeBertEmbedding()) 394 | tlf_embeddings = [] 395 | for line in line_flow: 396 | emb = self.line_cb_embeddings_dict.get(line) 397 | if emb is None: 398 | code = self.cfg_node_dict[line] 399 | emb = self.embedder.embedding(code) 400 | if emb.numel() == 0 or emb.shape == torch.Size([]): 401 | logger.error(f"strange embedding : {code}: {emb}") 402 | continue 403 | self.line_cb_embeddings_dict[line] = emb 404 | # self.code_cb_embeddings_dict[] = emb 405 | tlf_embeddings.append(emb) 406 | 407 | return self.embeddings_mean(tlf_embeddings) 408 | 409 | def embedding_code_flows(self, code_flows): 410 | return np.array([self.embedding_code_flow(cf) for cf in code_flows]) 411 | 412 | def embedding_code_flow(self, code_flow): 413 | 414 | if not self.embedder: 415 | self.set_embedder(CodeBertEmbedding()) 416 | tcf_embeddings = [] 417 | for code in code_flow: 418 | emb = self.code_cb_embeddings_dict.get(code) 419 | if emb is None: 420 | emb = self.embedder.embedding(code) 421 | if emb.numel() == 0 or emb.shape == torch.Size([]): 422 | logger.error(f"strange embedding : {code}: {emb}") 423 | continue 424 | self.code_cb_embeddings_dict[code] = emb 425 | tcf_embeddings.append(emb) 426 | return self.embeddings_mean(tcf_embeddings) 427 | 428 | @property 429 | def tcf_codebert_embeddings(self): 430 | if self._tcf_codebert_embeddings is None: 431 | # if os.path.exists(self.npy_file): 432 | # self._tcf_codebert_embeddings = np.load(self.npy_file) 433 | if self.taint_line_flows: 434 | self._tcf_codebert_embeddings = self.embedding_line_flows( 435 | self.taint_line_flows 436 | ) 437 | 438 | # np.save(self.npy_file, self._tcf_codebert_embeddings) 439 | return self._tcf_codebert_embeddings 440 | 441 | @property 442 | def embeddings(self): 443 | logger.debug("getting property embeddings") 444 | return self.tcf_codebert_embeddings 445 | 446 | def clear_intermediate_file(self): 447 | if os.path.exists(self.base_dir): 448 | import shutil 449 | 450 | shutil.rmtree(self.base_dir) 451 | 452 | @property 453 | def hash_dict(self): 454 | if self._hash_dict is None: 455 | self._hash_dict = Counter() 456 | for line in self.src_func.splitlines(): 457 | hash = line_hash(line) 458 | self._hash_dict[hash] += 1 459 | 460 | return self._hash_dict 461 | 462 | @property 463 | def fuzzy_hash(self): 464 | if self._fuzzy_hash is None: 465 | self._fuzzy_hash = ppdeep.hash(self.src_func) 466 | return self._fuzzy_hash 467 | 468 | # @property 469 | # def fuzzy_hash(self): 470 | # if self._fuzzy_hash is None: 471 | # self._fuzzy_hash = ppdeep.hash(self.ast_nodes) 472 | # return self._fuzzy_hash 473 | 474 | def get_ast_hash(self): 475 | root_node = self.ast.root_node 476 | 477 | child_dict = {} 478 | 479 | 480 | def init_child_dict(node): 481 | children = [] 482 | for child in node.children: 483 | children.append(child.id) 484 | child_dict[node.id] = children 485 | for child in node.children: 486 | init_child_dict(child) 487 | 488 | 489 | new_tree = AnyNode(id=0, text=None, data=None) 490 | nodes = [] 491 | def create_tree(root, node, parent=None): 492 | id = 
len(nodes) 493 | text = node.text.decode('utf-8') 494 | text_hash = hash(text) 495 | nodes.append(text) 496 | if id == 0: 497 | root.text = text 498 | root.data = node 499 | root.hash = text_hash 500 | else: 501 | newnode = AnyNode(id=id, text=text, hash=text_hash, data=node, parent=parent) 502 | for child in node.children: 503 | create_tree(root, child, parent=root if id == 0 else newnode) 504 | 505 | 506 | id2hash = {} 507 | id2number = {} 508 | def get_hash(): 509 | for i in range(len(child_dict) - 1, -1, -1): 510 | token = nodes[i] 511 | if not child_dict[i]: 512 | id2hash[i] = hash(token) 513 | id2number[i] = 1 514 | else: 515 | h = hash(token) 516 | n = 1 517 | if token == 'binary_expression': 518 | childtoken = [] 519 | for c in child_dict[i]: 520 | childtoken.append(nodes[c]) 521 | if '/' in childtoken or '-' in childtoken: 522 | j = 1 523 | for child_id in child_dict[i]: 524 | h += j * id2hash[child_id] 525 | n += id2number[child_id] 526 | j += 1 527 | else: 528 | for child_id in child_dict[i]: 529 | h += id2hash[child_id] 530 | n += id2number[child_id] 531 | else: 532 | for child_id in child_dict[i]: 533 | h += id2hash[child_id] 534 | n += id2number[child_id] 535 | id2hash[i] = h 536 | id2number[i] = n 537 | 538 | 539 | create_tree(new_tree, root_node) 540 | 541 | init_child_dict(new_tree) 542 | get_hash() 543 | 544 | hash_list_array = [[] for i in range(len(nodes) + 1)] 545 | for i in range(len(id2number)): 546 | children_num = id2number[i] 547 | hash_list_array[children_num].append(i) 548 | 549 | return hash_list_array, id2hash, child_dict 550 | 551 | 552 | @property 553 | def ast_hash(self): 554 | if not self._ast_hash: 555 | self._ast_hash = self.get_ast_hash() 556 | return self._ast_hash 557 | 558 | 559 | class FunctionManagerV2(FunctionManager): 560 | 561 | 562 | def __init__( 563 | self, 564 | embedder: CodeBertEmbedding, 565 | src_file=None, 566 | src_func=None, 567 | dst_dir=None, 568 | clear=True, 569 | ): 570 | super().__init__( 571 | embedder, 572 | src_file=src_file, 573 | src_func=src_func, 574 | dst_dir=dst_dir, 575 | clear=clear, 576 | gen_cfg=False, 577 | ) 578 | 579 | @property 580 | def cfg_node_dict(self): 581 | blacklist = ["else", "do"] 582 | if not self._cfg_node_dict: 583 | self._cfg_node_dict = {} 584 | for line, code in enumerate(self.src_func.splitlines()): 585 | if len(code) <= 1 or code in blacklist: 586 | continue 587 | self._cfg_node_dict[line] = code 588 | 589 | return self._cfg_node_dict 590 | 591 | 592 | class FunctionPairManager: 593 | def __init__( 594 | self, 595 | vuln_function_manager: FunctionManager, 596 | patch_function_manager: FunctionManager, 597 | ): 598 | self.vuln_fm = vuln_function_manager 599 | self.patch_fm = patch_function_manager 600 | 601 | def get_diff_lines(self): 602 | return diff_lines( 603 | self.vuln_fm.src_func.splitlines(), self.patch_fm.src_func.splitlines() 604 | ) 605 | 606 | def get_diff_lines_hash(self, filter_lines=[]): 607 | vuln_diff_line, patch_diff_line = self.get_diff_lines() 608 | 609 | if filter_lines != []: 610 | vuln_diff_line = [line for line in vuln_diff_line if line not in filter_lines] 611 | patch_diff_line = [line for line in patch_diff_line if line not in filter_lines] 612 | 613 | return ( 614 | list(map(line_hash, vuln_diff_line)), 615 | list(map(line_hash, patch_diff_line)), 616 | ) 617 | 618 | def get_diff_tcfs(self): 619 | """ 620 | get taint flow 621 | get diff taint flow 622 | """ 623 | if not self.vuln_fm.taint_code_flows or not self.patch_fm.taint_code_flows: 624 | return list(), list() 625 
| 626 | vuln_tcfs = set(self.vuln_fm.taint_code_flows) 627 | patch_tcfs = set(self.patch_fm.taint_code_flows) 628 | vuln_unique_tcfs = vuln_tcfs.difference(patch_tcfs) 629 | patch_unique_tcfs = patch_tcfs.difference(vuln_tcfs) 630 | return list(vuln_unique_tcfs), list(patch_unique_tcfs) 631 | 632 | def get_diff_embeddings(self, embed_type="codebert"): 633 | """ """ 634 | self.vuln_fm.npy_diff_file = ( 635 | f"{self.vuln_fm.base_dir}/{embed_type}/{self.patch_fm.unique_name}.diff.npy" 636 | ) 637 | if not os.path.exists(f"{self.vuln_fm.base_dir}/{embed_type}"): 638 | os.makedirs(f"{self.vuln_fm.base_dir}/{embed_type}", exist_ok=True) 639 | 640 | self.patch_fm.npy_diff_file = ( 641 | f"{self.patch_fm.base_dir}/{embed_type}/{self.vuln_fm.unique_name}.diff.npy" 642 | ) 643 | if not os.path.exists(f"{self.patch_fm.base_dir}/{embed_type}"): 644 | os.makedirs(f"{self.patch_fm.base_dir}/{embed_type}", exist_ok=True) 645 | 646 | vuln_tcfs, patch_tcfs = [], [] 647 | if not os.path.exists(self.vuln_fm.npy_diff_file) or not os.path.exists( 648 | self.patch_fm.npy_diff_file 649 | ): 650 | vuln_tcfs, patch_tcfs = self.get_diff_tcfs() 651 | 652 | logger.debug( 653 | f"getting different taint flow embeddings in {len(vuln_tcfs)} v.s {len(patch_tcfs)}" 654 | ) 655 | 656 | # get it from npy 657 | if os.path.exists(self.vuln_fm.npy_diff_file): 658 | vuln_emb = np.load(self.vuln_fm.npy_diff_file) 659 | else: 660 | vuln_emb = self.vuln_fm.embedding_code_flows(vuln_tcfs) 661 | np.save(self.vuln_fm.npy_diff_file, vuln_emb) 662 | 663 | # get it from npy 664 | if os.path.exists(self.patch_fm.npy_diff_file): 665 | patch_emb = np.load(self.patch_fm.npy_diff_file) 666 | else: 667 | patch_emb = self.patch_fm.embedding_code_flows(patch_tcfs) 668 | np.save(self.patch_fm.npy_diff_file, patch_emb) 669 | 670 | return vuln_emb, patch_emb 671 | -------------------------------------------------------------------------------- /Trace/norm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | import sys 5 | 6 | from .utils import TRACE_DIR 7 | 8 | 9 | ctags = f"{TRACE_DIR}/../Preprocessor/universal-ctags/ctags" 10 | 11 | 12 | def abstract(function): 13 | tmp = "./temp.cpp" 14 | with open(tmp, "w", encoding="UTF-8") as f: 15 | f.write(function) 16 | 17 | abstractBody = abstract_file(tmp, function) 18 | os.remove(tmp) 19 | return abstractBody 20 | 21 | 22 | def abstract_file(file, function=None): 23 | if function is None: 24 | with open(file, "r") as f: 25 | function = f.read() 26 | 27 | abstract_func = function 28 | 29 | """ 30 | ref: `--kinds-C++` : `ctags --list-kinds-full | grep C++` 31 | f function 32 | l local 33 | v variable 34 | z parameter 35 | """ 36 | command = f'{ctags} -f - --kinds-C++=flvz -u --fields=neKSt --language-force=c --language-force=c++ "{file}"' 37 | try: 38 | func_info = subprocess.check_output( 39 | command, stderr=subprocess.STDOUT, shell=True 40 | ).decode(errors="ignore") 41 | except subprocess.CalledProcessError as e: 42 | print("Parser Error:", e) 43 | func_info = "" 44 | 45 | variables = [] 46 | parameters = [] 47 | 48 | funcs = func_info.split("\n") 49 | local_reg = re.compile(r"local") 50 | parameter_reg = re.compile(r"parameter") 51 | function_reg = re.compile(r"(function)") 52 | # param_space = re.compile(r"\(\s*([^)]+?)\s*\)") 53 | # word = re.compile(r"\w+") 54 | datatype_reg = re.compile(r"(typeref:)\w*(:)") 55 | number_reg = re.compile(r"(\d+)") 56 | # func_body = re.compile(r"{([\S\s]*)}") 57 | 58 | lines = 
[] 59 | 60 | param_names = [] 61 | dtype_names = [] 62 | lvar_names = [] 63 | 64 | for func in funcs: 65 | elements = re.sub(r"[\t\s ]{2,}", "", func) 66 | elements = elements.split("\t") 67 | if ( 68 | func != "" 69 | and len(elements) >= 6 70 | and (local_reg.fullmatch(elements[3]) or local_reg.fullmatch(elements[4])) 71 | ): 72 | variables.append(elements) 73 | 74 | if ( 75 | func != "" 76 | and len(elements) >= 6 77 | and ( 78 | parameter_reg.match(elements[3]) or parameter_reg.fullmatch(elements[4]) 79 | ) 80 | ): 81 | parameters.append(elements) 82 | 83 | 84 | for func in funcs: 85 | elements = re.sub(r"[\t\s ]{2,}", "", func) 86 | elements = elements.split("\t") 87 | if func != "" and len(elements) >= 8 and function_reg.fullmatch(elements[3]): 88 | lines = ( 89 | int(number_reg.search(elements[4]).group(0)), 90 | int(number_reg.search(elements[7]).group(0)), 91 | ) 92 | 93 | # print (lines) 94 | 95 | line = 0 96 | for param_name in parameters: 97 | if number_reg.search(param_name[4]): 98 | line = int(number_reg.search(param_name[4]).group(0)) 99 | elif number_reg.search(param_name[5]): 100 | line = int(number_reg.search(param_name[5]).group(0)) 101 | if len(param_name) >= 4 and lines[0] <= int(line) <= lines[1]: 102 | param_names.append(param_name[0]) 103 | if len(param_name) >= 6 and datatype_reg.search(param_name[5]): 104 | dtype_names.append( 105 | re.sub(r" \*$", "", datatype_reg.sub("", param_name[5])) 106 | ) 107 | elif len(param_name) >= 7 and datatype_reg.search(param_name[6]): 108 | dtype_names.append( 109 | re.sub(r" \*$", "", datatype_reg.sub("", param_name[6])) 110 | ) 111 | 112 | for variable in variables: 113 | if number_reg.search(variable[4]): 114 | line = int(number_reg.search(variable[4]).group(0)) 115 | elif number_reg.search(variable[5]): 116 | line = int(number_reg.search(variable[5]).group(0)) 117 | if len(variable) >= 4 and lines[0] <= int(line) <= lines[1]: 118 | lvar_names.append(variable[0]) 119 | if len(variable) >= 6 and datatype_reg.search(variable[5]): 120 | dtype_names.append( 121 | re.sub(r" \*$", "", datatype_reg.sub("", variable[5])) 122 | ) 123 | elif len(variable) >= 7 and datatype_reg.search(variable[6]): 124 | dtype_names.append( 125 | re.sub(r" \*$", "", datatype_reg.sub("", variable[6])) 126 | ) 127 | 128 | 129 | try: 130 | param_id = 0 131 | for param_name in param_names: 132 | if len(param_name) == 0: 133 | continue 134 | paramPattern = re.compile("(^|\W)" + param_name + "(\W)") 135 | abstract_func = paramPattern.sub( 136 | f"\g<1>FPARAM{param_id}\g<2>", abstract_func 137 | ) 138 | param_id += 1 139 | 140 | dtype_id = 0 141 | for dtype in dtype_names: 142 | if len(dtype) == 0: 143 | continue 144 | dtypePattern = re.compile("(^|\W)" + dtype + "(\W)") 145 | abstract_func = dtypePattern.sub( 146 | f"\g<1>DTYPE{dtype_id}\g<2>", abstract_func 147 | ) 148 | dtype_id += 1 149 | 150 | lvar_id = 0 151 | for lvar in lvar_names: 152 | if len(lvar) == 0: 153 | continue 154 | lvarPattern = re.compile("(^|\W)" + lvar + "(\W)") 155 | abstract_func = lvarPattern.sub(f"\g<1>LVAR{lvar_id}\g<2>", abstract_func) 156 | lvar_id += 1 157 | 158 | except: # noqa: E722 159 | pass 160 | 161 | return abstract_func 162 | 163 | 164 | def norm(code): 165 | code = re.sub("(? 
s"${file}" 8 | }catch{ 9 | case e: Exception => println("Couldn't parse that file.") 10 | } 11 | 12 | } 13 | -------------------------------------------------------------------------------- /Trace/serializer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable 3 | import numpy as np 4 | import redis 5 | import config 6 | 7 | redis_host = config.redis_host 8 | redis_port = config.redis_port 9 | 10 | 11 | class Serializer: 12 | def __init__(self, host=redis_host, port=redis_port) -> None: 13 | self.host = host 14 | self.port = port 15 | 16 | self.r_patch_line_dict = redis.Redis(host=host, port=port, db=0) 17 | self.r_patch_hash_dict = redis.Redis(host=host, port=port, db=1) 18 | self.r_diff_embedding_dict = redis.Redis(host=host, port=port, db=2) 19 | self.r_error_func_list = redis.Redis(host=host, port=port, db=3) 20 | self.r_fuzzy_hash = redis.Redis(host=host, port=port, db=4) 21 | 22 | def set(self, handler, k, v): 23 | v = json.dumps(v) 24 | handler.set(k, v) 25 | 26 | def get(self, handler, k): 27 | v = handler.get(k) 28 | if v is not None: 29 | v = json.loads(v) # type: ignore 30 | return v 31 | 32 | def set_patch_line(self, k, v: Iterable): 33 | self.set(self.r_patch_line_dict, k, v) 34 | 35 | def get_patch_line(self, k): 36 | return self.get(self.r_patch_line_dict, k) 37 | 38 | def set_line_hash_dict(self, k, v: Iterable): 39 | self.set(self.r_patch_hash_dict, k, v) 40 | 41 | def get_line_hash_dict(self, k): 42 | return self.get(self.r_patch_hash_dict, k) 43 | 44 | def set_diff_embedding(self, k, v: Iterable[np.ndarray]): 45 | lv = tuple(map(lambda n: n.tolist(), v)) 46 | self.set(self.r_diff_embedding_dict, k, lv) 47 | 48 | def get_diff_embedding(self, k): 49 | v = self.get(self.r_diff_embedding_dict, k) 50 | if v is None: 51 | return None 52 | v = tuple(np.array(arr) for arr in v) 53 | return v 54 | 55 | def set_error_func(self, k): 56 | self.r_error_func_list.set(k, 1) 57 | 58 | def is_error_func(self, k): 59 | v = self.r_error_func_list.get(k) 60 | return v is not None 61 | 62 | def set_fuzzy_hash(self, k, v): 63 | self.set(self.r_fuzzy_hash, k, v) 64 | 65 | def get_fuzzy_hash(self, k): 66 | return self.get(self.r_fuzzy_hash, k) 67 | 68 | 69 | if __name__ == "__main__": 70 | s = Serializer() 71 | s.set_diff_embedding("1", (np.array([1.0, 2.0]), np.array([3.0, 4.0]))) 72 | print(s.get_diff_embedding("1")) 73 | 74 | s.set_patch_line("1", [("12", "34"), ("56", "78")]) 75 | print(s.get_patch_line("1")) 76 | 77 | s.set_error_func("1") 78 | print(s.is_error_func("1")) 79 | -------------------------------------------------------------------------------- /Trace/taintflow.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | 4 | 5 | class TaintFlowExtractor: 6 | 7 | def __init__(self, taint_file, taint_min_len=2) -> None: 8 | self.taint_file = taint_file 9 | self.taint_min_len = taint_min_len 10 | with open(self.taint_file, "r") as f: 11 | data = json.load(f) 12 | 13 | self.taint_flows = [] 14 | for item in data: 15 | taint_flow = item.get("elements", {}) 16 | self.taint_flows.append(taint_flow) 17 | 18 | self._taint_line_flows = None 19 | 20 | @property 21 | def taint_line_flows(self): 22 | if self._taint_line_flows is None: 23 | self._taint_line_flows = [] 24 | for taint_flow in self.taint_flows: 25 | line_flow = [] 26 | for node in taint_flow: 27 | if not line_flow or line_flow[len(line_flow)-1] != node["lineNumber"]: 28 | 
line_flow.append(node["lineNumber"]) 29 | 30 | if line_flow and len(line_flow) >= self.taint_min_len and line_flow not in self._taint_line_flows: 31 | self._taint_line_flows.append(line_flow) 32 | 33 | self._taint_line_flows = sorted(self._taint_line_flows, key=lambda x: x[0]) 34 | 35 | return self._taint_line_flows 36 | -------------------------------------------------------------------------------- /Trace/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import difflib 3 | import hashlib 4 | import os 5 | from typing import List 6 | 7 | vuln_to_patch_dict = {} 8 | 9 | 10 | TRACE_DIR = os.path.dirname(os.path.realpath(__file__)) 11 | 12 | 13 | def norm_line(line_list : List[str]) -> List[str]: 14 | return list(map(lambda line: line.strip(), line_list)) 15 | 16 | def get_file_pairs(file_path): 17 | with open(file_path, "r") as file: 18 | reader = csv.DictReader(file) 19 | for row in reader: 20 | yield row["old"], row["new"] 21 | 22 | 23 | def diff_lines(left_list, right_list): 24 | left_list = norm_line(left_list) 25 | right_list = norm_line(right_list) 26 | 27 | differ = difflib.Differ() 28 | diff = list(differ.compare(left_list, right_list)) 29 | 30 | left_diff = [] 31 | right_diff = [] 32 | 33 | left = 0 34 | right = 0 35 | for line in diff: 36 | if line.startswith("- "): 37 | left_diff.append(left_list[left]) 38 | left += 1 39 | elif line.startswith("+ "): 40 | right_diff.append(right_list[right]) 41 | right += 1 42 | elif not line.startswith("? ") and line.strip() != "": 43 | left += 1 44 | right += 1 45 | 46 | return left_diff, right_diff 47 | 48 | def line_hash(line): 49 | return hashlib.sha256(line.strip().encode()).hexdigest() 50 | -------------------------------------------------------------------------------- /config.example.yml: -------------------------------------------------------------------------------- 1 | basic: 2 | dataset: 3 | normal_sample_dataset_path: 'resources/NormalSample' 4 | old_new_func_dataset_path: 'resources/OldNewFuncs' 5 | trace: 6 | codebert_model_path: 'resources/codebert' 7 | joern_path: 'resources/joern-cli' 8 | workers: 9 | bloom_filter: 5 10 | token: 15 11 | syntax: 6 12 | trace: 32 13 | experiment: 14 | token_filter: 15 | jaccard_sim_threshold: 0.7 16 | trace: 17 | ast_sim_threshold_min: 0.6 18 | ast_sim_threshold_max: 1.0 19 | fuzzy_hash_sim_threshold_max: 0.9 20 | fuzzy_hash_sim_threshold_min: 0.7 21 | redis_host: 127.0.0.1 22 | redis_port: 6379 23 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import yaml 3 | 4 | 5 | def load_config(): 6 | config_file = "config.yml" 7 | if os.path.exists(config_file): 8 | with open(config_file, "r") as f: 9 | config = yaml.safe_load(f) 10 | else: 11 | config = { 12 | "basic": { 13 | "dataset": {"old_new_func_dataset_path": "", "normal_sample_dataset_path": ""}, 14 | "trace": {"codebert_model_path": "", "joern_path": ""}, 15 | "workers": {"bloom_filter": 5, "token": 15, "syntax": 6, "trace": 32} 16 | }, 17 | "experiment": { 18 | "token_filter": {"jaccard_sim_threshold": 0.7}, 19 | "trace": { 20 | "ast_sim_threshold_min": 0.7, 21 | "ast_sim_threshold_max": 0.9, 22 | "redis_host": "127.0.0.1", 23 | "redis_port": "6379", 24 | }, 25 | }, 26 | } 27 | 28 | with open(config_file, "w") as f: 29 | yaml.dump(config, f) 30 | 31 | print("config.yml not exist, quiting") 32 | exit(1) 33 | 34 | return config 35 | 36 | 
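# Note: everything below runs at import time. It loads config.yml once and exposes
# the validated values (dataset paths, worker counts, similarity thresholds, and the
# Redis host/port) as plain module attributes, so other components can simply do
# `import config` and read e.g. `config.bloom_filter_worker`.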
37 | config = load_config() 38 | 39 | """ 40 | check basic config 41 | """ 42 | basic_dataset = config["basic"].get("dataset", {}) 43 | old_new_func_dataset_path = basic_dataset.get("old_new_func_dataset_path") 44 | assert old_new_func_dataset_path is not None 45 | 46 | normal_sample_dataset_path = basic_dataset.get("normal_sample_dataset_path") 47 | assert normal_sample_dataset_path is not None 48 | 49 | 50 | trace = config["basic"].get("trace", {}) 51 | codebert_model_path = trace.get("codebert_model_path") 52 | assert codebert_model_path is not None 53 | 54 | joern_path = trace.get("joern_path") 55 | assert joern_path is not None 56 | 57 | workers = config["basic"].get("workers", {}) 58 | bloom_filter_worker = workers.get("bloom_filter") 59 | assert bloom_filter_worker is not None 60 | token_worker = workers.get("token") 61 | assert token_worker is not None 62 | syntax_worker = workers.get("syntax") 63 | assert syntax_worker is not None 64 | trace_worker = workers.get("trace") 65 | assert trace_worker is not None 66 | 67 | """ 68 | check experiment config 69 | """ 70 | experiment_token_filter = config["experiment"].get("token_filter", {}) 71 | jaccard_sim_threshold = experiment_token_filter.get("jaccard_sim_threshold", 0.7) 72 | 73 | experiment_trace = config["experiment"].get("trace", {}) 74 | ast_sim_threshold_min = experiment_trace.get("ast_sim_threshold_min", 0.7) 75 | ast_sim_threshold_max = experiment_trace.get("ast_sim_threshold_max", 0.9) 76 | 77 | redis_host = experiment_trace.get("redis_host") 78 | redis_port = experiment_trace.get("redis_port") 79 | -------------------------------------------------------------------------------- /dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bullseye 2 | 3 | WORKDIR /usr/src/app 4 | 5 | RUN apt-get update && apt-get install -y git libxml2 libjansson4 libyaml-0-2 vim 6 | 7 | COPY requirements.txt . 8 | RUN pip install --no-cache-dir -r requirements.txt && \ 9 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 10 | 11 | COPY . . 12 | 13 | # install redis 14 | RUN apt install build-essential -y --no-install-recommends 15 | RUN cd resources/redis-7.2.3 && \ 16 | make && \ 17 | make install && \ 18 | cd .. 
&& \ 19 | rm -rf /usr/src/app/resources/redis-7.2.3 20 | 21 | ENV JAVA_HOME "/usr/src/app/resources/jdk-17.0.11" 22 | ENV PATH $PATH:$JAVA_HOME/bin 23 | 24 | EXPOSE 8000 25 | 26 | CMD redis-server --daemonize yes && python3 server.py -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: FIRE 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - loguru=0.7.2 7 | - numpy=1.26.2 8 | - pygments=2.15.1 9 | - matplotlib=3.8.0 10 | - tqdm=4.65.0 11 | - levenshtein=0.23.0 12 | - networkx=3.2.1 13 | - redis-py=5.0.1 14 | - flask=3.0.0 15 | - pytorch::pytorch=2.1.0 16 | - pytorch::torchvision=0.16.0 17 | - pytorch::torchaudio=2.1.0 18 | - transformers=4.36.2 19 | - tree_sitter=0.20.4 20 | - anytree=2.8.0 21 | - pip 22 | - pip: 23 | - bloom-filter2==2.0.0 24 | - ppdeep==20200505 25 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import os 5 | import queue 6 | import sys 7 | import time 8 | import traceback 9 | from concurrent.futures import ProcessPoolExecutor, as_completed 10 | from multiprocessing import Manager 11 | 12 | from loguru import logger 13 | from tqdm import tqdm 14 | 15 | import BloomFilter 16 | import Dataset 17 | import SyntaxFilter 18 | import TokenFilter 19 | import Trace 20 | import config 21 | 22 | logger.remove() 23 | logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level="INFO") 24 | 25 | 26 | def progress_bar_process(total_cnt, pbar_queue, output_name="detect_info.json"): 27 | def get_time(info): 28 | if info["done"] >= info["input"]: 29 | return info["offset"] 30 | else: 31 | return info["offset"] + info["last"] - info["start"] 32 | 33 | def stop_timer(info): 34 | info["offset"] += time.perf_counter() - info["start"] 35 | info["start"] = time.perf_counter() 36 | info["last"] = info["start"] 37 | 38 | def start_timer(info): 39 | info["start"] = time.perf_counter() 40 | info["last"] = info["start"] 41 | 42 | def record_time(info): 43 | info["last"] = time.perf_counter() 44 | 45 | bloom_info = dict(input=0, done=0, offset=0.0, start=0.0, last=0.0) 46 | token_info = dict(input=0, done=0, offset=0.0, start=0.0, last=0.0) 47 | syntax_info = dict(input=0, done=0, offset=0.0, start=0.0, last=0.0) 48 | trace_info = dict(input=0, done=0, offset=0.0, start=0.0, last=0.0) 49 | vul_cnt = 0 50 | postfix_info = {} 51 | with tqdm( 52 | total=total_cnt, 53 | smoothing=0, 54 | unit="f", 55 | bar_format="{n_fmt}/{total_fmt}~{remaining}[{rate_fmt}{postfix}]", 56 | file=sys.stdout, 57 | ) as pbar: 58 | while True: 59 | try: 60 | info_from, status = pbar_queue.get(timeout=0.1) 61 | except queue.Empty: 62 | info_from, status = ("Nothing", False) 63 | if info_from == "__end_of_detection__": 64 | # dumping final infos 65 | with open(output_name, "w") as f: 66 | json.dump(postfix_info, f) 67 | break 68 | elif info_from == "dataset": 69 | if bloom_info["input"] == bloom_info["done"]: 70 | start_timer(bloom_info) 71 | bloom_info["input"] += 1 72 | elif info_from == "bloom": 73 | bloom_info["done"] += 1 74 | if bloom_info["input"] <= bloom_info["done"]: 75 | stop_timer(bloom_info) 76 | record_time(bloom_info) 77 | if status: 78 | if token_info["input"] == token_info["done"]: 79 | start_timer(token_info) 80 | token_info["input"] += 1 81 | else: 82 | 
pbar.update() 83 | elif info_from == "token": 84 | token_info["done"] += 1 85 | if token_info["input"] <= token_info["done"]: 86 | stop_timer(token_info) 87 | record_time(token_info) 88 | if status: 89 | if syntax_info["input"] == syntax_info["done"]: 90 | start_timer(syntax_info) 91 | syntax_info["input"] += 1 92 | else: 93 | pbar.update() 94 | elif info_from == "syntax": 95 | syntax_info["done"] += 1 96 | if syntax_info["input"] <= syntax_info["done"]: 97 | stop_timer(syntax_info) 98 | record_time(syntax_info) 99 | if status: 100 | if trace_info["input"] == trace_info["done"]: 101 | start_timer(trace_info) 102 | trace_info["input"] += 1 103 | else: 104 | pbar.update() 105 | elif info_from == "trace": 106 | trace_info["done"] += 1 107 | if trace_info["input"] <= trace_info["done"]: 108 | stop_timer(trace_info) 109 | record_time(trace_info) 110 | pbar.update() 111 | if status: 112 | vul_cnt += 1 113 | elif info_from == "Nothing": 114 | pass 115 | else: 116 | logger.error("Unknown Source Components") 117 | bloom_fail_filter_rate = token_info["input"] / max(bloom_info["done"], 1) 118 | bloom_speed = bloom_info["done"] / max(get_time(bloom_info), 1e-3) 119 | token_fail_filter_rate = syntax_info["input"] / max(token_info["done"], 1) 120 | token_speed = token_info["done"] / max(get_time(token_info), 1e-3) 121 | syntax_fail_filter_rate = trace_info["input"] / max(syntax_info["done"], 1) 122 | syntax_speed = syntax_info["done"] / max(get_time(syntax_info), 1e-3) 123 | trace_speed = trace_info["done"] / max(get_time(trace_info), 1e-3) 124 | postfix_info = { 125 | "bloom": "%d/%d(%.1f%%,%.1ff/s)" 126 | % ( 127 | bloom_info["done"], 128 | bloom_info["input"], 129 | 100 * (1 - bloom_fail_filter_rate), 130 | bloom_speed, 131 | ), 132 | "token": "%d/%d(%.1f%%,%.2f[%.2f]f/s)" 133 | % ( 134 | token_info["done"], 135 | token_info["input"], 136 | 100 * (1 - token_fail_filter_rate), 137 | token_speed, 138 | bloom_fail_filter_rate * bloom_speed, 139 | ), 140 | "syntax": "%d/%d(%.1f%%,%.2f[%.2f]f/s)" 141 | % ( 142 | syntax_info["done"], 143 | syntax_info["input"], 144 | 100 * (1 - syntax_fail_filter_rate), 145 | syntax_speed, 146 | token_fail_filter_rate * token_speed, 147 | ), 148 | "trace": "%d/%d(%d,%.2f[%.2f]f/s)" 149 | % ( 150 | trace_info["done"], 151 | trace_info["input"], 152 | vul_cnt, 153 | trace_speed, 154 | syntax_fail_filter_rate * syntax_speed, 155 | ), 156 | } 157 | pbar.set_postfix(postfix_info) 158 | 159 | 160 | def put_dataset_to_queue(dataset: Dataset.Base, output_queue, pbar_queue): 161 | # datasets = [] 162 | for func_path in dataset.get_funcs(): 163 | with open(func_path) as f: 164 | output_queue.put((f.read(), func_path, [])) 165 | pbar_queue.put(("dataset", False)) 166 | 167 | output_queue.put((None, "__end_of_detection__", [])) 168 | # return datasets 169 | 170 | 171 | def dump_trace_func(input_queue, output_name="trace.csv"): 172 | traces = [] 173 | while True: 174 | trace = input_queue.get() 175 | if trace == 0: 176 | break 177 | traces.append(trace) 178 | 179 | with open(output_name, "w") as csvfile: 180 | writer = csv.DictWriter( 181 | csvfile, 182 | fieldnames=[ 183 | "target_file", 184 | "vuln_file", 185 | "patch_file", 186 | "datail", 187 | "predict", 188 | ], 189 | ) 190 | writer.writeheader() 191 | writer.writerows(traces) 192 | 193 | 194 | def dump_vulnerable_func(input_queue, total_function_cnt, output_name="vuls.json"): 195 | vul_dict = {} 196 | 197 | vuls = [] 198 | vul_cnt = 0 199 | vul_all = 0 200 | while True: 201 | _, dst_file, similar_list = input_queue.get() 202 | 
if dst_file == "__end_of_detection__": 203 | break 204 | vul_cnt += 1 205 | 206 | vuls.append({"id": vul_cnt, "dst": dst_file, "sim": similar_list}) 207 | 208 | logger.success(f"[No. {vul_cnt}]Vul Detected in {dst_file}") 209 | logger.success("Similar to Vulnerability:") 210 | for exist_vul in similar_list: 211 | logger.success(exist_vul) 212 | vul_all += 1 213 | 214 | vul_dict["total_func"] = total_function_cnt 215 | vul_dict["cnt"] = vul_cnt 216 | vul_dict['all'] = vul_all 217 | vul_dict["vul"] = vuls 218 | 219 | logger.info(f"Dumping vulnerable function info to {output_name}") 220 | with open(output_name, "w") as f: 221 | json.dump(vul_dict, f, indent=4) 222 | 223 | if vul_cnt == 0: 224 | vul_dict = { 225 | "total_func": total_function_cnt, 226 | "cnt": vul_cnt, 227 | "all": vul_all, 228 | "vul": vuls 229 | } 230 | logger.info(f"Dumping vulnerable function info to {output_name}") 231 | with open(output_name, "w") as f: 232 | json.dump(vul_dict, f, indent=4) 233 | 234 | logger.info("Dump vulnerable function finished!") 235 | 236 | 237 | def main(ProjectDataset: Dataset.Project, output_name, rebuild_list): 238 | OldNewFuncsDataset = Dataset.OldNewFuncs( 239 | config.old_new_func_dataset_path, rebuild=("old-new-funcs" in rebuild_list) 240 | ) 241 | 242 | logger.info("Start Initialization") 243 | BloomFilter.initialization(OldNewFuncsDataset.get_funcs(vul=True), rebuild=("bloomFilter" in rebuild_list)) 244 | TokenFilter.initialization(OldNewFuncsDataset.get_funcs(vul=True)) 245 | SyntaxFilter.initialization(OldNewFuncsDataset.get_func_pairs()) 246 | Trace.initialization(OldNewFuncsDataset.get_func_pairs()) 247 | 248 | manager = Manager() 249 | dataset_queue = manager.Queue(maxsize=100) 250 | pbar_queue = manager.Queue(maxsize=100) 251 | bloom_filter_processed_queue = manager.Queue(maxsize=2000) 252 | token_filter_processed_queue = manager.Queue(maxsize=1000) 253 | syntax_filter_processed_queue = manager.Queue(maxsize=100) 254 | vulnerable_func_queue = manager.Queue(maxsize=100) 255 | trace_all_result_queue = manager.Queue(maxsize=100) 256 | 257 | logger.info("Start Detection") 258 | 259 | 260 | with ProcessPoolExecutor(max_workers=8) as executor: 261 | 262 | futures = [ 263 | executor.submit( 264 | progress_bar_process, 265 | len(ProjectDataset.get_funcs()), 266 | pbar_queue, 267 | os.path.splitext(output_name)[0] + ".detect_info.json", 268 | ), 269 | executor.submit( 270 | put_dataset_to_queue, ProjectDataset, dataset_queue, pbar_queue 271 | ), 272 | executor.submit( 273 | BloomFilter.detect, 274 | dataset_queue, 275 | bloom_filter_processed_queue, 276 | pbar_queue, 277 | ), 278 | executor.submit( 279 | TokenFilter.detect, 280 | bloom_filter_processed_queue, 281 | token_filter_processed_queue, 282 | pbar_queue, 283 | ), 284 | executor.submit( 285 | SyntaxFilter.detect, 286 | token_filter_processed_queue, 287 | syntax_filter_processed_queue, 288 | vulnerable_func_queue, 289 | pbar_queue, 290 | trace_all_result_queue, 291 | ), 292 | executor.submit( 293 | Trace.detect, 294 | syntax_filter_processed_queue, 295 | vulnerable_func_queue, 296 | pbar_queue, 297 | trace_all_result_queue, 298 | ), 299 | executor.submit(dump_vulnerable_func, vulnerable_func_queue, ProjectDataset.total_functions, output_name), 300 | executor.submit( 301 | dump_trace_func, 302 | trace_all_result_queue, 303 | os.path.splitext(output_name)[0] + ".trace.csv", 304 | ), 305 | ] 306 | 307 | 308 | for future in as_completed(futures): 309 | try: 310 | future.result() 311 | except Exception as e: 312 | exception_traceback = 
traceback.format_exc() 313 | logger.error(exception_traceback) 314 | 315 | logger.info("Detection Complete") 316 | 317 | 318 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 319 | 320 | if __name__ == "__main__": 321 | parser = argparse.ArgumentParser(description="Extract data from project dir") 322 | parser.add_argument("project", type=str, help="Path to the project dir") 323 | parser.add_argument("--rebuild", nargs="*", default=["target"], 324 | choices=["bloomFilter", "old-new-funcs", "target"], 325 | help="Rebuild any of the components/dataset cache") 326 | parser.add_argument( 327 | "--restore-processed", help="Restore processed cache" 328 | ) 329 | args = parser.parse_args() 330 | 331 | ProjectDataset = Dataset.Project(os.path.join(BASE_DIR, args.project), 332 | restore_processed=args.restore_processed, 333 | rebuild=("target" in args.rebuild)) 334 | 335 | project_name = os.path.basename(args.project) 336 | result_dir = f"result/{project_name}" 337 | os.makedirs(result_dir, exist_ok=True) 338 | 339 | main(ProjectDataset, output_name=f"{result_dir}/{project_name}.json", rebuild_list=args.rebuild) 340 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # FIRE-Public 2 | 3 | FIRE: Combining Multi-Stage Filtering with Taint Analysis for Scalable Recurring 4 | Vulnerability Detection. 5 | 6 | ## Overview 7 | 8 | The project consists of four components (packages): `BloomFilter` (SFBF, Section 3.1), `TokenFilter` (Token Similarity Filter, Section 3.2), 9 | `SyntaxFilter` (AST Similarity Filter, Section 3.3), and `Trace` (Vulnerability Identification Phase, Section 4). 10 | 11 | Besides, we provide utility classes in the `Dataset` package to load the datasets, including the `Old-New-Funcs` dataset, the `NormalSample` dataset, and a class to load the target system (`Dataset/target_project.py`). 12 | 13 | During detection, five directories are used: `cache`, `log`, `processed`, `result`, and `workspace`. 14 | 15 | We provide a dockerfile and a Flask server (`server.py`), so you can build the project as a Docker image and trigger vulnerability detection via HTTP requests. 16 | 17 | ## Installation 18 | 19 | ### Read this first before installing 20 | 21 | **Make sure you are installing the right versions of the requirements and dependencies!** 22 | 23 | Installing the wrong version of a dependency may cause exceptions and bugs, since several dependencies are under heavy development and change fast. 24 | 25 | **Do not extract the archives on Windows and then copy them to Linux. Extract them on Linux using `tar` and `unzip`.** 26 | 27 | Extracting files on Windows may lose some metadata and cause permission issues during detection. 28 | 29 | ### Install Python Requirements 30 | 31 | #### conda 32 | 33 | ```shell 34 | conda env create -f environment.yml 35 | ``` 36 | 37 | #### pip 38 | ```shell 39 | # Install Python requirements except Torch 40 | pip install -r requirements.txt 41 | # Install Torch 42 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cpu 43 | ``` 44 | 45 | ### Install CodeBert 46 | 47 | Put `codebert-base` in `resource/codebert`. 48 | 49 | We use the pretrained CodeBert model provided by neulab. You can find `codebert` here: [codebert-cpp](https://huggingface.co/neulab/codebert-cpp).
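To sanity-check the local model directory, something like the sketch below should load it through the `transformers` API (this is only an illustration, not FIRE's own embedding code; it assumes the model files, including the LFS objects, are in `resource/codebert`):

```python
# Minimal sketch (assumption: the model files live in resource/codebert).
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("resource/codebert")
model = AutoModel.from_pretrained("resource/codebert")

# Encode a tiny C function and run it through the encoder.
inputs = tokenizer("int add(int a, int b) { return a + b; }",
                   return_tensors="pt", truncation=True)
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # expected: [1, seq_len, 768] for codebert-base
```

If this fails with missing-file or size-mismatch errors, the Git LFS objects mentioned below were probably not downloaded.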
50 | 51 | FIRE can be extended to other languages. If you are interested in migrating FIRE from C/C++ to another language, change `codebert-cpp` to `codebert-<language>` and find the right pretrained model on Hugging Face. 52 | 53 | See also the neulab code-bert-score model list: [https://github.com/neulab/code-bert-score#huggingface--models](https://github.com/neulab/code-bert-score#huggingface--models) 54 | 55 | If there is no pretrained model for your language of interest, you can use this base model instead: [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base). 56 | 57 | Note: these CodeBert repos all contain Git LFS objects. Simply using `git clone` may miss some vital objects stored in LFS. You **should** manually download those LFS objects after cloning the model. 58 | 59 | ### Install Joern 60 | 61 | Joern needs Java to run. In our project we use `jdk-17.0.11`. 62 | 63 | #### Install Java 64 | 65 | Get the tar.gz tarball of the JDK and extract it to `resource/jdk-17.0.11`. 66 | 67 | We have tried multiple versions of Java, and Java 17 works best. **Make sure you are installing the right Java version.** 68 | 69 | ```bash 70 | JAVA_HOME="/path/to/FIRE-public/resource/jdk-17.0.11" 71 | PATH=$PATH:$JAVA_HOME/bin 72 | java --version 73 | ``` 74 | ``` 75 | java 17.0.11 2024-04-16 LTS 76 | Java(TM) SE Runtime Environment (build 17.0.11+7-LTS-207) 77 | Java HotSpot(TM) 64-Bit Server VM (build 17.0.11+7-LTS-207, mixed mode, sharing) 78 | ``` 79 | 80 | #### Install Joern-cli 81 | 82 | We use version `1.2.1` of Joern. You can find Joern in this GitHub repo: [joernio/joern](https://github.com/joernio/joern). 83 | 84 | You can find the v1.2.1 `joern-cli.zip` file here: [joern-cli.zip](https://github.com/joernio/joern/releases/download/v1.2.1/joern-cli.zip) 85 | 86 | Please download the zip archive of Joern and unzip it to `resource/joern-cli`. 87 | 88 | Version 1.2.1 works best for our project. **Make sure you are installing the right version.** 89 | 90 | ```bash 91 | ./resource/joern-cli/joern 92 | ``` 93 | ``` 94 | ██╗ ██████╗ ███████╗██████╗ ███╗ ██╗ 95 | ██║██╔═══██╗██╔════╝██╔══██╗████╗ ██║ 96 | ██║██║ ██║█████╗ ██████╔╝██╔██╗ ██║ 97 | ██ ██║██║ ██║██╔══╝ ██╔══██╗██║╚██╗██║ 98 | ╚█████╔╝╚██████╔╝███████╗██║ ██║██║ ╚████║ 99 | ╚════╝ ╚═════╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═══╝ 100 | Version: 1.2.1 101 | Type `help` or `browse(help)` to begin 102 | 103 | joern> 104 | ``` 105 | 106 | #### About Ctags 107 | 108 | Since Ctags is lightweight open-source software, we ship its prebuilt binary in `Dataset/universal-ctags` together with its COPYING file. 109 | So you don't need to install it. However, you should make sure the executable bit (+x) is set on the `ctags` file before running. 110 | 111 | ```bash 112 | ./Dataset/universal-ctags/ctags --version 113 | ``` 114 | ``` 115 | Universal Ctags 6.0.0(293f11e), Copyright (C) 2015-2022 Universal Ctags Team 116 | Universal Ctags is derived from Exuberant Ctags. 117 | Exuberant Ctags 5.8, Copyright (C) 1996-2009 Darren Hiebert 118 | Compiled: Dec 20 2023, 10:38:07 119 | URL: https://ctags.io/ 120 | Output version: 0.0 121 | Optional compiled features: +wildcards, +regex, +gnulib_regex, +iconv, +option-directory, +xpath, +json, +interactive, +yaml, +packcc, +optscript 122 | ``` 123 | 124 | #### About Redis 125 | 126 | `Trace` needs Redis for caching. We use the Redis Docker image in our experiments. 127 | 128 | **Run FIRE outside Docker** 129 | 130 | You can install Redis v7.2.3 using your package manager, or use Docker. 131 | 132 | For example, you can launch Redis using the command below.
133 | 134 | ```bash 135 | docker run -p 6379:6379 redis:7.2.3 136 | ``` 137 | 138 | **Build and Run FIRE in Docker** 139 | 140 | Please make sure you have put Redis 7.2.3 in `resource/redis-7.2.3`. 141 | 142 | The external Redis container is **not needed** during detection, since Redis is installed while building the Docker image. 143 | 144 | If you run FIRE outside Docker, this step is **not needed** either. 145 | 146 | ## Datasets 147 | 148 | We use the Old-New-Funcs dataset to store all the vulnerability/patch function pairs, which are used by all the components of FIRE. 149 | 150 | ### Old-New-Funcs Dataset 151 | 152 | We suggest putting the dataset in `resource/OldNewFuncs`. 153 | 154 | Unfortunately, we cannot open-source the dataset we used in this project, but you can build one from your own data following the structure below. 155 | 156 | An example of the Old-New-Funcs dataset folder structure: 157 | 158 | ``` 159 | |-- OldNewFuncs 160 | | |-- ffmpeg (software directory) 161 | | | |-- CVE-2009-0385 (CVE directory) 162 | | | | |-- CVE-2009-0385_CWE-189_72e715fb798f2cb79fd24a6d2eaeafb7c6eeda17_4xm.c_1.1_fourxm_read_header_OLD.vul [Vulnerable Version] 163 | | | | |-- CVE-2009-0385_CWE-189_72e715fb798f2cb79fd24a6d2eaeafb7c6eeda17_4xm.c_1.1_fourxm_read_header_NEW.vul [Patch Version] 164 | | | | |-- ...Other Old-New-Funcs files (with the filename extension `.vul`) 165 | | | |-- ...Other CVEs 166 | | |-- ...Other Software 167 | ``` 168 | 169 | We do not utilize the software and CVE directory names. However, we do utilize the old-new-funcs filenames 170 | in our project. Each Old-New-Funcs file should store a single function. 171 | 172 | The Old-New-Funcs filename structure: 173 | ``` 174 | [CVE-No.]_[CWE-No.]_[Commit]_[File Extracted From]_[Version]_[Function Name]_[OLD/NEW].vul 175 | ``` 176 | The `OLD` tag refers to the vulnerable version, while the `NEW` tag refers to the patched version. 177 | 178 | We utilize the `CVE`, `Function Name`, and `OLD/NEW` parts of the filename in FIRE, so please set them properly. 179 | 180 | ### ~~NormalSample Dataset~~ (no longer needed) 181 | 182 | The NormalSample dataset structure: 183 | 184 | We suggest putting the dataset at `resource/NormalSample`. 185 | 186 | ``` 187 | |-- NormalSample Dataset 188 | | |-- ffmpeg (software directory) 189 | | | |-- ...functions 190 | | |-- ...Other Software 191 | ``` 192 | 193 | There are no extra constraints on the filenames of the normal functions stored in the software directory. 194 | 195 | ## How To Run 196 | 197 | ### Run Locally 198 | 199 | Make sure you have properly installed all the requirements and prepared the datasets before running. 200 | 201 | You can execute `python3 main.py --help` to read the help message of this project. 202 | 203 | Currently, FIRE only runs on Linux. 204 | 205 | #### Basic Usage 206 | ```bash 207 | python3 main.py /path/to/target/system 208 | ``` 209 | 210 | #### Help Message 211 | ```bash 212 | python3 main.py --help 213 | ``` 214 | ``` 215 | usage: main.py [-h] [--rebuild [{bloomFilter,old-new-funcs,target} ...]] project 216 | 217 | Extract data from project dir 218 | 219 | positional arguments: 220 | project Path to the project dir 221 | 222 | options: 223 | -h, --help show this help message and exit 224 | --rebuild [{bloomFilter,old-new-funcs,target} ...] 225 | Rebuild any of the components/dataset cache 226 | ``` 227 | 228 | Note: It is better to put the project argument before the options to avoid parsing errors.
An example using the `--rebuild` option: 229 | 230 | ```bash 231 | python3 main.py /path/to/target --rebuild bloomFilter old-new-funcs target 232 | ``` 233 | 234 | #### Rebuild Option 235 | 236 | We provide the rebuild options to rebuild the caches when there are any updates to the datasets. **We suggest applying all the rebuild options the first time you run the project.** 237 | 238 | If you update the Old-New-Funcs dataset, please rebuild `bloomFilter` and `old-new-funcs`. 239 | 240 | If you do not specify any rebuild options, the `target` option is applied by default to extract the functions of the target system each time before vulnerability detection. 241 | 242 | Use spaces to separate the options if you want to apply multiple rebuild options. 243 | 244 | #### Results 245 | 246 | Detection results are not only displayed in the console but are also written to the `result` folder. You can find the detection results in `result/[target-system]`. 247 | 248 | ### Run Remote or In Docker 249 | 250 | Run `server.py` if you want to run FIRE remotely. If you use Docker, `server.py` runs automatically. 251 | This opens a Flask server on port 8000 on the machine/container. You can change the port in `server.py`. 252 | 253 | ```bash 254 | python3 server.py 255 | ``` 256 | 257 | You can submit a vulnerability detection job using the following HTTP request. 258 | 259 | #### Request 260 | 261 | - Method: GET 262 | - URL: /process?git_url={git_url}&branch={branch} 263 | - `git_url`: git URL of the target system. 264 | - `branch`: tag or branch of the target system. 265 | 266 | #### Response 267 | 268 | - Body (JSON) 269 | - `time`: detection runtime in seconds. 270 | - `vul`: detected vulnerabilities. 271 | - `vul_cnt`: count of the detected vulnerabilities. 272 | 273 | #### Docker build 274 | 275 | You should fully generate the caches (old-new-funcs and bloomFilter) before building the Docker image. 276 | 277 | ```bash 278 | docker build . 279 | ``` 280 | 281 | ### Notes 282 | 283 | In the `Trace` component we use a lazy caching technique (vectors are generated only when a vulnerability or patch function is actually needed) instead of generating the vectors of all vulnerability and patch functions in advance; this accelerates the experiments but can make the first run of FIRE slower than expected. In a production environment, all vectors of the vulnerability and patch functions should be generated in advance. So please **run again** to get the actual run speed. 284 | 285 | The experiments are conducted on a machine with a 3.40 GHz Intel i7-13700K processor and 48 GB of RAM, running Arch Linux with the Linux Zen kernel (Appendix C). **Please adjust the maximum number of processes in each component to your experiment environment to avoid crashes.** 286 | 287 | # Publication 288 | Siyue Feng, Yueming Wu, Wenjie Xue, Sikui Pan, Deqing Zou, Yang Liu and Hai Jin. 2024. FIRE: Combining Multi-Stage Filtering with Taint Analysis for Scalable Recurring Vulnerability Detection. In Proceedings of the 33rd USENIX Security Symposium (USENIX Security ’24), August 14–16, 2024, Philadelphia Marriott Downtown in Philadelphia, PA, USA, 18 pages.
289 | 290 | 291 | If you use our dataset or source code, please kindly cite our paper: 292 | ``` 293 | @INPROCEEDINGS{fire2024, 294 | author={Feng, Siyue and Wu, Yueming and Xue, Wenjie and Pan, Sikui and Zou, Deqing and Liu, Yang and Jin, Hai}, 295 | booktitle={33rd USENIX Security Symposium (USENIX Security ’24)}, 296 | title={FIRE: Combining Multi-Stage Filtering with Taint Analysis for Scalable Recurring Vulnerability Detection}, 297 | year={2024}} 298 | ``` 299 | 300 | # Support or Contact 301 | FIRE is developed in the National Engineering Research Center for Big Data Technology and System, Services Computing Technology and System Lab, Hubei Key Laboratory of Distributed System Security, Hubei Engineering Research Center on Big Data Security, Cluster and Grid Computing Lab, Huazhong University of Science and Technology, Wuhan, China by Siyue Feng (fengsiyue@hust.edu.cn), Yueming Wu (wuyueming21@gmail.com), Wenjie Xue (xuewenjie2021@hust.edu.cn), Sikui Pan (skpan@hust.edu.cn), Deqing Zou (deqingzou@hust.edu.cn), Yang Liu (yangliu@ntu.edu.sg), and Hai Jin (hjin@hust.edu.cn). 302 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru==0.7.2 2 | numpy==1.26.2 3 | pygments==2.15.1 4 | bloom-filter2==2.0.0 5 | matplotlib==3.8.0 6 | tqdm==4.65.0 7 | levenshtein==0.23.0 8 | networkx==3.2.1 9 | transformers==4.36.2 10 | redis==5.0.1 11 | flask==3.0.0 12 | ppdeep==20200505 13 | tree-sitter==0.20.4 14 | anytree==2.8.0 -------------------------------------------------------------------------------- /resource/OldNewFuncs/readme.md: -------------------------------------------------------------------------------- 1 | We suggest putting the old-new-funcs dataset here. 2 | 3 | In the `sample-project` dir there is a simple example. -------------------------------------------------------------------------------- /resource/OldNewFuncs/sample-project/CVE-SAMPLE/CVE-SAMPLE_CWE-SAMPLE_abcdef_SAMPLE.cpp__SAMPLE_FUNC_NEW.vul: -------------------------------------------------------------------------------- 1 | // Put the patched function here -------------------------------------------------------------------------------- /resource/OldNewFuncs/sample-project/CVE-SAMPLE/CVE-SAMPLE_CWE-SAMPLE_abcdef_SAMPLE.cpp__SAMPLE_FUNC_OLD.vul: -------------------------------------------------------------------------------- 1 | // Put the vulnerable function here -------------------------------------------------------------------------------- /resource/codebert/readme.md: -------------------------------------------------------------------------------- 1 | Put `codebert-base` in this directory and remove this file. -------------------------------------------------------------------------------- /resource/jdk-17.0.11/readme.md: -------------------------------------------------------------------------------- 1 | Put Oracle JDK 17.0.11 here and remove this file. -------------------------------------------------------------------------------- /resource/joern-cli/readme.md: -------------------------------------------------------------------------------- 1 | Put Joern 1.2.1 here and remove this file. -------------------------------------------------------------------------------- /resource/readme.md: -------------------------------------------------------------------------------- 1 | These are the resources FIRE needs.
2 | 3 | You should download corresponding version of resource and extract them to the corresponding directory before FIRE run. -------------------------------------------------------------------------------- /resource/redis-7.2.3/readme.md: -------------------------------------------------------------------------------- 1 | put redis 7.2.3 here and remove this file -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import subprocess 5 | import time 6 | 7 | import yaml 8 | from flask import Flask, request 9 | 10 | app = Flask(__name__) 11 | process_running = False 12 | 13 | 14 | @app.route('/') 15 | def hello(): 16 | return 'Hello, World!' 17 | 18 | 19 | def change_yml(jaccard_sim_threshold=0.7, ast_hash_sim_threshold_min=0.7): 20 | if os.path.exists("config.yml"): 21 | with open("config.yml", "r") as f: 22 | config = yaml.safe_load(f) 23 | config["experiment"]["token_filter"]["jaccard_sim_threshold"] = jaccard_sim_threshold 24 | config["experiment"]["trace"]["ast_sim_threshold_min"] = ast_hash_sim_threshold_min 25 | with open("config.yml", "w") as f: 26 | yaml.dump(config, f) 27 | 28 | 29 | @app.route('/process') 30 | def process(): 31 | global process_running 32 | 33 | if not process_running: 34 | process_running = True 35 | try: 36 | shutil.copy("config.default.yml", "config.yml") 37 | git_url = request.args.get('git_url', '') 38 | branch = request.args.get('branch', '') 39 | jaccard_sim_threshold_str = request.args.get('jaccard_sim_threshold', '0.7') 40 | ast_sim_threshold_min_str = request.args.get('ast_sim_threshold_min', '0.7') 41 | ast_sim_threshold_min = float(ast_sim_threshold_min_str) 42 | jaccard_sim_threshold = float(jaccard_sim_threshold_str) 43 | change_yml(jaccard_sim_threshold, ast_sim_threshold_min) 44 | 45 | if git_url and branch: 46 | git_name = os.path.basename(git_url.rstrip('/')) 47 | subprocess.call('git clone --branch %s --depth=1 %s' % (branch, git_url), shell=True) 48 | start_time = time.time() 49 | code = subprocess.call('python3 main.py %s' % git_name, shell=True) 50 | end_time = time.time() 51 | shutil.rmtree(git_name) 52 | if code != 0: 53 | process_running = False 54 | return json.dumps({'Error': 'Detect Failed'}), 500 55 | 56 | result_dir = f'result/{git_name}' 57 | trace_file = os.path.join(result_dir, f'{git_name}.trace.csv') 58 | log_file = os.path.join(result_dir, f'{git_name}.json') 59 | info_file = os.path.join(result_dir, f'{git_name}.detect_info.json') 60 | if os.path.exists(log_file): 61 | 62 | with open(log_file, 'r') as f: 63 | orig_vul_json = json.load(f) 64 | 65 | vul_cnt = orig_vul_json["all"] 66 | vul_json = {} 67 | for vul in orig_vul_json["vul"]: 68 | if vul["dst"] not in vul_json: 69 | vul_json[vul["dst"]] = vul["sim"] 70 | else: 71 | vul_json[vul["dst"]].extend(vul["sim"]) 72 | 73 | if os.path.exists(trace_file): 74 | with open(trace_file, "r") as f: 75 | csv_info = f.read() 76 | else: 77 | csv_info = "" 78 | 79 | if os.path.exists(info_file): 80 | with open(info_file, "r") as f: 81 | detect_info = json.load(f) 82 | else: 83 | detect_info = {} 84 | 85 | response = json.dumps( 86 | {"time": end_time - start_time, "vul": vul_json, "vul_cnt": vul_cnt, "csv_info": csv_info, 87 | "detect_info": detect_info}) 88 | else: 89 | return json.dumps({'Error': 'Log file not found.'}), 500 90 | 91 | shutil.rmtree(result_dir, ignore_errors=True) 92 | process_running = False 93 | 
return response 94 | else: 95 | process_running = False 96 | return json.dumps({'Error': 'Missing git_url or branch parameter.'}), 400 97 | 98 | except Exception as e: 99 | process_running = False 100 | return json.dumps({'Error': str(e)}), 500 101 | 102 | else: 103 | return json.dumps({'Error': 'Another process is already running.'}), 429 104 | 105 | 106 | if __name__ == '__main__': 107 | app.run("0.0.0.0", port=8000) 108 | --------------------------------------------------------------------------------
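With the server running, a detection job can be submitted with plain HTTP. The snippet below is a hypothetical client built only on the Python standard library (it is not part of the repository); the git URL and branch are placeholders, and the two threshold parameters are the optional query-string knobs that `server.py` reads:

```python
# client_sketch.py -- hypothetical example client, not shipped with FIRE.
import json
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({
    "git_url": "https://github.com/example/project.git",  # placeholder target repo
    "branch": "v1.0.0",                                    # placeholder tag/branch
    "jaccard_sim_threshold": 0.7,                          # optional, read by server.py
    "ast_sim_threshold_min": 0.7,                          # optional, read by server.py
})

# Detecting a whole project can take hours, so allow a generous timeout.
with urllib.request.urlopen(f"http://127.0.0.1:8000/process?{params}",
                            timeout=24 * 3600) as resp:
    result = json.load(resp)

print(result["vul_cnt"], "potentially vulnerable functions found")
for dst, similar in result["vul"].items():
    print(dst, "->", similar)
```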