├── varbert ├── __init__.py ├── fine-tune │ ├── __init__.py │ └── preprocess.py ├── tokenizer │ ├── preprocess.py │ └── train_bpe_tokenizer.py ├── mlm │ └── preprocess.py ├── generate_vocab.py ├── resize_model.py ├── cmlm │ ├── preprocess.py │ └── eval.py └── README.md ├── varcorpus ├── __init__.py ├── dataset-gen │ ├── __init__.py │ ├── log.py │ ├── pathmanager.py │ ├── decompiler │ │ ├── ida_unrecogn_func.py │ │ ├── ghidra_dec.py │ │ ├── ida_dec.py │ │ ├── run_decompilers.py │ │ └── ida_analysis.py │ ├── preprocess_vars.py │ ├── binary.py │ ├── joern_parser.py │ ├── generate.py │ ├── utils.py │ ├── strip_types.py │ ├── dwarf_info.py │ ├── parse_decompiled_code.py │ ├── runner.py │ ├── create_dataset_splits.py │ └── variable_matching.py └── README.md ├── requirements.txt ├── Dockerfile ├── .gitignore └── README.md /varbert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /varcorpus/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /varbert/fine-tune/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cpgqls_client 2 | regex 3 | pyelftools==0.29 4 | cle 5 | torch 6 | transformers 7 | jsonlines 8 | pandas 9 | gdown 10 | jupyterlab 11 | tensorboardX 12 | datasets 13 | scikit-learn 14 | accelerate>=0.20.1 -------------------------------------------------------------------------------- /varcorpus/dataset-gen/log.py: -------------------------------------------------------------------------------- 1 | import logging.config 2 | import os 3 | 4 | def setup_logging(LOG_DIRECTORY, DEBUG=False): 5 | if not os.path.exists(LOG_DIRECTORY): 6 | os.makedirs(LOG_DIRECTORY) 7 | 8 | log_level = 'DEBUG' if DEBUG else 'INFO' 9 | LOGGING_CONFIG = { 10 | 'version': 1, 11 | 'disable_existing_loggers': False, 12 | 13 | 'formatters': { 14 | 'default_formatter': { 15 | 'format': '%(asctime)s - %(levelname)s - %(name)s - %(filename)s - %(lineno)d : %(message)s' 16 | } 17 | }, 18 | 19 | 'handlers': { 20 | 'main_handler': { 21 | 'class': 'logging.FileHandler', 22 | 'formatter': 'default_formatter', 23 | 'filename': os.path.join(LOG_DIRECTORY, 'varcorpus.log') 24 | } 25 | }, 26 | 27 | 'loggers': { 28 | 'main': { 29 | 'handlers': ['main_handler'], 30 | 'level': log_level, 31 | 'propagate': False 32 | } 33 | } 34 | } 35 | 36 | logging.config.dictConfig(LOGGING_CONFIG) -------------------------------------------------------------------------------- /varcorpus/dataset-gen/pathmanager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | class PathManager: 5 | def __init__(self, args): 6 | # Base paths set by user input 7 | self.binaries_dir = args.binaries_dir 8 | self.data_dir = os.path.abspath(args.data_dir) 9 | self.joern_dir = os.path.abspath(args.joern_dir) 10 | if not args.tmpdir: 11 | self.tmpdir = tempfile.mkdtemp(prefix='varbert_tmpdir_', dir='/tmp') 12 | else: 13 
| self.tmpdir = args.tmpdir 14 | print("self.tmpdir", self.tmpdir) 15 | os.makedirs(self.tmpdir, exist_ok=True) 16 | 17 | self.ida_path = args.ida_path 18 | self.ghidra_path = args.ghidra_path 19 | self.corpus_language = args.corpus_language 20 | 21 | # Derived paths 22 | self.strip_bin_dir = os.path.join(self.tmpdir, 'binary/strip') 23 | self.type_strip_bin_dir = os.path.join(self.tmpdir, 'binary/type_strip') 24 | self.joern_data_path = os.path.join(self.tmpdir, 'joern') 25 | self.dc_path = os.path.join(self.tmpdir, 'dc') 26 | self.failed_path_ida = os.path.join(self.tmpdir, 'failed') 27 | self.dwarf_info_path = os.path.join(self.tmpdir, 'dwarf') 28 | self.match_path = os.path.join(self.tmpdir, 'map') 29 | 30 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | RUN apt-get update -y 4 | # for tdata geo locn 5 | RUN ln -snf /usr/share/zoneinfo/$CONTAINER_TIMEZONE /etc/localtime && echo $CONTAINER_TIMEZONE > /etc/timezone 6 | 7 | # Install dependencies 8 | RUN apt-get -y install python3.11 python3-pip git binutils-multiarch wget 9 | # Ghidra 10 | RUN apt-get -y install openjdk-17-jdk openjdk-11-jdk 11 | 12 | # Joern 13 | RUN apt-get install -y openjdk-11-jre-headless psmisc && \ 14 | apt-get clean; 15 | RUN pip3 install cpgqls_client regex 16 | 17 | WORKDIR /varbert_workdir 18 | 19 | # Dwarfwrite 20 | RUN git clone https://github.com/rhelmot/dwarfwrite.git 21 | WORKDIR /varbert_workdir/dwarfwrite 22 | RUN pip install . 23 | 24 | WORKDIR /varbert_workdir/ 25 | RUN git clone https://github.com/sefcom/VarBERT.git 26 | WORKDIR /varbert_workdir/VarBERT/ 27 | RUN pip install -r requirements.txt 28 | 29 | RUN apt-get install unzip 30 | RUN wget -cO /varbert_workdir/joern.tar.gz "https://www.dropbox.com/scl/fi/toh6087y5t5xyln47i5ih/modified_joern.tar.gz?rlkey=lfvjn1u7zvtp9a4cu8z8vgsof" && tar -xzf /varbert_workdir/joern.tar.gz -C /varbert_workdir && rm /varbert_workdir/joern.tar.gz 31 | RUN wget -cO /varbert_workdir/ghidra_10.4_PUBLIC_20230928.zip "https://github.com/NationalSecurityAgency/ghidra/releases/download/Ghidra_10.4_build/ghidra_10.4_PUBLIC_20230928.zip" && unzip /varbert_workdir/ghidra_10.4_PUBLIC_20230928.zip -d /varbert_workdir && rm /varbert_workdir/ghidra_10.4_PUBLIC_20230928.zip 32 | 33 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/decompiler/ida_unrecogn_func.py: -------------------------------------------------------------------------------- 1 | import idaapi 2 | import idc 3 | import ida_funcs 4 | import ida_hexrays 5 | import ida_kernwin 6 | import ida_loader 7 | 8 | def setup(): 9 | global analysis 10 | path = ida_loader.get_path(ida_loader.PATH_TYPE_CMD) 11 | 12 | def read_list(filename): 13 | with open(filename, 'r') as r: 14 | ty_addrs = r.read().split('\n') 15 | return ty_addrs 16 | 17 | def add_unrecognized_func(ty_addrs): 18 | for addr in ty_addrs: 19 | # check if func is recognized by IDA 20 | name = ida_funcs.get_func_name(int(addr)) 21 | if name: 22 | print(f"func present at: {addr} {name}") 23 | else: 24 | if ida_funcs.add_func(int(addr)): 25 | print(f"func recognized at: {addr} ") 26 | else: 27 | print(f"bad address {addr}") 28 | 29 | def go(): 30 | setup() 31 | ea = 0 32 | filename = idc.ARGV[1] 33 | ty_addrs = read_list(filename) 34 | add_unrecognized_func(ty_addrs) 35 | 36 | while True: 37 | func = ida_funcs.get_next_func(ea) 38 | if func is None: 39 | 
break 40 | ea = func.start_ea 41 | seg = idc.get_segm_name(ea) 42 | if seg != ".text": 43 | continue 44 | print('analyzing', ida_funcs.get_func_name(ea), hex(ea), ea) 45 | analyze_func(func) 46 | 47 | def analyze_func(func): 48 | cfunc = ida_hexrays.decompile_func(func, None, 0) 49 | if cfunc is None: 50 | return 51 | 52 | idaapi.auto_wait() 53 | go() 54 | -------------------------------------------------------------------------------- /varbert/tokenizer/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | import time 5 | import os 6 | import math 7 | import json 8 | import jsonlines 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import re 12 | import pandas as pd 13 | 14 | def update_data(data, pattern): 15 | new_data = [] 16 | for body in tqdm(data): 17 | idx = 0 18 | new_body = "" 19 | for each_var in list(re.finditer(pattern, body)): 20 | s = each_var.start() 21 | e = each_var.end() 22 | prefix = body[idx:s] 23 | var = body[s:e] 24 | orig_var = var.split("@@")[-2] 25 | new_body += prefix + orig_var 26 | idx = e 27 | new_body += body[idx:] 28 | new_data.append(new_body) 29 | return new_data 30 | 31 | def main(args): 32 | data = [] 33 | with jsonlines.open(args.input_file) as f: 34 | for each in tqdm(f): 35 | data.append(each['norm_func']) 36 | 37 | new_data = update_data(data, "@@\w+@@\w+@@") 38 | 39 | with jsonlines.open(args.output_file, mode='w') as writer: 40 | for item in new_data: 41 | writer.write({'func': item}) 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser(description='Process JSONL files.') 45 | parser.add_argument('--input_file', type=str, help='Path to the input HSC jsonl file') 46 | parser.add_argument('--output_file', type=str, help='Path to the save HSC jsonl file for tokenization') 47 | args = parser.parse_args() 48 | main(args) 49 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/decompiler/ghidra_dec.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import shutil 4 | import logging 5 | from ghidra.app.decompiler import DecompInterface, DecompileOptions 6 | 7 | def decompile_(binary_name): 8 | 9 | binary = getCurrentProgram() 10 | decomp_interface = DecompInterface() 11 | options = DecompileOptions() 12 | options.setNoCastPrint(True) 13 | decomp_interface.setOptions(options) 14 | decomp_interface.openProgram(binary) 15 | 16 | func_mg = binary.getFunctionManager() 17 | funcs = func_mg.getFunctions(True) 18 | 19 | # args 20 | args = getScriptArgs() 21 | out_path, failed_path, log_path = str(args[0]), str(args[1]), str(args[2]) 22 | 23 | regex_com = r"(/\*[\s\S]*?\*\/)|(//.*)" 24 | tot_lines = 0 25 | 26 | try: 27 | with open(out_path, 'w') as w: 28 | for func in funcs: 29 | results = decomp_interface.decompileFunction(func, 0, None ) 30 | addr = str(func.getEntryPoint()) 31 | func_res = results.getDecompiledFunction() 32 | if 'EXTERNAL' in func.getName(True): 33 | continue 34 | if func_res: 35 | func_c = str(func_res.getC()) 36 | # rm comments 37 | new_func = str(re.sub(regex_com, '', func_c, 0, re.MULTILINE)).strip() 38 | 39 | # for joern parsing 40 | before = '//----- (' 41 | after = ') ----------------------------------------------------\n' 42 | addr_line = before + addr + after 43 | tot_lines += 2 44 | w.write(addr_line) 45 | w.write(new_func) 46 | w.write('\n') 47 | tot_lines += 
new_func.count('\n') 48 | 49 | except Exception as e: 50 | # move log file to failed path 51 | shutil.move(os.path.join(log_path, (binary_name + '.log')), os.path.join(failed_path, ('ghidra_' + binary_name + '.log'))) 52 | 53 | if __name__ == '__main__': 54 | 55 | bin_name = str(locals()['currentProgram']).split(' - .')[0].strip() 56 | decompile_(bin_name) 57 | 58 | -------------------------------------------------------------------------------- /varbert/mlm/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | import time 5 | import os 6 | import math 7 | import json 8 | import jsonlines 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | import re 12 | import pandas as pd 13 | 14 | def update_function(data, pattern): 15 | new_data = [] 16 | for body in tqdm(data): 17 | idx = 0 18 | new_body = "" 19 | for each_var in list(re.finditer(pattern, body)): 20 | s = each_var.start() 21 | e = each_var.end() 22 | prefix = body[idx:s] 23 | var = body[s:e] 24 | orig_var = var.split("@@")[-2] 25 | new_body += prefix + orig_var 26 | idx = e 27 | new_body += body[idx:] 28 | new_data.append(new_body) 29 | return new_data 30 | 31 | def process_file(input_file, output_file, pattern): 32 | data = [] 33 | with jsonlines.open(input_file) as f: 34 | for each in tqdm(f): 35 | data.append(each['norm_func']) 36 | 37 | new_data = update_function(data, pattern) 38 | 39 | mlm_data = [] 40 | for idx, each in tqdm(enumerate(new_data)): 41 | if len(each) > 0 and not each.isspace(): 42 | mlm_data.append({'text': each, 'source': 'human', '_id': idx}) 43 | 44 | with jsonlines.open(output_file, mode='w') as f: 45 | for each in tqdm(mlm_data): 46 | f.write(each) 47 | 48 | def main(args): 49 | process_file(args.train_file, args.output_train_file, "@@\w+@@\w+@@") 50 | process_file(args.test_file, args.output_test_file, "@@\w+@@\w+@@") 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser(description='Process and save jsonl files for MLM') 54 | parser.add_argument('--train_file', type=str, required=True, help='Path to the input train JSONL file') 55 | parser.add_argument('--test_file', type=str, required=True, help='Path to the input test JSONL file') 56 | parser.add_argument('--output_train_file', type=str, required=True, help='Path to the output train JSONL file for MLM') 57 | parser.add_argument('--output_test_file', type=str, required=True, help='Path to the output test JSONL file for MLM') 58 | args = parser.parse_args() 59 | main(args) 60 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/decompiler/ida_dec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | import logging 5 | import subprocess 6 | from pathlib import Path 7 | 8 | l = logging.getLogger('main') 9 | 10 | def decompile_(decompiler_workdir, decompiler_path, binary_name, binary_path, binary_type, 11 | decompiled_binary_code_path, failed_path , type_strip_addrs, type_strip_mangled_names): 12 | try: 13 | shutil.copy(binary_path, decompiler_workdir) 14 | current_script_dir = Path(__file__).resolve().parent 15 | addr_file = os.path.join(decompiler_workdir, f'{binary_name}_addr') 16 | if binary_type == 'strip': 17 | run_str = f'{decompiler_path} -S"{current_script_dir}/ida_unrecogn_func.py {addr_file}" -L{decompiler_workdir}/log_{binary_name}.txt 
-Ohexrays:{decompiler_workdir}/outfile:ALL -A {decompiler_workdir}/{binary_name}' 18 | elif binary_type == 'type_strip': 19 | run_str = f'{decompiler_path} -S"{current_script_dir}/ida_analysis.py {addr_file}" -L{decompiler_workdir}/log_{binary_name}.txt -Ohexrays:{decompiler_workdir}/outfile:ALL -A {decompiler_workdir}/{binary_name}' 20 | 21 | for _ in range(5): 22 | subprocess.run([run_str], shell=True) 23 | if f'outfile.c' in os.listdir(decompiler_workdir): 24 | l.debug(f"outfile.c generated! for {binary_name}") 25 | break 26 | time.sleep(5) 27 | 28 | except Exception as e: 29 | shutil.move(os.path.join(decompiler_workdir, f'log_{binary_name}.txt'), os.path.join(failed_path, ('ida_' + binary_name + '.log'))) 30 | return None, str(os.path.join(failed_path, ('ida_' + binary_name + '.log'))) 31 | 32 | finally: 33 | # cleanup! 34 | if os.path.exists(os.path.join(decompiler_workdir, f'{binary_name}.i64')): 35 | os.remove(os.path.join(decompiler_workdir, f'{binary_name}.i64')) 36 | if binary_type == 'strip': 37 | subprocess.run(['mv', addr_file, type_strip_addrs ]) 38 | subprocess.run(['mv', f'{addr_file}_names', type_strip_mangled_names]) 39 | if os.path.exists(f'{decompiler_workdir}/outfile.c'): 40 | if subprocess.run(['mv', f'{decompiler_workdir}/outfile.c', f'{decompiled_binary_code_path}']).returncode == 0: 41 | return True, f'{decompiled_binary_code_path}' 42 | else: 43 | return False, str(os.path.join(failed_path, ('ida_' + binary_name + '.log'))) -------------------------------------------------------------------------------- /varcorpus/dataset-gen/decompiler/run_decompilers.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | import os 4 | import shutil 5 | import re 6 | from collections import defaultdict 7 | import json 8 | import sys 9 | import tempfile 10 | from decompiler.ida_dec import decompile_ 11 | # from logger import get_logger 12 | from pathlib import Path 13 | import logging 14 | l = logging.getLogger('main') 15 | 16 | class Decompiler: 17 | def __init__(self, decompiler, decompiler_path, decompiler_workdir, 18 | binary_name, binary_path, binary_type, 19 | decompiled_binary_code_path, failed_path, type_strip_addrs, type_strip_mangled_names) -> None: 20 | self.decompiler = decompiler 21 | self.decompiler_path = decompiler_path 22 | self.decompiler_workdir = decompiler_workdir 23 | self.binary_name = binary_name 24 | self.binary_path = binary_path 25 | self.binary_type = binary_type 26 | self.decompiled_binary_code_path = decompiled_binary_code_path 27 | self.failed_path = failed_path 28 | self.type_strip_addrs= type_strip_addrs 29 | self.type_strip_mangled_names = type_strip_mangled_names 30 | self.decompile() 31 | 32 | def decompile(self): 33 | 34 | if self.decompiler == "ida": 35 | success, path = decompile_(self.decompiler_workdir, Path(self.decompiler_path), self.binary_name, 36 | Path(self.binary_path), self.binary_type, Path(self.decompiled_binary_code_path), 37 | Path(self.failed_path), Path(self.type_strip_addrs), Path(self.type_strip_mangled_names)) 38 | if not success: 39 | l.error(f"Decompilation failed for {self.binary_name} :: {self.binary_type} :: check logs at {path}") 40 | return None, None 41 | return success, path 42 | 43 | elif self.decompiler == "ghidra": 44 | try: 45 | current_script_dir = Path(__file__).resolve().parent 46 | subprocess.call(['{} {} tmp_project -scriptPath {} -postScript ghidra_dec.py {} {} {} -import {} -readOnly -log {}.log'.format(self.decompiler_path, 
self.decompiler_workdir, current_script_dir, 47 | self.decompiled_binary_code_path, self.failed_path, 48 | self.decompiler_workdir, self.binary_path, self.decompiler_workdir, 49 | self.binary_name)], shell=True, 50 | stdout=subprocess.DEVNULL, 51 | stderr=subprocess.DEVNULL) 52 | 53 | except Exception as e: 54 | l.error(f"Decompilation failed for {self.binary_name} :: {self.binary_type} {e}") 55 | return None 56 | return True, self.decompiled_binary_code_path -------------------------------------------------------------------------------- /varcorpus/dataset-gen/preprocess_vars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import logging 5 | import concurrent.futures 6 | from itertools import islice 7 | from functools import partial 8 | from multiprocessing import Manager 9 | 10 | l = logging.getLogger('main') 11 | 12 | manager = Manager() 13 | vars_ = manager.dict() 14 | clean_vars = manager.dict() 15 | 16 | def read_(filename): 17 | with open(filename, 'r') as r: 18 | data = json.loads(r.read()) 19 | return data 20 | 21 | def read_jsonl(filename): 22 | with open(filename, 'r') as r: 23 | data = r.readlines() 24 | return data 25 | 26 | def count_(data): 27 | tot = 0 28 | for name, c in data.items(): 29 | tot += int(c) 30 | return tot 31 | 32 | def change_case(str_): 33 | 34 | # aP 35 | if len(str_) == 2: 36 | var = str_.lower() 37 | 38 | # _aP # __A 39 | elif len(str_) == 3 and (str_[0] == '_' or str_[:2] == '__'): 40 | var = str_.lower() 41 | 42 | else: 43 | s1 = re.sub('([A-Z]+)([A-Z][a-z]+)', r'\1_\2', str_) 44 | #s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', str_) 45 | var = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() 46 | return var 47 | 48 | 49 | def clean(data, lookup, start, batch_size: int): 50 | 51 | global vars_ 52 | global clean_vars 53 | end = start + batch_size 54 | 55 | # change case and lower 56 | for var in islice(data, start, end): 57 | if var in lookup: 58 | vars_[var] = lookup[var] 59 | up_var = lookup[var] 60 | else: 61 | up_var = change_case(var) 62 | # rm underscore 63 | up_var = up_var.strip('_') 64 | vars_[var] = up_var 65 | if up_var in clean_vars: 66 | clean_vars[up_var] += data[var] 67 | else: 68 | clean_vars[up_var] = data[var] 69 | 70 | def normalize_variables(filename, ty, outdir, lookup_dir, WORKERS): 71 | existing_lookup = False 72 | if os.path.exists(os.path.join(lookup_dir, 'universal_lookup.json')): 73 | with open(os.path.join(lookup_dir, 'universal_lookup.json'), 'r') as r: 74 | lookup = json.loads(r.read()) 75 | existing_lookup = True 76 | else: 77 | lookup = {} 78 | l.debug(f"len of lookup: {len(lookup)}") 79 | data = read_(filename) 80 | l.debug(f"cleaning variables for: {filename} | count: {len(data)}") 81 | 82 | batch = len(data) // (WORKERS - 1) 83 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor: 84 | cal_partial = partial(clean, data, lookup, batch_size=batch,) 85 | executor.map(cal_partial, [batch * i for i in range(WORKERS)]) 86 | l.debug(f"norm vars for file {filename} before: {len(vars_)} after: {len(set(vars_.values()))}") 87 | sorted_clean_vars = dict(sorted(clean_vars.items(), key=lambda item: item[1], reverse = True)) 88 | 89 | with open(os.path.join(outdir, f"{ty}_clean.json"), 'w') as w: 90 | w.write(json.dumps(dict(sorted_clean_vars))) 91 | 92 | if existing_lookup: 93 | for v in vars_: 94 | if v not in lookup: 95 | lookup[v] = vars_[v] 96 | 97 | with open(os.path.join(lookup_dir, 'universal_lookup.json'), 'w') as w: 
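            # persist the merged lookup so later runs map the same raw variable names to the same normalized names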
98 | w.write(json.dumps(lookup)) 99 | else: 100 | with open(os.path.join(lookup_dir, 'universal_lookup.json'), 'w') as w: 101 | w.write(json.dumps(dict(vars_))) 102 | 103 | 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/binary.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | import hashlib 4 | from elftools.elf.elffile import ELFFile 5 | from elftools.dwarf.locationlists import ( 6 | LocationEntry, LocationExpr, LocationParser) 7 | from elftools.dwarf.descriptions import ( 8 | describe_DWARF_expr, _import_extra, describe_attr_value,set_global_machine_arch) 9 | from collections import defaultdict 10 | from utils import write_json 11 | import os 12 | import sys 13 | import shutil 14 | import logging 15 | from utils import subprocess_ 16 | from dwarf_info import Dwarf 17 | from strip_types import type_strip_target_binary 18 | 19 | l = logging.getLogger('main') 20 | 21 | class Binary: 22 | def __init__(self, target_binary, b_name, path_manager, language, decompiler, load_data=True): 23 | self.target_binary = target_binary # to binary_path 24 | self.binary_name = b_name 25 | self.hash = None 26 | self.dwarf_data = None 27 | self.strip_binary = None 28 | self.type_strip_binary = None 29 | self.path_manager = path_manager 30 | self.language = language 31 | self.decompiler = decompiler 32 | self.dwarf_dict = None 33 | 34 | if load_data: 35 | self.dwarf_dict = self.load() 36 | self.strip_binary, self.type_strip_binary = self.modify_dwarf() 37 | 38 | # read things from binary! 
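    # load(): verify the ELF carries DWARF, hash the binary, collect per-function variable
    # and linkage-name info via Dwarf(), and dump it as JSON under <tmpdir>/dwarf/;
    # modify_dwarf() then produces the stripped and type-stripped copies used by the decompilers.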
39 | def load(self): 40 | try: 41 | l.debug(f'Reading DWARF info from :: {self.target_binary}') 42 | if not self.is_elf_has_dwarf(): 43 | return None 44 | 45 | self.hash = self.md5_hash() 46 | self.dwarf_data = Dwarf(self.target_binary, self.binary_name, self.decompiler) 47 | dwarf_dict = self.dump_data() 48 | l.debug(f'Finished reading DWARF info from :: {self.target_binary}') 49 | return dwarf_dict 50 | except Exception as e: 51 | l.error(f"Error in reading DWARF info from :: {self.target_binary} :: {e}") 52 | 53 | def md5_hash(self): 54 | return subprocess.check_output(['md5sum', self.target_binary]).decode('utf-8').strip().split(' ')[0] 55 | 56 | def is_elf_has_dwarf(self): 57 | try: 58 | with open(self.target_binary, 'rb') as f: 59 | elffile = ELFFile(f) 60 | if not elffile.has_dwarf_info(): 61 | return None 62 | set_global_machine_arch(elffile.get_machine_arch()) 63 | dwarf_info = elffile.get_dwarf_info() 64 | return dwarf_info 65 | except Exception as e: 66 | l.error(f"Error in is_elf_has_dwarf :: {self.target_binary} :: {e}") 67 | 68 | def dump_data(self): 69 | dump = defaultdict(dict) 70 | dump['hash'] = self.hash 71 | dump['vars_per_func'] = self.dwarf_data.vars_in_each_func 72 | dump['linkage_name_to_func_name'] = self.dwarf_data.linkage_name_to_func_name 73 | dump['language'] = self.language 74 | write_json(os.path.join(self.path_manager.tmpdir, 'dwarf', self.binary_name), dict(dump)) 75 | return dump 76 | 77 | def _strip_binary(self, target_binary): 78 | 79 | res = subprocess_(['strip', '--strip-all', target_binary]) 80 | if isinstance(res, Exception): 81 | l.error(f"error in stripping binary! {res} :: {target_binary}") 82 | return None 83 | return target_binary 84 | 85 | def _type_strip_binary(self, in_binary, out_binary): 86 | 87 | type_strip_target_binary(in_binary, out_binary, self.decompiler) 88 | if not os.path.exists(out_binary): 89 | l.error(f"error in stripping binary! 
:: {in_binary}") 90 | return out_binary 91 | 92 | def modify_dwarf(self): 93 | 94 | shutil.copy(self.target_binary, self.path_manager.strip_bin_dir) 95 | type_strip = self._type_strip_binary(os.path.join(self.path_manager.strip_bin_dir, self.binary_name), os.path.join(self.path_manager.type_strip_bin_dir, self.binary_name)) 96 | strip = self._strip_binary(os.path.join(self.path_manager.strip_bin_dir, self.binary_name)) 97 | return strip, type_strip 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /varbert/tokenizer/train_bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tokenizers import ByteLevelBPETokenizer 3 | from transformers import RobertaTokenizerFast 4 | 5 | import argparse 6 | import os 7 | 8 | def main(agrs): 9 | paths = [str(x) for x in Path(args.input_path).glob("**/*.txt")] 10 | 11 | # Initialize a tokenizer 12 | tokenizer = ByteLevelBPETokenizer(add_prefix_space=True) 13 | 14 | # Customize training 15 | tokenizer.train(files=paths, 16 | vocab_size=args.vocab_size, 17 | min_frequency=args.min_frequency, 18 | special_tokens=[ "", 19 | "", 20 | "", 21 | "", 22 | "", 23 | ]) 24 | os.makedirs(args.output_path,exist_ok=True) 25 | tokenizer.save_model(args.output_path) 26 | 27 | 28 | # Test the tokenizer 29 | tokenizer = RobertaTokenizerFast.from_pretrained(args.output_path, max_len=1024) 30 | ghidra_func = "undefined * FUN_001048d0(void)\n{\n undefined8 uVar1;\n long lVar2;\n undefined *puVar3;\n \n uVar1 = PyState_FindModule(&readlinemodule);\n lVar2 = PyModule_GetState(uVar1);\n if (*(lVar2 + 0x18) == 0) {\n FUN_00103605(&_Py_NoneStruct);\n puVar3 = &_Py_NoneStruct;\n }\n else {\n uVar1 = PyState_FindModule(&readlinemodule);\n lVar2 = PyModule_GetState(uVar1);\n FUN_00103605(*(lVar2 + 0x18));\n uVar1 = PyState_FindModule(&readlinemodule);\n lVar2 = PyModule_GetState(uVar1);\n puVar3 = *(lVar2 + 0x18);\n }\n return puVar3;\n}" 31 | 32 | tokens = tokenizer.tokenize(ghidra_func) 33 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 34 | # print("Normal Tokens:",str(tokens)) 35 | # print("Normal Token Ids:",str(token_ids)) 36 | 37 | human_func = 'int fast_s_mp_sqr(const mp_int *a, mp_int *b)\n{\n int olduse, res, pa, ix, iz;\n mp_digit W[MP_WARRAY], *tmpx;\n mp_word W1;\n /* grow the destination as required */\n pa = a->used + a->used;\n if (b->alloc < pa) {\n if ((res = mp_grow(b, pa)) != 0 /* no error, all is well */) {\n return res;\n }\n }\n /* number of output digits to produce */\n W1 = 0;\n for (ix = 0; ix < pa; ix++) {\n int tx, ty, iy;\n mp_word _W;\n mp_digit *tmpy;\n /* clear counter */\n _W = 0;\n /* get offsets into the two bignums */\n ty = MIN(a->used-1, ix);\n tx = ix - ty;\n /* setup temp aliases */\n tmpx = a->dp + tx;\n tmpy = a->dp + ty;\n /* this is the number of times the loop will iterrate, essentially\n while (tx++ < a->used && ty-- >= 0) { ... 
}\n */\n iy = MIN(a->used-tx, ty+1);\n /* now for squaring tx can never equal ty\n * we halve the distance since they approach at a rate of 2x\n * and we have to round because odd cases need to be executed\n */\n iy = MIN(iy, ((ty-tx)+1)>>1);\n /* execute loop */\n for (iz = 0; iz < iy; iz++) {\n _W += (mp_word)*tmpx++ * (mp_word)*tmpy--;\n }\n /* double the inner product and add carry */\n _W = _W + _W + W1;\n /* even columns have the square term in them */\n if (((unsigned)ix & 1u) == 0u) {\n _W += (mp_word)a->dp[ix>>1] * (mp_word)a->dp[ix>>1];\n }\n /* store it */\n W[ix] = _W & ((((mp_digit)1)<<((mp_digit)MP_DIGIT_BIT))-((mp_digit)1));\n /* make next carry */\n W1 = _W >> (mp_word)(CHAR_BIT*sizeof(mp_digit));\n }\n /* setup dest */\n olduse = b->used;\n b->used = a->used+a->used;\n {\n mp_digit *tmpb;\n tmpb = b->dp;\n for (ix = 0; ix < pa; ix++) {\n *tmpb++ = W[ix] & ((((mp_digit)1)<<((mp_digit)MP_DIGIT_BIT))-((mp_digit)1));\n }\n /* clear unused digits [that existed in the old copy of c] */\n for (; ix < olduse; ix++) {\n *tmpb++ = 0;\n }\n }\n mp_clamp(b);\n return 0 /* no error, all is well */;\n}' 38 | 39 | tokens = tokenizer.tokenize(human_func) 40 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 41 | # print("Normal Tokens:",str(tokens)) 42 | # print("Normal Token Ids:",str(token_ids)) 43 | 44 | if __name__ == "__main__": 45 | 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--input_path', type=str, help='path to the input text files') 48 | parser.add_argument('--vocab_size', type=int, default=50265, help='size of tokenizer vocabulary') 49 | parser.add_argument('--min_frequency', type=int, default=2, help='minimum frequency') 50 | parser.add_argument('--output_path', type=str, help='path to the output text files') 51 | 52 | 53 | args = parser.parse_args() 54 | 55 | -------------------------------------------------------------------------------- /varbert/generate_vocab.py: -------------------------------------------------------------------------------- 1 | # Generate vocab from preprocessed sets with fid 2 | 3 | import argparse 4 | import os 5 | import json 6 | import jsonlines as jsonl 7 | import re 8 | from collections import defaultdict 9 | from tqdm import tqdm 10 | 11 | def load_jsonl_files(file_path): 12 | data = [] 13 | with jsonl.open(file_path) as ofd: 14 | for each in tqdm(ofd, desc=f"Loading data from {file_path}"): 15 | data.append(each) 16 | return data 17 | 18 | def read_json(file_path): 19 | with open(file_path, 'r') as f: 20 | data = json.load(f) 21 | return data 22 | 23 | def calculate_distribution(data, dataset_type): 24 | var_distrib = defaultdict(int) 25 | for each in tqdm(data): 26 | func = each['norm_func'] 27 | pattern = "@@\w+@@\w+@@" 28 | if dataset_type == 'varcorpus': 29 | dwarf_norm_type = each['type_stripped_norm_vars'] 30 | 31 | for each_var in list(re.finditer(pattern,func)): 32 | s = each_var.start() 33 | e = each_var.end() 34 | var = func[s:e] 35 | orig_var = var.split("@@")[-2] 36 | 37 | # Collect variables only dwarf 38 | if dataset_type == 'varcorpus': 39 | if orig_var in dwarf_norm_type: 40 | var_distrib[orig_var]+=1 41 | elif dataset_type == 'hsc': 42 | var_distrib[orig_var]+=1 43 | 44 | sorted_var_distrib = sorted(var_distrib.items(), key = lambda x : x[1], reverse=True) 45 | return sorted_var_distrib 46 | 47 | 48 | def build_vocab(data, vocab_size, existing_vocab=None): 49 | if existing_vocab: 50 | vocab_list = list(existing_vocab) 51 | else: 52 | vocab_list = [] 53 | for idx, each in tqdm(enumerate(data)): 54 | if 
len(vocab_list) == args.vocab_size: 55 | print("limit reached:", args.vocab_size, "Missed:",len(data)-idx-1) 56 | break 57 | if each[0] in vocab_list: 58 | continue 59 | else: 60 | vocab_list.append(each[0]) 61 | 62 | idx2word, word2idx = {}, {} 63 | for i,each in enumerate(vocab_list): 64 | idx2word[i] = each 65 | word2idx[each] = i 66 | 67 | return idx2word, word2idx 68 | 69 | def save_json(data, output_path, filename): 70 | with open(os.path.join(output_path, filename), 'w') as w: 71 | w.write(json.dumps(data)) 72 | 73 | 74 | def main(args): 75 | # Load existing human vocabulary if provided 76 | if args.existing_vocab: 77 | with open(args.existing_vocab, 'r') as f: 78 | human_vocab = json.load(f) 79 | else: 80 | human_vocab = None 81 | 82 | # Load train and test data 83 | train_data = load_jsonl_files(args.train_file) 84 | test_data = load_jsonl_files(args.test_file) 85 | 86 | # TODO add check to 87 | var_distrib_train = calculate_distribution(train_data, args.dataset_type) 88 | var_distrib_test = calculate_distribution(test_data, args.dataset_type) 89 | 90 | # save only if needed 91 | # save_json(var_distrib_train, args.output_dir, 'var_distrib_train.json') 92 | # save_json(var_distrib_test, args.output_dir, 'var_distrib_test.json') 93 | 94 | print("Train data distribution", len(var_distrib_train)) 95 | print("Test data distribution", len(var_distrib_test)) 96 | 97 | existing_vocab_data = {} 98 | if args.existing_vocab: 99 | print("Human vocab size", len(human_vocab)) 100 | existing_vocab_data = read_json(args.existing_vocab) 101 | 102 | # Build and save the vocabulary 103 | idx2word, word2idx = build_vocab(var_distrib_train, args.vocab_size, existing_vocab=existing_vocab_data) 104 | print("Vocabulary size", len(idx2word)) 105 | save_json(idx2word, args.output_dir, 'idx_to_word.json') 106 | save_json(word2idx, args.output_dir, 'word_to_idx.json') 107 | 108 | if __name__ == "__main__": 109 | 110 | parser = argparse.ArgumentParser(description="Dataset Vocabulary Generator") 111 | parser.add_argument("--dataset_type", type=str, choices=['hsc', 'varcorpus'], required=True, help="Create vocab for HSC (source code) or VarCorpus (decompiled code)") 112 | parser.add_argument("--train_file", type=str, required=True, help="Path to the training data file") 113 | parser.add_argument("--test_file", type=str, required=True, help="Path to the test data file") 114 | parser.add_argument("--existing_vocab", type=str, help="Path to the existing human vocabulary file") 115 | parser.add_argument("--vocab_size", type=int, default=50000, help="Limit for the vocabulary size") 116 | parser.add_argument("--output_dir", type=str, required=True, help="Path where the output vocabulary will be saved") 117 | 118 | args = parser.parse_args() 119 | main(args) -------------------------------------------------------------------------------- /varcorpus/dataset-gen/joern_parser.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | import time 5 | import logging 6 | import subprocess 7 | from collections import defaultdict, OrderedDict 8 | from cpgqls_client import CPGQLSClient, import_code_query 9 | 10 | l = logging.getLogger('main') 11 | 12 | class JoernParser: 13 | def __init__(self, binary_name, binary_type, decompiler, 14 | decompiled_code, workdir, port, outpath ): 15 | self.binary_name = binary_name 16 | self.binary_type = binary_type 17 | self.decompiler = decompiler 18 | self.dc_inpath = decompiled_code 19 | self.port = port 20 | 
self.client = CPGQLSClient(f"localhost:{self.port}") 21 | self.joern_workdir = os.path.join(workdir, 'tmp_joern') 22 | self.joern_data = defaultdict(dict) 23 | self.joern_outpath = outpath 24 | self.parse_joern() 25 | 26 | 27 | def edit_dc(self): 28 | try: 29 | regex_ul = r'(\d)(uLL)' 30 | regex_l = r'(\d)(LL)' 31 | regex_u = r'(\d)(u)' 32 | 33 | with open(f'{self.dc_inpath}', 'r') as r: 34 | data = r.read() 35 | 36 | tmp_1 = re.sub(regex_ul, r'\g<1>', data) 37 | tmp_2 = re.sub(regex_l, r'\g<1>', tmp_1) 38 | final = re.sub(regex_u, r'\g<1>', tmp_2) 39 | 40 | with open(os.path.join(self.joern_workdir, f'{self.binary_name}.c'), 'w') as w: 41 | w.write(final) 42 | 43 | except Exception as e: 44 | l.error(f"Error in joern :: {self.decompiler} {self.binary_type} {self.binary_name} :: {e} ") 45 | 46 | 47 | def clean_up(self): 48 | self.client.execute(f'close("{self.binary_name}")') 49 | self.client.execute(f'delete("{self.binary_name}")') 50 | 51 | 52 | def split(self): 53 | try: 54 | l.debug(f"joern parsing! {self.decompiler} {self.binary_type} {self.binary_name}") 55 | regex = r"(\((\"([\s\S]*?)\"))((, )(\"([\s\S]*?)\")((, )(\d*)(, )(\d*)))((, )(\"([\s\S]*?)\")((, )(\d*)(, )(\d*)))" 56 | r = re.compile(regex) 57 | 58 | self.client.execute(import_code_query(self.joern_workdir, f'{self.binary_name}')) 59 | fetch_q = f'show(cpg.identifier.l.map(x => (x.location.filename, x.method.name, x.method.lineNumber.get, x.method.lineNumberEnd.get, x.name, x.lineNumber.get, x.columnNumber.get)).sortBy(_._7).sortBy(_._6).sortBy(_._1))' 60 | 61 | result = self.client.execute(fetch_q) 62 | res_stdout = result['stdout'] 63 | 64 | if '***temporary file:' in res_stdout: 65 | tmp_file = res_stdout.split(':')[-1].strip()[:-3] 66 | 67 | with open(tmp_file, 'r') as tf_read: 68 | res_stdout = tf_read.read() 69 | subprocess.check_output(['rm', '{}'.format(tmp_file)]) 70 | 71 | matches = r.finditer(res_stdout,re.MULTILINE) 72 | raw_data = defaultdict(dict) 73 | track_func = set() 74 | 75 | random = defaultdict(list) 76 | random_tmp = [] 77 | temp = '' 78 | test = set() 79 | main_dict = defaultdict(dict) 80 | tmp_file_name = '' 81 | for m in matches: 82 | 83 | file_path = m.group(3) 84 | func_name = m.group(7) 85 | func_start = m.group(10) 86 | func_end = m.group(12) 87 | var_name = m.group(16) 88 | var_line = m.group(19) 89 | var_col = m.group(21) 90 | 91 | if tmp_file_name != self.binary_name: 92 | tmp_file_name = self.binary_name 93 | random = defaultdict(list) 94 | raw_data = defaultdict(dict) 95 | random_tmp = [] 96 | 97 | pkg_func = func_name + '_' + func_start 98 | if pkg_func != temp: 99 | track_func = set() 100 | random = defaultdict(list) 101 | random_tmp = [] 102 | 103 | if var_name not in random_tmp: 104 | random_tmp.append(var_name) 105 | 106 | temp = pkg_func 107 | track_func.add(pkg_func) 108 | test.add(pkg_func) 109 | random[var_name].append(var_line) 110 | 111 | raw_data[pkg_func].update({"func_start":func_start, "func_end" : func_end}) 112 | raw_data[pkg_func].update({"variable": dict(random)}) 113 | raw_data[pkg_func].update({'tmp':random_tmp}) 114 | 115 | with open(self.joern_outpath, 'w') as w: 116 | w.write(json.dumps(raw_data)) 117 | self.joern_data = raw_data 118 | 119 | except Exception as e: 120 | l.error(f"Error in joern :: {self.decompiler} {self.binary_type} {self.binary_name} :: {e} ") 121 | 122 | def parse_joern(self): 123 | time.sleep(2) 124 | self.edit_dc() 125 | raw_data = self.split() 126 | self.clean_up() 127 | 128 | 129 | 130 | 
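# Usage sketch (hypothetical paths/values): constructing a JoernParser immediately runs
# parse_joern(), i.e. edit_dc() -> split() -> clean_up(). A Joern server must already be
# listening on the given port (see JoernServer in utils.py; generate.py uses port 8090):
#
#   JoernParser(binary_name='foo', binary_type='strip', decompiler='ida',
#               decompiled_code='/tmp/dc/foo.c', workdir='/tmp/work',
#               port=8090, outpath='/tmp/joern/foo.json')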
-------------------------------------------------------------------------------- /varcorpus/dataset-gen/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | import logging.config 5 | import concurrent.futures 6 | from functools import partial 7 | from multiprocessing import Manager 8 | import json 9 | 10 | from pathmanager import PathManager 11 | from runner import Runner 12 | 13 | 14 | def main(args): 15 | if args.decompiler == 'ida' and not args.ida_path: 16 | parser.error("IDA path is required when decompiling with IDA.") 17 | elif args.decompiler == 'ghidra' and not args.ghidra_path: 18 | parser.error("Ghidra path is required when decompiling with Ghidra.") 19 | 20 | if not args.corpus_language and not args.language_map: 21 | parser.error("Either --corpus_language or --language_map is required.") 22 | if args.corpus_language and args.language_map: 23 | parser.error("Provide only one of --corpus_language or --language_map, not both.") 24 | if args.language_map and not os.path.isfile(args.language_map): 25 | parser.error(f"Language map file not found: {args.language_map}") 26 | if args.corpus_language and args.corpus_language not in {"c", "cpp"}: 27 | parser.error("Invalid --corpus_language. Allowed values are: c, cpp") 28 | 29 | language_map = None 30 | if args.language_map: 31 | try: 32 | with open(args.language_map, 'r') as f: 33 | language_map = json.load(f) 34 | except Exception as e: 35 | parser.error(f"Failed to read language map JSON: {e}") 36 | 37 | path_manager = PathManager(args) 38 | target_binaries = [ 39 | os.path.abspath(path) 40 | for path in glob.glob(os.path.join(path_manager.binaries_dir, '**', '*'), recursive=True) 41 | if os.path.isfile(path) 42 | ] 43 | # Validate language_map if provided 44 | if language_map is not None: 45 | # keys are expected to be binary basenames; values 'c' or 'cpp' 46 | valid_langs = {'c', 'cpp'} 47 | invalid = {k: v for k, v in language_map.items() if v not in valid_langs} 48 | if invalid: 49 | parser.error(f"Invalid languages in language_map (allowed: c, cpp): {invalid}") 50 | missing = [os.path.basename(p) for p in target_binaries if os.path.basename(p) not in language_map] 51 | if missing: 52 | parser.error(f"language_map missing entries for binaries: {missing}") 53 | decompiler = args.decompiler 54 | if not args.corpus_language: 55 | pass 56 | #TODO: detect 57 | 58 | splits = True if args.splits else False 59 | l.info(f"Decompiling {len(target_binaries)} binaries with {decompiler}") 60 | 61 | default_language = None if language_map is not None else args.corpus_language 62 | effective_language_map = language_map if language_map else {} 63 | 64 | runner = Runner( 65 | decompiler=args.decompiler, 66 | target_binaries=target_binaries, 67 | WORKERS=args.WORKERS, 68 | path_manager=path_manager, 69 | PORT=8090, 70 | language=default_language, 71 | DEBUG=args.DEBUG, 72 | language_map=effective_language_map 73 | ) 74 | 75 | runner.run(PARSE=True, splits=splits) 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser(description="Data set generation for VarBERT") 80 | 81 | parser.add_argument( 82 | "-b", "--binaries_dir", 83 | type=str, 84 | help="Path to binaries dir", 85 | required=True, 86 | ) 87 | 88 | parser.add_argument( 89 | "-d", "--data_dir", 90 | type=str, 91 | help="Path to data dir", 92 | required=True, 93 | ) 94 | 95 | parser.add_argument( 96 | "--tmpdir", 97 | type=str, 98 | help="Path to save intermediate files. 
Default is /tmp", 99 | required=False, 100 | default="/tmp/varbert_tmpdir" 101 | ) 102 | 103 | parser.add_argument( 104 | "--decompiler", 105 | choices=['ida', 'ghidra'], 106 | type=str, 107 | help="choose decompiler IDA or Ghidra", 108 | required=True 109 | ) 110 | 111 | parser.add_argument( 112 | "-ida", "--ida_path", 113 | type=str, 114 | help="Path to IDA", 115 | required=False, 116 | ) 117 | 118 | parser.add_argument( 119 | "-ghidra", "--ghidra_path", 120 | type=str, 121 | help="Path to Ghidra", 122 | required=False, 123 | ) 124 | 125 | parser.add_argument( 126 | "-joern", "--joern_dir", 127 | type=str, 128 | help="Path to Joern", 129 | required=False, 130 | ) 131 | 132 | # TODO: if language not given, maybe detect it 133 | parser.add_argument( 134 | "-lang", "--corpus_language", 135 | type=str, 136 | help="Corpus language", 137 | required=False, 138 | ) 139 | 140 | parser.add_argument( 141 | "--language_map", 142 | type=str, 143 | help="Path to JSON mapping of binary name to language, e.g. {\"bin1\": \"c\", \"bin2\": \"cpp\"}", 144 | required=False, 145 | ) 146 | 147 | 148 | parser.add_argument( 149 | "-w", "--WORKERS", 150 | type=int, 151 | help="Number of workers", 152 | default=2, 153 | required=False 154 | ) 155 | 156 | parser.add_argument( 157 | "--DEBUG", 158 | help="Turn on debug logging mode", 159 | action='store_true', 160 | required=False 161 | ) 162 | parser.add_argument( 163 | "--splits", 164 | help="Create test and train split", 165 | action='store_true', 166 | required=False 167 | ) 168 | args = parser.parse_args() 169 | from log import setup_logging 170 | setup_logging(args.tmpdir, args.DEBUG) 171 | l = logging.getLogger('main') 172 | main(args) 173 | 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VarBERT 2 | VarBERT is a BERT-based model which predicts meaningful variable names and variable origins in decompiled code. Leveraging the power of transfer learning, VarBERT can help you in software reverse engineering tasks. VarBERT is pre-trained on 5M human-written source code functions, and then it is fine-tuned on decompiled code from IDA and Ghidra, spanning four compiler optimizations (*O0*, *O1*, *O2*, *O3*). 3 | We built two data sets: (a) Human Source Code data set (HSC) and (b) VarCorpus (for IDA and Ghidra). 4 | This work is developed for IEEE S&P 2024 paper ["Len or index or count, anything but v1": Predicting Variable Names in Decompilation Output with Transfer Learning](https://www.atipriya.com/files/papers/varbert_oakland24.pdf) 5 | 6 | Key Features 7 | 8 | - Pre-trained on 5.2M human-written source code functions. 9 | - Fine-tuned on decompiled code from IDA and Ghidra. 10 | - Supports four compiler optimizations: O0, O1, O2, O3. 11 | - Achieves an accuracy of 54.43% for IDA and 54.49% for Ghidra on O2 optimized binaries. 12 | - A total of 16 models are available, covering two decompilers, four optimizations, and two splitting strategies. 13 | 14 | ### Table of Contents 15 | - [Overview](#overview) 16 | - [VarBERT Model](#varbert-model) 17 | - [Using VarBERT](#use-varbert) 18 | - [Training and Inference](#training-and-inference) 19 | - [Data sets](#data-sets) 20 | - [Installation Instructions](#installation) 21 | - [Cite](#citing) 22 | 23 | ### Overview 24 | This repository contains details on generating a new dataset, and training and running inference on existing VarBERT models from the paper. 
To use VarBERT models in your day-to-day reverse engineering tasks, please refer to [Use VarBERT](#use-varbert). 25 | 26 | 27 | ### VarBERT Model 28 | We take inspiration for VARBERT from the concepts of transfer learning generally and specifically Bidirectional Encoder Representations from Transformers (BERT). 29 | 30 | - **Pre-training**: VarBERT is pre-trained on HSC functions using Masked Language Modeling (MLM) and Constrained Masked Language Modeling (CMLM). 31 | - **Fine-tuning**: VarBERT is then further fine-tuned on top of the previously pre-trained model using VarCorpus (decompilation output of IDA and Ghidra). It can be further extended to any other decompiler capable of generating C-Style decompilation output. 32 | 33 | ### Use VarBERT 34 | - The VarBERT API is a Python library to access and use the latest models. It can be used in three ways: 35 | 1. From the CLI, directly on decompiled text (without an attached decompiler). 36 | 2. As a scripting library. 37 | 3. As a decompiler plugin with [DAILA](https://github.com/mahaloz/DAILA) for enhanced decompiling experience. 38 | 39 | For a step-by-step guide and a demo on how to get started with the VarBERT API, please visit [VarBERT API](https://github.com/binsync/varbert_api/tree/main). 40 | 41 | ### Training and Inference 42 | For training a new model or running inference on existing models, see our detailed guide at [Training VarBERT](./varbert/README.md) 43 | 44 | Models available for download: 45 | - [Pre-trained models](https://www.dropbox.com/scl/fo/anibfmk6j8xkzi4nqk55f/h?rlkey=fw6ops1q3pqvsbdy5tl00brpw&dl=0) 46 | - [Fine-tuned models](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0) 47 | 48 | (A [README](https://www.dropbox.com/scl/fi/13s9z5z08u245jqdgfsdc/readme.md?rlkey=yjo33al04j1d5jrwc5pz2hhpz&dl=0) containing all the necessary links for the model is also available.) 49 | 50 | ### Data sets 51 | - **HSC**: Collected from C source files from the Debian APT repository, totaling 5.2M functions. 52 | 53 | - **VarCorpus**: Decompiled functions from C and C++ binaries, built from Gentoo package repository for four compiler optimizations: O0, O1, O2, and O3. 54 | 55 | Additionally, we have two splits: (a) Function Split (b) Binary Split. 56 | - Function Split: Functions are randomly distributed between the test and train sets. 57 | - Binary Split: All functions from a single binary are exclusively present in either the test set or the train set. 58 | To create a new data, follow detailed instuctions at [Building VarCorpus](./varcorpus/README.md) 59 | 60 | Data sets available at: 61 | - [HSC](https://www.dropbox.com/scl/fo/4cu2fmuh10c4wp7xt53tu/h?rlkey=mlsnkyed35m4rl512ipuocwtt&dl=0) 62 | - [VarCorpus](https://www.dropbox.com/scl/fo/3thmg8xoq2ugtjwjcgjsm/h?rlkey=azgjeq513g4semc1qdi5xyroj&dl=0) 63 | 64 | 65 | The fine-tuned models and their corresponding datasets are named `IDA-O0-Function` and `IDA-O0`, respectively. This naming convention indicates that the models and data set are based on functions decompiled from O0 binaries using the IDA decompiler. 66 | 67 | > [!NOTE] 68 | > Our existing data sets have been generated using IDA Pro 7.6 and Ghidra 10.4. 
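
For reference, a VarCorpus generation run with `varcorpus/dataset-gen/generate.py` looks roughly like the example below; all paths are placeholders, and [Building VarCorpus](./varcorpus/README.md) remains the authoritative walkthrough. For IDA, swap in `--decompiler ida -ida <path to IDA executable>`.

```bash
python3 varcorpus/dataset-gen/generate.py \
    -b /path/to/binaries \
    -d /path/to/output \
    --decompiler ghidra \
    -ghidra /path/to/ghidra/support/analyzeHeadless \
    -joern /path/to/joern \
    -lang c \
    -w 4 \
    --splits
```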
69 | 70 | ### Gentoo Binaries ### 71 | You can access the Gentoo binaries used to create VarCorpus here: https://www.dropbox.com/scl/fo/awtitjnc48k224373vcrx/h?rlkey=muj6t1watc6vn2ds6du7egoha&e=1&st=eicpyqln&dl=0 72 | 73 | ### Installation 74 | Prerequisites for training model or generating data set 75 | 76 | Linux with Python 3.8 or higher 77 | torch ≥ 1.9.0 78 | transformers ≥ 4.10.0 79 | 80 | #### Docker 81 | 82 | ``` 83 | docker build -t . varbert 84 | ``` 85 | 86 | #### Without Docker 87 | ```bash 88 | pip install -r requirements.txt 89 | 90 | # joern requires Java 11 91 | sudo apt-get install openjdk-11-jdk 92 | 93 | # Ghidra 10.4 requires Java 17+ 94 | sudo apt-get install openjdk-17-jdk 95 | 96 | git clone git@github.com:rhelmot/dwarfwrite.git 97 | cd dwarfwrite 98 | pip install . 99 | ``` 100 | Note: Ensure you install the correct Java version required by your specific Ghidra version. 101 | 102 | 103 | 104 | ### Citing 105 | 106 | Please cite our paper if you use this in your research: 107 | 108 | ``` 109 | @inproceedings{pal2024len, 110 | title={" Len or index or count, anything but v1": Predicting Variable Names in Decompilation Output with Transfer Learning}, 111 | author={Pal, Kuntal Kumar and Bajaj, Ati Priya and Banerjee, Pratyay and Dutcher, Audrey and Nakamura, Mutsumi and Basque, Zion Leonahenahe and Gupta, Himanshu and Sawant, Saurabh Arjun and Anantheswaran, Ujjwala and Shoshitaishvili, Yan and others}, 112 | booktitle={2024 IEEE Symposium on Security and Privacy (SP)}, 113 | pages={4069--4087}, 114 | year={2024}, 115 | organization={IEEE} 116 | } 117 | ``` 118 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import psutil 5 | import logging 6 | import subprocess 7 | from typing import List 8 | from pathlib import Path 9 | from concurrent.futures import process 10 | from elftools.elf.elffile import ELFFile 11 | 12 | # log = get_logger(sys._getframe().f_code.co_filename) 13 | 14 | def read_text(file_name): 15 | with open(file_name, 'r') as r: 16 | data = r.read().strip() 17 | return data 18 | 19 | def read_lines(file_name): 20 | with open(file_name, 'r') as r: 21 | data = r.read().split() 22 | return data 23 | 24 | def read_json(file_name): 25 | with open(file_name, 'r') as r: 26 | data = json.loads(r.read()) 27 | return data 28 | 29 | def write_json(file_name, data): 30 | with open(f'{file_name}.json', 'w') as w: 31 | w.write(json.dumps(data)) 32 | 33 | def subprocess_(command, timeout=None, shell=False): 34 | try: 35 | subprocess.check_call(command, timeout=timeout, shell=shell) 36 | except Exception as e: 37 | l.error(f"error in binary! :: {command} :: {e}") 38 | return e 39 | 40 | def is_elf(binary_path: str): 41 | ''' 42 | if elf return elffile obj 43 | ''' 44 | with open(binary_path, 'rb') as rb: 45 | bytes = rb.read(4) 46 | if bytes == b"\x7fELF": 47 | rb.seek(0) 48 | elffile = ELFFile(rb) 49 | if elffile and elffile.has_dwarf_info(): 50 | dwarf_info = elffile.get_dwarf_info() 51 | 52 | return elffile, dwarf_info 53 | 54 | def create_dirs(dir_names, tmpdir, ty): 55 | """ 56 | Create directories within the specified data directory. 57 | 58 | :param dir_names: List of directory names to create. 59 | :param tmpdir: Base data directory where directories will be created. 60 | :param ty: Subdirectory type, such as 'strip' or 'type-strip'. 
If None, only base directories are created. 61 | """ 62 | # Path(os.path.join(tmpdir, 'dwarf')).mkdir(parents=True, exist_ok=True) 63 | for name in dir_names: 64 | if ty: 65 | target_path = os.path.join(tmpdir, name, ty) 66 | else: 67 | target_path = os.path.join(tmpdir, name) 68 | Path(target_path).mkdir(parents=True, exist_ok=True) 69 | 70 | 71 | def set_up_data_dir(tmpdir, workdir, decompiler): 72 | """ 73 | Set up the data directory with required subdirectories for the given decompiler. 74 | 75 | :param tmpdir: The base data directory to set up. 76 | :param workdir: The working directory where some data directories will be created. 77 | :param decompiler: Name of the decompiler to create specific subdirectories. 78 | """ 79 | 80 | base_dirs = ['binary', 'failed'] 81 | dc_joern_dirs = ['dc', 'joern'] 82 | map_dirs = [f"map/{decompiler}", f"dc/{decompiler}/type_strip-addrs", f"dc/{decompiler}/type_strip-names"] 83 | workdir_dirs = [f'{decompiler}_data', 'tmp_joern'] 84 | 85 | # Create directories for 'strip' and 'type-strip' 86 | for ty in ['strip', 'type_strip']: 87 | create_dirs([f'{d}/{decompiler}' for d in dc_joern_dirs] + base_dirs, tmpdir, ty) 88 | 89 | # Create additional directories 90 | create_dirs(map_dirs, tmpdir, None) 91 | 92 | # Create workdir directories 93 | create_dirs(workdir_dirs, workdir, None) 94 | 95 | # copy binary 96 | create_dirs(['dwarf'], tmpdir, None) 97 | 98 | 99 | ### JOERN 100 | 101 | 102 | l = logging.getLogger('main') 103 | 104 | class JoernServer: 105 | def __init__(self, joern_path, port): 106 | self.joern_path = joern_path 107 | self.port = port 108 | self.process = None 109 | self.java_process = None 110 | 111 | def start(self): 112 | if self.is_server_running(): 113 | l.error(f"Joern server already running on port {self.port}") 114 | self.stop() 115 | return 116 | 117 | # hack: joern can't find the shell script fuzzyc2cpg.sh (CPG generator via the shell in this version) 118 | current_dir = os.getcwd() 119 | try: 120 | os.chdir(self.joern_path) 121 | joern_cmd = [os.path.join(self.joern_path, 'joern'), '--server', '--server-port', str(self.port)] 122 | self.process = subprocess.Popen(joern_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 123 | 124 | for _ in range(20): 125 | self.java_process = self.is_server_running() 126 | if self.java_process: 127 | l.debug(f"Joern server started on port {self.port}") 128 | return 129 | time.sleep(4) 130 | l.debug(f"Retrying to start Joern server on port {self.port}") 131 | except Exception as e: 132 | l.error(f"Failed to start Joern server on port {self.port} :: {e}") 133 | return 134 | 135 | finally: 136 | os.chdir(current_dir) 137 | 138 | def stop(self): 139 | if self.java_process is not None: 140 | self.java_process.kill() 141 | time.sleep(3) 142 | if self.is_server_running(): 143 | l.warning("Joern server did not terminate gracefully, forcing termination") 144 | self.java_process.kill() 145 | 146 | def is_server_running(self): 147 | try: 148 | for proc in psutil.process_iter(['pid', 'name', 'cmdline']): 149 | try: 150 | if proc.info['name'] == 'java' and 'io.shiftleft.joern.console.AmmoniteBridge' in proc.info['cmdline'] and '--server-port' in proc.info['cmdline'] and str(self.port) in proc.info['cmdline']: 151 | return psutil.Process(proc.info['pid']) 152 | except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): 153 | continue 154 | except Exception as e: 155 | l.error(f"Error while checking if server is running: {e}") 156 | 157 | return False 158 | 159 | def __enter__(self): 160 | 
self.start() 161 | return self 162 | 163 | def __exit__(self, exc_type, exc_value, traceback): 164 | self.stop() 165 | 166 | def exit(self): 167 | self.stop() 168 | 169 | def restart(self): 170 | self.server.stop() 171 | time.sleep(10) 172 | self.server.start() -------------------------------------------------------------------------------- /varbert/resize_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | from transformers import ( 3 | WEIGHTS_NAME, 4 | AdamW, 5 | BertConfig, 6 | BertForMaskedLM, 7 | BertTokenizer, 8 | CamembertConfig, 9 | CamembertForMaskedLM, 10 | CamembertTokenizer, 11 | DistilBertConfig, 12 | DistilBertForMaskedLM, 13 | DistilBertTokenizer, 14 | GPT2Config, 15 | GPT2LMHeadModel, 16 | GPT2Tokenizer, 17 | OpenAIGPTConfig, 18 | OpenAIGPTLMHeadModel, 19 | OpenAIGPTTokenizer, 20 | PreTrainedModel, 21 | PreTrainedTokenizer, 22 | RobertaConfig, 23 | RobertaForMaskedLM, 24 | RobertaTokenizer, 25 | RobertaTokenizerFast, 26 | get_linear_schedule_with_warmup, 27 | ) 28 | import argparse 29 | import os 30 | import shutil 31 | from typing import Dict, List, Tuple 32 | 33 | import numpy as np 34 | import torch 35 | import torch.nn as nn 36 | from torch.nn import CrossEntropyLoss, MSELoss 37 | from torch.nn.utils.rnn import pad_sequence 38 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 39 | from torch.utils.data.distributed import DistributedSampler 40 | from transformers.activations import ACT2FN, gelu 41 | 42 | vocab_size = 50001 43 | class RobertaLMHead2(nn.Module): 44 | 45 | def __init__(self,config): 46 | super().__init__() 47 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 48 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 49 | self.decoder = nn.Linear(config.hidden_size, vocab_size, bias=False) 50 | self.bias = nn.Parameter(torch.zeros(vocab_size)) 51 | 52 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 53 | self.decoder.bias = self.bias 54 | 55 | def forward(self, features, **kwargs): 56 | x = self.dense(features) 57 | x = gelu(x) 58 | x = self.layer_norm(x) 59 | # project back to size of vocabulary with bias 60 | x = self.decoder(x) 61 | return x 62 | 63 | class RobertaForMaskedLMv2(RobertaForMaskedLM): 64 | 65 | def __init__(self, config): 66 | super().__init__(config) 67 | self.lm_head2 = RobertaLMHead2(config) 68 | self.init_weights() 69 | 70 | def forward( 71 | self, 72 | input_ids=None, 73 | attention_mask=None, 74 | token_type_ids=None, 75 | position_ids=None, 76 | head_mask=None, 77 | inputs_embeds=None, 78 | encoder_hidden_states=None, 79 | encoder_attention_mask=None, 80 | labels=None, 81 | output_attentions=None, 82 | output_hidden_states=None, 83 | return_dict=None, 84 | type_label=None 85 | ): 86 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 87 | 88 | 89 | outputs = self.roberta( 90 | input_ids, 91 | attention_mask=attention_mask, 92 | token_type_ids=token_type_ids, 93 | position_ids=position_ids, 94 | head_mask=head_mask, 95 | inputs_embeds=inputs_embeds, 96 | encoder_hidden_states=encoder_hidden_states, 97 | encoder_attention_mask=encoder_attention_mask, 98 | output_attentions=output_attentions, 99 | output_hidden_states=output_hidden_states, 100 | return_dict=return_dict, 101 | ) 102 | sequence_output = outputs[0] 103 | prediction_scores = self.lm_head2(sequence_output) 104 | 105 | masked_lm_loss = None 106 | if 
labels is not None: 107 | loss_fct = CrossEntropyLoss() 108 | masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), labels.view(-1)) 109 | 110 | masked_loss = masked_lm_loss 111 | 112 | output = (prediction_scores,) + outputs[2:] 113 | return ((masked_loss,) + output) if masked_loss is not None else output 114 | 115 | 116 | def get_resized_model(oldmodel, newvocab_len): 117 | 118 | ## Change the bias dimensions 119 | def _get_resized_bias(old_bias, new_size): 120 | old_num_tokens = old_bias.data.size()[0] 121 | if old_num_tokens == new_size: 122 | return old_bias 123 | 124 | # Create new biases 125 | new_bias = nn.Parameter(torch.zeros(new_size)) 126 | new_bias.to(old_bias.device) 127 | 128 | # Copy from the previous weights 129 | num_tokens_to_copy = min(old_num_tokens, new_size) 130 | new_bias.data[:num_tokens_to_copy] = old_bias.data[:num_tokens_to_copy] 131 | return new_bias 132 | 133 | ## Change the decoder dimensions 134 | 135 | cls_layer_oldmodel = oldmodel.lm_head2.decoder 136 | oldvocab_len, old_embedding_dim = cls_layer_oldmodel.weight.size() 137 | print(f"Old Vocab Size: {oldvocab_len} \t Old Embedding Dim: {old_embedding_dim}") 138 | if oldvocab_len == newvocab_len: 139 | return oldmodel 140 | 141 | # Create new weights 142 | cls_layer_newmodel = nn.Linear(in_features=old_embedding_dim, out_features=newvocab_len) 143 | cls_layer_newmodel.to(cls_layer_oldmodel.weight.device) 144 | 145 | # initialize all weights (in particular added tokens) 146 | oldmodel._init_weights(cls_layer_newmodel) 147 | 148 | # Copy from the previous weights 149 | num_tokens_to_copy = min(oldvocab_len, newvocab_len) 150 | cls_layer_newmodel.weight.data[:num_tokens_to_copy, :] = cls_layer_oldmodel.weight.data[:num_tokens_to_copy, :] 151 | oldmodel.lm_head2.decoder = cls_layer_newmodel 152 | # Change the bias 153 | old_bias = oldmodel.lm_head2.bias 154 | oldmodel.lm_head2.bias = _get_resized_bias(old_bias, newvocab_len) 155 | 156 | return oldmodel 157 | 158 | 159 | 160 | def main(args): 161 | 162 | model = RobertaForMaskedLMv2.from_pretrained(args.old_model) 163 | vocab = json.load(open(args.vocab_path)) 164 | print(f"New Vocab Size : {len(vocab)}") 165 | newmodel = get_resized_model(model, len(vocab)+1) 166 | print(f"Final Cls Layer of New model : {newmodel.lm_head2.decoder}") 167 | newmodel.save_pretrained(args.out_model_path) 168 | 169 | # Move the other files 170 | files = ['tokenizer_config.json','vocab.json','training_args.bin','special_tokens_map.json','merges.txt'] 171 | for file in files: 172 | shutil.copyfile(os.path.join(args.old_model,file), os.path.join(args.out_model_path,file)) 173 | 174 | 175 | if __name__ == "__main__": 176 | 177 | parser = argparse.ArgumentParser() 178 | parser.add_argument('--old_model', type=str, help='name of the train file') 179 | parser.add_argument('--vocab_path', type=str, help='path to the out vocab, the size of output layer you need') 180 | parser.add_argument('--out_model_path', type=str, help='path to the new modified model') 181 | args = parser.parse_args() 182 | 183 | main(args) -------------------------------------------------------------------------------- /varcorpus/README.md: -------------------------------------------------------------------------------- 1 | ### Building VarCorpus 2 | 3 | To build VarCorpus, we collected C and C++ packages from Gentoo and built them across four compiler optimizations (O0, O1, O2 and O3) with debugging information (-g enabled), using [Bintoo](https://github.com/sefcom/bintoo). 
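For intuition, building with debug information means each package is compiled once per optimization level along these lines (illustrative flags only, not the exact Bintoo invocation):

```bash
for opt in O0 O1 O2 O3; do
    gcc -"$opt" -g -o prog_"$opt" prog.c   # -g emits the DWARF info the pipeline consumes
done
```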
4 | 
5 | The script reads binaries from a directory, generates type-stripped and stripped binaries, decompiles and parses them using Joern to match variables, and saves deduplicated training and testing sets for both the function and binary splits in a data directory. This pipeline supports two decompilers, IDA and Ghidra; it can be extended to any decompiler that generates C-style decompilation output.
6 | 
7 | To generate the dataset, choose one of the following language selection modes.
8 | 
9 | #### Single-language mode (Either C or CPP)
10 | 
11 | ```bash
12 | python3 generate.py \
13 | -b <binaries_dir> \
14 | -d <data_dir> \
15 | --decompiler <ida or ghidra> \
16 | -lang <c or cpp> \
17 | -ida <path_to_ida> or -ghidra <path_to_analyzeHeadless> \
18 | -w <num_workers> \
19 | -joern <path_to_joern_dir> \
20 | --splits
21 | ```
22 | 
23 | #### Mixed-language mode (Both C and CPP)
24 | 
25 | Provide a JSON mapping of binary basenames to languages and omit `-lang`:
26 | 
27 | ```json
28 | {
29 |   "bin_one": "c",
30 |   "bin_two": "cpp"
31 | }
32 | ```
33 | 
34 | ```bash
35 | python3 generate.py \
36 | -b <binaries_dir> \
37 | -d <data_dir> \
38 | --decompiler <ida or ghidra> \
39 | -ida <path_to_ida> or -ghidra <path_to_analyzeHeadless> \
40 | --language_map <path_to_language_map.json> \
41 | -w <num_workers> \
42 | -joern <path_to_joern_dir> \
43 | --splits
44 | ```
45 | 
46 | Notes:
47 | - Exactly one of `-lang/--corpus_language` or `--language_map` must be provided (mutually exclusive).
48 | - Allowed languages are `c` or `cpp` (lowercase).
49 | 
50 | The current implementation relies on a specific modified version of [Joern](https://github.com/joernio/joern). We made this modification to expedite the dataset creation process. Please download the compatible Joern version from [Joern](https://www.dropbox.com/scl/fi/toh6087y5t5xyln47i5ih/modified_joern.tar.gz?rlkey=lfvjn1u7zvtp9a4cu8z8vgsof&dl=0) and save it.
51 | 
52 | ```
53 | wget -O joern.tar.gz "https://www.dropbox.com/scl/fi/toh6087y5t5xyln47i5ih/modified_joern.tar.gz?rlkey=lfvjn1u7zvtp9a4cu8z8vgsof&dl=0"
54 | tar xf joern.tar.gz
55 | ```
56 | 
57 | For `-joern`, provide the path to the directory containing the joern executable from the downloaded version.
58 | ```
59 | -joern /path/to/joern/
60 | ```
61 | 
62 | ### Use Docker
63 | 
64 | You can skip the setup and use our Dockerfile directly to build the dataset.
65 | 
66 | Have your binaries directory and an output directory ready on your host machine. These will be mounted into the container so you can easily provide binaries and retrieve the train and test sets.
67 | 
68 | #### For Ghidra:
69 | 
70 | ```
71 | docker build -t varbert -f ../../Dockerfile ..
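# -t tags the image as "varbert"; -f selects the repository-level Dockerfile and the trailing ".."
# sets the Docker build context (both paths are relative to the directory where you run the command).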
72 | 
73 | docker run -it \
74 | -v $PWD/<binaries_dir>:/varbert_workdir/data/binaries \
75 | -v $PWD/<output_dir>:/varbert_workdir/data/sets \
76 | varbert \
77 | python3 /varbert_workdir/VarBERT/varcorpus/dataset-gen/generate.py \
78 | -b /varbert_workdir/data/binaries \
79 | -d /varbert_workdir/data/sets \
80 | --decompiler ghidra \
81 | -lang <c or cpp> \
82 | -ghidra /varbert_workdir/ghidra_10.4_PUBLIC/support/analyzeHeadless \
83 | -w <num_workers> \
84 | -joern /varbert_workdir/joern \
85 | --splits
86 | ```
87 | 
88 | To enable debug mode, add the `--DEBUG` flag and mount a tmp directory from the host to see intermediate files:
89 | 
90 | ```
91 | docker run -it \
92 | -v $PWD/<binaries_dir>:/varbert_workdir/data/binaries \
93 | -v $PWD/<output_dir>:/varbert_workdir/data/sets \
94 | -v $PWD/<tmp_dir>:/tmp/varbert_tmpdir \
95 | varbert \
96 | python3 /varbert_workdir/VarBERT/varcorpus/dataset-gen/generate.py \
97 | -b /varbert_workdir/data/binaries \
98 | -d /varbert_workdir/data/sets \
99 | --decompiler ghidra \
100 | -lang <c or cpp> \
101 | -ghidra /varbert_workdir/ghidra_10.4_PUBLIC/support/analyzeHeadless \
102 | -w <num_workers> \
103 | --DEBUG \
104 | -joern /varbert_workdir/joern \
105 | --splits
106 | ```
107 | 
108 | What this does:
109 | 
110 | - `-v $PWD/<binaries_dir>:/varbert_workdir/data/binaries`: Mounts your local binaries directory into the container.
111 | - `-v $PWD/<output_dir>:/varbert_workdir/data/sets`: Mounts the directory where you want to save the generated train/test sets.
112 | 
113 | Inside the container, your binaries are accessible at `/varbert_workdir/data/binaries`, the resulting data sets are saved to `/varbert_workdir/data/sets`, and intermediate files are available at `/tmp/varbert_tmpdir`.
114 | 
115 | 
116 | 
117 | #### For IDA:
118 | 
119 | Please update the Dockerfile to include your IDA installation, then run:
120 | 
121 | ```
122 | docker run -it \
123 | -v $PWD/<binaries_dir>:/varbert_workdir/data/binaries \
124 | -v $PWD/<output_dir>:/varbert_workdir/data/sets \
125 | varbert \
126 | python3 /varbert_workdir/VarBERT/varcorpus/dataset-gen/generate.py \
127 | -b /varbert_workdir/data/binaries \
128 | -d /varbert_workdir/data/sets \
129 | --decompiler ida \
130 | -lang <c or cpp> \
131 | -ida <path_to_ida_in_container> \
132 | -w <num_workers> \
133 | -joern /varbert_workdir/joern \
134 | --splits
135 | ```
136 | 
137 | 
138 | #### Notes:
139 | 
140 | - The train and test sets are split in an 80:20 ratio. If there aren't enough functions (or binaries) to meet this ratio, you may end up with no train or test sets after the run.
141 | - We built the dataset using **Ghidra 10.4**. If you wish to use a different version of Ghidra, please update the Ghidra download link in the Dockerfile accordingly.
142 | - In some cases a license popup must be accepted before you can successfully run IDA in Docker.
143 | - Disable type casts for more efficient variable matching (we disabled them while building VarCorpus).
144 | 
145 | ### Debug and temporary directories
146 | 
147 | - `--DEBUG`: preserves the temporary working directory so you can inspect intermediates and logs.
148 | - `--tmpdir <path>`: sets a custom working directory (default: `/tmp/varbert_tmpdir`).
Important subdirectories during a run: 149 | - `binary/strip`, `binary/type_strip`: modified binaries 150 | - `dc//`: decompiled code 151 | - `dwarf`: DWARF info per binary 152 | - `joern//`: Joern JSON 153 | - `map//`: variables mapped decompiled functions 154 | - `splits`: train and test sets 155 | 156 | 157 | Sample Function: 158 | ```json 159 | { 160 | "id": 5, 161 | "language": "C", 162 | "func": "__int64 __fastcall sub_13DF(_QWORD *@@var_1@@a1@@, _QWORD *@@var_0@@Ancestors@@)\n{\n while ( @@var_0@@Ancestors@@ )\n {\n if ( @@var_0@@Ancestors@@[1] == @@var_1@@a1@@[1] && @@var_0@@Ancestors@@[2] == *@@var_1@@a1@@ )\n return 1LL;\n @@var_0@@Ancestors@@ = *@@var_0@@Ancestors@@;\n }\n return 0LL;\n}", 163 | "type_stripped_vars": {"Ancestors": "dwarf", "a1": "ida"}, 164 | "stripped_vars": ["a2", "a1"], 165 | "mapped_vars": {"a2": "Ancestors", "a1": "a1"}, 166 | "func_name_dwarf": "is_ancestor", 167 | "hash": "2998f4a10a8f052257122c23897d10b7", 168 | "func_name": "8140277b36ef8461df62b160fed946cb_(00000000000013DF)", 169 | "norm_func": "__int64 __fastcall sub_13DF(_QWORD *@@var_1@@a1@@, _QWORD *@@var_0@@ancestors@@)\n{\n while ( @@var_0@@ancestors@@ )\n {\n if ( @@var_0@@ancestors@@[1] == @@var_1@@a1@@[1] && @@var_0@@ancestors@@[2] == *@@var_1@@a1@@ )\n return 1LL;\n @@var_0@@ancestors@@ = *@@var_0@@ancestors@@;\n }\n return 0LL;\n}", 170 | "vars_map": [["Ancestors", "ancestors"]], 171 | "fid": "5-8140277b36ef8461df62b160fed946cb-(00000000000013DF)" 172 | } 173 | ``` 174 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/strip_types.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import string 4 | import logging 5 | from elftools.elf.elffile import ELFFile 6 | from elftools.dwarf import constants 7 | import dwarfwrite 8 | from dwarfwrite.restructure import ReStructurer, VOID 9 | from cle.backends.elf.elf import ELF 10 | 11 | l = logging.getLogger('main') 12 | DWARF_VERSION = 4 13 | VOID_STAR_ID = 127 14 | 15 | class TypeStubRewriter(ReStructurer): 16 | def __init__(self, fp, cu_data=None): 17 | super().__init__(fp) 18 | fp.seek(0) 19 | self.rng = random.Random(fp.read()) 20 | self.cu_data = cu_data 21 | 22 | def unit_get_functions(self, unit): 23 | result = list(super().unit_get_functions(unit)) 24 | 25 | for die in unit.iter_DIEs(): 26 | if die.tag not in ('DW_TAG_class_type', 'DW_TAG_struct_type'): 27 | continue 28 | for member in die.iter_children(): 29 | if member.tag != 'DW_TAG_subprogram': 30 | continue 31 | result.append(member) 32 | 33 | return result 34 | 35 | def type_basic_encoding(self, handler): 36 | return constants.DW_ATE_unsigned 37 | 38 | def type_basic_name(self, handler): 39 | return { 40 | 1: "unsigned char", 41 | 2: "unsigned short", 42 | 4: "unsigned int", 43 | 8: "unsigned long long", 44 | 16: "unsigned __int128", 45 | 32: "unsigned __int256", 46 | }[handler] 47 | 48 | def function_get_linkage_name(self, handler): 49 | return None 50 | 51 | def parameter_get_name(self, handler): 52 | return super().parameter_get_name(handler) 53 | 54 | 55 | def get_type_size(cu, offset, wordsize): 56 | # returns type_size, base_type_offset, is_exemplar[true=yes, false=maybe, none=no] 57 | type_die = cu.get_DIE_from_refaddr(offset + cu.cu_offset) 58 | base_type_die = type_die 59 | base_type_offset = offset 60 | 61 | while True: 62 | if base_type_die.tag == 'DW_TAG_pointer_type': 63 | return wordsize, base_type_offset, None 64 | elif base_type_die.tag == 
'DW_TAG_structure_type': 65 | try: 66 | first_member = next(base_type_die.iter_children()) 67 | assert first_member.tag == 'DW_TAG_member' 68 | assert first_member.attributes['DW_AT_data_member_location'].value == 0 69 | except (KeyError, StopIteration, AssertionError): 70 | return wordsize, base_type_offset, None 71 | base_type_offset = first_member.attributes['DW_AT_type'].value 72 | base_type_die = cu.get_DIE_from_refaddr(base_type_offset + cu.cu_offset) 73 | continue 74 | # TODO unions 75 | elif base_type_die.tag == 'DW_TAG_base_type': 76 | size = base_type_die.attributes['DW_AT_byte_size'].value 77 | if any(ty in base_type_die.attributes['DW_AT_name'].value for ty in (b'char', b'int', b'long')): 78 | return size, base_type_offset, True 79 | else: 80 | return size, base_type_offset, None 81 | elif 'DW_AT_type' in base_type_die.attributes: 82 | base_type_offset = base_type_die.attributes['DW_AT_type'].value 83 | base_type_die = cu.get_DIE_from_refaddr(base_type_offset + cu.cu_offset) 84 | continue 85 | else: 86 | # afaik this case is only reached for void types 87 | return wordsize, base_type_offset, None 88 | 89 | def build_type_to_size(ipath): 90 | with open(ipath, 'rb') as ifp: 91 | elf = ELFFile(ifp) 92 | arch = ELF.extract_arch(elf) 93 | 94 | ifp.seek(0) 95 | dwarf = elf.get_dwarf_info() 96 | 97 | cu_data = {} 98 | for cu in dwarf.iter_CUs(): 99 | type_to_size = {} 100 | # import ipdb; ipdb.set_trace() 101 | cu_data[cu.cu_offset] = type_to_size 102 | for die in cu.iter_DIEs(): 103 | attr = die.attributes.get("DW_AT_type", None) 104 | if attr is None: 105 | continue 106 | 107 | type_size, _, _ = get_type_size(cu, attr.value, arch.bytes) 108 | type_to_size[attr.value] = type_size 109 | 110 | return cu_data 111 | 112 | class IdaStubRewriter(TypeStubRewriter): 113 | 114 | def get_attribute(self, die, name): 115 | r = super().get_attribute(die, name) 116 | if name != 'DW_AT_type' or r is None: 117 | return r 118 | size = self.cu_data[r.cu.cu_offset][r.offset - r.cu.cu_offset] 119 | return size 120 | 121 | def type_basic_size(self, handler): 122 | return handler 123 | 124 | def function_get_name(self, handler): 125 | r = super().function_get_name(handler) 126 | return r 127 | 128 | def parameter_get_location(self, handler): 129 | return None 130 | 131 | class GhidraStubRewriter(TypeStubRewriter): 132 | 133 | def __init__(self, fp, cu_data=None, low_pc_to_funcname=None): 134 | super().__init__(fp, cu_data) 135 | self.low_pc_to_funcname = low_pc_to_funcname 136 | 137 | def get_attribute(self, die, name): 138 | r = super().get_attribute(die, name) 139 | if name != 'DW_AT_type' or r is None: 140 | return r 141 | if die.tag == "DW_TAG_formal_parameter": 142 | varname = self.get_attribute(die, "DW_AT_name") 143 | if varname == b"this": 144 | return VOID_STAR_ID 145 | size = self.cu_data[r.cu.cu_offset][r.offset - r.cu.cu_offset] 146 | return size 147 | 148 | def type_basic_size(self, handler): 149 | if handler == VOID_STAR_ID: 150 | return 8 # assuming the binary is 64-bit 151 | return handler 152 | 153 | def function_get_name(self, handler): 154 | # import ipdb; ipdb.set_trace() 155 | low_pc = self.get_attribute(handler, "DW_AT_low_pc") 156 | if low_pc is not None and str(low_pc) in self.low_pc_to_funcname: 157 | return self.low_pc_to_funcname[str(low_pc)] 158 | return super().function_get_name(handler) 159 | 160 | def type_ptr_of(self, handler): 161 | if handler == VOID_STAR_ID: 162 | return VOID 163 | return None 164 | 165 | def parameter_get_artificial(self, handler): 166 | return None 167 
| 168 | def parameter_get_location(self, handler): 169 | return super().parameter_get_location(handler) 170 | 171 | 172 | def get_spec_offsets_and_names(ipath): 173 | 174 | elf = ELFFile(open(ipath, "rb")) 175 | all_spec_offsets = [ ] 176 | low_pc_to_funcname = {} 177 | spec_offset_to_low_pc = {} 178 | 179 | dwarf = elf.get_dwarf_info() 180 | for cu in dwarf.iter_CUs(): 181 | cu_offset = cu.cu_offset 182 | for die in cu.iter_DIEs(): 183 | for subdie in cu.iter_DIE_children(die): 184 | if subdie.tag == "DW_TAG_subprogram": 185 | if "DW_AT_low_pc" in subdie.attributes: 186 | low_pc = str(subdie.attributes.get("DW_AT_low_pc").value) 187 | if "DW_AT_specification" in subdie.attributes: 188 | spec_offset = subdie.attributes["DW_AT_specification"].value 189 | global_spec_offset = spec_offset + cu_offset 190 | all_spec_offsets.append(global_spec_offset) 191 | spec_offset_to_low_pc[str(global_spec_offset)] = low_pc 192 | 193 | for cu in dwarf.iter_CUs(): 194 | for die in cu.iter_DIEs(): 195 | for subdie in cu.iter_DIE_children(die): 196 | if subdie.offset in all_spec_offsets: 197 | if "DW_AT_name" in subdie.attributes: 198 | low_pc_to_funcname[spec_offset_to_low_pc[str(subdie.offset)]] = subdie.attributes.get("DW_AT_name").value 199 | 200 | return low_pc_to_funcname 201 | 202 | def type_strip_target_binary(ipath, opath, decompiler): 203 | try: 204 | cu_data = build_type_to_size(ipath) 205 | if decompiler == "ghidra": 206 | low_pc_to_funcname = get_spec_offsets_and_names(ipath) 207 | GhidraStubRewriter.rewrite_dwarf(in_path=ipath, out_path=opath, cu_data=cu_data, low_pc_to_funcname=low_pc_to_funcname) 208 | elif decompiler == "ida": 209 | IdaStubRewriter.rewrite_dwarf(in_path=ipath, out_path=opath, cu_data=cu_data) 210 | else: 211 | l.error("Unsupported Decompiler. Please choose from IDA or Ghidra") 212 | except Exception as e: 213 | l.error(f"Error occured while creating a type-strip binary: {e}") 214 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/decompiler/ida_analysis.py: -------------------------------------------------------------------------------- 1 | import idaapi 2 | import idc 3 | import ida_funcs 4 | import ida_hexrays 5 | import ida_kernwin 6 | import ida_loader 7 | 8 | from collections import namedtuple, defaultdict 9 | from sortedcontainers import SortedDict 10 | 11 | from elftools.dwarf.descriptions import describe_reg_name 12 | from elftools.elf.elffile import ELFFile 13 | from elftools.dwarf.dwarf_expr import DWARFExprParser 14 | from elftools.dwarf import locationlists 15 | 16 | import json 17 | 18 | LocationEntry = namedtuple("LocationEntry", ("begin_offset", "end_offset", "location")) 19 | NameResult = namedtuple("NameResult", ("name", "size")) 20 | 21 | class RegVarAnalysis: 22 | def __init__(self, fname): 23 | a = ELFFile(open(fname, 'rb')) 24 | b = a.get_dwarf_info() 25 | 26 | self.result = defaultdict(SortedDict) 27 | 28 | self.loc_parser = b.location_lists() 29 | self.expr_parser = DWARFExprParser(b.structs) 30 | self.range_lists = b.range_lists() 31 | 32 | for c in b.iter_CUs(): 33 | for x in c.iter_DIEs(): 34 | if x.tag != 'DW_TAG_variable' or 'DW_AT_name' not in x.attributes or x.get_parent() is x.cu.get_top_DIE() or 'DW_AT_location' not in x.attributes: 35 | continue 36 | # ??? 
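# DW_AT_location is encoded either as an inline DWARF expression (DW_FORM_exprloc) or as an
# offset into the location-list section (DW_FORM_sec_offset); any other form is unexpected here.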
37 | if x.attributes['DW_AT_location'].form == 'DW_FORM_exprloc': 38 | loclist = self.get_single_loc(x) 39 | elif x.attributes['DW_AT_location'].form != 'DW_FORM_sec_offset': 40 | assert False 41 | else: 42 | loclist = self.get_loclist(x) 43 | for loc in loclist: 44 | if len(loc.location) != 1 or not loc.location[0].op_name.startswith('DW_OP_reg'): 45 | # discard complicated variables 46 | continue 47 | expr = loc.location[0] 48 | if expr.op_name == 'DW_OP_regx': 49 | reg_name = describe_reg_name(expr.args[0], a.get_machine_arch()) 50 | else: 51 | reg_name = describe_reg_name(int(expr.op_name[9:]), a.get_machine_arch()) 52 | self.result[reg_name][loc.begin_offset] = NameResult(x.attributes['DW_AT_name'].value.decode(), loc.end_offset - loc.begin_offset) 53 | 54 | def get_single_loc(self, f): 55 | base_addr = 0 56 | low_pc = f.cu.get_top_DIE().attributes.get("DW_AT_low_pc", None) 57 | if low_pc is not None: 58 | base_addr = low_pc.value 59 | parent = f.get_parent() 60 | ranges = [] 61 | while parent is not f.cu.get_top_DIE(): 62 | if 'DW_AT_low_pc' in parent.attributes and 'DW_AT_high_pc' in parent.attributes: 63 | ranges.append(( 64 | parent.attributes['DW_AT_low_pc'].value + base_addr, 65 | parent.attributes['DW_AT_high_pc'].value + base_addr, 66 | )) 67 | break 68 | if 'DW_AT_ranges' in parent.attributes: 69 | rlist = self.range_lists.get_range_list_at_offset(parent.attributes['DW_AT_ranges'].value) 70 | ranges = [ 71 | (rentry.begin_offset + base_addr, rentry.end_offset + base_addr) for rentry in rlist 72 | ] 73 | break 74 | parent = parent.get_parent() 75 | else: 76 | return [] 77 | 78 | return [LocationEntry( 79 | location=self.expr_parser.parse_expr(f.attributes['DW_AT_location'].value), 80 | begin_offset=begin, 81 | end_offset=end 82 | ) for begin, end in ranges] 83 | 84 | 85 | def get_loclist(self, f): 86 | base_addr = 0 87 | low_pc = f.cu.get_top_DIE().attributes.get("DW_AT_low_pc", None) 88 | if low_pc is not None: 89 | base_addr = low_pc.value 90 | 91 | loc_list = self.loc_parser.get_location_list_at_offset(f.attributes['DW_AT_location'].value) 92 | result = [] 93 | for item in loc_list: 94 | if type(item) is locationlists.LocationEntry: 95 | try: 96 | result.append(LocationEntry( 97 | base_addr + item.begin_offset, 98 | base_addr + item.end_offset, 99 | self.expr_parser.parse_expr(item.loc_expr))) 100 | except KeyError as e: 101 | if e.args[0] == 249: # gnu extension dwarf expr ops 102 | continue 103 | else: 104 | raise 105 | elif type(item) is locationlists.BaseAddressEntry: 106 | base_addr = item.base_address 107 | else: 108 | raise TypeError("What kind of loclist entry is this?") 109 | return result 110 | 111 | def lookup(self, reg, addr): 112 | try: 113 | key = next(self.result[reg].irange(maximum = addr, reverse=True)) 114 | except StopIteration: 115 | return None 116 | else: 117 | val = self.result[reg][key] 118 | if key + val.size <= addr: 119 | return None 120 | return val.name 121 | 122 | analysis: RegVarAnalysis = None 123 | 124 | def setup(): 125 | global analysis 126 | if analysis is not None: 127 | return 128 | path = ida_loader.get_path(ida_loader.PATH_TYPE_CMD) 129 | analysis = RegVarAnalysis(path) 130 | 131 | def dump_list(list_, filename): 132 | with open(filename, 'w') as w: 133 | w.write("\n".join(list_)) 134 | 135 | def write_json(data, filename): 136 | with open(filename, 'w') as w: 137 | w.write(json.dumps(data)) 138 | 139 | def go(): 140 | setup() 141 | 142 | ea = 0 143 | collect_addrs, mangled_names_to_demangled_names = [], {} 144 | filename = 
idc.ARGV[1] 145 | while True: 146 | func = ida_funcs.get_next_func(ea) 147 | if func is None: 148 | break 149 | ea = func.start_ea 150 | seg = idc.get_segm_name(ea) 151 | if seg != ".text": 152 | continue 153 | collect_addrs.append(str(ea)) 154 | typ = idc.get_type(ea) 155 | # void sometimes introduce extra variables, updating return type helps in variable matching for type-strip binary 156 | if 'void' in str(typ): 157 | newtype = str(typ).replace("void", f"__int64 {str(ida_funcs.get_func_name(ea))}") + ";" 158 | res = idc.SetType(ea, newtype) 159 | 160 | print("analyzing" , ida_funcs.get_func_name(ea)) 161 | analyze_func(func) 162 | # # Demangle the name 163 | mangled_name = ida_funcs.get_func_name(ea) 164 | demangled_name = idc.demangle_name(mangled_name, idc.get_inf_attr(idc.INF_SHORT_DN)) 165 | if demangled_name: 166 | mangled_names_to_demangled_names[mangled_name] = demangled_name 167 | else: 168 | mangled_names_to_demangled_names[mangled_name] = mangled_name 169 | 170 | dump_list(collect_addrs, filename) 171 | write_json(mangled_names_to_demangled_names, f'{filename}_names') 172 | 173 | def analyze_func(func): 174 | cfunc = ida_hexrays.decompile_func(func, None, 0) 175 | if cfunc is None: 176 | return 177 | v = Visitor(func.start_ea, cfunc) 178 | v.apply_to(cfunc.body, None) 179 | return v 180 | 181 | class Visitor(idaapi.ctree_visitor_t): 182 | def __init__(self, ea, cfunc): 183 | super().__init__(idaapi.CV_FAST) 184 | self.ea = ea 185 | self.cfunc = cfunc 186 | self.vars = [] 187 | self.already_used = {lvar.name for lvar in cfunc.lvars if lvar.has_user_name} 188 | self.already_fixed = set(self.already_used) 189 | 190 | def visit_expr(self, expr): 191 | if expr.op == ida_hexrays.cot_var: 192 | lvar = expr.get_v().getv() 193 | old_name = lvar.name 194 | if expr.ea == idc.BADADDR: 195 | pass 196 | else: 197 | if old_name not in self.already_fixed: 198 | if lvar.location.is_reg1(): 199 | reg_name = ida_hexrays.print_vdloc(lvar.location, 8) 200 | if reg_name in analysis.result: 201 | var_name = analysis.lookup(reg_name, expr.ea) 202 | if var_name: 203 | nonce_int = 0 204 | nonce = '' 205 | while var_name + nonce in self.already_used: 206 | nonce = '_' + str(nonce_int) 207 | nonce_int += 1 208 | name = var_name + nonce 209 | ida_hexrays.rename_lvar(self.ea, old_name, name) 210 | self.already_used.add(name) 211 | self.already_fixed.add(old_name) 212 | 213 | return 0 214 | 215 | idaapi.auto_wait() 216 | go() 217 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/dwarf_info.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from elftools.elf.elffile import ELFFile 3 | from elftools.dwarf.locationlists import ( 4 | LocationEntry, LocationExpr, LocationParser) 5 | from elftools.dwarf.descriptions import ( 6 | describe_DWARF_expr, _import_extra, describe_attr_value,set_global_machine_arch) 7 | from collections import defaultdict 8 | 9 | l = logging.getLogger('main') 10 | 11 | def is_elf(binary_path: str): 12 | 13 | with open(binary_path, 'rb') as rb: 14 | if rb.read(4) != b"\x7fELF": 15 | l.error(f"File is not an ELF format: {binary_path}") 16 | return None 17 | rb.seek(0) 18 | elffile = ELFFile(rb) 19 | if not elffile.has_dwarf_info(): 20 | l.error(f"No DWARF info found in: {binary_path}") 21 | return None 22 | return elffile.get_dwarf_info() 23 | 24 | class Dwarf: 25 | def __init__(self, binary_path, binary_name, decompiler) -> None: 26 | self.binary_path = binary_path 27 | 
self.binary_name = binary_name 28 | self.decompiler = decompiler 29 | self.dwarf_info = is_elf(self.binary_path) 30 | self.vars_in_each_func = defaultdict(list) 31 | self.spec_offset_var_names = defaultdict(set) if decompiler == 'ghidra' else None 32 | self.linkage_name_to_func_name = {} 33 | self._load() 34 | 35 | def _load(self): 36 | if not self.dwarf_info: 37 | l.error(f"Failed to load DWARF information for: {self.binary_path}") 38 | return 39 | if self.decompiler == 'ida': 40 | self.spec_offset_var_names = self.collect_spec_and_names() 41 | self.vars_in_each_func, self.linkage_name_to_func_name = self.read_dwarf() 42 | 43 | # for Ghidra only 44 | def collect_spec_and_names(self): 45 | # 1. get CU's offset 2. get spec offset, if any function resolves use that function name with vars 46 | spec_offset_var_names = defaultdict(set) 47 | for CU in self.dwarf_info.iter_CUs(): 48 | try: 49 | current_cu_offset = CU.cu_offset 50 | top_DIE = CU.get_top_DIE() 51 | self.die_info_rec(top_DIE) 52 | current_spec_offset = 0 53 | 54 | for i in range(0, len(CU._dielist)): 55 | if CU._dielist[i].tag == 'DW_TAG_subprogram': 56 | 57 | spec_offset, current_spec_offset = 0, 0 58 | spec = CU._dielist[i].attributes.get("DW_AT_specification", None) 59 | 60 | # specification value + CU offset = defination of function 61 | # 15407 + 0xbd1 = 0x8800 62 | # <8800>: Abbrev Number: 74 (DW_TAG_subprogram) 63 | # <8801> DW_AT_external 64 | 65 | if spec: 66 | spec_offset = spec.value 67 | current_spec_offset = spec_offset + current_cu_offset 68 | 69 | if CU._dielist[i].tag == 'DW_TAG_formal_parameter' or CU._dielist[i].tag == 'DW_TAG_variable': 70 | if self.is_artifical(CU._dielist[i]): 71 | continue 72 | name = self.get_name(CU._dielist[i]) 73 | if not name: 74 | continue 75 | if current_spec_offset: 76 | spec_offset_var_names[str(current_spec_offset)].add(name) 77 | 78 | except Exception as e: 79 | l.error(f"Error in collect_spec_and_names :: {self.binary_path} :: {e}") 80 | return spec_offset_var_names 81 | 82 | # for global vars - consider only variables those have locations. 83 | # otherwise variables like stderr, optind, make it to global variables list and 84 | # sometimes these can be marked as valid variables for a particular binary 85 | 86 | def read_dwarf(self): 87 | # for each binary 88 | if not self.dwarf_info: 89 | return defaultdict(list), {} 90 | vars_in_each_func = defaultdict(list) 91 | previous_func_offset = 0 92 | l.debug(f"Reading DWARF info from :: {self.binary_path}") 93 | for CU in self.dwarf_info.iter_CUs(): 94 | try: 95 | current_func_name, ch, func_name = 'global_vars', '', '' 96 | func_offset = 0 97 | tmp_list, global_vars_list, local_vars_and_args_list = [], [], [] 98 | 99 | # caches die info 100 | top_DIE = CU.get_top_DIE() 101 | self.die_info_rec(top_DIE) 102 | 103 | # 1. global vars 2. 
vars and args in subprogram 104 | for i in range(0, len(CU._dielist)): 105 | try: 106 | if CU._dielist[i].tag == 'DW_TAG_variable': 107 | if not CU._dielist[i].attributes.get('DW_AT_location'): 108 | continue 109 | if self.is_artifical(CU._dielist[i]): 110 | continue 111 | var_name = self.get_name(CU._dielist[i]) 112 | if not var_name: 113 | continue 114 | tmp_list.append(var_name) 115 | 116 | # shouldn't check location for parameters because we don't add their location in DWARF 117 | if CU._dielist[i].tag == 'DW_TAG_formal_parameter': 118 | param_name = self.get_name(CU._dielist[i]) 119 | if not param_name: 120 | continue 121 | tmp_list.append(param_name) 122 | 123 | # for last func and it's var 124 | local_vars_and_args_list = tmp_list 125 | if CU._dielist[i].tag == 'DW_TAG_subprogram': 126 | if self.is_artifical(CU._dielist[i]): 127 | continue 128 | # IDA 129 | if self.decompiler == 'ghidra': 130 | # func addr is func name for now 131 | low_pc = CU._dielist[i].attributes.get('DW_AT_low_pc', None) 132 | addr = None 133 | if low_pc is not None: 134 | addr = str(low_pc.value) 135 | 136 | ranges = CU._dielist[i].attributes.get('DW_AT_ranges', None) 137 | if ranges is not None: 138 | addr = self.dwarf_info.range_lists().get_range_list_at_offset(ranges.value)[0].begin_offset 139 | 140 | if not addr: 141 | continue 142 | func_name = addr 143 | # Ghidra 144 | if self.decompiler == 'ida': 145 | func_name = self.get_name(CU._dielist[i]) 146 | if not func_name: 147 | continue 148 | 149 | func_linkage_name = self.get_linkage_name(CU._dielist[i]) 150 | func_offset = CU._dielist[i].offset 151 | # because func name from dwarf is without class name but IDA gives us funcname with classname 152 | # so we match them using linkage_name 153 | if func_linkage_name: 154 | self.linkage_name_to_func_name[func_linkage_name] = func_name 155 | func_name = func_linkage_name 156 | else: 157 | # need it later for matching 158 | self.linkage_name_to_func_name[func_name] = func_name 159 | func_name = func_name 160 | 161 | # because DIE's are serialized and subprogram comes before vars and params 162 | vars_from_specification_subprogram = [] 163 | if previous_func_offset in self.spec_offset_var_names: 164 | vars_from_specification_subprogram = self.spec_offset_var_names[str(previous_func_offset)] 165 | previous_func_offset = str(func_offset) 166 | 167 | if current_func_name != func_name: 168 | if current_func_name == 'global_vars': 169 | global_vars_list.extend(tmp_list) 170 | vars_in_each_func[current_func_name].extend(global_vars_list) 171 | else: 172 | if self.decompiler == 'ida': 173 | if vars_from_specification_subprogram: 174 | tmp_list.extend(vars_from_specification_subprogram) 175 | vars_in_each_func[current_func_name].extend(tmp_list) 176 | ch = current_func_name 177 | current_func_name = func_name 178 | tmp_list = [] 179 | except Exception as e: 180 | l.error(f"Error in reading DWARF {e}") 181 | 182 | if current_func_name != ch and func_name: 183 | vars_in_each_func[func_name].extend(local_vars_and_args_list) 184 | 185 | except Exception as e: 186 | l.error(f"Error in read_dwarf :: {self.binary_name} :: {e}") 187 | l.debug(f"Number of functions in {self.binary_name}: {str(len(vars_in_each_func))}") 188 | return vars_in_each_func, self.linkage_name_to_func_name 189 | 190 | def get_name(self, die): 191 | name = die.attributes.get('DW_AT_name', None) 192 | if name: 193 | return name.value.decode('ascii') 194 | 195 | def get_linkage_name(self, die): 196 | name = die.attributes.get('DW_AT_linkage_name', None) 197 | 
if name: 198 | return name.value.decode('ascii') 199 | 200 | # compiler generated variables or functions (destructors, ...) 201 | def is_artifical(self, die): 202 | return die.attributes.get('DW_AT_artificial', None) 203 | 204 | def die_info_rec(self, die, indent_level=' '): 205 | """ A recursive function for showing information about a DIE and its 206 | children. 207 | """ 208 | child_indent = indent_level + ' ' 209 | for child in die.iter_children(): 210 | self.die_info_rec(child, child_indent) 211 | 212 | @classmethod 213 | def get_vars_in_each_func(cls, binary_path): 214 | dwarf = cls(binary_path) 215 | return dwarf.vars_in_each_func 216 | 217 | @classmethod 218 | def get_vars_for_func(cls, binary_path, func_name): 219 | dwarf = cls(binary_path) 220 | return dwarf.vars_in_each_func[func_name] + dwarf.vars_in_each_func['global_vars'] -------------------------------------------------------------------------------- /varbert/cmlm/preprocess.py: -------------------------------------------------------------------------------- 1 | # from pymongo import MongoClient 2 | from tqdm import tqdm 3 | from transformers import RobertaTokenizerFast 4 | import jsonlines 5 | import sys 6 | import json 7 | import logging 8 | import traceback 9 | import random 10 | import os 11 | import re 12 | import argparse 13 | from multiprocessing import Process, Manager, cpu_count, Pool 14 | from itertools import repeat 15 | from collections import defaultdict 16 | 17 | l = logging.getLogger('model_main') 18 | 19 | def read_input_files(filename): 20 | 21 | samples, sample_ids = [], set() 22 | with jsonlines.open(filename,'r') as f: 23 | for each in tqdm(f): 24 | samples.append(each) 25 | sample_ids.add(str(each['mongo_id'])) 26 | return sample_ids, samples 27 | 28 | 29 | def prep_input_files(input_file, num_processes, tokenizer, word_to_idx, max_sample_chunk, input_file_ids, output_file_ids, preprocessed_outfile): 30 | 31 | #------------------------ INPUT FILES ------------------------ 32 | output_data = Manager().list() 33 | pool = Pool(processes=num_processes) # Instantiate the pool here 34 | each_alloc = len(input_file) // (num_processes-1) 35 | input_data = [input_file[i*each_alloc:(i+1)*each_alloc] for i in range(0,num_processes)] 36 | x = [len(each) for each in input_data] 37 | print(f"Allocation samples for each worker: {len(input_data)}, {x}") 38 | 39 | pool.starmap(generate_id_files,zip(input_data, 40 | repeat(output_data), 41 | repeat(tokenizer), 42 | repeat(word_to_idx), 43 | repeat(max_sample_chunk) 44 | )) 45 | pool.close() 46 | pool.join() 47 | 48 | # Write to Output file 49 | with jsonlines.open(preprocessed_outfile,'w') as f: 50 | for each in tqdm(output_data): 51 | f.write(each) 52 | 53 | # check : #source ids == target_ids after parallel processing 54 | print(f"src_tgt_intersection:", len(input_file_ids - output_file_ids), len(output_file_ids-input_file_ids)) 55 | 56 | for each in tqdm(output_data): 57 | output_file_ids.add(str(each['_id']).split("_")[0]) 58 | print("src_tgt_intersection:",len(input_file_ids - output_file_ids), len(output_file_ids-input_file_ids)) 59 | print(len(output_data)) 60 | 61 | # validate : vocab_check 62 | vocab_check = defaultdict(int) 63 | total = 0 64 | for each in tqdm(output_data): 65 | variables = each['orig_vars'] 66 | for var in variables: 67 | total += 1 68 | normvar = normalize(var) 69 | _, vocab_stat = get_var_token(normvar,word_to_idx) 70 | if "in_vocab" in vocab_stat: 71 | vocab_check['in_vocab']+=1 72 | if "not_in_vocab" in vocab_stat: 73 | 
vocab_check['not_in_vocab']+=1 74 | if "part_in_vocab" in vocab_stat: 75 | vocab_check['part_in_vocab']+=1 76 | 77 | print(vocab_check, round(vocab_check['in_vocab']*100/total,2), round(vocab_check['not_in_vocab']*100/total,2), round(vocab_check['part_in_vocab']*100/total,2)) 78 | 79 | 80 | def generate_id_files(data, output_data, tokenizer, word_to_idx, n): 81 | 82 | for d in tqdm(data): 83 | try: 84 | ppw = preprocess_word_mask(d,tokenizer,word_to_idx) 85 | outrow = {"words":ppw[4],"mod_words":ppw[6],"inputids":ppw[0],"labels":ppw[1],"gold_texts":ppw[2],"gold_texts_id":ppw[3],"meta":[],"orig_vars":ppw[5], "_id":ppw[7]} 86 | # if input length is more than max possible 1024 then split and make more sample found by tracing _id 87 | if len(outrow['inputids']) > n: 88 | for i in range(0, len(outrow['inputids']), n): 89 | sample = {"words": outrow['words'][i:i+n], 90 | "mod_words":outrow['mod_words'][i:i+n], 91 | "inputids":outrow['inputids'][i:i+n], 92 | "labels":outrow["labels"][i:i+n], 93 | "gold_texts":outrow["gold_texts"], 94 | "gold_texts_id":outrow["gold_texts_id"], 95 | "orig_vars":outrow["orig_vars"], 96 | "meta":outrow["meta"], 97 | "_id":str(outrow['_id'])+"_"+str((i)//n),} 98 | output_data.append(sample) 99 | else: 100 | output_data.append(outrow) 101 | except: 102 | print("Unexpected error:", sys.exc_info()[0]) 103 | traceback.print_exception(*sys.exc_info()) 104 | 105 | def change_case(strt): 106 | return ''.join(['_'+i.lower() if i.isupper() else i for i in strt]).lstrip('_') 107 | def is_camel_case(s): 108 | return s != s.lower() and s != s.upper() and "_" not in s 109 | 110 | def normalize(k): 111 | if is_camel_case(k): k=change_case(k) 112 | else: k=k.lower() 113 | return k 114 | 115 | def get_var_token(norm_variable_word,word_to_idx): 116 | vocab_check = defaultdict(int) 117 | token = word_to_idx.get(norm_variable_word,args.vocab_size) 118 | if token == args.vocab_size: 119 | vocab_check['not_in_vocab']+=1 120 | if "_" in norm_variable_word: 121 | word_splits=norm_variable_word.split("_") 122 | word_splits = [ee for ee in word_splits if ee ] 123 | for x in word_splits: 124 | ptoken=word_to_idx.get(x,args.vocab_size) 125 | if ptoken!=args.vocab_size: 126 | token=ptoken 127 | vocab_check['part_in_vocab']+=1 128 | break 129 | else: 130 | vocab_check['in_vocab']+=1 131 | return [token], vocab_check 132 | 133 | 134 | def canonicalize_code(code): 135 | code = re.sub('//.*?\\n|/\\*.*?\\*/', '\\n', code, flags=re.S) 136 | lines = [l.rstrip() for l in code.split('\\n')] 137 | code = '\\n'.join(lines) 138 | # code = re.sub('@@\\w+@@(\\w+)@@\\w+', '\\g<1>', code) 139 | return code 140 | 141 | def preprocess_word_mask(text,tokenizer, word_to_idx): 142 | # vars_map = text['vars_map'] #needed for vartype detection 143 | # vars_map = dict([[ee[1],ee[0]] for ee in vars_map]) 144 | _id = text['mongo_id'] 145 | ftext = canonicalize_code(text['norm_func']) 146 | words = ftext.replace("\n"," ").split(" ") 147 | pwords =[] 148 | tpwords =[] 149 | owords =[] 150 | towords =[] 151 | pos=0 152 | masked_pos=[] 153 | var_words =[] 154 | var_toks = [] 155 | mod_words = [] 156 | orig_vars = [] 157 | 158 | vocab=tokenizer.get_vocab() 159 | 160 | for word in words: 161 | if re.search(args.var_loc_pattern, word): 162 | idx = 0 163 | for each_var in list(re.finditer(args.var_loc_pattern,word)): 164 | s = each_var.start() 165 | e = each_var.end() 166 | prefix = word[idx:s] 167 | var = word[s:e] 168 | orig_var = var.split("@@")[-2] 169 | 170 | # Somethings attached before the variables 171 | if prefix: 172 | 
toks = tokenizer.tokenize(prefix) 173 | for t in toks: 174 | mod_words.append(t) 175 | tpwords.append(vocab[t]) 176 | towords.append(vocab[t]) 177 | 178 | # Original variable handling 179 | # ---- IF dwarf : inputs: labels:tokenized,normalized var 180 | # ---- NOT dwarf: inputs: tokenized,orig_vars labels:tokenized,orig_vars 181 | 182 | norm_variable_word = normalize(orig_var) 183 | var_tokens, _ = get_var_token(norm_variable_word,word_to_idx) 184 | var_toks.append(var_tokens) 185 | var_words.append(norm_variable_word) #Gold_texts (gold labels) 186 | mod_words.append(orig_var) 187 | orig_vars.append(orig_var) 188 | 189 | tpwords.append(vocab[""]) 190 | towords.append(var_tokens[0]) 191 | 192 | idx = e 193 | 194 | # Postfix if any 195 | postfix = word[idx:] 196 | if postfix: 197 | toks = tokenizer.tokenize(postfix) 198 | for t in toks: 199 | mod_words.append(t) 200 | tpwords.append(vocab[t]) 201 | towords.append(vocab[t]) 202 | 203 | else: 204 | # When there is no variables 205 | toks = tokenizer.tokenize(word) 206 | for t in toks: 207 | if t == "": 208 | continue #skip adding the token if code has keyword 209 | mod_words.append(t) 210 | tpwords.append(vocab[t]) 211 | towords.append(vocab[t]) 212 | 213 | assert len(tpwords) == len(towords) 214 | assert len(var_toks) == len(var_words) 215 | 216 | # 0 1 2 3 4 5 6 7 217 | return tpwords,towords,var_words, var_toks, words, orig_vars, mod_words, _id 218 | 219 | 220 | def main(args): 221 | 222 | tokenizer = RobertaTokenizerFast.from_pretrained(args.tokenizer) 223 | word_to_idx = json.load(open(args.vocab_word_to_idx)) 224 | idx_to_word = json.load(open(args.vocab_idx_to_word)) 225 | max_sample_chunk = args.max_chunk_size-2 226 | args.vocab_size = len(word_to_idx) #OOV as UNK[assinged id=vocab_size] 227 | print("Vocab_size:",args.vocab_size) 228 | 229 | train, test = [], [] 230 | src_train_ids, srctest_ids = set(), set() 231 | tgt_train_ids, tgttest_ids = set(), set() 232 | 233 | src_train_ids, train = read_input_files(args.train_file) 234 | srctest_ids, test = read_input_files(args.test_file) 235 | print(f"Data size Train: {len(train)} \t Test: {len(test)}") 236 | 237 | num_processes = min(args.workers, cpu_count()) 238 | print("Running with #workers : ",num_processes) 239 | 240 | prep_input_files(train, num_processes, tokenizer, word_to_idx, max_sample_chunk, src_train_ids, tgt_train_ids, args.out_train_file ) 241 | prep_input_files(test, num_processes, tokenizer, word_to_idx, max_sample_chunk, srctest_ids, tgttest_ids, args.out_test_file) 242 | 243 | 244 | if __name__ == "__main__": 245 | 246 | parser = argparse.ArgumentParser() 247 | parser.add_argument('--train_file', type=str, help='name of the train file') 248 | parser.add_argument('--test_file', type=str, help='name of name of the test file') 249 | 250 | parser.add_argument('--tokenizer', type=str, help='path to the tokenizer') 251 | 252 | parser.add_argument('--vocab_word_to_idx', type=str, help='Output Vocab Word to index file') 253 | parser.add_argument('--vocab_idx_to_word', type=str, help='Output Vocab Index to Word file') 254 | 255 | parser.add_argument('--vocab_size', type=int, default=50001, help='size of output vocabulary') 256 | parser.add_argument('--var_loc_pattern', type=str, default="@@\w+@@\w+@@", help='pattern representing variable location') 257 | parser.add_argument('--max_chunk_size', type=int, default=1024, help='size of maximum chunk of input for the model') 258 | parser.add_argument('--workers', type=int, default=20, help='number of parallel workers you need') 259 | 
260 | parser.add_argument('--out_train_file', type=str, help='name of the output train file') 261 | parser.add_argument('--out_test_file', type=str, help='name of name of the output test file') 262 | args = parser.parse_args() 263 | 264 | main(args) -------------------------------------------------------------------------------- /varcorpus/dataset-gen/parse_decompiled_code.py: -------------------------------------------------------------------------------- 1 | import re 2 | import logging 3 | 4 | 5 | l = logging.getLogger('main') 6 | # 7 | # Raw IDA decompiled code pre-processing 8 | # 9 | class IDAParser: 10 | def __init__(self, decompiled_code, binary_name, linkage_name_to_func_name) -> None: 11 | self.decompiled_code = decompiled_code 12 | self.functions = {} 13 | self.func_addr_to_name = {} 14 | self.func_name_to_addr = {} 15 | self.linkage_name_to_func_name = linkage_name_to_func_name 16 | self.func_name_to_linkage_name = self.create_dict_fn_to_ln() 17 | self.binary_name = binary_name 18 | self.func_name_wo_line = {} 19 | self.preprocess_ida_raw_code() 20 | 21 | def create_dict_fn_to_ln(self): 22 | 23 | # linkage name to func name (it is the signature abc(void), so split at '(') 24 | func_name_to_linkage_name = {} 25 | if self.linkage_name_to_func_name: 26 | for ln, fn in self.linkage_name_to_func_name.items(): 27 | func_name_to_linkage_name[fn.split('(')[0]] = ln 28 | return func_name_to_linkage_name 29 | 30 | def preprocess_ida_raw_code(self): 31 | 32 | data = self.decompiled_code 33 | func_addr_to_line_count = {} 34 | functions, self.func_name_to_addr, self.func_addr_to_name, self.func_name_wo_line = self.split_ida_c_file_into_funcs(data) 35 | for addr, func in functions.items(): 36 | try: 37 | func_sign = func.split('{')[0].strip() 38 | func_body = '{'.join(func.split('{')[1:]) 39 | 40 | if not addr in self.func_addr_to_name: 41 | continue 42 | func_name = self.func_addr_to_name[addr] 43 | 44 | # find local variables 45 | varlines_bodylines = func_body.strip("\n").split('\n\n') 46 | if len(varlines_bodylines) >= 2: 47 | var_dec_lines = varlines_bodylines[0] 48 | local_vars = self.find_local_vars(var_dec_lines) 49 | else: 50 | local_vars = [] 51 | self.functions[addr] = {'func_name': func_name, 52 | 'func_name_no_line': '_'.join(func_name.split('_')[:-1]), # refer to comments in split_ida_c_file_into_funcs 53 | 'func': func, 54 | 'func_prototype': func_sign, 55 | 'func_body': func_body, 56 | 'local_vars': local_vars, 57 | 'addr': addr 58 | } 59 | 60 | except Exception as e: 61 | l.error(f'Error in {self.binary_name}:{func_name}:{addr} = {e}') 62 | l.info(f'Functions after IDA parsing {self.binary_name} :: {len(self.functions)}') 63 | 64 | def split_ida_c_file_into_funcs(self, data: str): 65 | 66 | line_count = 1 67 | chunks = data.split('//----- ') 68 | data_declarations = chunks[0].split('//-------------------------------------------------------------------------')[2] 69 | func_dict, func_name_to_addr, func_addr_to_name, func_name_with_linenum_linkage_name = {}, {}, {}, {} 70 | func_name_wo_line = {} 71 | for chunk in chunks: 72 | lines = chunk.split('\n') 73 | line_count = line_count 74 | if not lines: 75 | continue 76 | if '-----------------------------------' in lines[0]: 77 | name = '' 78 | func_addr = lines[0].strip('-').strip() 79 | func = '\n'.join(lines[1:]) 80 | func_dict[func_addr] = func 81 | 82 | # func name and line number 83 | all_func_lines = func.splitlines() 84 | first_line = all_func_lines[0] 85 | sec_line = all_func_lines[1] 86 | if '(' in first_line: 87 | name 
= first_line.split('(')[0].split(' ')[-1].strip('*') + '_' + str(line_count + 1) 88 | elif '//' in first_line and '(' in sec_line: 89 | name = sec_line.split('(')[0].split(' ')[-1].strip('*') + '_' + str(line_count + 2) 90 | 91 | if name: 92 | # for cpp we use linkage name/mangled names instead of original function names 93 | # 1. dwarf info gives func name w/o class name but IDA has class_name::func_name. 2. to help with function overloading or same func name in different classes 94 | # if there is no mangled name, original function is copied to dict so we have all the func names. 95 | # so in case of C, we can use same dict. it is essentially func name 96 | 97 | tmp_wo_line = name.split('_') 98 | if len(tmp_wo_line) <=1 : 99 | continue 100 | name_wo_line, line_num = '_'.join(tmp_wo_line[:-1]), tmp_wo_line[-1] 101 | 102 | # replace demangled name with mangled name (cpp) or simply a func name (c) 103 | if name_wo_line in self.func_name_to_linkage_name: # it is not func from source 104 | # type-strip 105 | name = self.func_name_to_linkage_name[name_wo_line] 106 | name = f'{name}_{line_num}' 107 | elif name_wo_line.strip().split('::')[-1] in self.func_name_to_linkage_name: 108 | name = f'{name}_{line_num}' 109 | 110 | # strip 111 | func_name_to_addr[name] = func_addr 112 | func_addr_to_name[func_addr] = name 113 | func_name_wo_line[name_wo_line] = func_addr 114 | 115 | line_count += len(lines) - 1 116 | return func_dict, func_name_to_addr, func_addr_to_name, func_name_wo_line 117 | 118 | def find_local_vars(self, lines): 119 | # use regex 120 | local_vars = [] 121 | regex = r"(\w+(\[\d+\]|\d{0,6}));" 122 | matches = re.finditer(regex, lines) 123 | if matches: 124 | for m in matches: 125 | tmpvar = m.group(1) 126 | if not tmpvar: 127 | continue 128 | lv = tmpvar.split('[')[0] 129 | local_vars.append(lv) 130 | return local_vars 131 | 132 | 133 | # 134 | # Raw Ghidra decompiled code pre-processing 135 | # 136 | 137 | class GhidraParser: 138 | def __init__(self, decompiled_code_path, binary_name) -> None: 139 | self.decompiled_code_path = decompiled_code_path 140 | self.functions = {} 141 | self.func_addr_to_name = {} 142 | self.func_name_to_addr = {} 143 | self.linkage_name_to_func_name = None 144 | self.func_name_to_linkage_name = None 145 | self.binary_name = binary_name 146 | self.func_name_wo_line = {} 147 | self.preprocess_ghidra_raw_code() 148 | 149 | def preprocess_ghidra_raw_code(self): 150 | data = self.decompiled_code_path 151 | functions, self.func_name_to_addr, self.func_addr_to_name, self.func_name_wo_line = self.split_ghidra_c_file_into_funcs(data) 152 | if not functions: 153 | return 154 | for addr, func in functions.items(): 155 | try: 156 | func_sign = func.split('{')[0].strip() 157 | func_body = '{'.join(func.split('{')[1:]) 158 | if not addr in self.func_addr_to_name: 159 | continue 160 | func_name = self.func_addr_to_name[addr] 161 | 162 | varlines_bodylines = func_body.strip("\n").split('\n\n') 163 | if len(varlines_bodylines) >= 2: 164 | var_dec_lines = varlines_bodylines[0] 165 | local_vars = self.find_local_vars(varlines_bodylines) 166 | else: 167 | local_vars = [] 168 | 169 | self.functions[addr] = {'func_name': func_name, 170 | 'func_name_no_line': '_'.join(func_name.split('_')[:-1]), 171 | 'func': func, 172 | 'func_prototype': func_sign, 173 | 'func_body': func_body, 174 | 'local_vars': local_vars, 175 | # 'arguments': func_args, 176 | 'addr': addr 177 | } 178 | except Exception as e: 179 | l.error(f'Error in {self.binary_name}:{func_name}:{addr} = {e}') 180 | 181 | 
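    # Splits the raw Ghidra decompilation dump on its '//----- ' separator lines and returns the
    # per-function chunks keyed by address, plus name/address lookup maps for later matching.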
def split_ghidra_c_file_into_funcs(self, data: str): 182 | 183 | chunks = data.split('//----- ') 184 | 185 | func_dict, func_name_to_addr, func_addr_to_name = {}, {}, {} 186 | func_name_wo_line = {} 187 | 188 | line_count = 1 189 | for chunk in chunks[1:]: 190 | line_count = line_count 191 | lines = chunk.split('\n') 192 | if not lines: 193 | continue 194 | if '-----------------------------------' in lines[0]: 195 | func_addr = lines[0].strip('-').strip() 196 | func = '\n'.join(lines[1:]) 197 | func_dict[func_addr] = func 198 | 199 | # get func name and line number TODO: Ghidra had a func split in two lines - maybe use regex for it 200 | all_func_lines = func.split('\n\n') 201 | first_line = all_func_lines[0] 202 | if '(' in first_line: 203 | t_name = first_line.split('(')[0] 204 | if "\n" in t_name: 205 | name = t_name.split('\n')[-1].split(' ')[-1].strip('*') + '_' + str(line_count + 1) 206 | else: 207 | name = t_name.split(' ')[-1].strip('*') + '_' + str(line_count + 1) 208 | 209 | if name: 210 | name_wo_line = '_'.join(name[:-1]) 211 | func_name_to_addr[name] = func_addr 212 | func_addr_to_name[func_addr] = name 213 | func_name_wo_line[name_wo_line] = func_addr 214 | 215 | line_count += len(lines) - 1 216 | return func_dict, func_name_to_addr, func_addr_to_name, func_name_wo_line 217 | 218 | def find_local_vars(self, varlines_bodylines): 219 | 220 | all_vars = [] 221 | try: 222 | first_curly = '' 223 | sep = '' 224 | dec_end = 0 225 | for elem in range(len(varlines_bodylines)): 226 | if varlines_bodylines[elem] == '{' and first_curly == '': 227 | first_curly = '{' 228 | dec_start = elem + 1 229 | if first_curly == '{' and sep == '' and varlines_bodylines[elem] == ' ': 230 | dec_end = elem - 1 231 | 232 | for index in range(dec_start, dec_end + 1): 233 | if index < len(varlines_bodylines): 234 | f_sp = varlines_bodylines[index].split(' ') 235 | if f_sp: 236 | tmp_var = f_sp[-1][:-1].strip('*') 237 | 238 | if tmp_var.strip().startswith('['): 239 | up_var = f_sp[-2:-1][0].strip('*') 240 | all_vars.append(up_var) 241 | else: 242 | all_vars.append(tmp_var) 243 | 244 | except Exception as e: 245 | l.error(f'Error in finding local vars {self.binary_name} :: {e}') 246 | 247 | return all_vars 248 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import shutil 5 | import logging 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | import concurrent.futures 9 | from multiprocessing import Manager 10 | 11 | from binary import Binary 12 | from joern_parser import JoernParser 13 | from variable_matching import DataLoader 14 | from decompiler.run_decompilers import Decompiler 15 | from utils import set_up_data_dir, JoernServer, read_json 16 | from create_dataset_splits import create_train_and_test_sets 17 | 18 | l = logging.getLogger('main') 19 | 20 | class Runner: 21 | def __init__(self, decompiler, target_binaries, WORKERS, path_manager, PORT, language, DEBUG, language_map=None) -> None: 22 | self.decompiler = decompiler 23 | self.target_binaries = target_binaries 24 | self.WORKERS = WORKERS 25 | self.path_manager = path_manager 26 | self.binary = None 27 | self.PORT = PORT 28 | self.language = language 29 | self.DEBUG = DEBUG 30 | self.collect_workdir = Manager().dict() 31 | self.language_map = language_map or {} 32 | 33 | def _setup_environment_for_ghidra(self): 34 | java_home = 
"/usr/lib/jvm/java-17-openjdk-amd64" 35 | os.environ["JAVA_HOME"] = java_home 36 | if os.environ.get('JAVA_HOME') != java_home: 37 | l.error(f"INCORRECT JAVA VERSION FOR GHIDRA: {os.environ.get('JAVA_HOME')}") 38 | 39 | def decompile_runner(self, binary_path: str): 40 | """ 41 | Runner method to handle the decompilation of a binary. 42 | """ 43 | try: 44 | decompile_results = {} 45 | binary_name = Path(binary_path).name 46 | l.info(f"Processing target binary: {binary_name}") 47 | workdir = f'{self.path_manager.tmpdir}/workdir/{binary_name}' 48 | l.debug(f"{binary_name} in workdir {workdir}") 49 | set_up_data_dir(self.path_manager.tmpdir, workdir, self.decompiler) 50 | 51 | # collect dwarf info and create strip and type-strip binaries 52 | self.binary = Binary(binary_path, binary_name, self.path_manager, self.language, self.decompiler, True) 53 | if not (self.binary.strip_binary and self.binary.type_strip_binary): 54 | l.error(f"Strip/Type-strip binary not found for {self.binary.binary_name}") 55 | return 56 | l.debug(f"Strip/Type-strip binaries created for {self.binary.binary_name}") 57 | decomp_workdir = os.path.join(workdir, f"{self.decompiler}_data") 58 | decompiler_path = getattr(self.path_manager, f"{self.decompiler}_path") 59 | if self.decompiler == "ghidra": 60 | self._setup_environment_for_ghidra() 61 | for binary_type in ["type_strip", "strip"]: 62 | binary_path = getattr(self.binary, f"{binary_type}_binary") 63 | l.debug(f"Decompiling {self.binary.binary_name} :: {binary_type} with {self.decompiler} :: {binary_path}") 64 | 65 | dec = Decompiler( 66 | decompiler=self.decompiler, 67 | decompiler_path=Path(decompiler_path), 68 | decompiler_workdir=decomp_workdir, 69 | binary_name=self.binary.binary_name, 70 | binary_path=binary_path, 71 | binary_type=binary_type, 72 | decompiled_binary_code_path=os.path.join(self.path_manager.dc_path, str(self.decompiler), binary_type, f'{binary_name}.c'), 73 | failed_path=os.path.join(self.path_manager.failed_path_ida, binary_type), 74 | type_strip_addrs=os.path.join(self.path_manager.dc_path, self.decompiler, f"type_strip-addrs", binary_name), 75 | type_strip_mangled_names = os.path.join(self.path_manager.dc_path, self.decompiler, f"type_strip-names", binary_name) 76 | ) 77 | dec_path = dec 78 | 79 | if not dec_path: 80 | l.error(f"Decompilation failed for {self.binary.binary_name} :: {binary_type} ") 81 | l.info(f"Decompilation succesful for {self.binary.binary_name} :: {binary_type}!") 82 | decompile_results[binary_type] = dec_path 83 | 84 | except Exception as e: 85 | l.error(f"Error in decompiling {self.binary.binary_name} with {self.decompiler}: {e}") 86 | return 87 | 88 | self.collect_workdir[self.binary.binary_name] = workdir 89 | return self.binary.dwarf_dict, decompile_results 90 | 91 | def joern_runner(self, dwarf_data, binary_name, decompilation_results, PARSE): 92 | 93 | workdir = self.collect_workdir[binary_name] 94 | joern_data_strip, joern_data_type_strip = '', '' 95 | try: 96 | decompiled_code_strip_path = decompilation_results['strip'].decompiled_binary_code_path 97 | decompiled_code_type_strip_path = decompilation_results['type_strip'].decompiled_binary_code_path 98 | decompiled_code_type_strip_names = decompilation_results['type_strip'].type_strip_mangled_names 99 | dwarf_info_path = dwarf_data 100 | data_map_dump = os.path.join(self.path_manager.match_path, self.decompiler, binary_name) 101 | 102 | if not (os.path.exists(decompiled_code_strip_path) and os.path.exists(decompiled_code_type_strip_path)): 103 | l.error(f"Decompiled 
code not found for {binary_name} :: {self.decompiler}") 104 | return 105 | 106 | # paths for joern 107 | joern_strip_path = os.path.join(self.path_manager.joern_data_path, self.decompiler, 'strip', (binary_name + '.json')) 108 | joern_type_strip_path = os.path.join(self.path_manager.joern_data_path, self.decompiler, 'type_strip', (binary_name + '.json')) 109 | 110 | if PARSE: 111 | # JOERN 112 | l.info(f'Joern parsing for {binary_name} :: strip :: {self.decompiler} :: in {workdir}') 113 | joern_data_strip = JoernParser(binary_name=binary_name, binary_type='strip', decompiler=self.decompiler, 114 | decompiled_code=decompiled_code_strip_path, workdir=workdir, port=self.PORT, 115 | outpath=joern_strip_path).joern_data 116 | 117 | l.info(f'Joern parsing for {binary_name} :: type-strip :: {self.decompiler} :: in {workdir}') 118 | joern_data_type_strip = JoernParser(binary_name=binary_name, binary_type='type-strip', decompiler=self.decompiler, 119 | decompiled_code=decompiled_code_type_strip_path, workdir=workdir, port=self.PORT, 120 | outpath=joern_type_strip_path).joern_data 121 | 122 | 123 | else: 124 | joern_data_strip = read_json(joern_strip_path) 125 | joern_data_type_strip = read_json(joern_type_strip_path) 126 | 127 | l.info(f'Joern parsing completed for {binary_name} :: strip :: {self.decompiler} :: {len(joern_data_strip)}') 128 | l.info(f'Joern parsing completed for {binary_name} :: type-strip :: {self.decompiler} :: {len(joern_data_type_strip)}') 129 | 130 | 131 | if not (joern_data_strip and joern_data_type_strip): 132 | l.info(f"oops! strip/ty-strip joern not found! {binary_name}") 133 | return 134 | 135 | l.info(f'Start mapping for {binary_name} :: {self.decompiler}') 136 | 137 | # Select language strictly from map if provided, else fall back to default 138 | language_to_use = self.language_map.get(binary_name, self.language) if self.language_map else self.language 139 | matchvariable = DataLoader(binary_name, dwarf_info_path, self.decompiler, decompiled_code_strip_path, 140 | decompiled_code_type_strip_path, joern_data_strip, joern_data_type_strip, data_map_dump, language_to_use, decompiled_code_type_strip_names) 141 | 142 | 143 | except Exception as e: 144 | l.info(f"Error! 
:: {binary_name} :: {self.decompiler} :: {e} ") 145 | 146 | 147 | def run(self, PARSE, splits): 148 | 149 | if PARSE: 150 | try: 151 | joern_progress = tqdm(total=0, desc="Joern Processing") 152 | with concurrent.futures.ProcessPoolExecutor(max_workers=self.WORKERS) as executor: 153 | future_to_binary = {executor.submit(self.decompile_runner, binary): binary for binary in self.target_binaries} 154 | # As each future completes, process it with joern_runner 155 | # for future in concurrent.futures.as_completed(future_to_binary): 156 | for future in tqdm(concurrent.futures.as_completed(future_to_binary), total=len(future_to_binary), desc="Decompiling Binaries"): 157 | binary = future_to_binary[future] 158 | binary_name = Path(binary).name 159 | binary_info, decompilation_results = future.result() 160 | if not decompilation_results: 161 | l.error(f"No Decompilation results for {binary_name}") 162 | return 163 | l.info(f"Start Joern for :: {binary_name}") 164 | try: 165 | joern_progress.total += 1 166 | # Setup JAVA_HOME for Joern 167 | os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64" 168 | java_env = os.environ.get('JAVA_HOME', None) 169 | if java_env != '/usr/lib/jvm/java-11-openjdk-amd64': 170 | l.error(f"INCORRECT JAVA VERSION JOERN: {java_env}") 171 | with JoernServer(self.path_manager.joern_dir, self.PORT) as joern_server: 172 | self.joern_runner(binary_info, binary_name, decompilation_results, PARSE) 173 | joern_progress.update(1) 174 | except Exception as e: 175 | l.error(f"Error in starting Joern server for {binary_name} :: {e}") 176 | finally: 177 | if joern_server.is_server_running(): 178 | joern_server.stop() 179 | l.debug(f"Stop Joern for :: {binary_name}") 180 | time.sleep(10) 181 | 182 | # dedup functions -> create train and test sets 183 | joern_progress.close() 184 | if splits: 185 | create_train_and_test_sets( f'{self.path_manager.tmpdir}', self.decompiler) 186 | 187 | except Exception as e: 188 | l.error(f"Error during parallel decompilation: {e}") 189 | 190 | else: 191 | create_train_and_test_sets( f'{self.path_manager.tmpdir}', self.decompiler) 192 | 193 | # if not debug, delete tmpdir and copy splits to data dir 194 | source_dir = f'{self.path_manager.tmpdir}/splits/' 195 | # Get all file paths 196 | file_paths = glob.glob(os.path.join(source_dir, '**', 'final*.jsonl'), recursive=True) 197 | 198 | for file_path in file_paths: 199 | # Check if the filename contains 'final' 200 | if 'final' in os.path.basename(file_path): 201 | relative_path = os.path.relpath(file_path, start=source_dir) 202 | new_destination_path = os.path.join(self.path_manager.data_dir, relative_path) 203 | # Create directories if they don't exist 204 | os.makedirs(os.path.dirname(new_destination_path), exist_ok=True) 205 | # Copy the file 206 | shutil.copy(file_path, new_destination_path) 207 | l.info(f"Copied {file_path} to {new_destination_path}") 208 | if self.path_manager.tmpdir and not self.DEBUG: 209 | shutil.rmtree(self.path_manager.tmpdir) 210 | l.info(f"Deleted {self.path_manager.tmpdir}") -------------------------------------------------------------------------------- /varbert/fine-tune/preprocess.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from transformers import RobertaTokenizerFast 3 | import jsonlines 4 | import sys 5 | import json 6 | import traceback 7 | import random 8 | import os 9 | import re 10 | import argparse 11 | from collections import defaultdict 12 | from multiprocessing import Process, 
Manager, cpu_count, Pool 13 | from itertools import repeat, islice 14 | 15 | def read_input_files(filename): 16 | 17 | samples, sample_ids = [], set() 18 | with jsonlines.open(filename,'r') as f: 19 | for each in tqdm(f): 20 | if each.get('fid') is None: 21 | new_fid = str(each['id']) + "-" + each['func_name'].replace("_", "-") 22 | each['fid'] = new_fid 23 | if each.get('type_stripped_norm_vars') is None: 24 | vars_map = each.get('vars_map') 25 | norm_var_type = {} 26 | if vars_map: 27 | for pair in vars_map: 28 | norm_var = pair[1] 29 | var = pair[0] 30 | if norm_var in norm_var_type and each["type_stripped_vars"][var] != 'dwarf': 31 | norm_var_type[norm_var] = 'dwarf' 32 | else: 33 | norm_var_type[norm_var] = each["type_stripped_vars"][var] 34 | each['type_stripped_norm_vars'] = norm_var_type 35 | 36 | sample_ids.add(each['fid']) 37 | samples.append(each) 38 | return sample_ids, samples 39 | 40 | def prep_input_files(input_file, num_processes, tokenizer, word_to_idx, max_sample_chunk, input_file_ids, output_file_ids, preprocessed_outfile): 41 | 42 | #------------------------ INPUT FILES ------------------------ 43 | output_data = Manager().list() 44 | pool = Pool(processes=num_processes) # Instantiate the pool here 45 | each_alloc = len(input_file) // (num_processes-1) 46 | input_data = [input_file[i*each_alloc:(i+1)*each_alloc] for i in range(0,num_processes)] 47 | x = [len(each) for each in input_data] 48 | print(f"Allocation samples for each worker: {len(input_data)}, {x}") 49 | 50 | pool.starmap(generate_id_files,zip(input_data, 51 | repeat(output_data), 52 | repeat(tokenizer), 53 | repeat(word_to_idx), 54 | repeat(max_sample_chunk) 55 | )) 56 | pool.close() 57 | pool.join() 58 | 59 | # Write to Output file 60 | with jsonlines.open(preprocessed_outfile,'w') as f: 61 | for each in tqdm(output_data): 62 | output_file_ids.add(str(each['fid']).split("_")[0]) 63 | f.write(each) 64 | 65 | # validate : # source ids == target_ids all ids are present after parallel processing 66 | print(len(input_file_ids), len(output_file_ids)) 67 | print(f"src_tgt_intersection:", len(input_file_ids - output_file_ids), len(output_file_ids-input_file_ids)) 68 | 69 | # validate : vocab_check 70 | vocab_check = defaultdict(int) 71 | total = 0 72 | for each in tqdm(output_data): 73 | variables = list(each['type_dict'].keys()) 74 | for var in variables: 75 | total += 1 76 | _, vocab_stat = get_var_token(var, word_to_idx) 77 | if "in_vocab" in vocab_stat: 78 | vocab_check['in_vocab']+=1 79 | if "not_in_vocab" in vocab_stat: 80 | vocab_check['not_in_vocab']+=1 81 | if "part_in_vocab" in vocab_stat: 82 | vocab_check['part_in_vocab']+=1 83 | 84 | print(vocab_check, round(vocab_check['in_vocab']*100/total,2), round(vocab_check['not_in_vocab']*100/total,2)) 85 | 86 | def get_var_token(norm_variable_word,word_to_idx): 87 | vocab_check = defaultdict(int) 88 | token = word_to_idx.get(norm_variable_word,args.vocab_size) 89 | if token == args.vocab_size: 90 | vocab_check['not_in_vocab']+=1 91 | else: 92 | vocab_check['in_vocab']+=1 93 | return [token], vocab_check 94 | 95 | 96 | def preprocess_word_mask(text,tokenizer, word_to_idx): 97 | type_dict = text['type_stripped_norm_vars'] 98 | _id = text['_id'] if "_id" in text.keys() else text['fid'] 99 | ftext = text['norm_func'] 100 | words = ftext.replace("\n"," ").split(" ") 101 | pwords =[] 102 | tpwords =[] 103 | owords =[] 104 | towords =[] 105 | pos=0 106 | masked_pos=[] 107 | var_words =[] 108 | var_toks = [] 109 | mod_words = [] 110 | labels_typ = [] 111 | orig_vars = 
[] 112 | varmap_position = defaultdict(list) 113 | mask_one_ida = False 114 | 115 | vocab=tokenizer.get_vocab() 116 | 117 | for word in words: 118 | 119 | 120 | if re.search(args.var_loc_pattern, word): 121 | idx = 0 122 | for each_var in list(re.finditer(args.var_loc_pattern,word)): 123 | s = each_var.start() 124 | e = each_var.end() 125 | prefix = word[idx:s] 126 | var = word[s:e] 127 | orig_var = var.split("@@")[-2] 128 | 129 | # Somethings attached before the variables 130 | if prefix: 131 | toks = tokenizer.tokenize(prefix) 132 | for t in toks: 133 | mod_words.append(t) 134 | tpwords.append(vocab[t]) 135 | towords.append(vocab[t]) 136 | labels_typ.append(-100) 137 | 138 | var_tokens, _ = get_var_token(orig_var,word_to_idx) 139 | var_toks.append(var_tokens) 140 | var_words.append(orig_var) # Gold_texts (gold labels) 141 | mod_words.append(orig_var) 142 | orig_vars.append(orig_var) 143 | 144 | # Second label generation variable Type (decomp vs dwarf) 145 | if orig_var not in type_dict or type_dict[orig_var] == args.decompiler: # if it is not present consider ida 146 | labels_typ.append(1) 147 | tpwords.append(vocab[""]) 148 | towords.append(-100) 149 | varmap_position[-100].append(pos) 150 | elif type_dict[orig_var] == "dwarf": 151 | labels_typ.append(0) 152 | tpwords.append(vocab[""]) 153 | towords.append(var_tokens[0]) 154 | varmap_position[orig_var].append(pos) 155 | else: 156 | print("ERROR: CHECK LABEL TYPE IN STRIPPED BIN DICTIONARY") 157 | exit(0) 158 | pos += 1 159 | 160 | idx = e 161 | 162 | # Postfix if any 163 | postfix = word[idx:] 164 | if postfix: 165 | toks = tokenizer.tokenize(postfix) 166 | for t in toks: 167 | mod_words.append(t) 168 | tpwords.append(vocab[t]) 169 | towords.append(vocab[t]) 170 | labels_typ.append(-100) 171 | 172 | else: 173 | # When there are no variables 174 | toks = tokenizer.tokenize(word) 175 | for t in toks: 176 | if t == "": 177 | continue # skip adding the token if code has keyword 178 | mod_words.append(t) 179 | tpwords.append(vocab[t]) 180 | towords.append(vocab[t]) 181 | labels_typ.append(-100) 182 | 183 | assert len(tpwords) == len(towords) 184 | assert len(labels_typ) == len(towords) 185 | assert len(var_toks) == len(var_words) 186 | 187 | # 0 1 2 3 4 5 6 7 8 9 188 | return tpwords,towords,var_words, var_toks, labels_typ, words, orig_vars, mod_words, type_dict, _id, varmap_position 189 | 190 | 191 | def generate_id_files(data, output_data, tokenizer, word_to_idx, n): 192 | 193 | for d in tqdm(data): 194 | try: 195 | ppw = preprocess_word_mask(d,tokenizer,word_to_idx) 196 | outrow = {"words":ppw[5],"mod_words":ppw[7],"inputids":ppw[0],"labels":ppw[1],"gold_texts":ppw[2],"gold_texts_id":ppw[3],"labels_type":ppw[4],"meta":[],"orig_vars":ppw[6], "type_dict":ppw[8], "fid":ppw[9],'varmap_position':ppw[10]} 197 | # if input length is more than max possible 1024 then split and make more sample found by tracing _id 198 | if len(outrow['inputids']) > n: 199 | for i in range(0, len(outrow['inputids']), n): 200 | sample = {"words": outrow['words'][i:i+n], 201 | "mod_words":outrow['mod_words'][i:i+n], 202 | "inputids":outrow['inputids'][i:i+n], 203 | "labels":outrow["labels"][i:i+n], 204 | "labels_type":outrow["labels_type"][i:i+n], 205 | "gold_texts":outrow["gold_texts"], 206 | "gold_texts_id":outrow["gold_texts_id"], 207 | "orig_vars":outrow["orig_vars"], 208 | "type_dict":outrow["type_dict"], 209 | "meta":outrow["meta"], 210 | "fid":outrow['fid']+"_"+str((i)//n), 211 | "varmap_position":outrow["varmap_position"], 212 | } 213 | 
output_data.append(sample) 214 | else: 215 | output_data.append(outrow) 216 | except: 217 | print("Unexpected error:", sys.exc_info()[0]) 218 | traceback.print_exception(*sys.exc_info()) 219 | 220 | 221 | 222 | def main(args): 223 | 224 | tokenizer = RobertaTokenizerFast.from_pretrained(args.tokenizer) 225 | word_to_idx = json.load(open(args.vocab_word_to_idx)) 226 | idx_to_word = json.load(open(args.vocab_idx_to_word)) 227 | max_sample_chunk = args.max_chunk_size-2 228 | args.vocab_size = len(word_to_idx) # OOV as UNK[assinged id=vocab_size] 229 | print(f"Vocab_size: {args.vocab_size}") 230 | 231 | train, test = [], [] 232 | src_train_ids, srctest_ids = set(), set() 233 | tgt_train_ids, tgttest_ids = set(), set() 234 | 235 | src_train_ids, train = read_input_files(args.train_file) 236 | srctest_ids, test = read_input_files(args.test_file) 237 | print(f"Data size Train: {len(train)} \t Test: {len(test)}") 238 | 239 | num_processes = min(args.workers, cpu_count()) 240 | print(f"Running with #workers : {num_processes}") 241 | 242 | prep_input_files(train, num_processes, tokenizer, word_to_idx, max_sample_chunk, src_train_ids, tgt_train_ids, args.out_train_file ) 243 | prep_input_files(test, num_processes, tokenizer, word_to_idx, max_sample_chunk, srctest_ids, tgttest_ids, args.out_test_file) 244 | 245 | if __name__ == "__main__": 246 | 247 | parser = argparse.ArgumentParser() 248 | parser.add_argument('--train_file', type=str, help='name of the train file') 249 | parser.add_argument('--test_file', type=str, help='name of name of the test file') 250 | parser.add_argument('--test_notintrain_file', type=str, help='name of the test not in train file') 251 | 252 | parser.add_argument('--tokenizer', type=str, help='path to the tokenizer') 253 | 254 | parser.add_argument('--vocab_word_to_idx', type=str, help='Output Vocab Word to index file') 255 | parser.add_argument('--vocab_idx_to_word', type=str, help='Output Vocab Index to Word file') 256 | 257 | parser.add_argument('--vocab_size', type=int, default=150001, help='size of output vocabulary') 258 | parser.add_argument('--var_loc_pattern', type=str, default="@@\w+@@\w+@@", help='pattern representing variable location') 259 | parser.add_argument('--decompiler', type=str, default="ida", help='decompiler for type prediction; ida or ghidra') 260 | parser.add_argument('--max_chunk_size', type=int, default=1024, help='size of maximum chunk of input for the model') 261 | parser.add_argument('--workers', type=int, default=30, help='number of parallel workers you need') 262 | 263 | parser.add_argument('--out_train_file', type=str, help='name of the output train file') 264 | parser.add_argument('--out_test_file', type=str, help='name of name of the output test file') 265 | # parser.add_argument('--out_test_notintrain_file', type=str, help='name of the output test not in train file') 266 | 267 | 268 | args = parser.parse_args() 269 | 270 | main(args) 271 | -------------------------------------------------------------------------------- /varbert/README.md: -------------------------------------------------------------------------------- 1 | ### Table of Contents 2 | 3 | 1. [Training VarBERT](#training-varbert) 4 | 2. [Fine-tuning VarBERT](#fine-tune) 5 | 3. [Vocab Files](#vocab-files) 6 | 4. [Tokenizer](#tokenizer) 7 | 5. [Masked Language Modeling (MLM)](#masked-language-modeling-mlm) 8 | 6. [Constrained Masked Language Modeling (CMLM) in VarBERT](#constrained-masked-language-modeling-cmlm-in-varbert) 9 | 7. 
[Resize Model](#resize-model) 10 | 11 | ### Training VarBERT 12 | 13 | In our paper, we follow a two-step training process: 14 | - **Pre-training**: VarBERT is pre-trained on source code functions (HSC data set) using Masked Language Modeling (MLM) followed by Constrained Masked Language Modeling (CMLM). 15 | - **Fine-tuning**: Subsequently, VarBERT is fine-tuned on top of the pre-trained model using VarCorpus (decompilation output of IDA and Ghidra) to predict variable names and variable origins (i.e., whether a variable originates from source code or is decompiler-generated). 16 | 17 | Training Process Overview (from the paper): 18 | 19 | BERT-Base → MLM → CMLM → Fine-tune 20 | 21 | 22 | This approach can be adapted for use with any other decompiler capable of generating C-style decompilation output: use the pre-trained model from step one and fine-tune with a new (or existing) decompiler; see [Fine-tune](#fine-tune) below. 23 | 24 | **Essential Components for Training:** 25 | 1. Base Model: Choose from BERT-Base, MLM Model, or CMLM Model. Remember to resize the Base Model if the vocab size of the subsequent model is different ([How to Resize Model](#resize-model)). 26 | 2. [Tokenizer](#tokenizer) 27 | 3. Train and Test sets 28 | 4. [Vocab Files](#vocab-files): Contains the most frequent variable names from your training dataset. 29 | 30 | 31 | ### Fine-tune 32 | Fine-tuning is generally performed on top of a pre-trained model (MLM + CMLM Model in our case). However, fine-tuning can also be applied directly to an MLM model or a BERT-Base model. 33 | During this phase, the model learns to predict variable names and their origins. 34 | VarBERT predicts the `Top-N` candidates (where N can be 1, 3, 5, or 10) for variable names. For their origin, the output `dwarf` indicates a source code origin, while `ida` or `ghidra` indicates that the variable is decompiler-generated. 35 | 36 | Access our fine-tuned model from our paper: [Models](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0) 37 | 38 | To train a new model, follow these steps: 39 | 40 | - **Base Model**: [CMLM Model](https://www.dropbox.com/scl/fi/72ku0tf3o93kn67k60d7d/CMLM_MODEL.tar.gz?rlkey=8kwlfwc87uwcsab86np4bhub0&dl=0) 41 | - **Tokenizer**: Refer to [Tokenizer](#tokenizer) 42 | - **Train and Test sets**: Each VarCorpus data set tarball has both pre-processed and non-processed train and test sets. If you prefer to use the pre-processed sets, please jump to step 2 (i.e., training a model); otherwise, start with step 1. [VarCorpus](https://www.dropbox.com/scl/fo/3thmg8xoq2ugtjwjcgjsm/h?rlkey=azgjeq513g4semc1qdi5xyroj&dl=0). Alternatively, if you created a new data set using `generate.py` in [Building VarCorpus](./varcorpus/README.md), use the files saved in `data dir`. 43 | - **Vocab Files**: You can find our vocab files in the tarball of each trained model available at [link](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0), or refer to [Vocab Files](#vocab-files) to create new vocab files. 44 | 45 | 46 | 1. Preprocess data set for training 47 | 48 | ```python 49 | python3 preprocess.py \ 50 | --train_file \ 51 | --test_file \ 52 | --tokenizer \ 53 | --vocab_word_to_idx \ 54 | --vocab_idx_to_word \ 55 | --decompiler \ 56 | --max_chunk_size 800 \ 57 | --workers 2 \ 58 | --out_train_file \ 59 | --out_test_file 60 | ``` 61 | 62 | 2. 
Fine-tune a model 63 | 64 | 65 | ```python 66 | python3 -m torch.distributed.launch --nproc_per_node=2 training.py \ 67 | --overwrite_output_dir \ 68 | --train_data_file \ 69 | --output_dir \ 70 | --block_size 800 \ 71 | --tokenizer_name \ 72 | --model_type roberta \ 73 | --model_name_or_path \ 74 | --vocab_path \ 75 | --do_train \ 76 | --num_train_epochs 4 \ 77 | --save_steps 10000 \ 78 | --logging_steps 1000 \ 79 | --per_gpu_train_batch_size 4 \ 80 | --mlm; 81 | ``` 82 | The base model path (`--model_name_or_path`) can point to any model: a previously trained CMLM model, an MLM model, or a BERT-Base model. 83 | 84 | Note: To run training without distributed training, use `python3 training.py` with the same arguments. 85 | 86 | 3. Run Inference 87 | 88 | ```python 89 | python3 eval.py \ 90 | --model_name \ 91 | --tokenizer_name \ 92 | --block_size 800 \ 93 | --data_file \ 94 | --prefix ft_bintoo_test \ 95 | --batch_size 16 \ 96 | --pred_path resultdir \ 97 | --out_vocab_map 98 | ``` 99 | 100 | 101 | ### Vocab Files 102 | Vocab files consist of the top N most frequently occurring variable names from the chosen training dataset. Specifically, we use the top 50K variable names from the Human Source Code (HSC) dataset and the top 150K variable names from the VarCorpus dataset. These are the variable names the model learns to predict. 103 | 104 | (We use `50001` and `150001` as the vocab sizes for the CMLM model and fine-tuning, respectively: the top 50K/150K variable names plus 1 for UNK.) 105 | 106 | Use existing vocab files from our paper: 107 | 108 | Please note that we have different vocab files for each model. You can find our vocab files in the tarball of each trained model available at [link](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0). 109 | 110 | To create new vocab files: 111 | 112 | ```python 113 | python3 generate_vocab.py \ 114 | --dataset_type \ 115 | --train_file \ 116 | --test_file \ 117 | --vocab_size \ 118 | --output_dir \ 119 | ``` 120 | 121 | 122 | ### Tokenizer 123 | To adapt the model to the unique nature of the new data (source code), we use a Byte-Pair Encoding (BPE) tokenizer to learn a new source vocabulary. We train a tokenizer with a vocabulary size of 50K, similar to RoBERTa's 50,265. 124 | 125 | Use the tokenizer trained on the HSC data set (train set): [Tokenizer](https://www.dropbox.com/scl/fi/i8seayujpqdc0egavks18/tokenizer.tar.gz?rlkey=fnhorh3uo2diqv0v1qaymzo2r&dl=0) 126 | ```bash 127 | wget -O tokenizer.tar.gz "https://www.dropbox.com/scl/fi/i8seayujpqdc0egavks18/tokenizer.tar.gz?rlkey=fnhorh3uo2diqv0v1qaymzo2r&dl=0" 128 | ``` 129 | 130 | Training a new Tokenizer: 131 | 1. Prepare the train set: The input to the tokenizer should be in plain-text format. 132 | (If using the [HSC data set](https://www.dropbox.com/scl/fi/1eekwcsg7wr7cux6y34xb/hsc_data.tar.gz?rlkey=s3kjroqt7a27hoeoc56mfyljw&dl=0)) 133 | 134 | ```python 135 | python3 preprocess.py \ 136 | --input_file \ 137 | --output_file 138 | ``` 139 | 140 | 2. Train the Tokenizer 141 | 142 | ```python 143 | python3 train_bpe_tokenizer.py \ 144 | --input_path \ 145 | --vocab_size 50265 \ 146 | --min_frequency 2 \ 147 | --output_path 148 | ``` 149 | 150 | 151 | **To pre-train a model from scratch** 152 | 153 | ### Masked Language Modeling (MLM) 154 | 155 | We learn the representation of code tokens by training BERT from scratch with a Masked Language Modeling objective similar to the one used in RoBERTa. In this process, some tokens are randomly masked, and the model learns to predict these masked tokens, thereby gaining a deeper understanding of code-token representations.
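For intuition only, here is a minimal sketch of the random-masking step (an illustration, not part of `training.py`). It assumes the BPE tokenizer from the [Tokenizer](#tokenizer) section is extracted to `./tokenizer`; any RoBERTa-style tokenizer directory works for demonstration purposes:

```python
# Minimal sketch of MLM-style random masking (illustration only).
from transformers import RobertaTokenizerFast, DataCollatorForLanguageModeling

tokenizer = RobertaTokenizerFast.from_pretrained("./tokenizer")  # assumed local tokenizer path
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

enc = tokenizer("int add ( int a , int b ) { return a + b ; }")
batch = collator([{"input_ids": enc["input_ids"]}])

# ~15% of positions are replaced by <mask> (or a random token); only those positions
# keep their original token id as the label, all others are set to -100 and ignored.
print(tokenizer.decode(batch["input_ids"][0]))
print(batch["labels"][0])
```

The `--mlm_probability 0.15` flag in the MLM training command below controls this same masking rate.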
156 | 157 | #### Using a Pre-trained MLM Model 158 | 159 | Access our pre-trained MLM model, trained on 5.2M functions, from our paper: [MLM Model](https://www.dropbox.com/scl/fi/72ku0tf3o93kn67k60d7d/CMLM_MODEL.tar.gz?rlkey=8kwlfwc87uwcsab86np4bhub0&dl=0) 160 | 161 | #### To train a new model, follow these steps: 162 | 163 | 1. Preprocess data set for training 164 | [HSC data set](https://www.dropbox.com/scl/fi/1eekwcsg7wr7cux6y34xb/hsc_data.tar.gz?rlkey=s3kjroqt7a27hoeoc56mfyljw&dl=0) 165 | 166 | ```python 167 | python3 preprocess.py \ 168 | --train_file \ 169 | --test_file \ 170 | --output_train_file \ 171 | --output_test_file 172 | ``` 173 | 174 | 2. Train MLM Model 175 | 176 | - The MLM model is trained on top of the BERT-Base model using the pre-processed HSC train and test files from step 1. 177 | - For the BERT-Base model, refer to [BERT-Base Model](https://www.dropbox.com/scl/fi/18p37f5drph8pekv8kcj2/BERT_Base.tar.gz?rlkey=3x4mpr4hmkyndunhg9fpu0p3b&dl=0) 178 | 179 | 180 | ```python 181 | python3 training.py \ 182 | --model_name_or_path \ 183 | --model_type roberta \ 184 | --tokenizer_name \ 185 | --train_file \ 186 | --validation_file \ 187 | --max_seq_length 800 \ 188 | --mlm_probability 0.15 \ 189 | --num_train_epochs 40 \ 190 | --do_train \ 191 | --do_eval \ 192 | --per_device_train_batch_size 44 \ 193 | --per_device_eval_batch_size 44 \ 194 | --output_dir \ 195 | --save_steps 10000 \ 196 | --logging_steps 5000 \ 197 | --overwrite_output_dir 198 | ``` 199 | 200 | ### Constrained Masked Language Modeling (CMLM) 201 | Constrained MLM is a variation of MLM. In this approach, tokens are not randomly masked; instead, we specifically mask certain tokens, which in our case are variable names in source code functions. 202 | 203 | #### Using a Pre-trained CMLM Model: 204 | Access our pre-trained CMLM model from the paper: [CMLM Model](https://www.dropbox.com/scl/fi/72ku0tf3o93kn67k60d7d/CMLM_MODEL.tar.gz?rlkey=8kwlfwc87uwcsab86np4bhub0&dl=0) 205 | 206 | To train a new model, follow these steps: 207 | - **Base Model**: [MLM Model](https://www.dropbox.com/scl/fi/a0i61xeij0bogkusr4yf7/MLM_MODEL.tar.gz?rlkey=pqenu7f851sgdn6ofcfp6dxoa&dl=0) 208 | - **Tokenizer**: Refer to [Tokenizer](#tokenizer) 209 | - **Train and Test sets**: [CMLM Data set](https://www.dropbox.com/scl/fi/q0itko6fitpxx3dx71qhv/cmlm_dataset.tar.gz?rlkey=51j9iagvg8u3rak79euqjocml&dl=0) 210 | - **Vocab Files**: To generate new vocab files, refer to [Vocab Files](#vocab-files) or use the existing ones at [link](https://www.dropbox.com/scl/fi/yot7urpeem53dttditg7p/cmlm_dataset.tar.gz?rlkey=cned7sgijladr1pu5ery82z8a&dl=0) 211 | 212 | 213 | 1. Preprocess data set for training 214 | [HSC CMLM data set](https://www.dropbox.com/scl/fi/q0itko6fitpxx3dx71qhv/cmlm_dataset.tar.gz?rlkey=51j9iagvg8u3rak79euqjocml&dl=0) 215 | 216 | ```python 217 | python3 preprocess.py \ 218 | --train_file \ 219 | --test_file \ 220 | --tokenizer \ 221 | --vocab_word_to_idx \ 222 | --vocab_idx_to_word \ 223 | --max_chunk_size 800 \ 224 | --out_train_file \ 225 | --out_test_file \ 226 | --workers 4 \ 227 | --vocab_size 50001 228 | ``` 229 | 230 | 2. Train CMLM Model 231 | The CMLM model is trained on top of the MLM model using the pre-processed HSC train and test files from step 1.
232 | [MLM Model](https://www.dropbox.com/scl/fi/a0i61xeij0bogkusr4yf7/MLM_MODEL.tar.gz?rlkey=pqenu7f851sgdn6ofcfp6dxoa&dl=0) 233 | 234 | ```python 235 | python3 training.py \ 236 | --overwrite_output_dir \ 237 | --train_data_file \ 238 | --output_dir \ 239 | --block_size 800 \ 240 | --model_type roberta \ 241 | --model_name_or_path \ 242 | --tokenizer_name \ 243 | --do_train \ 244 | --num_train_epochs 30 \ 245 | --save_steps 50000 \ 246 | --logging_steps 5000 \ 247 | --per_gpu_train_batch_size 32 \ 248 | --mlm; 249 | ``` 250 | 251 | 3. Run Inference 252 | 253 | ```python 254 | python run_cmlm_scoring.py \ 255 | --model_name \ 256 | --data_file \ 257 | --prefix cmlm_hsc_5M_50K \ 258 | --pred_path \ 259 | --batch_size 40; 260 | ``` 261 | 262 | ### Resize Model 263 | 264 | It's necessary to resize the model when the vocab size changes. For example, if you're fine-tuning over a CMLM model initially trained with a 50K vocab size using a 150K vocab size for the fine-tuned, the CMLM model needs to be resized. 265 | 266 | ```python 267 | python3 resize_model.py \ 268 | --old_model \ 269 | --vocab_path \ 270 | --out_model_path 271 | ``` -------------------------------------------------------------------------------- /varbert/cmlm/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pprint 5 | import logging 6 | import argparse 7 | import jsonlines 8 | import numpy as np 9 | import pandas as pd 10 | from typing import Dict, List, Tuple 11 | from collections import defaultdict 12 | from tqdm import tqdm, trange 13 | 14 | import torch 15 | from torch.nn.utils.rnn import pad_sequence 16 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 17 | from torch.utils.data.distributed import DistributedSampler 18 | 19 | from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score 20 | 21 | from transformers import ( 22 | WEIGHTS_NAME, 23 | AdamW, 24 | BertConfig, 25 | BertForMaskedLM, 26 | BertTokenizer, 27 | CamembertConfig, 28 | CamembertForMaskedLM, 29 | CamembertTokenizer, 30 | DistilBertConfig, 31 | DistilBertForMaskedLM, 32 | DistilBertTokenizer, 33 | GPT2Config, 34 | GPT2LMHeadModel, 35 | GPT2Tokenizer, 36 | OpenAIGPTConfig, 37 | OpenAIGPTLMHeadModel, 38 | OpenAIGPTTokenizer, 39 | PreTrainedModel, 40 | PreTrainedTokenizer, 41 | RobertaConfig, 42 | RobertaForMaskedLM, 43 | RobertaTokenizer, 44 | get_linear_schedule_with_warmup, 45 | ) 46 | 47 | import torch.nn as nn 48 | from torch.nn import CrossEntropyLoss, MSELoss 49 | 50 | from transformers.activations import ACT2FN, gelu 51 | 52 | 53 | from transformers import RobertaConfig 54 | from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel, BertPreTrainedModel 55 | from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead 56 | 57 | ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = { 58 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin", 59 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin", 60 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin", 61 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-pytorch_model.bin", 62 | "roberta-base-openai-detector": 
"https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-pytorch_model.bin", 63 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-pytorch_model.bin", 64 | } 65 | 66 | l = logging.getLogger('model_main') 67 | vocab_size = 50001 68 | 69 | class RobertaLMHead2(nn.Module): 70 | 71 | def __init__(self,config): 72 | super().__init__() 73 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 74 | # self.dense = nn.Linear(config.hidden_size, config.hidden_size*8) 75 | 76 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 77 | # self.layer_norm = nn.LayerNorm(8*config.hidden_size, eps=config.layer_norm_eps) 78 | 79 | self.decoder = nn.Linear(config.hidden_size, vocab_size, bias=False) 80 | # self.decoder = nn.Linear(8*config.hidden_size, vocab_size, bias=False) 81 | 82 | self.bias = nn.Parameter(torch.zeros(vocab_size)) 83 | 84 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 85 | self.decoder.bias = self.bias 86 | 87 | def forward(self, features, **kwargs): 88 | x = self.dense(features) 89 | x = gelu(x) 90 | x = self.layer_norm(x) 91 | 92 | # project back to size of vocabulary with bias 93 | x = self.decoder(x) 94 | 95 | return x 96 | 97 | 98 | class RobertaForMaskedLMv2(RobertaForMaskedLM): 99 | 100 | def __init__(self, config): 101 | super().__init__(config) 102 | self.lm_head2 = RobertaLMHead2(config) 103 | self.init_weights() 104 | 105 | def forward( 106 | self, 107 | input_ids=None, 108 | attention_mask=None, 109 | token_type_ids=None, 110 | position_ids=None, 111 | head_mask=None, 112 | inputs_embeds=None, 113 | encoder_hidden_states=None, 114 | encoder_attention_mask=None, 115 | masked_lm_labels=None, 116 | output_attentions=None, 117 | output_hidden_states=None, 118 | return_dict=None, 119 | ): 120 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 121 | 122 | # print("inputs:",input_ids,"labels:",labels) 123 | outputs = self.roberta( 124 | input_ids, 125 | attention_mask=attention_mask, 126 | token_type_ids=token_type_ids, 127 | position_ids=position_ids, 128 | head_mask=head_mask, 129 | inputs_embeds=inputs_embeds, 130 | encoder_hidden_states=encoder_hidden_states, 131 | encoder_attention_mask=encoder_attention_mask, 132 | output_attentions=output_attentions, 133 | output_hidden_states=output_hidden_states, 134 | return_dict=return_dict, 135 | ) 136 | sequence_output = outputs[0] 137 | prediction_scores = self.lm_head2(sequence_output) 138 | output_pred_scores = torch.topk(prediction_scores,k=20,dim=-1) 139 | outputs = (output_pred_scores,) # Add hidden states and attention if they are here 140 | 141 | masked_lm_loss = None 142 | if labels is not None: 143 | loss_fct = CrossEntropyLoss() 144 | masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), masked_lm_labels.view(-1)) 145 | outputs = (masked_lm_loss,) + outputs 146 | 147 | return outputs 148 | # output = (prediction_scores,) + outputs[2:] 149 | # return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 150 | 151 | 152 | class RobertaForCMLM(BertPreTrainedModel): 153 | config_class = RobertaConfig 154 | pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 155 | base_model_prefix = "roberta" 156 | 157 | def __init__(self, config): 158 | super().__init__(config) 159 | 160 | self.roberta = RobertaModel(config) 161 | self.lm_head = RobertaLMHead(config) 162 | 163 
| self.init_weights() 164 | 165 | def get_output_embeddings(self): 166 | return self.lm_head.decoder 167 | 168 | def forward(self,input_ids=None,attention_mask=None,token_type_ids=None, 169 | position_ids=None,head_mask=None,inputs_embeds=None,masked_lm_labels=None): 170 | outputs = self.roberta( 171 | input_ids, 172 | attention_mask=attention_mask, 173 | token_type_ids=token_type_ids, 174 | position_ids=position_ids, 175 | head_mask=head_mask, 176 | inputs_embeds=inputs_embeds, 177 | ) 178 | sequence_output = outputs[0] 179 | prediction_scores = self.lm_head(sequence_output) 180 | output_pred_scores = torch.topk(prediction_scores,k=20,dim=-1) 181 | outputs = (output_pred_scores,) # Add hidden states and attention if they are here 182 | 183 | if masked_lm_labels is not None: 184 | loss_fct = CrossEntropyLoss() 185 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 186 | outputs = (masked_lm_loss,) + outputs 187 | 188 | return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) 189 | 190 | 191 | try: 192 | from torch.utils.tensorboard import SummaryWriter 193 | except ImportError: 194 | from tensorboardX import SummaryWriter 195 | 196 | 197 | logger = logging.getLogger(__name__) 198 | 199 | 200 | MODEL_CLASSES = { 201 | "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), 202 | "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 203 | "bert": (BertConfig, BertForMaskedLM, BertTokenizer), 204 | "roberta": (RobertaConfig, RobertaForMaskedLMv2, RobertaTokenizer), 205 | "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer), 206 | "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer), 207 | } 208 | 209 | 210 | class CMLDataset(Dataset): 211 | def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size=512,limit=None): 212 | assert os.path.isfile(file_path) 213 | logger.info("Creating features from dataset file at %s", file_path) 214 | 215 | self.examples=[] 216 | with jsonlines.open(file_path, 'r') as f: 217 | for ix,line in tqdm(enumerate(f),desc="Reading Jsonlines",ascii=True): 218 | if limit is not None and ix>limit: 219 | continue 220 | if (None in line["inputids"]) or (None in line["labels"]): 221 | print("LineNum:",ix,line) 222 | continue 223 | else: 224 | self.examples.append(line) 225 | 226 | if limit is not None: 227 | self.examples = self.examples[0:limit] 228 | self.block_size = int(block_size) 229 | self.tokenizer = tokenizer 230 | self.truncated={} 231 | 232 | def __len__(self): 233 | return len(self.examples) 234 | 235 | def __getitem__(self, i): 236 | block_size = self.block_size 237 | tokenizer = self.tokenizer 238 | 239 | item = self.examples[i] 240 | 241 | input_ids = item["inputids"] 242 | labels = item["labels"] 243 | assert len(input_ids) == len(labels) 244 | 245 | if len(input_ids) > block_size-2: 246 | self.truncated[i]=1 247 | 248 | if len(input_ids) >= block_size-2: 249 | input_ids = input_ids[0:block_size-2] 250 | labels = labels[0:block_size-2] 251 | elif len(input_ids) < block_size-2: 252 | input_ids = input_ids+[tokenizer.pad_token_id]*(self.block_size-2-len(input_ids)) 253 | labels = labels + [tokenizer.pad_token_id]*(self.block_size-2-len(labels)) 254 | 255 | input_ids = tokenizer.build_inputs_with_special_tokens(input_ids) 256 | labels = tokenizer.build_inputs_with_special_tokens(labels) 257 | 258 | assert len(input_ids) == len(labels) 259 | assert len(input_ids) == block_size 260 | try: 261 | 
input_ids = torch.tensor(input_ids, dtype=torch.long) 262 | labels = torch.tensor(labels, dtype=torch.long) 263 | mask_idxs = (input_ids==tokenizer.mask_token_id).bool() 264 | labels[~mask_idxs]=-100 265 | except: 266 | l.error(f"Unexpected error at index {i}: {sys.exc_info()[0]}") 267 | raise 268 | 269 | return input_ids , labels 270 | 271 | parser = argparse.ArgumentParser() 272 | parser.add_argument( 273 | "--model_name", 274 | type=str, 275 | help="The model checkpoint for weights initialization.", 276 | ) 277 | parser.add_argument( 278 | "--data_file", 279 | type=str, 280 | help="Input Data File to Score", 281 | ) 282 | parser.add_argument( 283 | "--meta_file", 284 | type=str, 285 | help="Input Meta File to Score", 286 | ) 287 | 288 | parser.add_argument( 289 | "--prefix", 290 | default="test", 291 | type=str, 292 | help="prefix to separate the output files", 293 | ) 294 | 295 | parser.add_argument( 296 | "--pred_path", 297 | default="outputs", 298 | type=str, 299 | help="path where the predictions will be stored", 300 | ) 301 | 302 | parser.add_argument( 303 | "--batch_size", 304 | default=20, 305 | type=int, 306 | help="Eval Batch Size", 307 | ) 308 | args = parser.parse_args() 309 | 310 | device = torch.device("cuda") 311 | n_gpu = torch.cuda.device_count() 312 | config_class, model_class, tokenizer_class = MODEL_CLASSES["roberta"] 313 | 314 | config = config_class.from_pretrained(args.model_name) 315 | tokenizer = tokenizer_class.from_pretrained(args.model_name) 316 | model = model_class.from_pretrained( 317 | args.model_name, 318 | from_tf=bool(".ckpt" in args.model_name), 319 | config=config, 320 | ) 321 | 322 | model.to(device) 323 | tiny_dataset = CMLDataset(tokenizer,file_path=args.data_file,block_size=1024) 324 | eval_sampler = SequentialSampler(tiny_dataset) 325 | eval_dataloader = DataLoader(tiny_dataset, sampler=eval_sampler, batch_size=args.batch_size) 326 | 327 | model.eval() 328 | eval_loss = 0.0 329 | nb_eval_steps = 0 330 | 331 | matched={1:0,3:0,5:0,10:0} 332 | totalmasked={1:0,3:0,5:0,10:0} 333 | 334 | pred_list={ 335 | 1 : [], 336 | 3 : [], 337 | 5 : [], 338 | 10: [] 339 | } 340 | gold_list=[] 341 | 342 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 343 | inputs, labels = batch[0], batch[1] 344 | only_masked = inputs==tokenizer.mask_token_id 345 | masked_gold = labels[only_masked] 346 | 347 | inputs = inputs.to(device) 348 | labels = labels.to(device) 349 | 350 | gold_list.append(masked_gold.tolist()) 351 | 352 | with torch.no_grad(): 353 | outputs = model(inputs, masked_lm_labels=labels) 354 | lm_loss = outputs[0] 355 | inference = outputs[1].indices.cpu() 356 | eval_loss += lm_loss.mean().item() 357 | 358 | # TopK Calculation 359 | masked_predict = inference[only_masked] 360 | for k in [1,3,5,10]: 361 | topked = masked_predict[:,0:k] 362 | pred_list[k].append(topked.tolist()) 363 | for i,cur_gold_tok in enumerate(masked_gold): 364 | totalmasked[k]+=1 365 | cur_predict_scores = topked[i] 366 | if cur_gold_tok in cur_predict_scores: 367 | matched[k]+=1 368 | 369 | nb_eval_steps += 1 370 | 371 | eval_loss = eval_loss / nb_eval_steps 372 | perplexity = torch.exp(torch.tensor(eval_loss)) 373 | print(f"Perplexity: ", perplexity) 374 | for i in [1,3,5,10]: 375 | print("TopK:", i, matched[i]/totalmasked[i]) 376 | 377 | print("Truncated:", len(tiny_dataset.truncated)) 378 | 379 | os.makedirs(args.pred_path, exist_ok=True) 380 | json.dump(pred_list,open(os.path.join(args.pred_path, args.prefix+"_pred_list.json"),"w")) 381 | 
json.dump(gold_list,open(os.path.join(args.pred_path, args.prefix+"_gold_list.json"),"w")) -------------------------------------------------------------------------------- /varcorpus/dataset-gen/create_dataset_splits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import json 4 | import sys 5 | from tqdm import tqdm 6 | import concurrent.futures 7 | from multiprocessing import Manager 8 | from itertools import islice 9 | from functools import partial 10 | import time 11 | import re 12 | import hashlib 13 | import numpy as np 14 | import itertools 15 | import random 16 | import logging 17 | from collections import OrderedDict, Counter 18 | from preprocess_vars import normalize_variables 19 | 20 | l = logging.getLogger('main') 21 | 22 | manager = Manager() 23 | md5_to_funcline = manager.dict() 24 | funcline_to_data = manager.dict() 25 | cfuncs_md5 = manager.dict() 26 | cppfuncs_md5 = manager.dict() 27 | dwarf_vars = manager.list() 28 | decompiler_vars = manager.list() 29 | new_funcs = manager.list() 30 | 31 | def read_json(filename): 32 | with open(filename, 'r') as r: 33 | data = json.loads(r.read()) 34 | if data: 35 | return data 36 | 37 | def write_json(filename, data): 38 | with open(filename, 'w') as w: 39 | w.write(json.dumps(data)) 40 | 41 | def write_list(filename, data, distribution=False, save_list=True): 42 | if distribution: 43 | c = Counter(data) 44 | sorted_dict = dict(sorted(c.items(), key=lambda item: item[1], reverse = True)) 45 | write_json(filename + '_distribution.json', sorted_dict) 46 | 47 | if save_list: 48 | with open(filename, 'w') as w: 49 | w.write('\n'.join(data)) 50 | 51 | 52 | def write_jsonlines(path, data): 53 | with open(path, 'a+') as w: 54 | w.write(json.dumps(data) + "\n") 55 | return True 56 | 57 | # collect hash(func_body) 58 | def hash_funcs(workdir, decomp, file_): 59 | global md5_to_funcline 60 | global funcline_to_data 61 | global cfuncs_md5 62 | global cppfuncs_md5 63 | 64 | var_regex = r"@@(var_\d+)@@(\w+)@@" 65 | json_data = read_json(os.path.join(workdir, "map", decomp, file_)) 66 | if not json_data: 67 | return 68 | 69 | id_ = 0 70 | for name, data in islice(json_data.items(), 0, len(json_data)): 71 | func = data['func'] 72 | up_func = re.sub(var_regex, "\\2", func) 73 | # dedup func body 74 | func_body = '{'.join(up_func.split('{')[1:]) 75 | # for Ghidra 76 | # if 'anon_var' in func_body or 'anon_func' in func_body: 77 | # continue 78 | func_body = func_body.replace('\t', '').replace('\n', '').replace(' ','') 79 | md5 = hashlib.md5(func_body.encode('utf-8')).hexdigest() 80 | id_ += 1 81 | data['id'] = id_ 82 | data['func_name'] = name 83 | data['md5'] = md5 84 | md5_to_funcline[md5] = name 85 | funcline_to_data[name] = data 86 | if data['language'] == 'c': 87 | cfuncs_md5[md5] = name 88 | else: 89 | cppfuncs_md5[md5] = name 90 | 91 | # de-dup functions before you split! 
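# md5_to_funcline keeps one entry per unique md5 of the whitespace- and annotation-normalized
# function body, so iterating it yields the de-duplicated set of functions. Each worker handles
# the slice [start, start + batch_size) of that map and appends the surviving records to
# dedup_func/<decompiler>/<binary>.jsonl, where the binary name is the prefix of the function key.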
92 | def func_dedup( workdir, decomp, start: int, batch_size: int, input): 93 | end = start + batch_size 94 | for hash_, funcline in islice(md5_to_funcline.items(), start, end): 95 | try: 96 | binary_name = funcline.split('_')[0] 97 | ret = write_jsonlines(os.path.join(workdir, "dedup_func", decomp, f'{binary_name}.jsonl'), funcline_to_data[funcline]) 98 | if not ret: 99 | l.error(f"error in dumping {hash_} | {funcline}") 100 | except Exception as e: 101 | l.error(f"Error in deduplicating functions{e}") 102 | 103 | def find_c_and_cpp_files(dedup_dir, decomp): 104 | 105 | files = os.listdir(os.path.join(dedup_dir, decomp)) 106 | random.seed(5) 107 | random.shuffle(files) 108 | 109 | c_name_and_lines = OrderedDict() 110 | cpp_name_and_lines = OrderedDict() 111 | # find c and cpp files 112 | c_funcs_per_binary, cpp_funcs_per_binary = [], [] 113 | for f in files: 114 | with open(os.path.join(dedup_dir, decomp, f), 'r') as r: 115 | lines = r.readlines() 116 | if lines: 117 | if json.loads(lines[0])['language'] == 'c': 118 | c_funcs_per_binary.append(len(lines)) 119 | c_name_and_lines[f] = len(lines) 120 | else: 121 | cpp_funcs_per_binary.append(len(lines)) 122 | cpp_name_and_lines[f] = len(lines) 123 | 124 | l.info(f"c files: {len(c_funcs_per_binary)} \tcpp files: {len(cpp_funcs_per_binary)}") 125 | return c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines 126 | 127 | # decide splits 128 | def decide_per_binary_split(dedup_dir, c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines, workdir, decomp): 129 | l.debug("Deciding binary splits...") 130 | def split_files(list_, name_and_lines): 131 | list_ = list(name_and_lines.items()) 132 | random.shuffle(list_) 133 | train_, test_ = 0, 0 134 | # sum_ = sum(list_) 135 | sum_ = sum([lines for _, lines in list_]) 136 | # if number of total functions are fewer, update the limits 137 | t1 = round(sum_ * 0.709, 2) 138 | t2 = round(sum_ * 0.805, 2) 139 | 140 | l.debug(f"total functions!: {sum_} \tlower limit: {t1} \tupper limit: {t2}") 141 | got = [] 142 | found = False 143 | for elem in range(0, len(list_)): 144 | # train_ += list_[elem] 145 | _, lines = list_[elem] 146 | train_ += lines 147 | got.append(list_[elem]) 148 | if t1 <= train_ <= t2: 149 | l.debug(f"functions in train split!: {train_}") 150 | l.debug(f"found the split! 
pick files until: {elem}") 151 | break 152 | l.debug(f"functions in test split: {sum_ - train_}") 153 | # train_files = dict(islice(name_and_lines.items(), 0, elem+1)) 154 | # test_files = dict(islice(name_and_lines.items(), elem+1, len(list_))) 155 | train_files = dict(list_[:elem+1]) 156 | test_files = dict(list_[elem+1:]) 157 | return train_files, test_files 158 | 159 | if len(c_funcs_per_binary): 160 | c_train_files, c_test_files = split_files(c_funcs_per_binary, c_name_and_lines) 161 | while True: 162 | if len(c_test_files) == 0: 163 | l.debug("retrying for c...") 164 | c_train_files, c_test_files = split_files(c_funcs_per_binary, c_name_and_lines) 165 | else: 166 | break 167 | if len(cpp_funcs_per_binary): 168 | cpp_train_files, cpp_test_files = split_files(cpp_funcs_per_binary, cpp_name_and_lines) 169 | while True: 170 | if len(cpp_test_files) == 0: 171 | l.debug("retrying for cpp...") 172 | cpp_train_files, cpp_test_files = split_files(cpp_funcs_per_binary, cpp_name_and_lines) 173 | else: 174 | break 175 | 176 | if len(cpp_funcs_per_binary) and len(c_funcs_per_binary): 177 | train_files = {**c_train_files, **cpp_train_files} 178 | test_files = {**c_test_files, **cpp_test_files} 179 | elif len(cpp_funcs_per_binary): 180 | train_files = cpp_train_files 181 | test_files = cpp_test_files 182 | elif len(c_funcs_per_binary): 183 | train_files = c_train_files 184 | test_files = c_test_files 185 | 186 | with open(os.path.join(workdir, "splits", decomp, "train_files.json"), 'w') as w: 187 | w.write(json.dumps(train_files)) 188 | 189 | with open(os.path.join(workdir, "splits", decomp, "test_files.json"), 'w') as w: 190 | w.write(json.dumps(test_files)) 191 | return list(train_files.keys()), list(test_files.keys()) 192 | 193 | # decide splits 194 | def decide_per_func_split(workdir, decomp): 195 | l.debug("Deciding function splits...") 196 | def split_funcs(all_funcs): 197 | 198 | shuffled_funcnames = list(all_funcs.values()) 199 | random.seed(5) 200 | random.shuffle(shuffled_funcnames) 201 | 202 | total = len(all_funcs) 203 | train_ = int(total * 0.80) 204 | train_funcs = shuffled_funcnames[:train_] 205 | test_funcs = shuffled_funcnames[train_:] 206 | return train_funcs, test_funcs 207 | 208 | l.info(f"Total C funcs {len(cfuncs_md5)}") 209 | c_train_funcs, c_test_funcs = split_funcs(cfuncs_md5) 210 | l.info(f"Total CPP funcs {len(cppfuncs_md5)}") 211 | cpp_train_funcs, cpp_test_funcs = split_funcs(cppfuncs_md5) 212 | 213 | # total funcs now 214 | c_train_funcs.extend(cpp_train_funcs) 215 | c_test_funcs.extend(cpp_test_funcs) 216 | 217 | with open(os.path.join(workdir, "splits", decomp, "train_funcs.txt"), 'w') as w: 218 | w.write("\n".join(c_train_funcs)) 219 | 220 | with open(os.path.join(workdir, "splits", decomp, "test_funcs.txt"), 'w') as w: 221 | w.write("\n".join(c_test_funcs)) 222 | 223 | return c_train_funcs, c_test_funcs 224 | 225 | # create binary splits 226 | def create_per_binary_split(files, filename, workdir, decomp): 227 | 228 | with open(os.path.join(workdir, "splits", decomp,'per-binary', f'{filename}.jsonl'), 'w') as w: 229 | for file_ in files: 230 | with open(os.path.join(workdir, "dedup_func", decomp, file_), 'r') as r: 231 | lines = r.readlines() 232 | for line in lines: 233 | w.write(json.dumps(json.loads(line)) + '\n') 234 | 235 | # create func splits 236 | def create_per_func_split(funcs, filename, workdir, decomp): 237 | 238 | with open(os.path.join(workdir, "splits", decomp,'per-func', f'{filename}.jsonl'), 'w') as w: 239 | for func in funcs: 240 | if func not in 
funcline_to_data: 241 | print(func) 242 | w.write( json.dumps(funcline_to_data[func])+ "\n") 243 | 244 | def get_variables(decomp, func_): 245 | 246 | global dwarf_vars 247 | global decompiler_vars 248 | 249 | var_regex = r"@@(var_\d+)@@(\w+)@@" 250 | data = json.loads(func_) 251 | vars_ = data['type_stripped_vars'] 252 | for var, ty in vars_.items(): 253 | if ty == "dwarf": 254 | dwarf_vars.append(var) 255 | elif ty == decomp: 256 | decompiler_vars.append(var) 257 | else: 258 | l.error(f"Error in getting variables: {var} {ty}") 259 | 260 | def substitute_norm(lookup, func): 261 | 262 | global new_funcs 263 | try: 264 | data = json.loads(func) 265 | func = data['func'] 266 | tyvars_= data['type_stripped_vars'] 267 | # if the variable is removed in cleaning invalid vars 268 | if not tyvars_: 269 | return 270 | vars_ = [] 271 | for v in tyvars_: 272 | if tyvars_[v] == 'dwarf': 273 | vars_.append(v) 274 | # vars_map can be empty if there are no dwarf variable names in the func 275 | vars_map = [] 276 | norm_func = func 277 | for var in vars_: 278 | og_var = rf"(@@var_\d+@@)({var})(@@)" 279 | if var in lookup: 280 | if lookup[var] == '': 281 | continue 282 | else: 283 | norm_func = re.sub(og_var, rf'\1{lookup[var]}\3', norm_func) 284 | vars_map.append([var, lookup[var]]) 285 | else: 286 | l.error(f"something fishy! {var} {data['id']}") 287 | 288 | up_data = data 289 | up_data['norm_func'] = norm_func 290 | up_data['vars_map'] = vars_map 291 | up_data['fid'] = str(data['id']) + "-" + data['func_name'].replace("_","-") 292 | 293 | # for preprocessing for vocab (vars_map and type_stripped_norm_vars same thing) 294 | norm_var_type = {} 295 | for pair in vars_map: 296 | norm_var = pair[1] 297 | var = pair[0] 298 | if norm_var in norm_var_type and up_data["type_stripped_vars"][var] != 'dwarf': 299 | norm_var_type[norm_var] = 'dwarf' 300 | else: 301 | norm_var_type[norm_var] = up_data["type_stripped_vars"][var] 302 | 303 | up_data['type_stripped_norm_vars'] = norm_var_type 304 | new_funcs.append(up_data) 305 | except Exception as e: 306 | l.error(f'Error in updating function in substitute_norm: {e}') 307 | 308 | def write_norm(filename, funcs): 309 | with open(filename, 'w') as w: 310 | for func in funcs: 311 | w.write(json.dumps(func) + "\n") 312 | 313 | def create_train_and_test_sets(workdir, decomp, WORKERS=4): 314 | 315 | total_steps = 4 316 | progress = tqdm(total=total_steps, desc="Initializing") 317 | 318 | # create dirs 319 | dirs = [f"dedup_func/{decomp}", "lookup", f"splits/{decomp}/per-func", f"splits/{decomp}/per-binary", f"vars/{decomp}/per-func", f"vars/{decomp}/per-binary" ] 320 | for d in dirs: 321 | os.makedirs(os.path.join(workdir, d), exist_ok=True) 322 | files = os.listdir(os.path.join(workdir, 'map', decomp)) 323 | 324 | progress.set_description("Deduplicating functions") 325 | # create hash of funcs for de-dup 326 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor: 327 | # executor.map(hash_funcs, files) 328 | partial_func = partial(hash_funcs, workdir, decomp) 329 | executor.map(partial_func, files) 330 | 331 | l.info(f"functions before de-dup: {len(funcline_to_data)} \tfunctions after de-dup: {len(md5_to_funcline)}") 332 | 333 | batch = len(md5_to_funcline) // (WORKERS - 1) 334 | if os.listdir(os.path.join(workdir, "dedup_func", decomp)): 335 | print("files in dedup dir!, remove them") 336 | l.error("files in dedup dir!, remove them") 337 | exit(1) 338 | 339 | l.debug("now saving dedup code!") 340 | with 
concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor: 341 | cal_partial = partial(func_dedup, workdir, decomp, batch_size=batch, input=input,) 342 | executor.map(cal_partial, [batch * i for i in range(WORKERS)]) 343 | 344 | l.info(f'Binaries in dedup func: {len(os.listdir(os.path.join(workdir, "dedup_func", decomp)))}') 345 | progress.update(1) 346 | # find c and cpp files 347 | c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines = find_c_and_cpp_files(os.path.join(workdir, "dedup_func"), decomp) 348 | 349 | # # per-binary splits 350 | progress.set_description("Creating Binary Split") 351 | train_b, test_b = decide_per_binary_split(os.path.join(workdir, "dedup_func"), c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines, workdir, decomp) 352 | l.debug(f"per binary numbers: Train: {len(train_b)} \tTest: {len(test_b)}") 353 | create_per_binary_split(train_b, "train", workdir, decomp) 354 | create_per_binary_split(test_b, "test", workdir, decomp) 355 | train_b.extend(test_b) 356 | all_b = train_b 357 | l.debug(f"total samples in binary split: {len(all_b)}") 358 | progress.update(1) 359 | 360 | # per-func splits 361 | progress.set_description("Creating Function Split") 362 | train_fn, test_fn = decide_per_func_split(workdir, decomp) 363 | l.debug(f"per func numbers: Train: {len(train_fn)} \tTest: {len(test_fn)}") 364 | create_per_func_split(train_fn, "train", workdir, decomp) 365 | create_per_func_split(test_fn, "test", workdir, decomp) 366 | train_fn.extend(test_fn) 367 | all_fn = train_fn 368 | l.debug(f"total samples in func split: {len(all_fn)}") 369 | progress.update(1) 370 | 371 | progress.set_description("Updating variables and saving splits") 372 | clean_up_variables(WORKERS, workdir, decomp) 373 | progress.update(1) 374 | progress.set_description("Train and Test splits created!") 375 | progress.close() 376 | 377 | # TODO: Fix later 378 | def clean_up_variables(WORKERS, workdir, decomp): 379 | # replace variables 380 | import glob 381 | files = glob.glob(f'{workdir}/splits/{decomp}/**/*', recursive=True) 382 | for f in files: 383 | new_funcs[:] = [] 384 | if os.path.isfile(f) and f.endswith('jsonl') and 'all' not in f and 'final' not in f: 385 | tmp = f.split('/') 386 | ty = tmp[-2] 387 | file_ = tmp[-1] 388 | with open(f, 'r') as r: 389 | funclines = r.readlines() 390 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor: 391 | partial_func = partial(get_variables, decomp) 392 | executor.map(partial_func, funclines) 393 | 394 | # create variable lists 395 | var_out_file = os.path.join(workdir, "vars", decomp, f"{ty}", f"{file_[:-6]}_dwarf_vars") 396 | write_list(var_out_file, list(dwarf_vars), True, False) 397 | 398 | write_list(os.path.join(workdir, "vars", decomp, f"{ty}", f"{file_[:-6]}_decomp_vars"), list(decompiler_vars), True, False) 399 | 400 | clean_var_out = os.path.join(workdir, "vars", decomp, f"{ty}") 401 | # subprocess.call(['python3.8', 'preprocess_vars.py', f'{var_out_file}_distribution.json', file_[:-6], clean_var_out, os.path.join(workdir, 'lookup')]) 402 | normalize_variables(f'{var_out_file}_distribution.json', file_[:-6], clean_var_out, os.path.join(workdir, 'lookup'), WORKERS) 403 | import time; time.sleep(10) 404 | with open(os.path.join(workdir, 'lookup', 'universal_lookup.json'), 'r') as r: 405 | lookup = json.loads(r.read()) 406 | 407 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor: 408 | partial_func = partial(substitute_norm, 
lookup) 409 | executor.map(partial_func, funclines) 410 | # why so many funcs? # FIXME 411 | write_norm(os.path.join(workdir, "splits", decomp, ty, f"final_{file_[:-6]}.jsonl"), new_funcs) 412 | l.info(f"Functions in file {f}: {str(len(new_funcs))}") 413 | # create_train_and_test_sets('/tmp/varbert_tmpdir', 'ghidra') 414 | -------------------------------------------------------------------------------- /varcorpus/dataset-gen/variable_matching.py: -------------------------------------------------------------------------------- 1 | from binary import Binary 2 | from utils import read_text, read_json, write_json 3 | import os 4 | import string 5 | import re 6 | import json 7 | from collections import defaultdict, Counter, OrderedDict 8 | import sys 9 | import hashlib 10 | import logging 11 | from parse_decompiled_code import IDAParser, GhidraParser 12 | from typing import Any, List, Dict 13 | 14 | l = logging.getLogger('main') 15 | 16 | class DataLoader: 17 | def __init__(self, binary_name, dwarf_info_path, decompiler, decompiled_code_strip_path, 18 | decompiled_code_type_strip_path, joern_strip_data, joern_type_strip_data, data_map_dump, language, decompiled_code_type_strip_ln_fn_names) -> None: 19 | self.binary_name = binary_name 20 | self.decompiler = decompiler 21 | self.decompiled_code_strip_path = decompiled_code_strip_path 22 | self.decompiled_code_type_strip_path = decompiled_code_type_strip_path 23 | self.joern_strip_data = joern_strip_data 24 | self.joern_type_strip_data = joern_type_strip_data 25 | self.dwarf_vars, self.dwarf_funcs, self.linkage_name_to_func_name, self.decompiled_code_type_strip_ln_fn_names = self._update_values(dwarf_info_path, decompiled_code_type_strip_ln_fn_names) 26 | self.data_map_dump = data_map_dump 27 | self.language = language 28 | self.sample = {} 29 | self.run() 30 | 31 | def _update_values(self, path, decompiled_code_type_strip_ln_fn_names): 32 | if isinstance(path, str): 33 | tmp = read_json(path) 34 | else: tmp = path 35 | linkage_name_to_func_name = {} 36 | if 'vars_per_func' in tmp.keys(): 37 | dwarf_vars = tmp['vars_per_func'] 38 | dwarf_funcs = dwarf_vars.keys() 39 | if 'linkage_name_to_func_name' in tmp.keys(): 40 | if tmp['linkage_name_to_func_name']: 41 | linkage_name_to_func_name = tmp['linkage_name_to_func_name'] 42 | else: 43 | for funcname, _ in tmp['vars_per_func'].items(): 44 | linkage_name_to_func_name[funcname] = funcname 45 | # TODO: if setting corpus language to c and cpp then enable this 46 | # else: 47 | # # C does not have linkage names; also, at the time of building the data set, we did not have the cpp implementation, so hack it 48 | # for funcname, _ in tmp['vars_per_func'].items(): 49 | # linkage_name_to_func_name[funcname] = funcname 50 | # print("linkage", linkage_name_to_func_name) 51 | 52 | # cpp 53 | if os.path.exists(decompiled_code_type_strip_ln_fn_names): 54 | decompiled_code_type_strip_ln_fn_names = read_json(decompiled_code_type_strip_ln_fn_names) 55 | else: 56 | # c 57 | decompiled_code_type_strip_ln_fn_names = linkage_name_to_func_name 58 | return dwarf_vars, dwarf_funcs, linkage_name_to_func_name, decompiled_code_type_strip_ln_fn_names 59 | 60 | def run(self): 61 | if self.decompiler.lower() == 'ida': 62 | parser_strip = IDAParser(read_text(self.decompiled_code_strip_path), self.binary_name, None) 63 | parser_type_strip = IDAParser(read_text(self.decompiled_code_type_strip_path), self.binary_name, self.decompiled_code_type_strip_ln_fn_names) 64 | elif self.decompiler.lower() == 'ghidra': 65 | parser_strip = 
GhidraParser(read_text(self.decompiled_code_strip_path), self.binary_name) 66 | parser_type_strip = GhidraParser(read_text(self.decompiled_code_type_strip_path), self.binary_name) 67 | 68 | # load joern data 69 | joern_data_strip = JoernDataLoader(self.joern_strip_data, parser_strip.func_name_to_addr, None) 70 | joern_data_type_strip = JoernDataLoader(self.joern_type_strip_data, parser_type_strip.func_name_to_addr, parser_type_strip.func_name_to_linkage_name) 71 | # filter out functions that are not common 72 | common_func_addr_decompiled_code = set(parser_strip.func_addr_to_name.keys()).intersection(parser_type_strip.func_addr_to_name.keys()) 73 | common_func_addr_from_joern = set(joern_data_strip.joern_addr_to_name.keys()).intersection(joern_data_type_strip.joern_addr_to_name.keys()) 74 | 75 | # joern names to addr 76 | # remove comments at the end because Joern's line numbers include comments 77 | for addr in common_func_addr_from_joern: 78 | try: 79 | if addr not in common_func_addr_decompiled_code: 80 | continue 81 | # remove decompiler/compiler generated functions - we get dwarf names from type-strip 82 | if addr not in parser_type_strip.func_addr_to_name: 83 | continue 84 | l.debug(f'Trying for func addr {self.binary_name} :: {self.decompiler} {addr}') 85 | 86 | func_name_type_strip = parser_type_strip.func_addr_to_name[addr] 87 | func_name_strip = parser_strip.func_addr_to_name[addr] 88 | 89 | # if the function is compiler-generated (artificial) we do not collect DWARF info for that function (we need only source functions) 90 | # For Ghidra: we use addr as funcname in dwarf info that is collected from binary 91 | dwarf_addr_compatible_addr = None 92 | if self.decompiler.lower() == 'ghidra': 93 | dwarf_addr_compatible_addr = str(int(addr[1:-1], 16) - 0x00100000) 94 | if dwarf_addr_compatible_addr not in self.dwarf_funcs: 95 | l.debug(f'Func not in DWARF / Source! {self.binary_name} :: {self.decompiler} {dwarf_addr_compatible_addr}') 96 | continue 97 | 98 | # For IDA, func name collected from binary is w/o line number 99 | elif self.decompiler.lower() == 'ida': 100 | # funcnames from cpp have '::' in type-strip dc and not in parsed dwarf data - get linkage names so we are good 101 | funcname_from_type_strip = parser_type_strip.functions[addr]['func_name_no_line'] 102 | 103 | if funcname_from_type_strip not in self.dwarf_funcs and funcname_from_type_strip.strip().split('::')[-1] not in self.dwarf_funcs: 104 | l.debug(f'Func not in DWARF / Source! {self.binary_name} :: {self.decompiler} {funcname_from_type_strip}') 105 | continue 106 | 107 | dc_functions_strip = parser_strip.functions[addr] 108 | dc_functions_type_strip = parser_type_strip.functions[addr] 109 | 110 | dwarf_func_name = '_'.join(func_name_type_strip.split('_')[:-1]) 111 | cfunc_strip = CFunc(self, dc_functions_strip, joern_data_strip, "strip", func_name_strip, dwarf_func_name, addr, None) 112 | cfunc_type_strip = CFunc(self, dc_functions_type_strip, joern_data_type_strip, "type-strip", func_name_type_strip, dwarf_func_name, addr, self.linkage_name_to_func_name) 113 | 114 | # unequal number of joern vars! do not proceed 115 | if len(cfunc_strip.joern_vars) != len(cfunc_type_strip.joern_vars): 116 | l.debug(f'Return! unequal number of joern vars! {self.binary_name} :: {self.decompiler} {addr}') 117 | continue 118 | 119 | # equal number of local vars - identify missing vars 120 | if len(cfunc_strip.local_vars_dc) == len(cfunc_type_strip.local_vars_dc): 121 | # identify missing vars and update lines to current lines (not needed for type_strip, 
but doing it for uniformity) 122 | l.debug(f'Identify missing vars! {self.binary_name} :: {self.decompiler} :: {addr}') 123 | cfunc_strip.joern_vars = self.identify_missing_vars(cfunc_strip) 124 | cfunc_type_strip.joern_vars = self.identify_missing_vars(cfunc_type_strip) 125 | 126 | l.debug(f'Removing invalid variable names! {self.binary_name} :: {self.decompiler} :: {addr}') 127 | 128 | # remove invalid variable names 129 | valid_strip_2_type_strip_varnames, cfunc_strip.joern_vars, cfunc_type_strip.joern_vars = self.remove_invalid_variable_names(cfunc_strip, cfunc_type_strip, 130 | parser_strip, parser_type_strip) 131 | 132 | l.debug(f'Start matching variable names! {self.binary_name} :: {self.decompiler} :: {addr}') 133 | matchvar = MatchVariables(cfunc_strip, cfunc_type_strip, valid_strip_2_type_strip_varnames, self.dwarf_vars, self.decompiler, cfunc_type_strip.dwarf_func_name, self.binary_name, self.language, dwarf_addr_compatible_addr) 134 | 135 | sample_name, sample_data = matchvar.dump_sample() 136 | if sample_data: 137 | l.debug(f'Dump sample! {self.binary_name} :: {self.decompiler} :: {addr}') 138 | self.sample[sample_name] = sample_data 139 | 140 | except Exception as e: 141 | l.error(f"ERRR {self.binary_name} :: {e} {addr}") 142 | l.info(f'Variable Matching complete for {self.binary_name} :: {self.decompiler} :: samples :: {len(self.sample)}') 143 | write_json(self.data_map_dump, dict(self.sample)) 144 | 145 | def identify_missing_vars(self, cfunc): 146 | # 1. update file lines to func lines 147 | # 2. joern may miss some occurences of variables without declaration (i.e. variables from .data, .bss section. find them and update corresponding lines) 148 | joern_vars = cfunc.joern_vars 149 | tmp_dict = {} 150 | for var, var_lines in joern_vars.items(): 151 | updated_lines = [] 152 | updated_lines[:] = [int(number) - int(cfunc.joern_func_start) +1 for number in var_lines] 153 | var = var.strip('~') 154 | find_all = rf'([^\d\w_-])({var})([^\d\w_-])' 155 | for i, line in enumerate(cfunc.func_lines, 1): 156 | matches = re.search(find_all, line) 157 | if matches: 158 | if i not in updated_lines and not line.startswith('//'): 159 | updated_lines.append(i) 160 | 161 | updated_lines = list(set(updated_lines)) 162 | 163 | tmp_dict[var] = updated_lines 164 | return tmp_dict 165 | 166 | 167 | def remove_invalid_variable_names(self, cfunc_strip, cfunc_type_strip, parser_strip, parser_type_strip): 168 | 169 | # https://www.hex-rays.com/products/ida/support/idadoc/1361.shtml 170 | if self.decompiler.lower() == 'ida': 171 | data_types = ['_BOOL1', '_BOOL2', '_BOOL4', '__int8', '__int16', '__int32', '__int64', '__int128', '_BYTE', '_WORD', '_DWORD', '_QWORD', '_OWORD', '_TBYTE', '_UNKNOWN', '__pure', '__noreturn', '__usercall', '__userpurge', '__spoils', '__hidden', '__return_ptr', '__struct_ptr', '__array_ptr', '__unused', '__cppobj', '__ptr32', '__ptr64', '__shifted', '__high'] 172 | 173 | # anymore undefined? 
174 | if self.decompiler.lower() == 'ghidra': 175 | data_types = ['ulong', 'uint', 'ushort', 'ulonglong', 'bool', 'char', 'int', 'long', 'undefined', 'undefined1', 'undefined2', 'undefined4', 'undefined8', 'byte', 'FILE', 'size_t'] 176 | 177 | invalid_vars = ['null', 'true', 'false', 'True', 'False', 'NULL', 'char', 'int'] 178 | # also func names which may have been identified as var names 179 | 180 | zip_joern_vars = dict(zip(cfunc_strip.joern_vars, cfunc_type_strip.joern_vars)) 181 | # strip: type-strip 182 | valid_strip_2_type_strip_varnames = {} 183 | valid_strip_vars = {} 184 | valid_type_strip_vars = {} 185 | 186 | for strip_var, type_strip_var in zip_joern_vars.items(): 187 | try: 188 | strip_var_woc, type_strip_var_woc = strip_var, type_strip_var 189 | strip_var = strip_var.strip('~') 190 | type_strip_var = type_strip_var.strip('~') 191 | 192 | # remove func name 193 | if strip_var in parser_strip.func_name_wo_line.keys() or type_strip_var in parser_type_strip.func_name_wo_line.keys(): 194 | continue 195 | # remove func name with :: 196 | if strip_var.strip().split('::')[-1] in parser_strip.func_name_wo_line.keys() or type_strip_var.strip().split('::')[-1] in parser_type_strip.func_name_wo_line.keys(): 197 | continue 198 | # (rn: sub_2540_475) 199 | if strip_var in parser_strip.func_name_to_addr.keys() or type_strip_var in parser_type_strip.func_name_to_addr.keys(): 200 | continue 201 | # remove types 202 | if strip_var in data_types or type_strip_var in data_types: 203 | continue 204 | # remove invalid name 205 | if strip_var in invalid_vars or type_strip_var in invalid_vars: 206 | continue 207 | 208 | if '::' in strip_var: 209 | strip_var_woc = strip_var.strip().split('::')[-1] 210 | if '::' in type_strip_var: 211 | type_strip_var_woc = type_strip_var.strip().split('::')[-1] 212 | 213 | # create valid mapping 214 | valid_strip_2_type_strip_varnames[strip_var_woc] = type_strip_var_woc 215 | 216 | # update joern vars 217 | if strip_var in cfunc_strip.joern_vars: 218 | valid_strip_vars[strip_var_woc] = cfunc_strip.joern_vars[strip_var] 219 | 220 | elif f'~{strip_var}' in cfunc_strip.joern_vars: 221 | # if strip_var_woc: 222 | valid_strip_vars[strip_var_woc] = cfunc_strip.joern_vars[f'~{strip_var}'] 223 | 224 | if type_strip_var in cfunc_type_strip.joern_vars: 225 | # if type_strip_var_woc: 226 | valid_type_strip_vars[type_strip_var_woc] = cfunc_type_strip.joern_vars[type_strip_var] 227 | # else: 228 | # valid_type_strip_vars[type_strip_var] = cfunc_type_strip.joern_vars[type_strip_var] 229 | 230 | if f'~{type_strip_var}' in cfunc_type_strip.joern_vars: 231 | # if type_strip_var_woc: 232 | valid_type_strip_vars[type_strip_var_woc] = cfunc_type_strip.joern_vars[f'~{type_strip_var}'] 233 | # else: 234 | # valid_type_strip_vars[type_strip_var] = cfunc_type_strip.joern_vars[f'~{type_strip_var}'] 235 | except Exception as e: 236 | l.error(f'Error in removing invalid variable names! 
{self.binary_name} :: {self.decompiler} :: {e}') 237 | return valid_strip_2_type_strip_varnames, valid_strip_vars, valid_type_strip_vars 238 | 239 | 240 | class JoernDataLoader: 241 | def __init__(self, joern_data, decompiled_code_name_to_addr, func_name_to_linkage_name) -> None: 242 | self.functions = {} 243 | self.joern_data = joern_data 244 | self.decompiled_code_name_to_addr = decompiled_code_name_to_addr 245 | self.joern_name_to_addr = {} 246 | self.joern_addr_to_name = {} 247 | self.joern_start_line_to_funcname = {} 248 | self._load(func_name_to_linkage_name) 249 | 250 | def _load(self, func_name_to_linkage_name): 251 | # key: func name, val: {var: lines} 252 | func_line_num_to_addr = {} 253 | for k, v in self.decompiled_code_name_to_addr.items(): 254 | line_num = str(k.split('_')[-1]) 255 | func_line_num_to_addr[line_num] = v 256 | try: 257 | counter = 0 258 | for func_name, v in self.joern_data.items(): 259 | try: 260 | # check if func name in decompiled code funcs and get addr mapping 261 | # type-strip - replace demangled name with mangled name (cpp) or simply a func name (c) 262 | # some func names from joern are different from parsed decompiled code (IDA: YaSkkServ::`anonymous namespace'::signal_dictionary_update_handler | Joern: signal_dictionary_update_handler ) 263 | if func_name_to_linkage_name: 264 | tmp_wo_line = func_name.split('_') 265 | if len(tmp_wo_line) <= 1: 266 | continue 267 | name_wo_line, line_num = '_'.join(tmp_wo_line[:-1]), tmp_wo_line[-1] 268 | 269 | if func_name in self.decompiled_code_name_to_addr: 274 | addr = self.decompiled_code_name_to_addr[func_name] 275 | elif line_num in func_line_num_to_addr: 276 | addr = func_line_num_to_addr[line_num] 277 | else: 278 | l.warning("joern and IDA func name did not match") 279 | continue  # no address mapping for this function; skip it instead of reusing a stale addr 280 | self.joern_addr_to_name[addr] = func_name 281 | self.joern_name_to_addr[func_name] = addr 282 | self.joern_start_line_to_funcname[v['func_start']] = func_name 283 | 284 | joern_vars = v['variable'] 285 | start = v['func_start'] 286 | end = v['func_end'] 287 | if start and end: 288 | self.functions[func_name] = {'func_name': func_name, 289 | 'variables': joern_vars, 290 | # 'var_lines': joern_var_lines, 291 | 'start': start, 292 | 'end': end} 293 | counter += 1 294 | except Exception as e: 295 | l.error(f'Error in loading joern data! {e}') 296 | except Exception as e: 297 | l.error(f'Error in getting joern vars! 
{e}') 298 | 299 | 300 | class CFunc: 301 | 302 | def __init__(self, dcl, dc_func, joern_data, binary_type, func_name, dwarf_func_name, func_addr, linkage_name_to_func_name ) -> None: 303 | 304 | self.local_vars_dc = dc_func['local_vars'] 305 | 306 | self.func = dc_func['func'] 307 | self.func_addr = func_addr 308 | self.func_prototype = None 309 | self.func_body = None 310 | self.func_lines = None 311 | 312 | self.func_name = func_name 313 | self.func_name_no_line = '_'.join(func_name.split('_')[:-1]) 314 | self.line = func_name.split('_')[-1] 315 | self.binary_name = dcl.binary_name 316 | self.binary_type = binary_type 317 | 318 | self.decompiler = dcl.decompiler 319 | self.dwarf_func_name = dwarf_func_name 320 | self.dwarf_mangled_func_name = linkage_name_to_func_name 321 | 322 | # start 323 | if self.func_name in joern_data.functions: 324 | self.joern_func_start = joern_data.functions[self.func_name]['start'] 325 | elif self.line in joern_data.joern_start_line_to_funcname: 326 | tmp_func_name = joern_data.joern_start_line_to_funcname[self.line] 327 | self.joern_func_start = joern_data.functions[tmp_func_name]['start'] 328 | 329 | # variables 330 | if self.func_name in joern_data.functions: 331 | self.all_vars_joern = joern_data.functions[self.func_name]['variables'] 332 | elif self.line in joern_data.joern_start_line_to_funcname: 333 | tmp_func_name = joern_data.joern_start_line_to_funcname[self.line] 334 | self.all_vars_joern = joern_data.functions[tmp_func_name]['variables'] 335 | 336 | self.set_func_details() 337 | 338 | def __repr__(self) -> str: 339 | pass 340 | 341 | @property 342 | def joern_vars(self): 343 | return self.all_vars_joern 344 | 345 | @joern_vars.setter 346 | def joern_vars(self, new_value): 347 | self.all_vars_joern = new_value 348 | 349 | def set_func_details(self) -> None: 350 | if self.func: 351 | self.func_lines = self.func.split('\n') 352 | self.func_prototype = self.func.split('{')[0] 353 | self.func_body = '{'.join(self.func.split('{')[1:]) 354 | 355 | 356 | class MatchVariables: 357 | 358 | def __init__(self, strip: CFunc, type_strip: CFunc, valid_strip_2_type_strip_varnames, dwarf_vars, decompiler, dwarf_func_name, binary_name, language, dwarf_addr_compatible_addr=None) -> None: 359 | self.cfunc_strip = strip 360 | self.cfunc_type_strip = type_strip 361 | self.mapped_vars = valid_strip_2_type_strip_varnames 362 | self.dwarf_vars = dwarf_vars 363 | self.labelled_vars = {} 364 | self.modified_func = None 365 | self.decompiler = decompiler 366 | self.dwarf_func_name = dwarf_func_name 367 | self.md5_hash = None 368 | self.binary_name = binary_name 369 | self.language = language 370 | self.dwarf_addr_compatible_addr = dwarf_addr_compatible_addr 371 | self.match() 372 | 373 | def match(self): 374 | if len(self.cfunc_strip.joern_vars) != len(self.cfunc_type_strip.joern_vars): 375 | return 376 | if len(self.mapped_vars) == 0: 377 | return 378 | 379 | self.modified_func = self.update_func() 380 | # label DWARF and decompiler-gen 381 | self.labelled_vars = self.label_vars() 382 | self.md5_hash = self.func_hash() 383 | 384 | def update_func(self): 385 | def rm_comments(func): 386 | cm_regex = r'// .*' 387 | cm_func = re.sub(cm_regex, ' ', func).strip() 388 | return cm_func 389 | 390 | # pre-process variables and replace them with "@@dwarf_var_name@@var_id@@" 391 | varname2token = {} 392 | for i, varname in enumerate(self.mapped_vars, 0): 393 | varname2token[varname] = f"@@var_{i}@@{self.mapped_vars[varname]}@@" 394 | new_func = self.cfunc_strip.func 395 | 396 | # if 
no line numbers available 397 | allowed_prefixes = [" ", "&", "(", "*", "++", "--", "!"] 398 | allowed_suffixes = [" ", ")", ",", ";", "[", "++", "--"] 399 | for varname, newname in varname2token.items(): 400 | for p in allowed_prefixes: 401 | for s in allowed_suffixes: 402 | new_func = new_func.replace(f"{p}{varname}{s}", f"{p}{newname}{s}") 403 | 404 | # no var is labelled as stderr. it is neither in global vars nor in vars for this function, so we do not add it 405 | if '@@' not in new_func: 406 | return None 407 | return rm_comments(new_func) 408 | 409 | 410 | def label_vars(self): 411 | labelled_vars = {} 412 | for strip_var, type_strip_var in self.mapped_vars.items(): 413 | if self.decompiler == 'ida': 414 | check_value = self.cfunc_type_strip.dwarf_func_name 415 | elif self.decompiler == 'ghidra': 416 | check_value = self.dwarf_addr_compatible_addr 417 | if type_strip_var in self.dwarf_vars[check_value] or type_strip_var in self.dwarf_vars['global_vars']: 418 | labelled_vars[type_strip_var] = 'dwarf' 419 | else: 420 | labelled_vars[type_strip_var] = self.decompiler 421 | return labelled_vars 422 | 423 | 424 | def func_hash(self) -> str: 425 | var_regex = r"@@(var_\d+)@@(\w+)@@" 426 | up_func = re.sub(var_regex, "\\2", self.modified_func) 427 | func_body = '{'.join(up_func.split('{')[1:]) 428 | md5 = hashlib.md5(func_body.encode('utf-8')).hexdigest() 429 | return md5 430 | 431 | 432 | def dump_sample(self): 433 | if self.decompiler == 'ida': 434 | func_name_dwarf = str(self.cfunc_type_strip.dwarf_mangled_func_name[str(self.dwarf_func_name)]) 435 | elif self.decompiler == 'ghidra': 436 | func_name_dwarf = str(self.cfunc_type_strip.dwarf_func_name) 437 | if self.modified_func: 438 | name = f'{self.binary_name}_{self.cfunc_strip.func_addr}' 439 | data = { 'func': self.modified_func, 440 | 'type_stripped_vars': dict(self.labelled_vars), 441 | 'stripped_vars': list(self.mapped_vars.keys()), 442 | 'mapped_vars': dict(self.mapped_vars), 443 | 'func_name_dwarf': func_name_dwarf, 444 | 'dwarf_mangled_func_name': str(self.dwarf_func_name), 445 | 'hash': self.md5_hash, 446 | 'language': self.language 447 | } 448 | 449 | return name, data 450 | else: 451 | return None, None 452 | --------------------------------------------------------------------------------
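A minimal, self-contained sketch of the @@var_N@@name@@ placeholder convention shared by variable_matching.py (MatchVariables.update_func, func_hash) and create_dataset_splits.py (get_variables, substitute_norm). The sample function text and variable names below are invented for illustration and are not part of the corpus, and a word-boundary regex stands in for the allowed-prefix/suffix replacement that update_func actually performs.

import hashlib
import re

# Hypothetical decompiled function text (illustration only, not from the corpus).
func = "int sub_401000(int a1) { int v2; v2 = a1 + 1; return v2; }"

# strip-name -> type-strip (DWARF) name mapping, analogous to what
# remove_invalid_variable_names produces (names here are made up).
mapped_vars = {"a1": "count", "v2": "result"}

# Tag each variable the way update_func does, as @@var_<id>@@<type-strip name>@@.
# A word-boundary regex is used here for brevity instead of the prefix/suffix lists.
tagged = func
for i, (strip_var, ts_var) in enumerate(mapped_vars.items()):
    tagged = re.sub(rf"\b{strip_var}\b", f"@@var_{i}@@{ts_var}@@", tagged)

# func_hash-style canonicalization: drop the placeholder, keep only the type-strip
# name, then hash the function body so renamed-but-identical functions deduplicate.
var_regex = r"@@(var_\d+)@@(\w+)@@"
canonical = re.sub(var_regex, r"\2", tagged)
body = "{".join(canonical.split("{")[1:])
print(tagged)
print(hashlib.md5(body.encode("utf-8")).hexdigest())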