├── varbert
│   ├── __init__.py
│   ├── fine-tune
│   │   ├── __init__.py
│   │   └── preprocess.py
│   ├── tokenizer
│   │   ├── preprocess.py
│   │   └── train_bpe_tokenizer.py
│   ├── mlm
│   │   └── preprocess.py
│   ├── generate_vocab.py
│   ├── resize_model.py
│   ├── cmlm
│   │   ├── preprocess.py
│   │   └── eval.py
│   └── README.md
├── varcorpus
│   ├── __init__.py
│   ├── dataset-gen
│   │   ├── __init__.py
│   │   ├── log.py
│   │   ├── pathmanager.py
│   │   ├── decompiler
│   │   │   ├── ida_unrecogn_func.py
│   │   │   ├── ghidra_dec.py
│   │   │   ├── ida_dec.py
│   │   │   ├── run_decompilers.py
│   │   │   └── ida_analysis.py
│   │   ├── preprocess_vars.py
│   │   ├── binary.py
│   │   ├── joern_parser.py
│   │   ├── generate.py
│   │   ├── utils.py
│   │   ├── strip_types.py
│   │   ├── dwarf_info.py
│   │   ├── parse_decompiled_code.py
│   │   ├── runner.py
│   │   ├── create_dataset_splits.py
│   │   └── variable_matching.py
│   └── README.md
├── requirements.txt
├── Dockerfile
├── .gitignore
└── README.md
/varbert/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/varcorpus/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/varbert/fine-tune/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cpgqls_client
2 | regex
3 | pyelftools==0.29
4 | cle
5 | torch
6 | transformers
7 | jsonlines
8 | pandas
9 | gdown
10 | jupyterlab
11 | tensorboardX
12 | datasets
13 | scikit-learn
14 | accelerate>=0.20.1
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/log.py:
--------------------------------------------------------------------------------
1 | import logging.config
2 | import os
3 |
4 | def setup_logging(LOG_DIRECTORY, DEBUG=False):
5 | if not os.path.exists(LOG_DIRECTORY):
6 | os.makedirs(LOG_DIRECTORY)
7 |
8 | log_level = 'DEBUG' if DEBUG else 'INFO'
9 | LOGGING_CONFIG = {
10 | 'version': 1,
11 | 'disable_existing_loggers': False,
12 |
13 | 'formatters': {
14 | 'default_formatter': {
15 | 'format': '%(asctime)s - %(levelname)s - %(name)s - %(filename)s - %(lineno)d : %(message)s'
16 | }
17 | },
18 |
19 | 'handlers': {
20 | 'main_handler': {
21 | 'class': 'logging.FileHandler',
22 | 'formatter': 'default_formatter',
23 | 'filename': os.path.join(LOG_DIRECTORY, 'varcorpus.log')
24 | }
25 | },
26 |
27 | 'loggers': {
28 | 'main': {
29 | 'handlers': ['main_handler'],
30 | 'level': log_level,
31 | 'propagate': False
32 | }
33 | }
34 | }
35 |
36 | logging.config.dictConfig(LOGGING_CONFIG)
--------------------------------------------------------------------------------
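The rest of the dataset-generation scripts rely on this configuration by fetching the `'main'` logger by name. A minimal usage sketch (the log directory below is an illustrative path, not one the pipeline mandates):

```python
import logging
from log import setup_logging

# Configure file logging once, early in the run; DEBUG=True lowers the level to DEBUG.
setup_logging("/tmp/varbert_logs", DEBUG=True)

# Every other module in dataset-gen obtains the same named logger like this.
l = logging.getLogger('main')
l.info("logging configured")
```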
/varcorpus/dataset-gen/pathmanager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | class PathManager:
5 | def __init__(self, args):
6 | # Base paths set by user input
7 | self.binaries_dir = args.binaries_dir
8 | self.data_dir = os.path.abspath(args.data_dir)
9 | self.joern_dir = os.path.abspath(args.joern_dir)
10 | if not args.tmpdir:
11 | self.tmpdir = tempfile.mkdtemp(prefix='varbert_tmpdir_', dir='/tmp')
12 | else:
13 | self.tmpdir = args.tmpdir
14 | print("self.tmpdir", self.tmpdir)
15 | os.makedirs(self.tmpdir, exist_ok=True)
16 |
17 | self.ida_path = args.ida_path
18 | self.ghidra_path = args.ghidra_path
19 | self.corpus_language = args.corpus_language
20 |
21 | # Derived paths
22 | self.strip_bin_dir = os.path.join(self.tmpdir, 'binary/strip')
23 | self.type_strip_bin_dir = os.path.join(self.tmpdir, 'binary/type_strip')
24 | self.joern_data_path = os.path.join(self.tmpdir, 'joern')
25 | self.dc_path = os.path.join(self.tmpdir, 'dc')
26 | self.failed_path_ida = os.path.join(self.tmpdir, 'failed')
27 | self.dwarf_info_path = os.path.join(self.tmpdir, 'dwarf')
28 | self.match_path = os.path.join(self.tmpdir, 'map')
29 |
30 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:22.04
2 |
3 | RUN apt-get update -y
 4 | # set timezone non-interactively for tzdata (CONTAINER_TIMEZONE is assumed to be provided at build time; falls back to UTC)
 5 | RUN ln -snf /usr/share/zoneinfo/${CONTAINER_TIMEZONE:-Etc/UTC} /etc/localtime && echo ${CONTAINER_TIMEZONE:-Etc/UTC} > /etc/timezone
6 |
7 | # Install dependencies
8 | RUN apt-get -y install python3.11 python3-pip git binutils-multiarch wget
9 | # Ghidra
10 | RUN apt-get -y install openjdk-17-jdk openjdk-11-jdk
11 |
12 | # Joern
13 | RUN apt-get install -y openjdk-11-jre-headless psmisc && \
14 | apt-get clean;
15 | RUN pip3 install cpgqls_client regex
16 |
17 | WORKDIR /varbert_workdir
18 |
19 | # Dwarfwrite
20 | RUN git clone https://github.com/rhelmot/dwarfwrite.git
21 | WORKDIR /varbert_workdir/dwarfwrite
22 | RUN pip install .
23 |
24 | WORKDIR /varbert_workdir/
25 | RUN git clone https://github.com/sefcom/VarBERT.git
26 | WORKDIR /varbert_workdir/VarBERT/
27 | RUN pip install -r requirements.txt
28 |
29 | RUN apt-get install -y unzip
30 | RUN wget -cO /varbert_workdir/joern.tar.gz "https://www.dropbox.com/scl/fi/toh6087y5t5xyln47i5ih/modified_joern.tar.gz?rlkey=lfvjn1u7zvtp9a4cu8z8vgsof" && tar -xzf /varbert_workdir/joern.tar.gz -C /varbert_workdir && rm /varbert_workdir/joern.tar.gz
31 | RUN wget -cO /varbert_workdir/ghidra_10.4_PUBLIC_20230928.zip "https://github.com/NationalSecurityAgency/ghidra/releases/download/Ghidra_10.4_build/ghidra_10.4_PUBLIC_20230928.zip" && unzip /varbert_workdir/ghidra_10.4_PUBLIC_20230928.zip -d /varbert_workdir && rm /varbert_workdir/ghidra_10.4_PUBLIC_20230928.zip
32 |
33 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/decompiler/ida_unrecogn_func.py:
--------------------------------------------------------------------------------
1 | import idaapi
2 | import idc
3 | import ida_funcs
4 | import ida_hexrays
5 | import ida_kernwin
6 | import ida_loader
7 |
8 | def setup():
9 | global analysis
10 | path = ida_loader.get_path(ida_loader.PATH_TYPE_CMD)
11 |
12 | def read_list(filename):
13 | with open(filename, 'r') as r:
14 |         ty_addrs = [a.strip() for a in r.read().split('\n') if a.strip()]
15 |     return ty_addrs
16 |
17 | def add_unrecognized_func(ty_addrs):
18 | for addr in ty_addrs:
19 | # check if func is recognized by IDA
20 | name = ida_funcs.get_func_name(int(addr))
21 | if name:
22 | print(f"func present at: {addr} {name}")
23 | else:
24 | if ida_funcs.add_func(int(addr)):
25 | print(f"func recognized at: {addr} ")
26 | else:
27 | print(f"bad address {addr}")
28 |
29 | def go():
30 | setup()
31 | ea = 0
32 | filename = idc.ARGV[1]
33 | ty_addrs = read_list(filename)
34 | add_unrecognized_func(ty_addrs)
35 |
36 | while True:
37 | func = ida_funcs.get_next_func(ea)
38 | if func is None:
39 | break
40 | ea = func.start_ea
41 | seg = idc.get_segm_name(ea)
42 | if seg != ".text":
43 | continue
44 | print('analyzing', ida_funcs.get_func_name(ea), hex(ea), ea)
45 | analyze_func(func)
46 |
47 | def analyze_func(func):
48 | cfunc = ida_hexrays.decompile_func(func, None, 0)
49 | if cfunc is None:
50 | return
51 |
52 | idaapi.auto_wait()
53 | go()
54 |
--------------------------------------------------------------------------------
/varbert/tokenizer/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import numpy as np
4 | import time
5 | import os
6 | import math
7 | import json
8 | import jsonlines
9 | from tqdm import tqdm
10 | from collections import defaultdict
11 | import re
12 | import pandas as pd
13 |
14 | def update_data(data, pattern):
15 | new_data = []
16 | for body in tqdm(data):
17 | idx = 0
18 | new_body = ""
19 | for each_var in list(re.finditer(pattern, body)):
20 | s = each_var.start()
21 | e = each_var.end()
22 | prefix = body[idx:s]
23 | var = body[s:e]
24 | orig_var = var.split("@@")[-2]
25 | new_body += prefix + orig_var
26 | idx = e
27 | new_body += body[idx:]
28 | new_data.append(new_body)
29 | return new_data
30 |
31 | def main(args):
32 | data = []
33 | with jsonlines.open(args.input_file) as f:
34 | for each in tqdm(f):
35 | data.append(each['norm_func'])
36 |
37 |     new_data = update_data(data, r"@@\w+@@\w+@@")
38 |
39 | with jsonlines.open(args.output_file, mode='w') as writer:
40 | for item in new_data:
41 | writer.write({'func': item})
42 |
43 | if __name__ == "__main__":
44 | parser = argparse.ArgumentParser(description='Process JSONL files.')
45 | parser.add_argument('--input_file', type=str, help='Path to the input HSC jsonl file')
46 | parser.add_argument('--output_file', type=str, help='Path to the save HSC jsonl file for tokenization')
47 | args = parser.parse_args()
48 | main(args)
49 |
--------------------------------------------------------------------------------
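For concreteness, a tiny sketch of what `update_data` does to one annotated function body (the `@@name@@name@@` annotation format is assumed from the regex used above; the second `@@`-delimited field is the one kept):

```python
# minimal demonstration, assuming this file is importable as `preprocess`
from preprocess import update_data

body = "int @@v1@@counter@@ = 0; return @@v2@@length@@ + @@v1@@counter@@;"
cleaned = update_data([body], r"@@\w+@@\w+@@")
print(cleaned[0])  # -> int counter = 0; return length + counter;
```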
/varcorpus/dataset-gen/decompiler/ghidra_dec.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import shutil
4 | import logging
5 | from ghidra.app.decompiler import DecompInterface, DecompileOptions
6 |
7 | def decompile_(binary_name):
8 |
9 | binary = getCurrentProgram()
10 | decomp_interface = DecompInterface()
11 | options = DecompileOptions()
12 | options.setNoCastPrint(True)
13 | decomp_interface.setOptions(options)
14 | decomp_interface.openProgram(binary)
15 |
16 | func_mg = binary.getFunctionManager()
17 | funcs = func_mg.getFunctions(True)
18 |
19 | # args
20 | args = getScriptArgs()
21 | out_path, failed_path, log_path = str(args[0]), str(args[1]), str(args[2])
22 |
23 | regex_com = r"(/\*[\s\S]*?\*\/)|(//.*)"
24 | tot_lines = 0
25 |
26 | try:
27 | with open(out_path, 'w') as w:
28 | for func in funcs:
29 | results = decomp_interface.decompileFunction(func, 0, None )
30 | addr = str(func.getEntryPoint())
31 | func_res = results.getDecompiledFunction()
32 | if 'EXTERNAL' in func.getName(True):
33 | continue
34 | if func_res:
35 | func_c = str(func_res.getC())
36 | # rm comments
37 | new_func = str(re.sub(regex_com, '', func_c, 0, re.MULTILINE)).strip()
38 |
39 | # for joern parsing
40 | before = '//----- ('
41 | after = ') ----------------------------------------------------\n'
42 | addr_line = before + addr + after
43 | tot_lines += 2
44 | w.write(addr_line)
45 | w.write(new_func)
46 | w.write('\n')
47 | tot_lines += new_func.count('\n')
48 |
49 | except Exception as e:
50 | # move log file to failed path
51 | shutil.move(os.path.join(log_path, (binary_name + '.log')), os.path.join(failed_path, ('ghidra_' + binary_name + '.log')))
52 |
53 | if __name__ == '__main__':
54 |
55 | bin_name = str(locals()['currentProgram']).split(' - .')[0].strip()
56 | decompile_(bin_name)
57 |
58 |
--------------------------------------------------------------------------------
/varbert/mlm/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import numpy as np
4 | import time
5 | import os
6 | import math
7 | import json
8 | import jsonlines
9 | from tqdm import tqdm
10 | from collections import defaultdict
11 | import re
12 | import pandas as pd
13 |
14 | def update_function(data, pattern):
15 | new_data = []
16 | for body in tqdm(data):
17 | idx = 0
18 | new_body = ""
19 | for each_var in list(re.finditer(pattern, body)):
20 | s = each_var.start()
21 | e = each_var.end()
22 | prefix = body[idx:s]
23 | var = body[s:e]
24 | orig_var = var.split("@@")[-2]
25 | new_body += prefix + orig_var
26 | idx = e
27 | new_body += body[idx:]
28 | new_data.append(new_body)
29 | return new_data
30 |
31 | def process_file(input_file, output_file, pattern):
32 | data = []
33 | with jsonlines.open(input_file) as f:
34 | for each in tqdm(f):
35 | data.append(each['norm_func'])
36 |
37 | new_data = update_function(data, pattern)
38 |
39 | mlm_data = []
40 | for idx, each in tqdm(enumerate(new_data)):
41 | if len(each) > 0 and not each.isspace():
42 | mlm_data.append({'text': each, 'source': 'human', '_id': idx})
43 |
44 | with jsonlines.open(output_file, mode='w') as f:
45 | for each in tqdm(mlm_data):
46 | f.write(each)
47 |
48 | def main(args):
49 |     process_file(args.train_file, args.output_train_file, r"@@\w+@@\w+@@")
50 |     process_file(args.test_file, args.output_test_file, r"@@\w+@@\w+@@")
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser(description='Process and save jsonl files for MLM')
54 | parser.add_argument('--train_file', type=str, required=True, help='Path to the input train JSONL file')
55 | parser.add_argument('--test_file', type=str, required=True, help='Path to the input test JSONL file')
56 | parser.add_argument('--output_train_file', type=str, required=True, help='Path to the output train JSONL file for MLM')
57 | parser.add_argument('--output_test_file', type=str, required=True, help='Path to the output test JSONL file for MLM')
58 | args = parser.parse_args()
59 | main(args)
60 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/decompiler/ida_dec.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import shutil
4 | import logging
5 | import subprocess
6 | from pathlib import Path
7 |
8 | l = logging.getLogger('main')
9 |
10 | def decompile_(decompiler_workdir, decompiler_path, binary_name, binary_path, binary_type,
11 | decompiled_binary_code_path, failed_path , type_strip_addrs, type_strip_mangled_names):
12 | try:
13 | shutil.copy(binary_path, decompiler_workdir)
14 | current_script_dir = Path(__file__).resolve().parent
15 | addr_file = os.path.join(decompiler_workdir, f'{binary_name}_addr')
16 | if binary_type == 'strip':
17 | run_str = f'{decompiler_path} -S"{current_script_dir}/ida_unrecogn_func.py {addr_file}" -L{decompiler_workdir}/log_{binary_name}.txt -Ohexrays:{decompiler_workdir}/outfile:ALL -A {decompiler_workdir}/{binary_name}'
18 | elif binary_type == 'type_strip':
19 | run_str = f'{decompiler_path} -S"{current_script_dir}/ida_analysis.py {addr_file}" -L{decompiler_workdir}/log_{binary_name}.txt -Ohexrays:{decompiler_workdir}/outfile:ALL -A {decompiler_workdir}/{binary_name}'
20 |
21 | for _ in range(5):
22 | subprocess.run([run_str], shell=True)
23 | if f'outfile.c' in os.listdir(decompiler_workdir):
24 | l.debug(f"outfile.c generated! for {binary_name}")
25 | break
26 | time.sleep(5)
27 |
28 | except Exception as e:
29 | shutil.move(os.path.join(decompiler_workdir, f'log_{binary_name}.txt'), os.path.join(failed_path, ('ida_' + binary_name + '.log')))
30 | return None, str(os.path.join(failed_path, ('ida_' + binary_name + '.log')))
31 |
32 | finally:
33 | # cleanup!
34 | if os.path.exists(os.path.join(decompiler_workdir, f'{binary_name}.i64')):
35 | os.remove(os.path.join(decompiler_workdir, f'{binary_name}.i64'))
36 | if binary_type == 'strip':
37 | subprocess.run(['mv', addr_file, type_strip_addrs ])
38 | subprocess.run(['mv', f'{addr_file}_names', type_strip_mangled_names])
39 | if os.path.exists(f'{decompiler_workdir}/outfile.c'):
40 | if subprocess.run(['mv', f'{decompiler_workdir}/outfile.c', f'{decompiled_binary_code_path}']).returncode == 0:
41 | return True, f'{decompiled_binary_code_path}'
42 | else:
43 | return False, str(os.path.join(failed_path, ('ida_' + binary_name + '.log')))
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/decompiler/run_decompilers.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import time
3 | import os
4 | import shutil
5 | import re
6 | from collections import defaultdict
7 | import json
8 | import sys
9 | import tempfile
10 | from decompiler.ida_dec import decompile_
11 | # from logger import get_logger
12 | from pathlib import Path
13 | import logging
14 | l = logging.getLogger('main')
15 |
16 | class Decompiler:
17 | def __init__(self, decompiler, decompiler_path, decompiler_workdir,
18 | binary_name, binary_path, binary_type,
19 | decompiled_binary_code_path, failed_path, type_strip_addrs, type_strip_mangled_names) -> None:
20 | self.decompiler = decompiler
21 | self.decompiler_path = decompiler_path
22 | self.decompiler_workdir = decompiler_workdir
23 | self.binary_name = binary_name
24 | self.binary_path = binary_path
25 | self.binary_type = binary_type
26 | self.decompiled_binary_code_path = decompiled_binary_code_path
27 | self.failed_path = failed_path
28 | self.type_strip_addrs= type_strip_addrs
29 | self.type_strip_mangled_names = type_strip_mangled_names
30 | self.decompile()
31 |
32 | def decompile(self):
33 |
34 | if self.decompiler == "ida":
35 | success, path = decompile_(self.decompiler_workdir, Path(self.decompiler_path), self.binary_name,
36 | Path(self.binary_path), self.binary_type, Path(self.decompiled_binary_code_path),
37 | Path(self.failed_path), Path(self.type_strip_addrs), Path(self.type_strip_mangled_names))
38 | if not success:
39 | l.error(f"Decompilation failed for {self.binary_name} :: {self.binary_type} :: check logs at {path}")
40 | return None, None
41 | return success, path
42 |
43 | elif self.decompiler == "ghidra":
44 | try:
45 | current_script_dir = Path(__file__).resolve().parent
46 | subprocess.call(['{} {} tmp_project -scriptPath {} -postScript ghidra_dec.py {} {} {} -import {} -readOnly -log {}.log'.format(self.decompiler_path, self.decompiler_workdir, current_script_dir,
47 | self.decompiled_binary_code_path, self.failed_path,
48 | self.decompiler_workdir, self.binary_path, self.decompiler_workdir,
49 | self.binary_name)], shell=True,
50 | stdout=subprocess.DEVNULL,
51 | stderr=subprocess.DEVNULL)
52 |
53 | except Exception as e:
54 | l.error(f"Decompilation failed for {self.binary_name} :: {self.binary_type} {e}")
55 | return None
56 | return True, self.decompiled_binary_code_path
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/preprocess_vars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import logging
5 | import concurrent.futures
6 | from itertools import islice
7 | from functools import partial
8 | from multiprocessing import Manager
9 |
10 | l = logging.getLogger('main')
11 |
12 | manager = Manager()
13 | vars_ = manager.dict()
14 | clean_vars = manager.dict()
15 |
16 | def read_(filename):
17 | with open(filename, 'r') as r:
18 | data = json.loads(r.read())
19 | return data
20 |
21 | def read_jsonl(filename):
22 | with open(filename, 'r') as r:
23 | data = r.readlines()
24 | return data
25 |
26 | def count_(data):
27 | tot = 0
28 | for name, c in data.items():
29 | tot += int(c)
30 | return tot
31 |
32 | def change_case(str_):
33 |
34 | # aP
35 | if len(str_) == 2:
36 | var = str_.lower()
37 |
38 | # _aP # __A
39 | elif len(str_) == 3 and (str_[0] == '_' or str_[:2] == '__'):
40 | var = str_.lower()
41 |
42 | else:
43 | s1 = re.sub('([A-Z]+)([A-Z][a-z]+)', r'\1_\2', str_)
44 | #s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', str_)
45 | var = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
46 | return var
47 |
48 |
49 | def clean(data, lookup, start, batch_size: int):
50 |
51 | global vars_
52 | global clean_vars
53 | end = start + batch_size
54 |
55 | # change case and lower
56 | for var in islice(data, start, end):
57 | if var in lookup:
58 | vars_[var] = lookup[var]
59 | up_var = lookup[var]
60 | else:
61 | up_var = change_case(var)
62 | # rm underscore
63 | up_var = up_var.strip('_')
64 | vars_[var] = up_var
65 | if up_var in clean_vars:
66 | clean_vars[up_var] += data[var]
67 | else:
68 | clean_vars[up_var] = data[var]
69 |
70 | def normalize_variables(filename, ty, outdir, lookup_dir, WORKERS):
71 | existing_lookup = False
72 | if os.path.exists(os.path.join(lookup_dir, 'universal_lookup.json')):
73 | with open(os.path.join(lookup_dir, 'universal_lookup.json'), 'r') as r:
74 | lookup = json.loads(r.read())
75 | existing_lookup = True
76 | else:
77 | lookup = {}
78 | l.debug(f"len of lookup: {len(lookup)}")
79 | data = read_(filename)
80 | l.debug(f"cleaning variables for: {filename} | count: {len(data)}")
81 |
82 | batch = len(data) // (WORKERS - 1)
83 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor:
84 | cal_partial = partial(clean, data, lookup, batch_size=batch,)
85 | executor.map(cal_partial, [batch * i for i in range(WORKERS)])
86 | l.debug(f"norm vars for file {filename} before: {len(vars_)} after: {len(set(vars_.values()))}")
87 | sorted_clean_vars = dict(sorted(clean_vars.items(), key=lambda item: item[1], reverse = True))
88 |
89 | with open(os.path.join(outdir, f"{ty}_clean.json"), 'w') as w:
90 | w.write(json.dumps(dict(sorted_clean_vars)))
91 |
92 | if existing_lookup:
93 | for v in vars_:
94 | if v not in lookup:
95 | lookup[v] = vars_[v]
96 |
97 | with open(os.path.join(lookup_dir, 'universal_lookup.json'), 'w') as w:
98 | w.write(json.dumps(lookup))
99 | else:
100 | with open(os.path.join(lookup_dir, 'universal_lookup.json'), 'w') as w:
101 | w.write(json.dumps(dict(vars_)))
102 |
103 |
104 |
--------------------------------------------------------------------------------
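A quick sanity check of the CamelCase-to-snake_case normalization above (a sketch; it assumes this module is importable as `preprocess_vars`, and note that importing it also starts a `multiprocessing.Manager`):

```python
from preprocess_vars import change_case

print(change_case("maxBufferSize"))  # -> max_buffer_size
print(change_case("HTTPResponse"))   # -> http_response
print(change_case("aP"))             # -> ap  (two-character names are simply lowercased)
```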
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/binary.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from pathlib import Path
3 | import hashlib
4 | from elftools.elf.elffile import ELFFile
5 | from elftools.dwarf.locationlists import (
6 | LocationEntry, LocationExpr, LocationParser)
7 | from elftools.dwarf.descriptions import (
8 | describe_DWARF_expr, _import_extra, describe_attr_value,set_global_machine_arch)
9 | from collections import defaultdict
10 | from utils import write_json
11 | import os
12 | import sys
13 | import shutil
14 | import logging
15 | from utils import subprocess_
16 | from dwarf_info import Dwarf
17 | from strip_types import type_strip_target_binary
18 |
19 | l = logging.getLogger('main')
20 |
21 | class Binary:
22 | def __init__(self, target_binary, b_name, path_manager, language, decompiler, load_data=True):
23 | self.target_binary = target_binary # to binary_path
24 | self.binary_name = b_name
25 | self.hash = None
26 | self.dwarf_data = None
27 | self.strip_binary = None
28 | self.type_strip_binary = None
29 | self.path_manager = path_manager
30 | self.language = language
31 | self.decompiler = decompiler
32 | self.dwarf_dict = None
33 |
34 | if load_data:
35 | self.dwarf_dict = self.load()
36 | self.strip_binary, self.type_strip_binary = self.modify_dwarf()
37 |
38 | # read things from binary!
39 | def load(self):
40 | try:
41 | l.debug(f'Reading DWARF info from :: {self.target_binary}')
42 | if not self.is_elf_has_dwarf():
43 | return None
44 |
45 | self.hash = self.md5_hash()
46 | self.dwarf_data = Dwarf(self.target_binary, self.binary_name, self.decompiler)
47 | dwarf_dict = self.dump_data()
48 | l.debug(f'Finished reading DWARF info from :: {self.target_binary}')
49 | return dwarf_dict
50 | except Exception as e:
51 | l.error(f"Error in reading DWARF info from :: {self.target_binary} :: {e}")
52 |
53 | def md5_hash(self):
54 | return subprocess.check_output(['md5sum', self.target_binary]).decode('utf-8').strip().split(' ')[0]
55 |
56 | def is_elf_has_dwarf(self):
57 | try:
58 | with open(self.target_binary, 'rb') as f:
59 | elffile = ELFFile(f)
60 | if not elffile.has_dwarf_info():
61 | return None
62 | set_global_machine_arch(elffile.get_machine_arch())
63 | dwarf_info = elffile.get_dwarf_info()
64 | return dwarf_info
65 | except Exception as e:
66 | l.error(f"Error in is_elf_has_dwarf :: {self.target_binary} :: {e}")
67 |
68 | def dump_data(self):
69 | dump = defaultdict(dict)
70 | dump['hash'] = self.hash
71 | dump['vars_per_func'] = self.dwarf_data.vars_in_each_func
72 | dump['linkage_name_to_func_name'] = self.dwarf_data.linkage_name_to_func_name
73 | dump['language'] = self.language
74 | write_json(os.path.join(self.path_manager.tmpdir, 'dwarf', self.binary_name), dict(dump))
75 | return dump
76 |
77 | def _strip_binary(self, target_binary):
78 |
79 | res = subprocess_(['strip', '--strip-all', target_binary])
80 | if isinstance(res, Exception):
81 | l.error(f"error in stripping binary! {res} :: {target_binary}")
82 | return None
83 | return target_binary
84 |
85 | def _type_strip_binary(self, in_binary, out_binary):
86 |
87 | type_strip_target_binary(in_binary, out_binary, self.decompiler)
88 | if not os.path.exists(out_binary):
89 | l.error(f"error in stripping binary! :: {in_binary}")
90 | return out_binary
91 |
92 | def modify_dwarf(self):
93 |
94 | shutil.copy(self.target_binary, self.path_manager.strip_bin_dir)
95 | type_strip = self._type_strip_binary(os.path.join(self.path_manager.strip_bin_dir, self.binary_name), os.path.join(self.path_manager.type_strip_bin_dir, self.binary_name))
96 | strip = self._strip_binary(os.path.join(self.path_manager.strip_bin_dir, self.binary_name))
97 | return strip, type_strip
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/varbert/tokenizer/train_bpe_tokenizer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from tokenizers import ByteLevelBPETokenizer
3 | from transformers import RobertaTokenizerFast
4 |
5 | import argparse
6 | import os
7 |
8 | def main(args):
9 | paths = [str(x) for x in Path(args.input_path).glob("**/*.txt")]
10 |
11 | # Initialize a tokenizer
12 | tokenizer = ByteLevelBPETokenizer(add_prefix_space=True)
13 |
14 | # Customize training
15 | tokenizer.train(files=paths,
16 | vocab_size=args.vocab_size,
17 | min_frequency=args.min_frequency,
18 | special_tokens=[ "",
19 | "",
20 | "",
21 | "",
22 | "",
23 | ])
24 | os.makedirs(args.output_path,exist_ok=True)
25 | tokenizer.save_model(args.output_path)
26 |
27 |
28 | # Test the tokenizer
29 | tokenizer = RobertaTokenizerFast.from_pretrained(args.output_path, max_len=1024)
30 | ghidra_func = "undefined * FUN_001048d0(void)\n{\n undefined8 uVar1;\n long lVar2;\n undefined *puVar3;\n \n uVar1 = PyState_FindModule(&readlinemodule);\n lVar2 = PyModule_GetState(uVar1);\n if (*(lVar2 + 0x18) == 0) {\n FUN_00103605(&_Py_NoneStruct);\n puVar3 = &_Py_NoneStruct;\n }\n else {\n uVar1 = PyState_FindModule(&readlinemodule);\n lVar2 = PyModule_GetState(uVar1);\n FUN_00103605(*(lVar2 + 0x18));\n uVar1 = PyState_FindModule(&readlinemodule);\n lVar2 = PyModule_GetState(uVar1);\n puVar3 = *(lVar2 + 0x18);\n }\n return puVar3;\n}"
31 |
32 | tokens = tokenizer.tokenize(ghidra_func)
33 | token_ids = tokenizer.convert_tokens_to_ids(tokens)
34 | # print("Normal Tokens:",str(tokens))
35 | # print("Normal Token Ids:",str(token_ids))
36 |
37 | human_func = 'int fast_s_mp_sqr(const mp_int *a, mp_int *b)\n{\n int olduse, res, pa, ix, iz;\n mp_digit W[MP_WARRAY], *tmpx;\n mp_word W1;\n /* grow the destination as required */\n pa = a->used + a->used;\n if (b->alloc < pa) {\n if ((res = mp_grow(b, pa)) != 0 /* no error, all is well */) {\n return res;\n }\n }\n /* number of output digits to produce */\n W1 = 0;\n for (ix = 0; ix < pa; ix++) {\n int tx, ty, iy;\n mp_word _W;\n mp_digit *tmpy;\n /* clear counter */\n _W = 0;\n /* get offsets into the two bignums */\n ty = MIN(a->used-1, ix);\n tx = ix - ty;\n /* setup temp aliases */\n tmpx = a->dp + tx;\n tmpy = a->dp + ty;\n /* this is the number of times the loop will iterrate, essentially\n while (tx++ < a->used && ty-- >= 0) { ... }\n */\n iy = MIN(a->used-tx, ty+1);\n /* now for squaring tx can never equal ty\n * we halve the distance since they approach at a rate of 2x\n * and we have to round because odd cases need to be executed\n */\n iy = MIN(iy, ((ty-tx)+1)>>1);\n /* execute loop */\n for (iz = 0; iz < iy; iz++) {\n _W += (mp_word)*tmpx++ * (mp_word)*tmpy--;\n }\n /* double the inner product and add carry */\n _W = _W + _W + W1;\n /* even columns have the square term in them */\n if (((unsigned)ix & 1u) == 0u) {\n _W += (mp_word)a->dp[ix>>1] * (mp_word)a->dp[ix>>1];\n }\n /* store it */\n W[ix] = _W & ((((mp_digit)1)<<((mp_digit)MP_DIGIT_BIT))-((mp_digit)1));\n /* make next carry */\n W1 = _W >> (mp_word)(CHAR_BIT*sizeof(mp_digit));\n }\n /* setup dest */\n olduse = b->used;\n b->used = a->used+a->used;\n {\n mp_digit *tmpb;\n tmpb = b->dp;\n for (ix = 0; ix < pa; ix++) {\n *tmpb++ = W[ix] & ((((mp_digit)1)<<((mp_digit)MP_DIGIT_BIT))-((mp_digit)1));\n }\n /* clear unused digits [that existed in the old copy of c] */\n for (; ix < olduse; ix++) {\n *tmpb++ = 0;\n }\n }\n mp_clamp(b);\n return 0 /* no error, all is well */;\n}'
38 |
39 | tokens = tokenizer.tokenize(human_func)
40 | token_ids = tokenizer.convert_tokens_to_ids(tokens)
41 | # print("Normal Tokens:",str(tokens))
42 | # print("Normal Token Ids:",str(token_ids))
43 |
44 | if __name__ == "__main__":
45 |
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('--input_path', type=str, help='path to the input text files')
48 | parser.add_argument('--vocab_size', type=int, default=50265, help='size of tokenizer vocabulary')
49 | parser.add_argument('--min_frequency', type=int, default=2, help='minimum frequency')
50 | parser.add_argument('--output_path', type=str, help='path to the output text files')
51 |
52 |
53 |     args = parser.parse_args()
54 |
55 |     main(args)
--------------------------------------------------------------------------------
/varbert/generate_vocab.py:
--------------------------------------------------------------------------------
1 | # Generate vocab from preprocessed sets with fid
2 |
3 | import argparse
4 | import os
5 | import json
6 | import jsonlines as jsonl
7 | import re
8 | from collections import defaultdict
9 | from tqdm import tqdm
10 |
11 | def load_jsonl_files(file_path):
12 | data = []
13 | with jsonl.open(file_path) as ofd:
14 | for each in tqdm(ofd, desc=f"Loading data from {file_path}"):
15 | data.append(each)
16 | return data
17 |
18 | def read_json(file_path):
19 | with open(file_path, 'r') as f:
20 | data = json.load(f)
21 | return data
22 |
23 | def calculate_distribution(data, dataset_type):
24 | var_distrib = defaultdict(int)
25 | for each in tqdm(data):
26 | func = each['norm_func']
27 |         pattern = r"@@\w+@@\w+@@"
28 | if dataset_type == 'varcorpus':
29 | dwarf_norm_type = each['type_stripped_norm_vars']
30 |
31 | for each_var in list(re.finditer(pattern,func)):
32 | s = each_var.start()
33 | e = each_var.end()
34 | var = func[s:e]
35 | orig_var = var.split("@@")[-2]
36 |
37 | # Collect variables only dwarf
38 | if dataset_type == 'varcorpus':
39 | if orig_var in dwarf_norm_type:
40 | var_distrib[orig_var]+=1
41 | elif dataset_type == 'hsc':
42 | var_distrib[orig_var]+=1
43 |
44 | sorted_var_distrib = sorted(var_distrib.items(), key = lambda x : x[1], reverse=True)
45 | return sorted_var_distrib
46 |
47 |
48 | def build_vocab(data, vocab_size, existing_vocab=None):
49 | if existing_vocab:
50 | vocab_list = list(existing_vocab)
51 | else:
52 | vocab_list = []
53 | for idx, each in tqdm(enumerate(data)):
54 |         if len(vocab_list) >= vocab_size:
55 |             print("limit reached:", vocab_size, "Missed:", len(data)-idx-1)
56 | break
57 | if each[0] in vocab_list:
58 | continue
59 | else:
60 | vocab_list.append(each[0])
61 |
62 | idx2word, word2idx = {}, {}
63 | for i,each in enumerate(vocab_list):
64 | idx2word[i] = each
65 | word2idx[each] = i
66 |
67 | return idx2word, word2idx
68 |
69 | def save_json(data, output_path, filename):
70 | with open(os.path.join(output_path, filename), 'w') as w:
71 | w.write(json.dumps(data))
72 |
73 |
74 | def main(args):
75 | # Load existing human vocabulary if provided
76 | if args.existing_vocab:
77 | with open(args.existing_vocab, 'r') as f:
78 | human_vocab = json.load(f)
79 | else:
80 | human_vocab = None
81 |
82 | # Load train and test data
83 | train_data = load_jsonl_files(args.train_file)
84 | test_data = load_jsonl_files(args.test_file)
85 |
86 | # TODO add check to
87 | var_distrib_train = calculate_distribution(train_data, args.dataset_type)
88 | var_distrib_test = calculate_distribution(test_data, args.dataset_type)
89 |
90 | # save only if needed
91 | # save_json(var_distrib_train, args.output_dir, 'var_distrib_train.json')
92 | # save_json(var_distrib_test, args.output_dir, 'var_distrib_test.json')
93 |
94 | print("Train data distribution", len(var_distrib_train))
95 | print("Test data distribution", len(var_distrib_test))
96 |
97 | existing_vocab_data = {}
98 | if args.existing_vocab:
99 | print("Human vocab size", len(human_vocab))
100 | existing_vocab_data = read_json(args.existing_vocab)
101 |
102 | # Build and save the vocabulary
103 | idx2word, word2idx = build_vocab(var_distrib_train, args.vocab_size, existing_vocab=existing_vocab_data)
104 | print("Vocabulary size", len(idx2word))
105 | save_json(idx2word, args.output_dir, 'idx_to_word.json')
106 | save_json(word2idx, args.output_dir, 'word_to_idx.json')
107 |
108 | if __name__ == "__main__":
109 |
110 | parser = argparse.ArgumentParser(description="Dataset Vocabulary Generator")
111 | parser.add_argument("--dataset_type", type=str, choices=['hsc', 'varcorpus'], required=True, help="Create vocab for HSC (source code) or VarCorpus (decompiled code)")
112 | parser.add_argument("--train_file", type=str, required=True, help="Path to the training data file")
113 | parser.add_argument("--test_file", type=str, required=True, help="Path to the test data file")
114 | parser.add_argument("--existing_vocab", type=str, help="Path to the existing human vocabulary file")
115 | parser.add_argument("--vocab_size", type=int, default=50000, help="Limit for the vocabulary size")
116 | parser.add_argument("--output_dir", type=str, required=True, help="Path where the output vocabulary will be saved")
117 |
118 | args = parser.parse_args()
119 | main(args)
--------------------------------------------------------------------------------
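The two files written at the end are plain JSON index maps; a small illustrative sketch of their shape (the variable names here are made up):

```python
# idx_to_word.json and word_to_idx.json are mirror mappings, e.g.:
idx2word = {0: "i", 1: "len", 2: "count"}      # JSON serialization turns these keys into strings
word2idx = {"i": 0, "len": 1, "count": 2}
```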
/varcorpus/dataset-gen/joern_parser.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import json
4 | import time
5 | import logging
6 | import subprocess
7 | from collections import defaultdict, OrderedDict
8 | from cpgqls_client import CPGQLSClient, import_code_query
9 |
10 | l = logging.getLogger('main')
11 |
12 | class JoernParser:
13 | def __init__(self, binary_name, binary_type, decompiler,
14 | decompiled_code, workdir, port, outpath ):
15 | self.binary_name = binary_name
16 | self.binary_type = binary_type
17 | self.decompiler = decompiler
18 | self.dc_inpath = decompiled_code
19 | self.port = port
20 | self.client = CPGQLSClient(f"localhost:{self.port}")
21 | self.joern_workdir = os.path.join(workdir, 'tmp_joern')
22 | self.joern_data = defaultdict(dict)
23 | self.joern_outpath = outpath
24 | self.parse_joern()
25 |
26 |
27 | def edit_dc(self):
28 | try:
29 | regex_ul = r'(\d)(uLL)'
30 | regex_l = r'(\d)(LL)'
31 | regex_u = r'(\d)(u)'
32 |
33 | with open(f'{self.dc_inpath}', 'r') as r:
34 | data = r.read()
35 |
36 | tmp_1 = re.sub(regex_ul, r'\g<1>', data)
37 | tmp_2 = re.sub(regex_l, r'\g<1>', tmp_1)
38 | final = re.sub(regex_u, r'\g<1>', tmp_2)
39 |
40 | with open(os.path.join(self.joern_workdir, f'{self.binary_name}.c'), 'w') as w:
41 | w.write(final)
42 |
43 | except Exception as e:
44 | l.error(f"Error in joern :: {self.decompiler} {self.binary_type} {self.binary_name} :: {e} ")
45 |
46 |
47 | def clean_up(self):
48 | self.client.execute(f'close("{self.binary_name}")')
49 | self.client.execute(f'delete("{self.binary_name}")')
50 |
51 |
52 | def split(self):
53 | try:
54 | l.debug(f"joern parsing! {self.decompiler} {self.binary_type} {self.binary_name}")
55 | regex = r"(\((\"([\s\S]*?)\"))((, )(\"([\s\S]*?)\")((, )(\d*)(, )(\d*)))((, )(\"([\s\S]*?)\")((, )(\d*)(, )(\d*)))"
56 | r = re.compile(regex)
57 |
58 | self.client.execute(import_code_query(self.joern_workdir, f'{self.binary_name}'))
59 | fetch_q = f'show(cpg.identifier.l.map(x => (x.location.filename, x.method.name, x.method.lineNumber.get, x.method.lineNumberEnd.get, x.name, x.lineNumber.get, x.columnNumber.get)).sortBy(_._7).sortBy(_._6).sortBy(_._1))'
60 |
61 | result = self.client.execute(fetch_q)
62 | res_stdout = result['stdout']
63 |
64 | if '***temporary file:' in res_stdout:
65 | tmp_file = res_stdout.split(':')[-1].strip()[:-3]
66 |
67 | with open(tmp_file, 'r') as tf_read:
68 | res_stdout = tf_read.read()
69 | subprocess.check_output(['rm', '{}'.format(tmp_file)])
70 |
71 | matches = r.finditer(res_stdout,re.MULTILINE)
72 | raw_data = defaultdict(dict)
73 | track_func = set()
74 |
75 | random = defaultdict(list)
76 | random_tmp = []
77 | temp = ''
78 | test = set()
79 | main_dict = defaultdict(dict)
80 | tmp_file_name = ''
81 | for m in matches:
82 |
83 | file_path = m.group(3)
84 | func_name = m.group(7)
85 | func_start = m.group(10)
86 | func_end = m.group(12)
87 | var_name = m.group(16)
88 | var_line = m.group(19)
89 | var_col = m.group(21)
90 |
91 | if tmp_file_name != self.binary_name:
92 | tmp_file_name = self.binary_name
93 | random = defaultdict(list)
94 | raw_data = defaultdict(dict)
95 | random_tmp = []
96 |
97 | pkg_func = func_name + '_' + func_start
98 | if pkg_func != temp:
99 | track_func = set()
100 | random = defaultdict(list)
101 | random_tmp = []
102 |
103 | if var_name not in random_tmp:
104 | random_tmp.append(var_name)
105 |
106 | temp = pkg_func
107 | track_func.add(pkg_func)
108 | test.add(pkg_func)
109 | random[var_name].append(var_line)
110 |
111 | raw_data[pkg_func].update({"func_start":func_start, "func_end" : func_end})
112 | raw_data[pkg_func].update({"variable": dict(random)})
113 | raw_data[pkg_func].update({'tmp':random_tmp})
114 |
115 | with open(self.joern_outpath, 'w') as w:
116 | w.write(json.dumps(raw_data))
117 | self.joern_data = raw_data
118 |
119 | except Exception as e:
120 | l.error(f"Error in joern :: {self.decompiler} {self.binary_type} {self.binary_name} :: {e} ")
121 |
122 | def parse_joern(self):
123 | time.sleep(2)
124 | self.edit_dc()
125 |         self.split()
126 | self.clean_up()
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import logging.config
5 | import concurrent.futures
6 | from functools import partial
7 | from multiprocessing import Manager
8 | import json
9 |
10 | from pathmanager import PathManager
11 | from runner import Runner
12 |
13 |
14 | def main(args):
15 | if args.decompiler == 'ida' and not args.ida_path:
16 | parser.error("IDA path is required when decompiling with IDA.")
17 | elif args.decompiler == 'ghidra' and not args.ghidra_path:
18 | parser.error("Ghidra path is required when decompiling with Ghidra.")
19 |
20 | if not args.corpus_language and not args.language_map:
21 | parser.error("Either --corpus_language or --language_map is required.")
22 | if args.corpus_language and args.language_map:
23 | parser.error("Provide only one of --corpus_language or --language_map, not both.")
24 | if args.language_map and not os.path.isfile(args.language_map):
25 | parser.error(f"Language map file not found: {args.language_map}")
26 | if args.corpus_language and args.corpus_language not in {"c", "cpp"}:
27 | parser.error("Invalid --corpus_language. Allowed values are: c, cpp")
28 |
29 | language_map = None
30 | if args.language_map:
31 | try:
32 | with open(args.language_map, 'r') as f:
33 | language_map = json.load(f)
34 | except Exception as e:
35 | parser.error(f"Failed to read language map JSON: {e}")
36 |
37 | path_manager = PathManager(args)
38 | target_binaries = [
39 | os.path.abspath(path)
40 | for path in glob.glob(os.path.join(path_manager.binaries_dir, '**', '*'), recursive=True)
41 | if os.path.isfile(path)
42 | ]
43 | # Validate language_map if provided
44 | if language_map is not None:
45 | # keys are expected to be binary basenames; values 'c' or 'cpp'
46 | valid_langs = {'c', 'cpp'}
47 | invalid = {k: v for k, v in language_map.items() if v not in valid_langs}
48 | if invalid:
49 | parser.error(f"Invalid languages in language_map (allowed: c, cpp): {invalid}")
50 | missing = [os.path.basename(p) for p in target_binaries if os.path.basename(p) not in language_map]
51 | if missing:
52 | parser.error(f"language_map missing entries for binaries: {missing}")
53 | decompiler = args.decompiler
54 | if not args.corpus_language:
55 | pass
56 | #TODO: detect
57 |
58 | splits = True if args.splits else False
59 | l.info(f"Decompiling {len(target_binaries)} binaries with {decompiler}")
60 |
61 | default_language = None if language_map is not None else args.corpus_language
62 | effective_language_map = language_map if language_map else {}
63 |
64 | runner = Runner(
65 | decompiler=args.decompiler,
66 | target_binaries=target_binaries,
67 | WORKERS=args.WORKERS,
68 | path_manager=path_manager,
69 | PORT=8090,
70 | language=default_language,
71 | DEBUG=args.DEBUG,
72 | language_map=effective_language_map
73 | )
74 |
75 | runner.run(PARSE=True, splits=splits)
76 |
77 |
78 | if __name__ == "__main__":
79 | parser = argparse.ArgumentParser(description="Data set generation for VarBERT")
80 |
81 | parser.add_argument(
82 | "-b", "--binaries_dir",
83 | type=str,
84 | help="Path to binaries dir",
85 | required=True,
86 | )
87 |
88 | parser.add_argument(
89 | "-d", "--data_dir",
90 | type=str,
91 | help="Path to data dir",
92 | required=True,
93 | )
94 |
95 | parser.add_argument(
96 | "--tmpdir",
97 | type=str,
98 | help="Path to save intermediate files. Default is /tmp",
99 | required=False,
100 | default="/tmp/varbert_tmpdir"
101 | )
102 |
103 | parser.add_argument(
104 | "--decompiler",
105 | choices=['ida', 'ghidra'],
106 | type=str,
107 | help="choose decompiler IDA or Ghidra",
108 | required=True
109 | )
110 |
111 | parser.add_argument(
112 | "-ida", "--ida_path",
113 | type=str,
114 | help="Path to IDA",
115 | required=False,
116 | )
117 |
118 | parser.add_argument(
119 | "-ghidra", "--ghidra_path",
120 | type=str,
121 | help="Path to Ghidra",
122 | required=False,
123 | )
124 |
125 | parser.add_argument(
126 | "-joern", "--joern_dir",
127 | type=str,
128 | help="Path to Joern",
129 | required=False,
130 | )
131 |
132 | # TODO: if language not given, maybe detect it
133 | parser.add_argument(
134 | "-lang", "--corpus_language",
135 | type=str,
136 | help="Corpus language",
137 | required=False,
138 | )
139 |
140 | parser.add_argument(
141 | "--language_map",
142 | type=str,
143 | help="Path to JSON mapping of binary name to language, e.g. {\"bin1\": \"c\", \"bin2\": \"cpp\"}",
144 | required=False,
145 | )
146 |
147 |
148 | parser.add_argument(
149 | "-w", "--WORKERS",
150 | type=int,
151 | help="Number of workers",
152 | default=2,
153 | required=False
154 | )
155 |
156 | parser.add_argument(
157 | "--DEBUG",
158 | help="Turn on debug logging mode",
159 | action='store_true',
160 | required=False
161 | )
162 | parser.add_argument(
163 | "--splits",
164 | help="Create test and train split",
165 | action='store_true',
166 | required=False
167 | )
168 | args = parser.parse_args()
169 | from log import setup_logging
170 | setup_logging(args.tmpdir, args.DEBUG)
171 | l = logging.getLogger('main')
172 | main(args)
173 |
174 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VarBERT
2 | VarBERT is a BERT-based model that predicts meaningful variable names and variable origins in decompiled code. Leveraging transfer learning, VarBERT can assist with software reverse engineering tasks. It is pre-trained on 5.2M human-written source code functions and then fine-tuned on decompiled code from IDA and Ghidra, spanning four compiler optimizations (*O0*, *O1*, *O2*, *O3*).
3 | We built two data sets: (a) the Human Source Code data set (HSC) and (b) VarCorpus (for IDA and Ghidra).
4 | This work accompanies the IEEE S&P 2024 paper ["Len or index or count, anything but v1": Predicting Variable Names in Decompilation Output with Transfer Learning](https://www.atipriya.com/files/papers/varbert_oakland24.pdf).
5 |
6 | Key Features
7 |
8 | - Pre-trained on 5.2M human-written source code functions.
9 | - Fine-tuned on decompiled code from IDA and Ghidra.
10 | - Supports four compiler optimizations: O0, O1, O2, O3.
11 | - Achieves an accuracy of 54.43% for IDA and 54.49% for Ghidra on O2 optimized binaries.
12 | - A total of 16 models are available, covering two decompilers, four optimizations, and two splitting strategies.
13 |
14 | ### Table of Contents
15 | - [Overview](#overview)
16 | - [VarBERT Model](#varbert-model)
17 | - [Using VarBERT](#use-varbert)
18 | - [Training and Inference](#training-and-inference)
19 | - [Data sets](#data-sets)
20 | - [Installation Instructions](#installation)
21 | - [Cite](#citing)
22 |
23 | ### Overview
24 | This repository contains details on generating a new dataset, and training and running inference on existing VarBERT models from the paper. To use VarBERT models in your day-to-day reverse engineering tasks, please refer to [Use VarBERT](#use-varbert).
25 |
26 |
27 | ### VarBERT Model
28 | VarBERT draws on transfer learning in general, and on Bidirectional Encoder Representations from Transformers (BERT) in particular.
29 |
30 | - **Pre-training**: VarBERT is pre-trained on HSC functions using Masked Language Modeling (MLM) and Constrained Masked Language Modeling (CMLM).
31 | - **Fine-tuning**: VarBERT is then further fine-tuned on top of the previously pre-trained model using VarCorpus (decompilation output of IDA and Ghidra). It can be further extended to any other decompiler capable of generating C-Style decompilation output.
32 |
33 | ### Use VarBERT
34 | - The VarBERT API is a Python library to access and use the latest models. It can be used in three ways:
35 | 1. From the CLI, directly on decompiled text (without an attached decompiler).
36 | 2. As a scripting library.
37 | 3. As a decompiler plugin with [DAILA](https://github.com/mahaloz/DAILA) for an enhanced decompilation experience.
38 |
39 | For a step-by-step guide and a demo on how to get started with the VarBERT API, please visit [VarBERT API](https://github.com/binsync/varbert_api/tree/main).
40 |
41 | ### Training and Inference
42 | For training a new model or running inference on existing models, see our detailed guide at [Training VarBERT](./varbert/README.md)
43 |
44 | Models available for download:
45 | - [Pre-trained models](https://www.dropbox.com/scl/fo/anibfmk6j8xkzi4nqk55f/h?rlkey=fw6ops1q3pqvsbdy5tl00brpw&dl=0)
46 | - [Fine-tuned models](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0)
47 |
48 | (A [README](https://www.dropbox.com/scl/fi/13s9z5z08u245jqdgfsdc/readme.md?rlkey=yjo33al04j1d5jrwc5pz2hhpz&dl=0) containing all the necessary links for the model is also available.)
49 |
50 | ### Data sets
51 | - **HSC**: Collected from C source files in the Debian APT repository, totaling 5.2M functions.
52 |
53 | - **VarCorpus**: Decompiled functions from C and C++ binaries, built from the Gentoo package repository for four compiler optimizations: O0, O1, O2, and O3.
54 |
55 | Additionally, we provide two splits: (a) Function Split and (b) Binary Split.
56 | - Function Split: Functions are randomly distributed between the test and train sets.
57 | - Binary Split: All functions from a single binary are exclusively present in either the test set or the train set.
58 | To create a new data set, follow the detailed instructions at [Building VarCorpus](./varcorpus/README.md). A minimal illustrative sketch of the two splitting strategies is shown below.
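For intuition, here is a minimal, hypothetical sketch of how the two strategies differ (illustration only, not the pipeline's implementation; see `varcorpus/dataset-gen/create_dataset_splits.py` for the real logic):

```python
import random

def function_split(functions, test_ratio=0.1):
    # functions: list of dicts with a 'binary' id and a 'func' body (illustrative schema)
    funcs = functions[:]
    random.shuffle(funcs)
    cut = int(len(funcs) * test_ratio)
    return funcs[cut:], funcs[:cut]              # train and test may share binaries

def binary_split(functions, test_ratio=0.1):
    binaries = sorted({f['binary'] for f in functions})
    random.shuffle(binaries)
    test_bins = set(binaries[:int(len(binaries) * test_ratio)])
    train = [f for f in functions if f['binary'] not in test_bins]
    test = [f for f in functions if f['binary'] in test_bins]
    return train, test                           # no binary appears in both sets
```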
59 |
60 | Data sets available at:
61 | - [HSC](https://www.dropbox.com/scl/fo/4cu2fmuh10c4wp7xt53tu/h?rlkey=mlsnkyed35m4rl512ipuocwtt&dl=0)
62 | - [VarCorpus](https://www.dropbox.com/scl/fo/3thmg8xoq2ugtjwjcgjsm/h?rlkey=azgjeq513g4semc1qdi5xyroj&dl=0)
63 |
64 |
65 | The fine-tuned models and their corresponding data sets follow a naming convention such as `IDA-O0-Function` (model) and `IDA-O0` (data set), indicating functions decompiled from O0 binaries with the IDA decompiler, split at the function level.
66 |
67 | > [!NOTE]
68 | > Our existing data sets have been generated using IDA Pro 7.6 and Ghidra 10.4.
69 |
70 | ### Gentoo Binaries ###
71 | You can access the Gentoo binaries used to create VarCorpus here: https://www.dropbox.com/scl/fo/awtitjnc48k224373vcrx/h?rlkey=muj6t1watc6vn2ds6du7egoha&e=1&st=eicpyqln&dl=0
72 |
73 | ### Installation
74 | Prerequisites for training a model or generating a data set:
75 |
76 | - Linux with Python 3.8 or higher
77 | - torch ≥ 1.9.0
78 | - transformers ≥ 4.10.0
79 |
80 | #### Docker
81 |
82 | ```bash
83 | docker build -t varbert .
84 | ```
85 |
86 | #### Without Docker
87 | ```bash
88 | pip install -r requirements.txt
89 |
90 | # joern requires Java 11
91 | sudo apt-get install openjdk-11-jdk
92 |
93 | # Ghidra 10.4 requires Java 17+
94 | sudo apt-get install openjdk-17-jdk
95 |
96 | git clone git@github.com:rhelmot/dwarfwrite.git
97 | cd dwarfwrite
98 | pip install .
99 | ```
100 | Note: Ensure you install the correct Java version required by your specific Ghidra version.
101 |
102 |
103 |
104 | ### Citing
105 |
106 | Please cite our paper if you use this in your research:
107 |
108 | ```
109 | @inproceedings{pal2024len,
110 | title={" Len or index or count, anything but v1": Predicting Variable Names in Decompilation Output with Transfer Learning},
111 | author={Pal, Kuntal Kumar and Bajaj, Ati Priya and Banerjee, Pratyay and Dutcher, Audrey and Nakamura, Mutsumi and Basque, Zion Leonahenahe and Gupta, Himanshu and Sawant, Saurabh Arjun and Anantheswaran, Ujjwala and Shoshitaishvili, Yan and others},
112 | booktitle={2024 IEEE Symposium on Security and Privacy (SP)},
113 | pages={4069--4087},
114 | year={2024},
115 | organization={IEEE}
116 | }
117 | ```
118 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | import psutil
5 | import logging
6 | import subprocess
7 | from typing import List
8 | from pathlib import Path
9 | from concurrent.futures import process
10 | from elftools.elf.elffile import ELFFile
11 |
12 | # log = get_logger(sys._getframe().f_code.co_filename)
13 |
14 | def read_text(file_name):
15 | with open(file_name, 'r') as r:
16 | data = r.read().strip()
17 | return data
18 |
19 | def read_lines(file_name):
20 | with open(file_name, 'r') as r:
21 | data = r.read().split()
22 | return data
23 |
24 | def read_json(file_name):
25 | with open(file_name, 'r') as r:
26 | data = json.loads(r.read())
27 | return data
28 |
29 | def write_json(file_name, data):
30 | with open(f'{file_name}.json', 'w') as w:
31 | w.write(json.dumps(data))
32 |
33 | def subprocess_(command, timeout=None, shell=False):
34 | try:
35 | subprocess.check_call(command, timeout=timeout, shell=shell)
36 | except Exception as e:
37 | l.error(f"error in binary! :: {command} :: {e}")
38 | return e
39 |
40 | def is_elf(binary_path: str):
41 | '''
42 | if elf return elffile obj
43 | '''
44 | with open(binary_path, 'rb') as rb:
45 | bytes = rb.read(4)
46 | if bytes == b"\x7fELF":
47 | rb.seek(0)
48 | elffile = ELFFile(rb)
49 | if elffile and elffile.has_dwarf_info():
50 | dwarf_info = elffile.get_dwarf_info()
51 |
52 | return elffile, dwarf_info
53 |
54 | def create_dirs(dir_names, tmpdir, ty):
55 | """
56 | Create directories within the specified data directory.
57 |
58 | :param dir_names: List of directory names to create.
59 | :param tmpdir: Base data directory where directories will be created.
60 | :param ty: Subdirectory type, such as 'strip' or 'type-strip'. If None, only base directories are created.
61 | """
62 | # Path(os.path.join(tmpdir, 'dwarf')).mkdir(parents=True, exist_ok=True)
63 | for name in dir_names:
64 | if ty:
65 | target_path = os.path.join(tmpdir, name, ty)
66 | else:
67 | target_path = os.path.join(tmpdir, name)
68 | Path(target_path).mkdir(parents=True, exist_ok=True)
69 |
70 |
71 | def set_up_data_dir(tmpdir, workdir, decompiler):
72 | """
73 | Set up the data directory with required subdirectories for the given decompiler.
74 |
75 | :param tmpdir: The base data directory to set up.
76 | :param workdir: The working directory where some data directories will be created.
77 | :param decompiler: Name of the decompiler to create specific subdirectories.
78 | """
79 |
80 | base_dirs = ['binary', 'failed']
81 | dc_joern_dirs = ['dc', 'joern']
82 | map_dirs = [f"map/{decompiler}", f"dc/{decompiler}/type_strip-addrs", f"dc/{decompiler}/type_strip-names"]
83 | workdir_dirs = [f'{decompiler}_data', 'tmp_joern']
84 |
85 | # Create directories for 'strip' and 'type-strip'
86 | for ty in ['strip', 'type_strip']:
87 | create_dirs([f'{d}/{decompiler}' for d in dc_joern_dirs] + base_dirs, tmpdir, ty)
88 |
89 | # Create additional directories
90 | create_dirs(map_dirs, tmpdir, None)
91 |
92 | # Create workdir directories
93 | create_dirs(workdir_dirs, workdir, None)
94 |
95 |     # directory for per-binary DWARF info
96 | create_dirs(['dwarf'], tmpdir, None)
97 |
98 |
99 | ### JOERN
100 |
101 |
102 | l = logging.getLogger('main')
103 |
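# JoernServer wraps a long-lived `joern --server` process on a fixed port.
# Used as a context manager: start() launches the process and polls until the
# backing java process is visible; stop() kills that java process.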
104 | class JoernServer:
105 | def __init__(self, joern_path, port):
106 | self.joern_path = joern_path
107 | self.port = port
108 | self.process = None
109 | self.java_process = None
110 |
111 | def start(self):
112 | if self.is_server_running():
113 | l.error(f"Joern server already running on port {self.port}")
114 | self.stop()
115 | return
116 |
117 |         # workaround: this Joern build resolves fuzzyc2cpg.sh (the CPG generator shell script) relative to the current directory, so run it from the Joern directory
118 | current_dir = os.getcwd()
119 | try:
120 | os.chdir(self.joern_path)
121 | joern_cmd = [os.path.join(self.joern_path, 'joern'), '--server', '--server-port', str(self.port)]
122 | self.process = subprocess.Popen(joern_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
123 |
124 | for _ in range(20):
125 | self.java_process = self.is_server_running()
126 | if self.java_process:
127 | l.debug(f"Joern server started on port {self.port}")
128 | return
129 | time.sleep(4)
130 | l.debug(f"Retrying to start Joern server on port {self.port}")
131 | except Exception as e:
132 | l.error(f"Failed to start Joern server on port {self.port} :: {e}")
133 | return
134 |
135 | finally:
136 | os.chdir(current_dir)
137 |
138 | def stop(self):
139 | if self.java_process is not None:
140 |             self.java_process.terminate()
141 | time.sleep(3)
142 | if self.is_server_running():
143 | l.warning("Joern server did not terminate gracefully, forcing termination")
144 | self.java_process.kill()
145 |
146 | def is_server_running(self):
147 | try:
148 | for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
149 | try:
150 | if proc.info['name'] == 'java' and 'io.shiftleft.joern.console.AmmoniteBridge' in proc.info['cmdline'] and '--server-port' in proc.info['cmdline'] and str(self.port) in proc.info['cmdline']:
151 | return psutil.Process(proc.info['pid'])
152 | except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
153 | continue
154 | except Exception as e:
155 | l.error(f"Error while checking if server is running: {e}")
156 |
157 | return False
158 |
159 | def __enter__(self):
160 | self.start()
161 | return self
162 |
163 | def __exit__(self, exc_type, exc_value, traceback):
164 | self.stop()
165 |
166 | def exit(self):
167 | self.stop()
168 |
169 |     def restart(self):
170 |         self.stop()
171 |         time.sleep(10)
172 |         self.start()
--------------------------------------------------------------------------------
/varbert/resize_model.py:
--------------------------------------------------------------------------------
1 | import json
2 | from transformers import (
3 | WEIGHTS_NAME,
4 | AdamW,
5 | BertConfig,
6 | BertForMaskedLM,
7 | BertTokenizer,
8 | CamembertConfig,
9 | CamembertForMaskedLM,
10 | CamembertTokenizer,
11 | DistilBertConfig,
12 | DistilBertForMaskedLM,
13 | DistilBertTokenizer,
14 | GPT2Config,
15 | GPT2LMHeadModel,
16 | GPT2Tokenizer,
17 | OpenAIGPTConfig,
18 | OpenAIGPTLMHeadModel,
19 | OpenAIGPTTokenizer,
20 | PreTrainedModel,
21 | PreTrainedTokenizer,
22 | RobertaConfig,
23 | RobertaForMaskedLM,
24 | RobertaTokenizer,
25 | RobertaTokenizerFast,
26 | get_linear_schedule_with_warmup,
27 | )
28 | import argparse
29 | import os
30 | import shutil
31 | from typing import Dict, List, Tuple
32 |
33 | import numpy as np
34 | import torch
35 | import torch.nn as nn
36 | from torch.nn import CrossEntropyLoss, MSELoss
37 | from torch.nn.utils.rnn import pad_sequence
38 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
39 | from torch.utils.data.distributed import DistributedSampler
40 | from transformers.activations import ACT2FN, gelu
41 |
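# Output (variable-name) vocabulary size the checkpoint is assumed to have been
# trained with: a 50,000-entry name vocab plus one id reserved for OOV names
# (matches the default --vocab_size in cmlm/preprocess.py).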
42 | vocab_size = 50001
43 | class RobertaLMHead2(nn.Module):
44 |
45 | def __init__(self,config):
46 | super().__init__()
47 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
48 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
49 | self.decoder = nn.Linear(config.hidden_size, vocab_size, bias=False)
50 | self.bias = nn.Parameter(torch.zeros(vocab_size))
51 |
52 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
53 | self.decoder.bias = self.bias
54 |
55 | def forward(self, features, **kwargs):
56 | x = self.dense(features)
57 | x = gelu(x)
58 | x = self.layer_norm(x)
59 | # project back to size of vocabulary with bias
60 | x = self.decoder(x)
61 | return x
62 |
63 | class RobertaForMaskedLMv2(RobertaForMaskedLM):
64 |
65 | def __init__(self, config):
66 | super().__init__(config)
67 | self.lm_head2 = RobertaLMHead2(config)
68 | self.init_weights()
69 |
70 | def forward(
71 | self,
72 | input_ids=None,
73 | attention_mask=None,
74 | token_type_ids=None,
75 | position_ids=None,
76 | head_mask=None,
77 | inputs_embeds=None,
78 | encoder_hidden_states=None,
79 | encoder_attention_mask=None,
80 | labels=None,
81 | output_attentions=None,
82 | output_hidden_states=None,
83 | return_dict=None,
84 | type_label=None
85 | ):
86 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
87 |
88 |
89 | outputs = self.roberta(
90 | input_ids,
91 | attention_mask=attention_mask,
92 | token_type_ids=token_type_ids,
93 | position_ids=position_ids,
94 | head_mask=head_mask,
95 | inputs_embeds=inputs_embeds,
96 | encoder_hidden_states=encoder_hidden_states,
97 | encoder_attention_mask=encoder_attention_mask,
98 | output_attentions=output_attentions,
99 | output_hidden_states=output_hidden_states,
100 | return_dict=return_dict,
101 | )
102 | sequence_output = outputs[0]
103 | prediction_scores = self.lm_head2(sequence_output)
104 |
105 | masked_lm_loss = None
106 | if labels is not None:
107 | loss_fct = CrossEntropyLoss()
108 | masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), labels.view(-1))
109 |
110 | masked_loss = masked_lm_loss
111 |
112 | output = (prediction_scores,) + outputs[2:]
113 | return ((masked_loss,) + output) if masked_loss is not None else output
114 |
115 |
116 | def get_resized_model(oldmodel, newvocab_len):
117 |
118 | ## Change the bias dimensions
119 | def _get_resized_bias(old_bias, new_size):
120 | old_num_tokens = old_bias.data.size()[0]
121 | if old_num_tokens == new_size:
122 | return old_bias
123 |
124 | # Create new biases
125 |         # allocate on the old bias's device (Parameter.to() is not in-place)
126 |         new_bias = nn.Parameter(torch.zeros(new_size, device=old_bias.device))
127 |
128 | # Copy from the previous weights
129 | num_tokens_to_copy = min(old_num_tokens, new_size)
130 | new_bias.data[:num_tokens_to_copy] = old_bias.data[:num_tokens_to_copy]
131 | return new_bias
132 |
133 | ## Change the decoder dimensions
134 |
135 | cls_layer_oldmodel = oldmodel.lm_head2.decoder
136 | oldvocab_len, old_embedding_dim = cls_layer_oldmodel.weight.size()
137 | print(f"Old Vocab Size: {oldvocab_len} \t Old Embedding Dim: {old_embedding_dim}")
138 | if oldvocab_len == newvocab_len:
139 | return oldmodel
140 |
141 | # Create new weights
142 | cls_layer_newmodel = nn.Linear(in_features=old_embedding_dim, out_features=newvocab_len)
143 | cls_layer_newmodel.to(cls_layer_oldmodel.weight.device)
144 |
145 | # initialize all weights (in particular added tokens)
146 | oldmodel._init_weights(cls_layer_newmodel)
147 |
148 | # Copy from the previous weights
149 | num_tokens_to_copy = min(oldvocab_len, newvocab_len)
150 | cls_layer_newmodel.weight.data[:num_tokens_to_copy, :] = cls_layer_oldmodel.weight.data[:num_tokens_to_copy, :]
151 | oldmodel.lm_head2.decoder = cls_layer_newmodel
152 | # Change the bias
153 | old_bias = oldmodel.lm_head2.bias
154 | oldmodel.lm_head2.bias = _get_resized_bias(old_bias, newvocab_len)
155 |
156 | return oldmodel
157 |
158 |
159 |
160 | def main(args):
161 |
162 | model = RobertaForMaskedLMv2.from_pretrained(args.old_model)
163 | vocab = json.load(open(args.vocab_path))
164 | print(f"New Vocab Size : {len(vocab)}")
165 | newmodel = get_resized_model(model, len(vocab)+1)
166 | print(f"Final Cls Layer of New model : {newmodel.lm_head2.decoder}")
167 | newmodel.save_pretrained(args.out_model_path)
168 |
169 | # Move the other files
170 | files = ['tokenizer_config.json','vocab.json','training_args.bin','special_tokens_map.json','merges.txt']
171 | for file in files:
172 | shutil.copyfile(os.path.join(args.old_model,file), os.path.join(args.out_model_path,file))
173 |
174 |
175 | if __name__ == "__main__":
176 |
177 | parser = argparse.ArgumentParser()
178 |     parser.add_argument('--old_model', type=str, help='path to the existing model directory')
179 |     parser.add_argument('--vocab_path', type=str, help='path to the output vocab JSON; its length determines the size of the new output layer')
180 |     parser.add_argument('--out_model_path', type=str, help='path to save the resized model')
181 | args = parser.parse_args()
182 |
183 | main(args)
--------------------------------------------------------------------------------
/varcorpus/README.md:
--------------------------------------------------------------------------------
1 | ### Building VarCorpus
2 |
3 | To build VarCorpus, we collected C and C++ packages from Gentoo and built them across four compiler optimizations (O0, O1, O2 and O3) with debugging information (-g enabled), using [Bintoo](https://github.com/sefcom/bintoo).
4 |
5 | The script reads binaries from a directory, generates type-stripped and stripped variants, decompiles them, parses the decompiled code with Joern to match variables, and saves deduplicated training and testing sets for both the function and binary splits in a data directory. The pipeline supports two decompilers, IDA and Ghidra, and can be extended to any decompiler that produces C-style decompilation output.
6 |
7 | To generate the dataset, choose one of the following language selection modes.
8 |
9 | #### Single-language mode (Either C or CPP)
10 |
11 | ```bash
12 | python3 generate.py \
13 | -b <binaries_dir> \
14 | -d <data_dir> \
15 | --decompiler <ida or ghidra> \
16 | -lang <c or cpp> \
17 | -ida <path_to_ida> or -ghidra <path_to_analyzeHeadless> \
18 | -w <num_workers> \
19 | -joern <path_to_joern_dir> \
20 | --splits
21 | ```
22 |
23 | #### Mixed-language mode (Both C and CPP)
24 |
25 | Provide a JSON mapping of binary basenames to languages and omit `-lang`:
26 |
27 | ```json
28 | {
29 | "bin_one": "c",
30 | "bin_two": "cpp"
31 | }
32 | ```
33 |
34 | ```bash
35 | python3 generate.py \
36 | -b <binaries_dir> \
37 | -d <data_dir> \
38 | --decompiler <ida or ghidra> \
39 | -ida <path_to_ida> or -ghidra <path_to_analyzeHeadless> \
40 | --language_map <path_to_language_map.json> \
41 | -w <num_workers> \
42 | -joern <path_to_joern_dir> \
43 | --splits
44 | ```
45 |
46 | Notes:
47 | - Exactly one of `-lang/--corpus_language` or `--language_map` must be provided (mutually exclusive).
48 | - Allowed languages are `c` or `cpp` (lowercase).
49 |
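If you already know each binary's source language, a map in the format shown above can be generated with a few lines (a minimal sketch; how you decide between `c` and `cpp` is up to you, everything below defaults to `c`):

```python
import json
import os
import sys

binaries_dir = sys.argv[1]

# keys are binary basenames, values are "c" or "cpp"
lang_map = {name: "c" for name in sorted(os.listdir(binaries_dir))}

with open("language_map.json", "w") as f:
    json.dump(lang_map, f, indent=2)
```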
50 | The current implementation relies on a specific modified version of [Joern](https://github.com/joernio/joern). We made this modification to expedite the data set creation process. Please download the compatible Joern version from [Joern](https://www.dropbox.com/scl/fi/toh6087y5t5xyln47i5ih/modified_joern.tar.gz?rlkey=lfvjn1u7zvtp9a4cu8z8vgsof&dl=0) and save it.
51 |
52 | ```
53 | wget -O joern.tar.gz "https://www.dropbox.com/scl/fi/toh6087y5t5xyln47i5ih/modified_joern.tar.gz?rlkey=lfvjn1u7zvtp9a4cu8z8vgsof&dl=0"
54 | tar xf joern.tar.gz
55 | ```
56 |
57 | For `-joern`, provide the path to the directory containing the `joern` executable from the downloaded archive.
58 | ```
59 | -joern <path_to_extracted>/joern/
60 | ```
61 |
62 | ### Use Docker
63 |
64 | You can skip the setup and use our Dockerfile directly to build the dataset.
65 | 
66 | Have your binaries directory and an output directory ready on your host machine. These will be mounted into the container so you can easily provide binaries and retrieve the train and test sets.
67 |
68 | #### For Ghidra:
69 |
70 | ```
71 | docker build -t varbert -f ../../Dockerfile ..
72 |
73 | docker run -it \
74 | -v $PWD/<binaries_dir>:/varbert_workdir/data/binaries \
75 | -v $PWD/<output_dir>:/varbert_workdir/data/sets \
76 | varbert \
77 | python3 /varbert_workdir/VarBERT/varcorpus/dataset-gen/generate.py \
78 | -b /varbert_workdir/data/binaries \
79 | -d /varbert_workdir/data/sets \
80 | --decompiler ghidra \
81 | -lang <c or cpp> \
82 | -ghidra /varbert_workdir/ghidra_10.4_PUBLIC/support/analyzeHeadless \
83 | -w <num_workers> \
84 | -joern /varbert_workdir/joern \
85 | --splits
86 | ```
87 |
88 | To enable debug mode add `--DEBUG` arg and mount a tmp directory from host to see intermediate files:
89 |
90 | ```
91 | docker run -it \
92 | -v $PWD/<binaries_dir>:/varbert_workdir/data/binaries \
93 | -v $PWD/<output_dir>:/varbert_workdir/data/sets \
94 | -v $PWD/<tmp_dir>:/tmp/varbert_tmpdir \
95 | varbert \
96 | python3 /varbert_workdir/VarBERT/varcorpus/dataset-gen/generate.py \
97 | -b /varbert_workdir/data/binaries \
98 | -d /varbert_workdir/data/sets \
99 | --decompiler ghidra \
100 | -lang <c or cpp> \
101 | -ghidra /varbert_workdir/ghidra_10.4_PUBLIC/support/analyzeHeadless \
102 | -w <num_workers> \
103 | --DEBUG \
104 | -joern /varbert_workdir/joern \
105 | --splits
106 | ```
107 |
108 | What this does:
109 |
110 | - `-v $PWD/<binaries_dir>:/varbert_workdir/data/binaries`: Mounts your local binaries directory into the container.
111 | - `-v $PWD/<output_dir>:/varbert_workdir/data/sets`: Mounts the directory where you want to save the generated train/test sets.
112 |
113 | Inside the container, your binaries are accessible at `/varbert_workdir/data/binaries`, resulting data sets will be saved to `/varbert_workdir/data/sets` and intermediate files are available at `/tmp/varbert_tmpdir`.
114 |
115 |
116 |
117 | #### For IDA:
118 |
119 | Please update Dockerfile to include your IDA and run.
120 |
121 | ```
122 | docker run -it \
123 | -v $PWD/<binaries_dir>:/varbert_workdir/data/binaries \
124 | -v $PWD/<output_dir>:/varbert_workdir/data/sets \
125 | varbert \
126 | python3 /varbert_workdir/VarBERT/varcorpus/dataset-gen/generate.py \
127 | -b /varbert_workdir/data/binaries \
128 | -d /varbert_workdir/data/sets \
129 | --decompiler ida \
130 | -lang <c or cpp> \
131 | -ida <path_to_ida_in_container> \
132 | -w <num_workers> \
133 | -joern /varbert_workdir/joern \
134 | --splits
135 | ```
136 |
137 |
138 | #### Notes:
139 |
140 | - The train and test sets are split in an 80:20 ratio. If there aren't enough functions (or binaries) to meet this ratio, you may end up with no train or test sets after the run.
141 | - We built the dataset using **Ghidra 10.4**. If you wish to use a different version of Ghidra, please update the Ghidra download link in the Dockerfile accordingly.
142 | - In some cases a license popup must be accepted before IDA will run successfully in Docker.
143 | - Disable type casts in the decompiler for more efficient variable matching (we disabled them while building VarCorpus).
144 |
145 | ### Debug and temporary directories
146 |
147 | - `--DEBUG`: preserves the temporary working directory so you can inspect intermediate files and logs.
148 | - `--tmpdir <path>`: sets a custom working directory (default: `/tmp/varbert_tmpdir`). Important subdirectories during a run (see the sketch below):
149 |   - `binary/strip`, `binary/type_strip`: modified binaries
150 |   - `dc/<decompiler>/<strip-type>`: decompiled code
151 |   - `dwarf`: DWARF info per binary
152 |   - `joern/<decompiler>/<strip-type>`: Joern JSON
153 |   - `map/<decompiler>/`: variables mapped to decompiled functions
154 |   - `splits`: train and test sets
155 |
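For example, after a `--DEBUG` run you can eyeball the layout and the resulting splits (a minimal sketch; it assumes the default tmpdir and that `--splits` was passed):

```bash
# show the working-directory layout described above
find /tmp/varbert_tmpdir -maxdepth 3 -type d | sort

# deduplicated train and test sets
ls /tmp/varbert_tmpdir/splits
```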
156 |
157 | Sample Function:
158 | ```json
159 | {
160 | "id": 5,
161 | "language": "C",
162 | "func": "__int64 __fastcall sub_13DF(_QWORD *@@var_1@@a1@@, _QWORD *@@var_0@@Ancestors@@)\n{\n while ( @@var_0@@Ancestors@@ )\n {\n if ( @@var_0@@Ancestors@@[1] == @@var_1@@a1@@[1] && @@var_0@@Ancestors@@[2] == *@@var_1@@a1@@ )\n return 1LL;\n @@var_0@@Ancestors@@ = *@@var_0@@Ancestors@@;\n }\n return 0LL;\n}",
163 | "type_stripped_vars": {"Ancestors": "dwarf", "a1": "ida"},
164 | "stripped_vars": ["a2", "a1"],
165 | "mapped_vars": {"a2": "Ancestors", "a1": "a1"},
166 | "func_name_dwarf": "is_ancestor",
167 | "hash": "2998f4a10a8f052257122c23897d10b7",
168 | "func_name": "8140277b36ef8461df62b160fed946cb_(00000000000013DF)",
169 | "norm_func": "__int64 __fastcall sub_13DF(_QWORD *@@var_1@@a1@@, _QWORD *@@var_0@@ancestors@@)\n{\n while ( @@var_0@@ancestors@@ )\n {\n if ( @@var_0@@ancestors@@[1] == @@var_1@@a1@@[1] && @@var_0@@ancestors@@[2] == *@@var_1@@a1@@ )\n return 1LL;\n @@var_0@@ancestors@@ = *@@var_0@@ancestors@@;\n }\n return 0LL;\n}",
170 | "vars_map": [["Ancestors", "ancestors"]],
171 | "fid": "5-8140277b36ef8461df62b160fed946cb-(00000000000013DF)"
172 | }
173 | ```
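
A minimal sketch of consuming such samples (it assumes each split is written as JSON lines and uses the field names and `@@var_N@@name@@` markers shown above; the path is hypothetical):

```python
import json
import re

# decompiler placeholder + ground-truth name, e.g. @@var_0@@Ancestors@@ -> ("var_0", "Ancestors")
VAR_MARKER = re.compile(r"@@(var_\d+)@@(\w+)@@")

def load_samples(path):
    """Yield one function record per line of a split file."""
    with open(path) as f:
        for line in f:
            yield json.loads(line)

def variables(sample):
    """Return (placeholder, name) pairs in order of appearance in the function."""
    return VAR_MARKER.findall(sample["func"])

# for sample in load_samples("train.jsonl"):
#     print(sample["func_name_dwarf"], variables(sample))
```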
174 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/strip_types.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import random
3 | import string
4 | import logging
5 | from elftools.elf.elffile import ELFFile
6 | from elftools.dwarf import constants
7 | import dwarfwrite
8 | from dwarfwrite.restructure import ReStructurer, VOID
9 | from cle.backends.elf.elf import ELF
10 |
11 | l = logging.getLogger('main')
12 | DWARF_VERSION = 4
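# sentinel "type id" used to tag C++ `this` parameters so GhidraStubRewriter
# can emit them as void* (see type_ptr_of / type_basic_size below)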
13 | VOID_STAR_ID = 127
14 |
15 | class TypeStubRewriter(ReStructurer):
16 | def __init__(self, fp, cu_data=None):
17 | super().__init__(fp)
18 | fp.seek(0)
19 | self.rng = random.Random(fp.read())
20 | self.cu_data = cu_data
21 |
22 | def unit_get_functions(self, unit):
23 | result = list(super().unit_get_functions(unit))
24 |
25 | for die in unit.iter_DIEs():
26 | if die.tag not in ('DW_TAG_class_type', 'DW_TAG_struct_type'):
27 | continue
28 | for member in die.iter_children():
29 | if member.tag != 'DW_TAG_subprogram':
30 | continue
31 | result.append(member)
32 |
33 | return result
34 |
35 | def type_basic_encoding(self, handler):
36 | return constants.DW_ATE_unsigned
37 |
38 | def type_basic_name(self, handler):
39 | return {
40 | 1: "unsigned char",
41 | 2: "unsigned short",
42 | 4: "unsigned int",
43 | 8: "unsigned long long",
44 | 16: "unsigned __int128",
45 | 32: "unsigned __int256",
46 | }[handler]
47 |
48 | def function_get_linkage_name(self, handler):
49 | return None
50 |
51 | def parameter_get_name(self, handler):
52 | return super().parameter_get_name(handler)
53 |
54 |
55 | def get_type_size(cu, offset, wordsize):
56 | # returns type_size, base_type_offset, is_exemplar[true=yes, false=maybe, none=no]
57 | type_die = cu.get_DIE_from_refaddr(offset + cu.cu_offset)
58 | base_type_die = type_die
59 | base_type_offset = offset
60 |
61 | while True:
62 | if base_type_die.tag == 'DW_TAG_pointer_type':
63 | return wordsize, base_type_offset, None
64 | elif base_type_die.tag == 'DW_TAG_structure_type':
65 | try:
66 | first_member = next(base_type_die.iter_children())
67 | assert first_member.tag == 'DW_TAG_member'
68 | assert first_member.attributes['DW_AT_data_member_location'].value == 0
69 | except (KeyError, StopIteration, AssertionError):
70 | return wordsize, base_type_offset, None
71 | base_type_offset = first_member.attributes['DW_AT_type'].value
72 | base_type_die = cu.get_DIE_from_refaddr(base_type_offset + cu.cu_offset)
73 | continue
74 | # TODO unions
75 | elif base_type_die.tag == 'DW_TAG_base_type':
76 | size = base_type_die.attributes['DW_AT_byte_size'].value
77 | if any(ty in base_type_die.attributes['DW_AT_name'].value for ty in (b'char', b'int', b'long')):
78 | return size, base_type_offset, True
79 | else:
80 | return size, base_type_offset, None
81 | elif 'DW_AT_type' in base_type_die.attributes:
82 | base_type_offset = base_type_die.attributes['DW_AT_type'].value
83 | base_type_die = cu.get_DIE_from_refaddr(base_type_offset + cu.cu_offset)
84 | continue
85 | else:
86 | # afaik this case is only reached for void types
87 | return wordsize, base_type_offset, None
88 |
89 | def build_type_to_size(ipath):
90 | with open(ipath, 'rb') as ifp:
91 | elf = ELFFile(ifp)
92 | arch = ELF.extract_arch(elf)
93 |
94 | ifp.seek(0)
95 | dwarf = elf.get_dwarf_info()
96 |
97 | cu_data = {}
98 | for cu in dwarf.iter_CUs():
99 | type_to_size = {}
100 | # import ipdb; ipdb.set_trace()
101 | cu_data[cu.cu_offset] = type_to_size
102 | for die in cu.iter_DIEs():
103 | attr = die.attributes.get("DW_AT_type", None)
104 | if attr is None:
105 | continue
106 |
107 | type_size, _, _ = get_type_size(cu, attr.value, arch.bytes)
108 | type_to_size[attr.value] = type_size
109 |
110 | return cu_data
111 |
112 | class IdaStubRewriter(TypeStubRewriter):
113 |
114 | def get_attribute(self, die, name):
115 | r = super().get_attribute(die, name)
116 | if name != 'DW_AT_type' or r is None:
117 | return r
118 | size = self.cu_data[r.cu.cu_offset][r.offset - r.cu.cu_offset]
119 | return size
120 |
121 | def type_basic_size(self, handler):
122 | return handler
123 |
124 | def function_get_name(self, handler):
125 | r = super().function_get_name(handler)
126 | return r
127 |
128 | def parameter_get_location(self, handler):
129 | return None
130 |
131 | class GhidraStubRewriter(TypeStubRewriter):
132 |
133 | def __init__(self, fp, cu_data=None, low_pc_to_funcname=None):
134 | super().__init__(fp, cu_data)
135 | self.low_pc_to_funcname = low_pc_to_funcname
136 |
137 | def get_attribute(self, die, name):
138 | r = super().get_attribute(die, name)
139 | if name != 'DW_AT_type' or r is None:
140 | return r
141 | if die.tag == "DW_TAG_formal_parameter":
142 | varname = self.get_attribute(die, "DW_AT_name")
143 | if varname == b"this":
144 | return VOID_STAR_ID
145 | size = self.cu_data[r.cu.cu_offset][r.offset - r.cu.cu_offset]
146 | return size
147 |
148 | def type_basic_size(self, handler):
149 | if handler == VOID_STAR_ID:
150 | return 8 # assuming the binary is 64-bit
151 | return handler
152 |
153 | def function_get_name(self, handler):
154 | # import ipdb; ipdb.set_trace()
155 | low_pc = self.get_attribute(handler, "DW_AT_low_pc")
156 | if low_pc is not None and str(low_pc) in self.low_pc_to_funcname:
157 | return self.low_pc_to_funcname[str(low_pc)]
158 | return super().function_get_name(handler)
159 |
160 | def type_ptr_of(self, handler):
161 | if handler == VOID_STAR_ID:
162 | return VOID
163 | return None
164 |
165 | def parameter_get_artificial(self, handler):
166 | return None
167 |
168 | def parameter_get_location(self, handler):
169 | return super().parameter_get_location(handler)
170 |
171 |
172 | def get_spec_offsets_and_names(ipath):
173 |
174 | elf = ELFFile(open(ipath, "rb"))
175 | all_spec_offsets = [ ]
176 | low_pc_to_funcname = {}
177 | spec_offset_to_low_pc = {}
178 |
179 | dwarf = elf.get_dwarf_info()
180 | for cu in dwarf.iter_CUs():
181 | cu_offset = cu.cu_offset
182 | for die in cu.iter_DIEs():
183 | for subdie in cu.iter_DIE_children(die):
184 | if subdie.tag == "DW_TAG_subprogram":
185 | if "DW_AT_low_pc" in subdie.attributes:
186 | low_pc = str(subdie.attributes.get("DW_AT_low_pc").value)
187 | if "DW_AT_specification" in subdie.attributes:
188 | spec_offset = subdie.attributes["DW_AT_specification"].value
189 | global_spec_offset = spec_offset + cu_offset
190 | all_spec_offsets.append(global_spec_offset)
191 | spec_offset_to_low_pc[str(global_spec_offset)] = low_pc
192 |
193 | for cu in dwarf.iter_CUs():
194 | for die in cu.iter_DIEs():
195 | for subdie in cu.iter_DIE_children(die):
196 | if subdie.offset in all_spec_offsets:
197 | if "DW_AT_name" in subdie.attributes:
198 | low_pc_to_funcname[spec_offset_to_low_pc[str(subdie.offset)]] = subdie.attributes.get("DW_AT_name").value
199 |
200 | return low_pc_to_funcname
201 |
202 | def type_strip_target_binary(ipath, opath, decompiler):
203 | try:
204 | cu_data = build_type_to_size(ipath)
205 | if decompiler == "ghidra":
206 | low_pc_to_funcname = get_spec_offsets_and_names(ipath)
207 | GhidraStubRewriter.rewrite_dwarf(in_path=ipath, out_path=opath, cu_data=cu_data, low_pc_to_funcname=low_pc_to_funcname)
208 | elif decompiler == "ida":
209 | IdaStubRewriter.rewrite_dwarf(in_path=ipath, out_path=opath, cu_data=cu_data)
210 | else:
211 | l.error("Unsupported Decompiler. Please choose from IDA or Ghidra")
212 | except Exception as e:
213 |         l.error(f"Error occurred while creating a type-strip binary: {e}")
214 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/decompiler/ida_analysis.py:
--------------------------------------------------------------------------------
1 | import idaapi
2 | import idc
3 | import ida_funcs
4 | import ida_hexrays
5 | import ida_kernwin
6 | import ida_loader
7 |
8 | from collections import namedtuple, defaultdict
9 | from sortedcontainers import SortedDict
10 |
11 | from elftools.dwarf.descriptions import describe_reg_name
12 | from elftools.elf.elffile import ELFFile
13 | from elftools.dwarf.dwarf_expr import DWARFExprParser
14 | from elftools.dwarf import locationlists
15 |
16 | import json
17 |
18 | LocationEntry = namedtuple("LocationEntry", ("begin_offset", "end_offset", "location"))
19 | NameResult = namedtuple("NameResult", ("name", "size"))
20 |
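# RegVarAnalysis walks the binary's DWARF location lists and records, per
# register, which source-level variable occupies it over which address range.
# Visitor (below) calls lookup(reg, addr) to rename Hex-Rays lvars accordingly.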
21 | class RegVarAnalysis:
22 | def __init__(self, fname):
23 | a = ELFFile(open(fname, 'rb'))
24 | b = a.get_dwarf_info()
25 |
26 | self.result = defaultdict(SortedDict)
27 |
28 | self.loc_parser = b.location_lists()
29 | self.expr_parser = DWARFExprParser(b.structs)
30 | self.range_lists = b.range_lists()
31 |
32 | for c in b.iter_CUs():
33 | for x in c.iter_DIEs():
34 | if x.tag != 'DW_TAG_variable' or 'DW_AT_name' not in x.attributes or x.get_parent() is x.cu.get_top_DIE() or 'DW_AT_location' not in x.attributes:
35 | continue
36 | # ???
37 | if x.attributes['DW_AT_location'].form == 'DW_FORM_exprloc':
38 | loclist = self.get_single_loc(x)
39 | elif x.attributes['DW_AT_location'].form != 'DW_FORM_sec_offset':
40 | assert False
41 | else:
42 | loclist = self.get_loclist(x)
43 | for loc in loclist:
44 | if len(loc.location) != 1 or not loc.location[0].op_name.startswith('DW_OP_reg'):
45 | # discard complicated variables
46 | continue
47 | expr = loc.location[0]
48 | if expr.op_name == 'DW_OP_regx':
49 | reg_name = describe_reg_name(expr.args[0], a.get_machine_arch())
50 | else:
51 | reg_name = describe_reg_name(int(expr.op_name[9:]), a.get_machine_arch())
52 | self.result[reg_name][loc.begin_offset] = NameResult(x.attributes['DW_AT_name'].value.decode(), loc.end_offset - loc.begin_offset)
53 |
54 | def get_single_loc(self, f):
55 | base_addr = 0
56 | low_pc = f.cu.get_top_DIE().attributes.get("DW_AT_low_pc", None)
57 | if low_pc is not None:
58 | base_addr = low_pc.value
59 | parent = f.get_parent()
60 | ranges = []
61 | while parent is not f.cu.get_top_DIE():
62 | if 'DW_AT_low_pc' in parent.attributes and 'DW_AT_high_pc' in parent.attributes:
63 | ranges.append((
64 | parent.attributes['DW_AT_low_pc'].value + base_addr,
65 | parent.attributes['DW_AT_high_pc'].value + base_addr,
66 | ))
67 | break
68 | if 'DW_AT_ranges' in parent.attributes:
69 | rlist = self.range_lists.get_range_list_at_offset(parent.attributes['DW_AT_ranges'].value)
70 | ranges = [
71 | (rentry.begin_offset + base_addr, rentry.end_offset + base_addr) for rentry in rlist
72 | ]
73 | break
74 | parent = parent.get_parent()
75 | else:
76 | return []
77 |
78 | return [LocationEntry(
79 | location=self.expr_parser.parse_expr(f.attributes['DW_AT_location'].value),
80 | begin_offset=begin,
81 | end_offset=end
82 | ) for begin, end in ranges]
83 |
84 |
85 | def get_loclist(self, f):
86 | base_addr = 0
87 | low_pc = f.cu.get_top_DIE().attributes.get("DW_AT_low_pc", None)
88 | if low_pc is not None:
89 | base_addr = low_pc.value
90 |
91 | loc_list = self.loc_parser.get_location_list_at_offset(f.attributes['DW_AT_location'].value)
92 | result = []
93 | for item in loc_list:
94 | if type(item) is locationlists.LocationEntry:
95 | try:
96 | result.append(LocationEntry(
97 | base_addr + item.begin_offset,
98 | base_addr + item.end_offset,
99 | self.expr_parser.parse_expr(item.loc_expr)))
100 | except KeyError as e:
101 | if e.args[0] == 249: # gnu extension dwarf expr ops
102 | continue
103 | else:
104 | raise
105 | elif type(item) is locationlists.BaseAddressEntry:
106 | base_addr = item.base_address
107 | else:
108 | raise TypeError("What kind of loclist entry is this?")
109 | return result
110 |
111 | def lookup(self, reg, addr):
112 | try:
113 | key = next(self.result[reg].irange(maximum = addr, reverse=True))
114 | except StopIteration:
115 | return None
116 | else:
117 | val = self.result[reg][key]
118 | if key + val.size <= addr:
119 | return None
120 | return val.name
121 |
122 | analysis: RegVarAnalysis = None
123 |
124 | def setup():
125 | global analysis
126 | if analysis is not None:
127 | return
128 | path = ida_loader.get_path(ida_loader.PATH_TYPE_CMD)
129 | analysis = RegVarAnalysis(path)
130 |
131 | def dump_list(list_, filename):
132 | with open(filename, 'w') as w:
133 | w.write("\n".join(list_))
134 |
135 | def write_json(data, filename):
136 | with open(filename, 'w') as w:
137 | w.write(json.dumps(data))
138 |
139 | def go():
140 | setup()
141 |
142 | ea = 0
143 | collect_addrs, mangled_names_to_demangled_names = [], {}
144 | filename = idc.ARGV[1]
145 | while True:
146 | func = ida_funcs.get_next_func(ea)
147 | if func is None:
148 | break
149 | ea = func.start_ea
150 | seg = idc.get_segm_name(ea)
151 | if seg != ".text":
152 | continue
153 | collect_addrs.append(str(ea))
154 | typ = idc.get_type(ea)
155 |         # void return types sometimes introduce extra variables; updating the return type helps variable matching for the type-strip binary
156 | if 'void' in str(typ):
157 | newtype = str(typ).replace("void", f"__int64 {str(ida_funcs.get_func_name(ea))}") + ";"
158 | res = idc.SetType(ea, newtype)
159 |
160 | print("analyzing" , ida_funcs.get_func_name(ea))
161 | analyze_func(func)
162 | # # Demangle the name
163 | mangled_name = ida_funcs.get_func_name(ea)
164 | demangled_name = idc.demangle_name(mangled_name, idc.get_inf_attr(idc.INF_SHORT_DN))
165 | if demangled_name:
166 | mangled_names_to_demangled_names[mangled_name] = demangled_name
167 | else:
168 | mangled_names_to_demangled_names[mangled_name] = mangled_name
169 |
170 | dump_list(collect_addrs, filename)
171 | write_json(mangled_names_to_demangled_names, f'{filename}_names')
172 |
173 | def analyze_func(func):
174 | cfunc = ida_hexrays.decompile_func(func, None, 0)
175 | if cfunc is None:
176 | return
177 | v = Visitor(func.start_ea, cfunc)
178 | v.apply_to(cfunc.body, None)
179 | return v
180 |
181 | class Visitor(idaapi.ctree_visitor_t):
182 | def __init__(self, ea, cfunc):
183 | super().__init__(idaapi.CV_FAST)
184 | self.ea = ea
185 | self.cfunc = cfunc
186 | self.vars = []
187 | self.already_used = {lvar.name for lvar in cfunc.lvars if lvar.has_user_name}
188 | self.already_fixed = set(self.already_used)
189 |
190 | def visit_expr(self, expr):
191 | if expr.op == ida_hexrays.cot_var:
192 | lvar = expr.get_v().getv()
193 | old_name = lvar.name
194 | if expr.ea == idc.BADADDR:
195 | pass
196 | else:
197 | if old_name not in self.already_fixed:
198 | if lvar.location.is_reg1():
199 | reg_name = ida_hexrays.print_vdloc(lvar.location, 8)
200 | if reg_name in analysis.result:
201 | var_name = analysis.lookup(reg_name, expr.ea)
202 | if var_name:
203 | nonce_int = 0
204 | nonce = ''
205 | while var_name + nonce in self.already_used:
206 | nonce = '_' + str(nonce_int)
207 | nonce_int += 1
208 | name = var_name + nonce
209 | ida_hexrays.rename_lvar(self.ea, old_name, name)
210 | self.already_used.add(name)
211 | self.already_fixed.add(old_name)
212 |
213 | return 0
214 |
215 | idaapi.auto_wait()
216 | go()
217 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/dwarf_info.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from elftools.elf.elffile import ELFFile
3 | from elftools.dwarf.locationlists import (
4 | LocationEntry, LocationExpr, LocationParser)
5 | from elftools.dwarf.descriptions import (
6 | describe_DWARF_expr, _import_extra, describe_attr_value,set_global_machine_arch)
7 | from collections import defaultdict
8 |
9 | l = logging.getLogger('main')
10 |
11 | def is_elf(binary_path: str):
12 |
13 | with open(binary_path, 'rb') as rb:
14 | if rb.read(4) != b"\x7fELF":
15 | l.error(f"File is not an ELF format: {binary_path}")
16 | return None
17 | rb.seek(0)
18 | elffile = ELFFile(rb)
19 | if not elffile.has_dwarf_info():
20 | l.error(f"No DWARF info found in: {binary_path}")
21 | return None
22 | return elffile.get_dwarf_info()
23 |
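# Extracts, per function, the variable and parameter names recorded in DWARF.
# Keys of vars_in_each_func are (linkage) function names for IDA and low_pc
# addresses for Ghidra; the 'global_vars' key collects file-scope variables.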
24 | class Dwarf:
25 | def __init__(self, binary_path, binary_name, decompiler) -> None:
26 | self.binary_path = binary_path
27 | self.binary_name = binary_name
28 | self.decompiler = decompiler
29 | self.dwarf_info = is_elf(self.binary_path)
30 | self.vars_in_each_func = defaultdict(list)
31 | self.spec_offset_var_names = defaultdict(set) if decompiler == 'ghidra' else None
32 | self.linkage_name_to_func_name = {}
33 | self._load()
34 |
35 | def _load(self):
36 | if not self.dwarf_info:
37 | l.error(f"Failed to load DWARF information for: {self.binary_path}")
38 | return
39 | if self.decompiler == 'ida':
40 | self.spec_offset_var_names = self.collect_spec_and_names()
41 | self.vars_in_each_func, self.linkage_name_to_func_name = self.read_dwarf()
42 |
43 |     # collect variable names attached to DW_AT_specification DIEs (C++ methods); used on the IDA path
44 | def collect_spec_and_names(self):
45 | # 1. get CU's offset 2. get spec offset, if any function resolves use that function name with vars
46 | spec_offset_var_names = defaultdict(set)
47 | for CU in self.dwarf_info.iter_CUs():
48 | try:
49 | current_cu_offset = CU.cu_offset
50 | top_DIE = CU.get_top_DIE()
51 | self.die_info_rec(top_DIE)
52 | current_spec_offset = 0
53 |
54 | for i in range(0, len(CU._dielist)):
55 | if CU._dielist[i].tag == 'DW_TAG_subprogram':
56 |
57 | spec_offset, current_spec_offset = 0, 0
58 | spec = CU._dielist[i].attributes.get("DW_AT_specification", None)
59 |
60 |                         # specification value + CU offset = definition of function
61 | # 15407 + 0xbd1 = 0x8800
62 | # <8800>: Abbrev Number: 74 (DW_TAG_subprogram)
63 | # <8801> DW_AT_external
64 |
65 | if spec:
66 | spec_offset = spec.value
67 | current_spec_offset = spec_offset + current_cu_offset
68 |
69 | if CU._dielist[i].tag == 'DW_TAG_formal_parameter' or CU._dielist[i].tag == 'DW_TAG_variable':
70 | if self.is_artifical(CU._dielist[i]):
71 | continue
72 | name = self.get_name(CU._dielist[i])
73 | if not name:
74 | continue
75 | if current_spec_offset:
76 | spec_offset_var_names[str(current_spec_offset)].add(name)
77 |
78 | except Exception as e:
79 | l.error(f"Error in collect_spec_and_names :: {self.binary_path} :: {e}")
80 | return spec_offset_var_names
81 |
82 |     # for global vars, consider only variables that have locations;
83 |     # otherwise variables like stderr and optind make it into the global variables list
84 |     # and can sometimes be marked as valid variables for a particular binary
85 |
86 | def read_dwarf(self):
87 | # for each binary
88 | if not self.dwarf_info:
89 | return defaultdict(list), {}
90 | vars_in_each_func = defaultdict(list)
91 | previous_func_offset = 0
92 | l.debug(f"Reading DWARF info from :: {self.binary_path}")
93 | for CU in self.dwarf_info.iter_CUs():
94 | try:
95 | current_func_name, ch, func_name = 'global_vars', '', ''
96 | func_offset = 0
97 | tmp_list, global_vars_list, local_vars_and_args_list = [], [], []
98 |
99 | # caches die info
100 | top_DIE = CU.get_top_DIE()
101 | self.die_info_rec(top_DIE)
102 |
103 | # 1. global vars 2. vars and args in subprogram
104 | for i in range(0, len(CU._dielist)):
105 | try:
106 | if CU._dielist[i].tag == 'DW_TAG_variable':
107 | if not CU._dielist[i].attributes.get('DW_AT_location'):
108 | continue
109 | if self.is_artifical(CU._dielist[i]):
110 | continue
111 | var_name = self.get_name(CU._dielist[i])
112 | if not var_name:
113 | continue
114 | tmp_list.append(var_name)
115 |
116 | # shouldn't check location for parameters because we don't add their location in DWARF
117 | if CU._dielist[i].tag == 'DW_TAG_formal_parameter':
118 | param_name = self.get_name(CU._dielist[i])
119 | if not param_name:
120 | continue
121 | tmp_list.append(param_name)
122 |
123 | # for last func and it's var
124 | local_vars_and_args_list = tmp_list
125 | if CU._dielist[i].tag == 'DW_TAG_subprogram':
126 | if self.is_artifical(CU._dielist[i]):
127 | continue
128 |                         # Ghidra
129 | if self.decompiler == 'ghidra':
130 | # func addr is func name for now
131 | low_pc = CU._dielist[i].attributes.get('DW_AT_low_pc', None)
132 | addr = None
133 | if low_pc is not None:
134 | addr = str(low_pc.value)
135 |
136 | ranges = CU._dielist[i].attributes.get('DW_AT_ranges', None)
137 | if ranges is not None:
138 | addr = self.dwarf_info.range_lists().get_range_list_at_offset(ranges.value)[0].begin_offset
139 |
140 | if not addr:
141 | continue
142 | func_name = addr
143 |                         # IDA
144 | if self.decompiler == 'ida':
145 | func_name = self.get_name(CU._dielist[i])
146 | if not func_name:
147 | continue
148 |
149 | func_linkage_name = self.get_linkage_name(CU._dielist[i])
150 | func_offset = CU._dielist[i].offset
151 | # because func name from dwarf is without class name but IDA gives us funcname with classname
152 | # so we match them using linkage_name
153 | if func_linkage_name:
154 | self.linkage_name_to_func_name[func_linkage_name] = func_name
155 | func_name = func_linkage_name
156 | else:
157 | # need it later for matching
158 | self.linkage_name_to_func_name[func_name] = func_name
159 |                             # no linkage name: use the DWARF name as-is
160 |
161 | # because DIE's are serialized and subprogram comes before vars and params
162 | vars_from_specification_subprogram = []
163 | if previous_func_offset in self.spec_offset_var_names:
164 | vars_from_specification_subprogram = self.spec_offset_var_names[str(previous_func_offset)]
165 | previous_func_offset = str(func_offset)
166 |
167 | if current_func_name != func_name:
168 | if current_func_name == 'global_vars':
169 | global_vars_list.extend(tmp_list)
170 | vars_in_each_func[current_func_name].extend(global_vars_list)
171 | else:
172 | if self.decompiler == 'ida':
173 | if vars_from_specification_subprogram:
174 | tmp_list.extend(vars_from_specification_subprogram)
175 | vars_in_each_func[current_func_name].extend(tmp_list)
176 | ch = current_func_name
177 | current_func_name = func_name
178 | tmp_list = []
179 | except Exception as e:
180 | l.error(f"Error in reading DWARF {e}")
181 |
182 | if current_func_name != ch and func_name:
183 | vars_in_each_func[func_name].extend(local_vars_and_args_list)
184 |
185 | except Exception as e:
186 | l.error(f"Error in read_dwarf :: {self.binary_name} :: {e}")
187 | l.debug(f"Number of functions in {self.binary_name}: {str(len(vars_in_each_func))}")
188 | return vars_in_each_func, self.linkage_name_to_func_name
189 |
190 | def get_name(self, die):
191 | name = die.attributes.get('DW_AT_name', None)
192 | if name:
193 | return name.value.decode('ascii')
194 |
195 | def get_linkage_name(self, die):
196 | name = die.attributes.get('DW_AT_linkage_name', None)
197 | if name:
198 | return name.value.decode('ascii')
199 |
200 | # compiler generated variables or functions (destructors, ...)
201 | def is_artifical(self, die):
202 | return die.attributes.get('DW_AT_artificial', None)
203 |
204 | def die_info_rec(self, die, indent_level=' '):
205 |         """ Recursively walk a DIE and its children; pyelftools caches every
206 |             visited DIE in CU._dielist, which read_dwarf then iterates over.
207 |         """
208 | child_indent = indent_level + ' '
209 | for child in die.iter_children():
210 | self.die_info_rec(child, child_indent)
211 |
212 | @classmethod
213 | def get_vars_in_each_func(cls, binary_path):
214 | dwarf = cls(binary_path)
215 | return dwarf.vars_in_each_func
216 |
217 | @classmethod
218 | def get_vars_for_func(cls, binary_path, func_name):
219 | dwarf = cls(binary_path)
220 | return dwarf.vars_in_each_func[func_name] + dwarf.vars_in_each_func['global_vars']
--------------------------------------------------------------------------------
/varbert/cmlm/preprocess.py:
--------------------------------------------------------------------------------
1 | # from pymongo import MongoClient
2 | from tqdm import tqdm
3 | from transformers import RobertaTokenizerFast
4 | import jsonlines
5 | import sys
6 | import json
7 | import logging
8 | import traceback
9 | import random
10 | import os
11 | import re
12 | import argparse
13 | from multiprocessing import Process, Manager, cpu_count, Pool
14 | from itertools import repeat
15 | from collections import defaultdict
16 |
17 | l = logging.getLogger('model_main')
18 |
19 | def read_input_files(filename):
20 |
21 | samples, sample_ids = [], set()
22 | with jsonlines.open(filename,'r') as f:
23 | for each in tqdm(f):
24 | samples.append(each)
25 | sample_ids.add(str(each['mongo_id']))
26 | return sample_ids, samples
27 |
28 |
29 | def prep_input_files(input_file, num_processes, tokenizer, word_to_idx, max_sample_chunk, input_file_ids, output_file_ids, preprocessed_outfile):
30 |
31 | #------------------------ INPUT FILES ------------------------
32 | output_data = Manager().list()
33 | pool = Pool(processes=num_processes) # Instantiate the pool here
34 | each_alloc = len(input_file) // (num_processes-1)
35 | input_data = [input_file[i*each_alloc:(i+1)*each_alloc] for i in range(0,num_processes)]
36 | x = [len(each) for each in input_data]
37 | print(f"Allocation samples for each worker: {len(input_data)}, {x}")
38 |
39 | pool.starmap(generate_id_files,zip(input_data,
40 | repeat(output_data),
41 | repeat(tokenizer),
42 | repeat(word_to_idx),
43 | repeat(max_sample_chunk)
44 | ))
45 | pool.close()
46 | pool.join()
47 |
48 | # Write to Output file
49 | with jsonlines.open(preprocessed_outfile,'w') as f:
50 | for each in tqdm(output_data):
51 | f.write(each)
52 |
53 | # check : #source ids == target_ids after parallel processing
54 | print(f"src_tgt_intersection:", len(input_file_ids - output_file_ids), len(output_file_ids-input_file_ids))
55 |
56 | for each in tqdm(output_data):
57 | output_file_ids.add(str(each['_id']).split("_")[0])
58 | print("src_tgt_intersection:",len(input_file_ids - output_file_ids), len(output_file_ids-input_file_ids))
59 | print(len(output_data))
60 |
61 | # validate : vocab_check
62 | vocab_check = defaultdict(int)
63 | total = 0
64 | for each in tqdm(output_data):
65 | variables = each['orig_vars']
66 | for var in variables:
67 | total += 1
68 | normvar = normalize(var)
69 | _, vocab_stat = get_var_token(normvar,word_to_idx)
70 | if "in_vocab" in vocab_stat:
71 | vocab_check['in_vocab']+=1
72 | if "not_in_vocab" in vocab_stat:
73 | vocab_check['not_in_vocab']+=1
74 | if "part_in_vocab" in vocab_stat:
75 | vocab_check['part_in_vocab']+=1
76 |
77 | print(vocab_check, round(vocab_check['in_vocab']*100/total,2), round(vocab_check['not_in_vocab']*100/total,2), round(vocab_check['part_in_vocab']*100/total,2))
78 |
79 |
80 | def generate_id_files(data, output_data, tokenizer, word_to_idx, n):
81 |
82 | for d in tqdm(data):
83 | try:
84 | ppw = preprocess_word_mask(d,tokenizer,word_to_idx)
85 | outrow = {"words":ppw[4],"mod_words":ppw[6],"inputids":ppw[0],"labels":ppw[1],"gold_texts":ppw[2],"gold_texts_id":ppw[3],"meta":[],"orig_vars":ppw[5], "_id":ppw[7]}
86 |             # if the input exceeds the max chunk size, split it into multiple samples; originals can be traced back via _id
87 | if len(outrow['inputids']) > n:
88 | for i in range(0, len(outrow['inputids']), n):
89 | sample = {"words": outrow['words'][i:i+n],
90 | "mod_words":outrow['mod_words'][i:i+n],
91 | "inputids":outrow['inputids'][i:i+n],
92 | "labels":outrow["labels"][i:i+n],
93 | "gold_texts":outrow["gold_texts"],
94 | "gold_texts_id":outrow["gold_texts_id"],
95 | "orig_vars":outrow["orig_vars"],
96 | "meta":outrow["meta"],
97 | "_id":str(outrow['_id'])+"_"+str((i)//n),}
98 | output_data.append(sample)
99 | else:
100 | output_data.append(outrow)
101 | except:
102 | print("Unexpected error:", sys.exc_info()[0])
103 | traceback.print_exception(*sys.exc_info())
104 |
105 | def change_case(strt):
106 | return ''.join(['_'+i.lower() if i.isupper() else i for i in strt]).lstrip('_')
107 | def is_camel_case(s):
108 | return s != s.lower() and s != s.upper() and "_" not in s
109 |
110 | def normalize(k):
111 | if is_camel_case(k): k=change_case(k)
112 | else: k=k.lower()
113 | return k
114 |
115 | def get_var_token(norm_variable_word,word_to_idx):
116 | vocab_check = defaultdict(int)
117 | token = word_to_idx.get(norm_variable_word,args.vocab_size)
118 | if token == args.vocab_size:
119 | vocab_check['not_in_vocab']+=1
120 | if "_" in norm_variable_word:
121 | word_splits=norm_variable_word.split("_")
122 | word_splits = [ee for ee in word_splits if ee ]
123 | for x in word_splits:
124 | ptoken=word_to_idx.get(x,args.vocab_size)
125 | if ptoken!=args.vocab_size:
126 | token=ptoken
127 | vocab_check['part_in_vocab']+=1
128 | break
129 | else:
130 | vocab_check['in_vocab']+=1
131 | return [token], vocab_check
132 |
133 |
134 | def canonicalize_code(code):
135 | code = re.sub('//.*?\\n|/\\*.*?\\*/', '\\n', code, flags=re.S)
136 | lines = [l.rstrip() for l in code.split('\\n')]
137 | code = '\\n'.join(lines)
138 | # code = re.sub('@@\\w+@@(\\w+)@@\\w+', '\\g<1>', code)
139 | return code
140 |
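# Turns one annotated function into model inputs: each @@var_N@@name@@ marker
# contributes the tokenizer's <mask> id on the input side and the id of the
# normalized variable name (from the output vocab) as the label at that
# position; all other code tokens are tokenized and copied through unchanged.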
141 | def preprocess_word_mask(text,tokenizer, word_to_idx):
142 | # vars_map = text['vars_map'] #needed for vartype detection
143 | # vars_map = dict([[ee[1],ee[0]] for ee in vars_map])
144 | _id = text['mongo_id']
145 | ftext = canonicalize_code(text['norm_func'])
146 | words = ftext.replace("\n"," ").split(" ")
147 | pwords =[]
148 | tpwords =[]
149 | owords =[]
150 | towords =[]
151 | pos=0
152 | masked_pos=[]
153 | var_words =[]
154 | var_toks = []
155 | mod_words = []
156 | orig_vars = []
157 |
158 | vocab=tokenizer.get_vocab()
159 |
160 | for word in words:
161 | if re.search(args.var_loc_pattern, word):
162 | idx = 0
163 | for each_var in list(re.finditer(args.var_loc_pattern,word)):
164 | s = each_var.start()
165 | e = each_var.end()
166 | prefix = word[idx:s]
167 | var = word[s:e]
168 | orig_var = var.split("@@")[-2]
169 |
170 | # Somethings attached before the variables
171 | if prefix:
172 | toks = tokenizer.tokenize(prefix)
173 | for t in toks:
174 | mod_words.append(t)
175 | tpwords.append(vocab[t])
176 | towords.append(vocab[t])
177 |
178 | # Original variable handling
179 |                 # ---- IF dwarf : inputs: <mask>               labels: tokenized, normalized var
180 |                 # ---- NOT dwarf: inputs: tokenized, orig_vars  labels: tokenized, orig_vars
181 |
182 | norm_variable_word = normalize(orig_var)
183 | var_tokens, _ = get_var_token(norm_variable_word,word_to_idx)
184 | var_toks.append(var_tokens)
185 | var_words.append(norm_variable_word) #Gold_texts (gold labels)
186 | mod_words.append(orig_var)
187 | orig_vars.append(orig_var)
188 |
189 |                 tpwords.append(vocab["<mask>"])
190 | towords.append(var_tokens[0])
191 |
192 | idx = e
193 |
194 | # Postfix if any
195 | postfix = word[idx:]
196 | if postfix:
197 | toks = tokenizer.tokenize(postfix)
198 | for t in toks:
199 | mod_words.append(t)
200 | tpwords.append(vocab[t])
201 | towords.append(vocab[t])
202 |
203 | else:
204 | # When there is no variables
205 | toks = tokenizer.tokenize(word)
206 | for t in toks:
207 |                     if t == "<mask>":
208 |                         continue  # skip the token if the code itself contains the <mask> keyword
209 | mod_words.append(t)
210 | tpwords.append(vocab[t])
211 | towords.append(vocab[t])
212 |
213 | assert len(tpwords) == len(towords)
214 | assert len(var_toks) == len(var_words)
215 |
216 | # 0 1 2 3 4 5 6 7
217 | return tpwords,towords,var_words, var_toks, words, orig_vars, mod_words, _id
218 |
219 |
220 | def main(args):
221 |
222 | tokenizer = RobertaTokenizerFast.from_pretrained(args.tokenizer)
223 | word_to_idx = json.load(open(args.vocab_word_to_idx))
224 | idx_to_word = json.load(open(args.vocab_idx_to_word))
225 | max_sample_chunk = args.max_chunk_size-2
226 |     args.vocab_size = len(word_to_idx)  # OOV is treated as UNK (assigned id = vocab_size)
227 | print("Vocab_size:",args.vocab_size)
228 |
229 | train, test = [], []
230 | src_train_ids, srctest_ids = set(), set()
231 | tgt_train_ids, tgttest_ids = set(), set()
232 |
233 | src_train_ids, train = read_input_files(args.train_file)
234 | srctest_ids, test = read_input_files(args.test_file)
235 | print(f"Data size Train: {len(train)} \t Test: {len(test)}")
236 |
237 | num_processes = min(args.workers, cpu_count())
238 | print("Running with #workers : ",num_processes)
239 |
240 | prep_input_files(train, num_processes, tokenizer, word_to_idx, max_sample_chunk, src_train_ids, tgt_train_ids, args.out_train_file )
241 | prep_input_files(test, num_processes, tokenizer, word_to_idx, max_sample_chunk, srctest_ids, tgttest_ids, args.out_test_file)
242 |
243 |
244 | if __name__ == "__main__":
245 |
246 | parser = argparse.ArgumentParser()
247 | parser.add_argument('--train_file', type=str, help='name of the train file')
248 | parser.add_argument('--test_file', type=str, help='name of name of the test file')
249 |
250 | parser.add_argument('--tokenizer', type=str, help='path to the tokenizer')
251 |
252 | parser.add_argument('--vocab_word_to_idx', type=str, help='Output Vocab Word to index file')
253 | parser.add_argument('--vocab_idx_to_word', type=str, help='Output Vocab Index to Word file')
254 |
255 | parser.add_argument('--vocab_size', type=int, default=50001, help='size of output vocabulary')
256 | parser.add_argument('--var_loc_pattern', type=str, default="@@\w+@@\w+@@", help='pattern representing variable location')
257 | parser.add_argument('--max_chunk_size', type=int, default=1024, help='size of maximum chunk of input for the model')
258 | parser.add_argument('--workers', type=int, default=20, help='number of parallel workers you need')
259 |
260 | parser.add_argument('--out_train_file', type=str, help='name of the output train file')
261 | parser.add_argument('--out_test_file', type=str, help='name of name of the output test file')
262 | args = parser.parse_args()
263 |
264 | main(args)
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/parse_decompiled_code.py:
--------------------------------------------------------------------------------
1 | import re
2 | import logging
3 |
4 |
5 | l = logging.getLogger('main')
6 | #
7 | # Raw IDA decompiled code pre-processing
8 | #
9 | class IDAParser:
10 | def __init__(self, decompiled_code, binary_name, linkage_name_to_func_name) -> None:
11 | self.decompiled_code = decompiled_code
12 | self.functions = {}
13 | self.func_addr_to_name = {}
14 | self.func_name_to_addr = {}
15 | self.linkage_name_to_func_name = linkage_name_to_func_name
16 | self.func_name_to_linkage_name = self.create_dict_fn_to_ln()
17 | self.binary_name = binary_name
18 | self.func_name_wo_line = {}
19 | self.preprocess_ida_raw_code()
20 |
21 | def create_dict_fn_to_ln(self):
22 |
23 | # linkage name to func name (it is the signature abc(void), so split at '(')
24 | func_name_to_linkage_name = {}
25 | if self.linkage_name_to_func_name:
26 | for ln, fn in self.linkage_name_to_func_name.items():
27 | func_name_to_linkage_name[fn.split('(')[0]] = ln
28 | return func_name_to_linkage_name
29 |
30 | def preprocess_ida_raw_code(self):
31 |
32 | data = self.decompiled_code
33 | func_addr_to_line_count = {}
34 | functions, self.func_name_to_addr, self.func_addr_to_name, self.func_name_wo_line = self.split_ida_c_file_into_funcs(data)
35 | for addr, func in functions.items():
36 | try:
37 | func_sign = func.split('{')[0].strip()
38 | func_body = '{'.join(func.split('{')[1:])
39 |
40 | if not addr in self.func_addr_to_name:
41 | continue
42 | func_name = self.func_addr_to_name[addr]
43 |
44 | # find local variables
45 | varlines_bodylines = func_body.strip("\n").split('\n\n')
46 | if len(varlines_bodylines) >= 2:
47 | var_dec_lines = varlines_bodylines[0]
48 | local_vars = self.find_local_vars(var_dec_lines)
49 | else:
50 | local_vars = []
51 | self.functions[addr] = {'func_name': func_name,
52 | 'func_name_no_line': '_'.join(func_name.split('_')[:-1]), # refer to comments in split_ida_c_file_into_funcs
53 | 'func': func,
54 | 'func_prototype': func_sign,
55 | 'func_body': func_body,
56 | 'local_vars': local_vars,
57 | 'addr': addr
58 | }
59 |
60 | except Exception as e:
61 | l.error(f'Error in {self.binary_name}:{func_name}:{addr} = {e}')
62 | l.info(f'Functions after IDA parsing {self.binary_name} :: {len(self.functions)}')
63 |
64 | def split_ida_c_file_into_funcs(self, data: str):
65 |
66 | line_count = 1
67 | chunks = data.split('//----- ')
68 | data_declarations = chunks[0].split('//-------------------------------------------------------------------------')[2]
69 | func_dict, func_name_to_addr, func_addr_to_name, func_name_with_linenum_linkage_name = {}, {}, {}, {}
70 | func_name_wo_line = {}
71 | for chunk in chunks:
72 | lines = chunk.split('\n')
73 | line_count = line_count
74 | if not lines:
75 | continue
76 | if '-----------------------------------' in lines[0]:
77 | name = ''
78 | func_addr = lines[0].strip('-').strip()
79 | func = '\n'.join(lines[1:])
80 | func_dict[func_addr] = func
81 |
82 | # func name and line number
83 | all_func_lines = func.splitlines()
84 | first_line = all_func_lines[0]
85 | sec_line = all_func_lines[1]
86 | if '(' in first_line:
87 | name = first_line.split('(')[0].split(' ')[-1].strip('*') + '_' + str(line_count + 1)
88 | elif '//' in first_line and '(' in sec_line:
89 | name = sec_line.split('(')[0].split(' ')[-1].strip('*') + '_' + str(line_count + 2)
90 |
91 | if name:
92 |                 # for C++ we use linkage (mangled) names instead of the demangled function names:
93 |                 # 1. DWARF gives the name without the class, while IDA shows class_name::func_name; 2. mangled names disambiguate overloads and identical names in different classes
94 |                 # if there is no mangled name, the original function name is copied into the dict so we still have all the names,
95 |                 # so for C the same dict can be used; it is essentially just the function name
96 |
97 | tmp_wo_line = name.split('_')
98 | if len(tmp_wo_line) <=1 :
99 | continue
100 | name_wo_line, line_num = '_'.join(tmp_wo_line[:-1]), tmp_wo_line[-1]
101 |
102 | # replace demangled name with mangled name (cpp) or simply a func name (c)
103 | if name_wo_line in self.func_name_to_linkage_name: # it is not func from source
104 | # type-strip
105 | name = self.func_name_to_linkage_name[name_wo_line]
106 | name = f'{name}_{line_num}'
107 | elif name_wo_line.strip().split('::')[-1] in self.func_name_to_linkage_name:
108 | name = f'{name}_{line_num}'
109 |
110 | # strip
111 | func_name_to_addr[name] = func_addr
112 | func_addr_to_name[func_addr] = name
113 | func_name_wo_line[name_wo_line] = func_addr
114 |
115 | line_count += len(lines) - 1
116 | return func_dict, func_name_to_addr, func_addr_to_name, func_name_wo_line
117 |
118 | def find_local_vars(self, lines):
119 | # use regex
120 | local_vars = []
121 | regex = r"(\w+(\[\d+\]|\d{0,6}));"
122 | matches = re.finditer(regex, lines)
123 | if matches:
124 | for m in matches:
125 | tmpvar = m.group(1)
126 | if not tmpvar:
127 | continue
128 | lv = tmpvar.split('[')[0]
129 | local_vars.append(lv)
130 | return local_vars
131 |
132 |
133 | #
134 | # Raw Ghidra decompiled code pre-processing
135 | #
136 |
137 | class GhidraParser:
138 | def __init__(self, decompiled_code_path, binary_name) -> None:
139 | self.decompiled_code_path = decompiled_code_path
140 | self.functions = {}
141 | self.func_addr_to_name = {}
142 | self.func_name_to_addr = {}
143 | self.linkage_name_to_func_name = None
144 | self.func_name_to_linkage_name = None
145 | self.binary_name = binary_name
146 | self.func_name_wo_line = {}
147 | self.preprocess_ghidra_raw_code()
148 |
149 | def preprocess_ghidra_raw_code(self):
150 | data = self.decompiled_code_path
151 | functions, self.func_name_to_addr, self.func_addr_to_name, self.func_name_wo_line = self.split_ghidra_c_file_into_funcs(data)
152 | if not functions:
153 | return
154 | for addr, func in functions.items():
155 | try:
156 | func_sign = func.split('{')[0].strip()
157 | func_body = '{'.join(func.split('{')[1:])
158 | if not addr in self.func_addr_to_name:
159 | continue
160 | func_name = self.func_addr_to_name[addr]
161 |
162 | varlines_bodylines = func_body.strip("\n").split('\n\n')
163 | if len(varlines_bodylines) >= 2:
164 | var_dec_lines = varlines_bodylines[0]
165 | local_vars = self.find_local_vars(varlines_bodylines)
166 | else:
167 | local_vars = []
168 |
169 | self.functions[addr] = {'func_name': func_name,
170 | 'func_name_no_line': '_'.join(func_name.split('_')[:-1]),
171 | 'func': func,
172 | 'func_prototype': func_sign,
173 | 'func_body': func_body,
174 | 'local_vars': local_vars,
175 | # 'arguments': func_args,
176 | 'addr': addr
177 | }
178 | except Exception as e:
179 | l.error(f'Error in {self.binary_name}:{func_name}:{addr} = {e}')
180 |
181 | def split_ghidra_c_file_into_funcs(self, data: str):
182 |
183 | chunks = data.split('//----- ')
184 |
185 | func_dict, func_name_to_addr, func_addr_to_name = {}, {}, {}
186 | func_name_wo_line = {}
187 |
188 | line_count = 1
189 | for chunk in chunks[1:]:
190 | # line_count carries over from the previous chunk
191 | lines = chunk.split('\n')
192 | if not lines:
193 | continue
194 | if '-----------------------------------' in lines[0]:
195 | func_addr = lines[0].strip('-').strip()
196 | func = '\n'.join(lines[1:])
197 | func_dict[func_addr] = func
198 |
199 | # get func name and line number. TODO: Ghidra sometimes splits a function signature across two lines - consider using a regex for this
200 | all_func_lines = func.split('\n\n')
201 | first_line = all_func_lines[0]
202 | if '(' in first_line:
203 | t_name = first_line.split('(')[0]
204 | if "\n" in t_name:
205 | name = t_name.split('\n')[-1].split(' ')[-1].strip('*') + '_' + str(line_count + 1)
206 | else:
207 | name = t_name.split(' ')[-1].strip('*') + '_' + str(line_count + 1)
208 |
209 | if name:
210 | name_wo_line = '_'.join(name.split('_')[:-1])
211 | func_name_to_addr[name] = func_addr
212 | func_addr_to_name[func_addr] = name
213 | func_name_wo_line[name_wo_line] = func_addr
214 |
215 | line_count += len(lines) - 1
216 | return func_dict, func_name_to_addr, func_addr_to_name, func_name_wo_line
217 |
218 | def find_local_vars(self, varlines_bodylines):
219 |
220 | all_vars = []
221 | try:
222 | first_curly = ''
223 | sep = ''
224 | dec_end = 0
225 | for elem in range(len(varlines_bodylines)):
226 | if varlines_bodylines[elem] == '{' and first_curly == '':
227 | first_curly = '{'
228 | dec_start = elem + 1
229 | if first_curly == '{' and sep == '' and varlines_bodylines[elem] == ' ':
230 | dec_end = elem - 1
231 |
232 | for index in range(dec_start, dec_end + 1):
233 | if index < len(varlines_bodylines):
234 | f_sp = varlines_bodylines[index].split(' ')
235 | if f_sp:
236 | tmp_var = f_sp[-1][:-1].strip('*')
237 |
238 | if tmp_var.strip().startswith('['):
239 | up_var = f_sp[-2:-1][0].strip('*')
240 | all_vars.append(up_var)
241 | else:
242 | all_vars.append(tmp_var)
243 |
244 | except Exception as e:
245 | l.error(f'Error in finding local vars {self.binary_name} :: {e}')
246 |
247 | return all_vars
248 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import glob
4 | import shutil
5 | import logging
6 | from tqdm import tqdm
7 | from pathlib import Path
8 | import concurrent.futures
9 | from multiprocessing import Manager
10 |
11 | from binary import Binary
12 | from joern_parser import JoernParser
13 | from variable_matching import DataLoader
14 | from decompiler.run_decompilers import Decompiler
15 | from utils import set_up_data_dir, JoernServer, read_json
16 | from create_dataset_splits import create_train_and_test_sets
17 |
18 | l = logging.getLogger('main')
19 |
20 | class Runner:
21 | def __init__(self, decompiler, target_binaries, WORKERS, path_manager, PORT, language, DEBUG, language_map=None) -> None:
22 | self.decompiler = decompiler
23 | self.target_binaries = target_binaries
24 | self.WORKERS = WORKERS
25 | self.path_manager = path_manager
26 | self.binary = None
27 | self.PORT = PORT
28 | self.language = language
29 | self.DEBUG = DEBUG
30 | self.collect_workdir = Manager().dict()
31 | self.language_map = language_map or {}
32 |
33 | def _setup_environment_for_ghidra(self):
34 | java_home = "/usr/lib/jvm/java-17-openjdk-amd64"
35 | os.environ["JAVA_HOME"] = java_home
36 | if os.environ.get('JAVA_HOME') != java_home:
37 | l.error(f"INCORRECT JAVA VERSION FOR GHIDRA: {os.environ.get('JAVA_HOME')}")
38 |
39 | def decompile_runner(self, binary_path: str):
40 | """
41 | Runner method to handle the decompilation of a binary.
42 | """
43 | try:
44 | decompile_results = {}
45 | binary_name = Path(binary_path).name
46 | l.info(f"Processing target binary: {binary_name}")
47 | workdir = f'{self.path_manager.tmpdir}/workdir/{binary_name}'
48 | l.debug(f"{binary_name} in workdir {workdir}")
49 | set_up_data_dir(self.path_manager.tmpdir, workdir, self.decompiler)
50 |
51 | # collect dwarf info and create strip and type-strip binaries
52 | self.binary = Binary(binary_path, binary_name, self.path_manager, self.language, self.decompiler, True)
53 | if not (self.binary.strip_binary and self.binary.type_strip_binary):
54 | l.error(f"Strip/Type-strip binary not found for {self.binary.binary_name}")
55 | return
56 | l.debug(f"Strip/Type-strip binaries created for {self.binary.binary_name}")
57 | decomp_workdir = os.path.join(workdir, f"{self.decompiler}_data")
58 | decompiler_path = getattr(self.path_manager, f"{self.decompiler}_path")
59 | if self.decompiler == "ghidra":
60 | self._setup_environment_for_ghidra()
61 | for binary_type in ["type_strip", "strip"]:
62 | binary_path = getattr(self.binary, f"{binary_type}_binary")
63 | l.debug(f"Decompiling {self.binary.binary_name} :: {binary_type} with {self.decompiler} :: {binary_path}")
64 |
65 | dec = Decompiler(
66 | decompiler=self.decompiler,
67 | decompiler_path=Path(decompiler_path),
68 | decompiler_workdir=decomp_workdir,
69 | binary_name=self.binary.binary_name,
70 | binary_path=binary_path,
71 | binary_type=binary_type,
72 | decompiled_binary_code_path=os.path.join(self.path_manager.dc_path, str(self.decompiler), binary_type, f'{binary_name}.c'),
73 | failed_path=os.path.join(self.path_manager.failed_path_ida, binary_type),
74 | type_strip_addrs=os.path.join(self.path_manager.dc_path, self.decompiler, f"type_strip-addrs", binary_name),
75 | type_strip_mangled_names = os.path.join(self.path_manager.dc_path, self.decompiler, f"type_strip-names", binary_name)
76 | )
77 | dec_path = dec
78 |
79 | if not dec_path:
80 | l.error(f"Decompilation failed for {self.binary.binary_name} :: {binary_type} ")
81 | l.info(f"Decompilation successful for {self.binary.binary_name} :: {binary_type}!")
82 | decompile_results[binary_type] = dec_path
83 |
84 | except Exception as e:
85 | l.error(f"Error in decompiling {self.binary.binary_name} with {self.decompiler}: {e}")
86 | return
87 |
88 | self.collect_workdir[self.binary.binary_name] = workdir
89 | return self.binary.dwarf_dict, decompile_results
90 |
91 | def joern_runner(self, dwarf_data, binary_name, decompilation_results, PARSE):
92 |
93 | workdir = self.collect_workdir[binary_name]
94 | joern_data_strip, joern_data_type_strip = '', ''
95 | try:
96 | decompiled_code_strip_path = decompilation_results['strip'].decompiled_binary_code_path
97 | decompiled_code_type_strip_path = decompilation_results['type_strip'].decompiled_binary_code_path
98 | decompiled_code_type_strip_names = decompilation_results['type_strip'].type_strip_mangled_names
99 | dwarf_info_path = dwarf_data
100 | data_map_dump = os.path.join(self.path_manager.match_path, self.decompiler, binary_name)
101 |
102 | if not (os.path.exists(decompiled_code_strip_path) and os.path.exists(decompiled_code_type_strip_path)):
103 | l.error(f"Decompiled code not found for {binary_name} :: {self.decompiler}")
104 | return
105 |
106 | # paths for joern
107 | joern_strip_path = os.path.join(self.path_manager.joern_data_path, self.decompiler, 'strip', (binary_name + '.json'))
108 | joern_type_strip_path = os.path.join(self.path_manager.joern_data_path, self.decompiler, 'type_strip', (binary_name + '.json'))
109 |
110 | if PARSE:
111 | # JOERN
112 | l.info(f'Joern parsing for {binary_name} :: strip :: {self.decompiler} :: in {workdir}')
113 | joern_data_strip = JoernParser(binary_name=binary_name, binary_type='strip', decompiler=self.decompiler,
114 | decompiled_code=decompiled_code_strip_path, workdir=workdir, port=self.PORT,
115 | outpath=joern_strip_path).joern_data
116 |
117 | l.info(f'Joern parsing for {binary_name} :: type-strip :: {self.decompiler} :: in {workdir}')
118 | joern_data_type_strip = JoernParser(binary_name=binary_name, binary_type='type-strip', decompiler=self.decompiler,
119 | decompiled_code=decompiled_code_type_strip_path, workdir=workdir, port=self.PORT,
120 | outpath=joern_type_strip_path).joern_data
121 |
122 |
123 | else:
124 | joern_data_strip = read_json(joern_strip_path)
125 | joern_data_type_strip = read_json(joern_type_strip_path)
126 |
127 | l.info(f'Joern parsing completed for {binary_name} :: strip :: {self.decompiler} :: {len(joern_data_strip)}')
128 | l.info(f'Joern parsing completed for {binary_name} :: type-strip :: {self.decompiler} :: {len(joern_data_type_strip)}')
129 |
130 |
131 | if not (joern_data_strip and joern_data_type_strip):
132 | l.info(f"oops! strip/ty-strip joern not found! {binary_name}")
133 | return
134 |
135 | l.info(f'Start mapping for {binary_name} :: {self.decompiler}')
136 |
137 | # Select language strictly from map if provided, else fall back to default
138 | language_to_use = self.language_map.get(binary_name, self.language) if self.language_map else self.language
139 | matchvariable = DataLoader(binary_name, dwarf_info_path, self.decompiler, decompiled_code_strip_path,
140 | decompiled_code_type_strip_path, joern_data_strip, joern_data_type_strip, data_map_dump, language_to_use, decompiled_code_type_strip_names)
141 |
142 |
143 | except Exception as e:
144 | l.info(f"Error! :: {binary_name} :: {self.decompiler} :: {e} ")
145 |
146 |
147 | def run(self, PARSE, splits):
148 |
149 | if PARSE:
150 | try:
151 | joern_progress = tqdm(total=0, desc="Joern Processing")
152 | with concurrent.futures.ProcessPoolExecutor(max_workers=self.WORKERS) as executor:
153 | future_to_binary = {executor.submit(self.decompile_runner, binary): binary for binary in self.target_binaries}
154 | # As each future completes, process it with joern_runner
155 | # for future in concurrent.futures.as_completed(future_to_binary):
156 | for future in tqdm(concurrent.futures.as_completed(future_to_binary), total=len(future_to_binary), desc="Decompiling Binaries"):
157 | binary = future_to_binary[future]
158 | binary_name = Path(binary).name
159 | result = future.result()  # may be None if decompilation failed for this binary
160 | if not result or not result[1]:
161 | l.error(f"No decompilation results for {binary_name}"); continue
162 | binary_info, decompilation_results = result
163 | l.info(f"Start Joern for :: {binary_name}")
164 | try:
165 | joern_progress.total += 1
166 | # Setup JAVA_HOME for Joern
167 | os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
168 | java_env = os.environ.get('JAVA_HOME', None)
169 | if java_env != '/usr/lib/jvm/java-11-openjdk-amd64':
170 | l.error(f"INCORRECT JAVA VERSION JOERN: {java_env}")
171 | with JoernServer(self.path_manager.joern_dir, self.PORT) as joern_server:
172 | self.joern_runner(binary_info, binary_name, decompilation_results, PARSE)
173 | joern_progress.update(1)
174 | except Exception as e:
175 | l.error(f"Error in starting Joern server for {binary_name} :: {e}")
176 | finally:
177 | if joern_server.is_server_running():
178 | joern_server.stop()
179 | l.debug(f"Stop Joern for :: {binary_name}")
180 | time.sleep(10)
181 |
182 | # dedup functions -> create train and test sets
183 | joern_progress.close()
184 | if splits:
185 | create_train_and_test_sets( f'{self.path_manager.tmpdir}', self.decompiler)
186 |
187 | except Exception as e:
188 | l.error(f"Error during parallel decompilation: {e}")
189 |
190 | else:
191 | create_train_and_test_sets( f'{self.path_manager.tmpdir}', self.decompiler)
192 |
193 | # copy final splits to the data dir; delete tmpdir afterwards unless running in debug mode
194 | source_dir = f'{self.path_manager.tmpdir}/splits/'
195 | # Get all file paths
196 | file_paths = glob.glob(os.path.join(source_dir, '**', 'final*.jsonl'), recursive=True)
197 |
198 | for file_path in file_paths:
199 | # Check if the filename contains 'final'
200 | if 'final' in os.path.basename(file_path):
201 | relative_path = os.path.relpath(file_path, start=source_dir)
202 | new_destination_path = os.path.join(self.path_manager.data_dir, relative_path)
203 | # Create directories if they don't exist
204 | os.makedirs(os.path.dirname(new_destination_path), exist_ok=True)
205 | # Copy the file
206 | shutil.copy(file_path, new_destination_path)
207 | l.info(f"Copied {file_path} to {new_destination_path}")
208 | if self.path_manager.tmpdir and not self.DEBUG:
209 | shutil.rmtree(self.path_manager.tmpdir)
210 | l.info(f"Deleted {self.path_manager.tmpdir}")
--------------------------------------------------------------------------------
/varbert/fine-tune/preprocess.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from transformers import RobertaTokenizerFast
3 | import jsonlines
4 | import sys
5 | import json
6 | import traceback
7 | import random
8 | import os
9 | import re
10 | import argparse
11 | from collections import defaultdict
12 | from multiprocessing import Process, Manager, cpu_count, Pool
13 | from itertools import repeat, islice
14 |
15 | def read_input_files(filename):
16 |
17 | samples, sample_ids = [], set()
18 | with jsonlines.open(filename,'r') as f:
19 | for each in tqdm(f):
20 | if each.get('fid') is None:
21 | new_fid = str(each['id']) + "-" + each['func_name'].replace("_", "-")
22 | each['fid'] = new_fid
23 | if each.get('type_stripped_norm_vars') is None:
24 | vars_map = each.get('vars_map')
25 | norm_var_type = {}
26 | if vars_map:
27 | for pair in vars_map:
28 | norm_var = pair[1]
29 | var = pair[0]
30 | if norm_var in norm_var_type and each["type_stripped_vars"][var] != 'dwarf':
31 | norm_var_type[norm_var] = 'dwarf'
32 | else:
33 | norm_var_type[norm_var] = each["type_stripped_vars"][var]
34 | each['type_stripped_norm_vars'] = norm_var_type
35 |
36 | sample_ids.add(each['fid'])
37 | samples.append(each)
38 | return sample_ids, samples
39 |
40 | def prep_input_files(input_file, num_processes, tokenizer, word_to_idx, max_sample_chunk, input_file_ids, output_file_ids, preprocessed_outfile):
41 |
42 | #------------------------ INPUT FILES ------------------------
43 | output_data = Manager().list()
44 | pool = Pool(processes=num_processes) # Instantiate the pool here
45 | each_alloc = len(input_file) // (num_processes-1)
46 | input_data = [input_file[i*each_alloc:(i+1)*each_alloc] for i in range(0,num_processes)]
47 | x = [len(each) for each in input_data]
48 | print(f"Allocation samples for each worker: {len(input_data)}, {x}")
49 |
50 | pool.starmap(generate_id_files,zip(input_data,
51 | repeat(output_data),
52 | repeat(tokenizer),
53 | repeat(word_to_idx),
54 | repeat(max_sample_chunk)
55 | ))
56 | pool.close()
57 | pool.join()
58 |
59 | # Write to Output file
60 | with jsonlines.open(preprocessed_outfile,'w') as f:
61 | for each in tqdm(output_data):
62 | output_file_ids.add(str(each['fid']).split("_")[0])
63 | f.write(each)
64 |
65 | # validate : # source ids == target_ids all ids are present after parallel processing
66 | print(len(input_file_ids), len(output_file_ids))
67 | print("src/tgt id difference:", len(input_file_ids - output_file_ids), len(output_file_ids - input_file_ids))
68 |
69 | # validate : vocab_check
70 | vocab_check = defaultdict(int)
71 | total = 0
72 | for each in tqdm(output_data):
73 | variables = list(each['type_dict'].keys())
74 | for var in variables:
75 | total += 1
76 | _, vocab_stat = get_var_token(var, word_to_idx)
77 | if "in_vocab" in vocab_stat:
78 | vocab_check['in_vocab']+=1
79 | if "not_in_vocab" in vocab_stat:
80 | vocab_check['not_in_vocab']+=1
81 | if "part_in_vocab" in vocab_stat:
82 | vocab_check['part_in_vocab']+=1
83 |
84 | print(vocab_check, round(vocab_check['in_vocab']*100/total,2), round(vocab_check['not_in_vocab']*100/total,2))
85 |
86 | def get_var_token(norm_variable_word,word_to_idx):
87 | vocab_check = defaultdict(int)
88 | token = word_to_idx.get(norm_variable_word,args.vocab_size)
89 | if token == args.vocab_size:
90 | vocab_check['not_in_vocab']+=1
91 | else:
92 | vocab_check['in_vocab']+=1
93 | return [token], vocab_check
94 |
95 |
96 | def preprocess_word_mask(text,tokenizer, word_to_idx):
97 | type_dict = text['type_stripped_norm_vars']
98 | _id = text['_id'] if "_id" in text.keys() else text['fid']
99 | ftext = text['norm_func']
100 | words = ftext.replace("\n"," ").split(" ")
101 | pwords =[]
102 | tpwords =[]
103 | owords =[]
104 | towords =[]
105 | pos=0
106 | masked_pos=[]
107 | var_words =[]
108 | var_toks = []
109 | mod_words = []
110 | labels_typ = []
111 | orig_vars = []
112 | varmap_position = defaultdict(list)
113 | mask_one_ida = False
114 |
115 | vocab=tokenizer.get_vocab()
116 |
117 | for word in words:
118 |
119 |
120 | if re.search(args.var_loc_pattern, word):
121 | idx = 0
122 | for each_var in list(re.finditer(args.var_loc_pattern,word)):
123 | s = each_var.start()
124 | e = each_var.end()
125 | prefix = word[idx:s]
126 | var = word[s:e]
127 | orig_var = var.split("@@")[-2]
128 |
129 | # Somethings attached before the variables
130 | if prefix:
131 | toks = tokenizer.tokenize(prefix)
132 | for t in toks:
133 | mod_words.append(t)
134 | tpwords.append(vocab[t])
135 | towords.append(vocab[t])
136 | labels_typ.append(-100)
137 |
138 | var_tokens, _ = get_var_token(orig_var,word_to_idx)
139 | var_toks.append(var_tokens)
140 | var_words.append(orig_var) # Gold_texts (gold labels)
141 | mod_words.append(orig_var)
142 | orig_vars.append(orig_var)
143 |
144 | # Second label generation variable Type (decomp vs dwarf)
145 | if orig_var not in type_dict or type_dict[orig_var] == args.decompiler: # if it is not present consider ida
146 | labels_typ.append(1)
147 | tpwords.append(vocab["<mask>"])
148 | towords.append(-100)
149 | varmap_position[-100].append(pos)
150 | elif type_dict[orig_var] == "dwarf":
151 | labels_typ.append(0)
152 | tpwords.append(vocab["<mask>"])
153 | towords.append(var_tokens[0])
154 | varmap_position[orig_var].append(pos)
155 | else:
156 | print("ERROR: CHECK LABEL TYPE IN STRIPPED BIN DICTIONARY")
157 | exit(0)
158 | pos += 1
159 |
160 | idx = e
161 |
162 | # Postfix if any
163 | postfix = word[idx:]
164 | if postfix:
165 | toks = tokenizer.tokenize(postfix)
166 | for t in toks:
167 | mod_words.append(t)
168 | tpwords.append(vocab[t])
169 | towords.append(vocab[t])
170 | labels_typ.append(-100)
171 |
172 | else:
173 | # When there are no variables
174 | toks = tokenizer.tokenize(word)
175 | for t in toks:
176 | if t == "<mask>":
177 | continue # skip the token if the code itself contains the <mask> keyword
178 | mod_words.append(t)
179 | tpwords.append(vocab[t])
180 | towords.append(vocab[t])
181 | labels_typ.append(-100)
182 |
183 | assert len(tpwords) == len(towords)
184 | assert len(labels_typ) == len(towords)
185 | assert len(var_toks) == len(var_words)
186 |
187 | # 0 1 2 3 4 5 6 7 8 9 10
188 | return tpwords,towords,var_words, var_toks, labels_typ, words, orig_vars, mod_words, type_dict, _id, varmap_position
189 |
190 |
191 | def generate_id_files(data, output_data, tokenizer, word_to_idx, n):
192 |
193 | for d in tqdm(data):
194 | try:
195 | ppw = preprocess_word_mask(d,tokenizer,word_to_idx)
196 | outrow = {"words":ppw[5],"mod_words":ppw[7],"inputids":ppw[0],"labels":ppw[1],"gold_texts":ppw[2],"gold_texts_id":ppw[3],"labels_type":ppw[4],"meta":[],"orig_vars":ppw[6], "type_dict":ppw[8], "fid":ppw[9],'varmap_position':ppw[10]}
197 | # if input length is more than max possible 1024 then split and make more sample found by tracing _id
198 | if len(outrow['inputids']) > n:
199 | for i in range(0, len(outrow['inputids']), n):
200 | sample = {"words": outrow['words'][i:i+n],
201 | "mod_words":outrow['mod_words'][i:i+n],
202 | "inputids":outrow['inputids'][i:i+n],
203 | "labels":outrow["labels"][i:i+n],
204 | "labels_type":outrow["labels_type"][i:i+n],
205 | "gold_texts":outrow["gold_texts"],
206 | "gold_texts_id":outrow["gold_texts_id"],
207 | "orig_vars":outrow["orig_vars"],
208 | "type_dict":outrow["type_dict"],
209 | "meta":outrow["meta"],
210 | "fid":outrow['fid']+"_"+str((i)//n),
211 | "varmap_position":outrow["varmap_position"],
212 | }
213 | output_data.append(sample)
214 | else:
215 | output_data.append(outrow)
216 | except:
217 | print("Unexpected error:", sys.exc_info()[0])
218 | traceback.print_exception(*sys.exc_info())
219 |
220 |
221 |
222 | def main(args):
223 |
224 | tokenizer = RobertaTokenizerFast.from_pretrained(args.tokenizer)
225 | word_to_idx = json.load(open(args.vocab_word_to_idx))
226 | idx_to_word = json.load(open(args.vocab_idx_to_word))
227 | max_sample_chunk = args.max_chunk_size-2
228 | args.vocab_size = len(word_to_idx) # OOV variables map to UNK (assigned id = vocab_size)
229 | print(f"Vocab_size: {args.vocab_size}")
230 |
231 | train, test = [], []
232 | src_train_ids, srctest_ids = set(), set()
233 | tgt_train_ids, tgttest_ids = set(), set()
234 |
235 | src_train_ids, train = read_input_files(args.train_file)
236 | srctest_ids, test = read_input_files(args.test_file)
237 | print(f"Data size Train: {len(train)} \t Test: {len(test)}")
238 |
239 | num_processes = min(args.workers, cpu_count())
240 | print(f"Running with #workers : {num_processes}")
241 |
242 | prep_input_files(train, num_processes, tokenizer, word_to_idx, max_sample_chunk, src_train_ids, tgt_train_ids, args.out_train_file )
243 | prep_input_files(test, num_processes, tokenizer, word_to_idx, max_sample_chunk, srctest_ids, tgttest_ids, args.out_test_file)
244 |
245 | if __name__ == "__main__":
246 |
247 | parser = argparse.ArgumentParser()
248 | parser.add_argument('--train_file', type=str, help='name of the train file')
249 | parser.add_argument('--test_file', type=str, help='name of the test file')
250 | parser.add_argument('--test_notintrain_file', type=str, help='name of the test not in train file')
251 |
252 | parser.add_argument('--tokenizer', type=str, help='path to the tokenizer')
253 |
254 | parser.add_argument('--vocab_word_to_idx', type=str, help='Output Vocab Word to index file')
255 | parser.add_argument('--vocab_idx_to_word', type=str, help='Output Vocab Index to Word file')
256 |
257 | parser.add_argument('--vocab_size', type=int, default=150001, help='size of output vocabulary')
258 | parser.add_argument('--var_loc_pattern', type=str, default=r"@@\w+@@\w+@@", help='pattern representing a variable location')
259 | parser.add_argument('--decompiler', type=str, default="ida", help='decompiler for type prediction; ida or ghidra')
260 | parser.add_argument('--max_chunk_size', type=int, default=1024, help='size of maximum chunk of input for the model')
261 | parser.add_argument('--workers', type=int, default=30, help='number of parallel workers you need')
262 |
263 | parser.add_argument('--out_train_file', type=str, help='name of the output train file')
264 | parser.add_argument('--out_test_file', type=str, help='name of the output test file')
265 | # parser.add_argument('--out_test_notintrain_file', type=str, help='name of the output test not in train file')
266 |
267 |
268 | args = parser.parse_args()
269 |
270 | main(args)
271 |
--------------------------------------------------------------------------------
/varbert/README.md:
--------------------------------------------------------------------------------
1 | ### Table of Contents
2 |
3 | 1. [Training VarBERT](#training-varbert)
4 | 2. [Fine-tuning VarBERT](#fine-tune)
5 | 3. [Vocab Files](#vocab-files)
6 | 4. [Tokenizer](#tokenizer)
7 | 5. [Masked Language Modeling (MLM)](#masked-language-modeling-mlm)
8 | 6. [Constrained Masked Language Modeling (CMLM)](#constrained-masked-language-modeling-cmlm)
9 | 7. [Resize Model](#resize-model)
10 |
11 | ### Training VarBERT
12 |
13 | In our paper, we follow a two-step training process:
14 | - **Pre-training**: VarBERT is pre-trained on source code functions (HSC data set) using Masked Language Modeling (MLM) followed by Constrained Masked Language Modeling (CMLM).
15 | - **Fine-tuning**: Subsequently, VarBERT is fine-tuned on top of the pre-trained model using VarCorpus (decompilation output of IDA and Ghidra) to predict variable names and variable origins (i.e., whether a variable originates from source code or is decompiler-generated).
16 |
17 | Training Process Overview (from the paper):
18 |
19 | BERT-Base → MLM → CMLM → Fine-tune
20 |
21 |
22 | This approach can be adapted to any other decompiler capable of producing C-style decompilation output: take the pre-trained model from the first step and fine-tune it on a new (or existing) decompiler's output (see [Fine-tune](#fine-tune)).
23 |
24 | **Essential Components for Training:**
25 | 1. Base Model: Choose from BERT-Base, the MLM model, or the CMLM model. Remember to resize the base model if the vocabulary size of the next training stage differs ([How to Resize Model](#resize-model)).
26 | 2. [Tokenizer](#tokenizer)
27 | 3. Train and Test sets
28 | 4. [Vocab Files](#vocab-files): Contains the most frequent variable names from your training dataset.
29 |
30 |
31 | ### Fine-tune
32 | Fine-tuning is generally performed on top of a pre-trained model (MLM + CMLM Model in our case). However, fine-tuning can also be directly applied to an MLM model or a BERT-Base model.
33 | During this phase, the model learns to predict variable names and their origins.
34 | VarBERT predicts the `Top-N` candidates (where N can be 1, 3, 5, or 10) for each variable name. For the origin, an output of `dwarf` indicates that the variable comes from source code, while `ida` or `ghidra` indicates that it is decompiler-generated.
35 |
36 | Access our fine-tuned model from our paper: [Models](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0)
37 |
38 | To train a new model follow these steps:
39 |
40 | - **Base Model**: [CMLM Model](https://www.dropbox.com/scl/fi/72ku0tf3o93kn67k60d7d/CMLM_MODEL.tar.gz?rlkey=8kwlfwc87uwcsab86np4bhub0&dl=0)
41 | - **Tokenizer**: Refer to [Tokenizer](#tokenizer)
42 | - **Train and Test sets**: Each VarCorpus data set tarball contains both pre-processed and unprocessed train and test sets. If you prefer to use the pre-processed sets, jump straight to step 2 (training a model); otherwise, start with step 1. [VarCorpus](https://www.dropbox.com/scl/fo/3thmg8xoq2ugtjwjcgjsm/h?rlkey=azgjeq513g4semc1qdi5xyroj&dl=0). Alternatively, if you created a new data set using `generate.py` in [Building VarCorpus](./varcorpus/README.md), use the files saved in your `data dir`.
43 | - **Vocab Files**: You can find our vocab files in tarball of each trained model available at [link](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0) or refer to [Vocab Files](#vocab-files) to create new vocab files.
44 |
45 |
46 | 1. Preprocess data set for training
47 |
48 | ```python
49 | python3 preprocess.py \
50 | --train_file \
51 | --test_file \
52 | --tokenizer \
53 | --vocab_word_to_idx \
54 | --vocab_idx_to_word \
55 | --decompiler \
56 | --max_chunk_size 800 \
57 | --workers 2 \
58 | --out_train_file \
59 | --out_test_file
60 | ```
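
For reference, the sketch below writes a single, purely illustrative record containing only the fields that `preprocess.py` reads; real VarCorpus records carry additional fields, and every value here is made up. Variable occurrences in `norm_func` are wrapped in the `@@var_N@@name@@` markers produced during corpus generation:

```python
import jsonlines

# illustrative only: keys match what preprocess.py reads, values are invented
record = {
    "id": 12,
    "fid": "12-main",
    "func_name": "main",
    "norm_func": "int main(int @@var_1@@argc@@, char **@@var_2@@argv@@)\n{\n  int @@var_3@@v3@@;\n}",
    "type_stripped_vars": {"argc": "dwarf", "argv": "dwarf", "v3": "ida"},
    "vars_map": [["argc", "argc"], ["argv", "argv"]],
    "type_stripped_norm_vars": {"argc": "dwarf", "argv": "dwarf"},
}
with jsonlines.open("train.jsonl", "w") as w:  # placeholder file name
    w.write(record)
```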
61 |
62 | 2. Fine-tune a model
63 |
64 |
65 | ```python
66 | python3 -m torch.distributed.launch --nproc_per_node=2 training.py \
67 | --overwrite_output_dir \
68 | --train_data_file \
69 | --output_dir \
70 | --block_size 800 \
71 | --tokenizer_name \
72 | --model_type roberta \
73 | --model_name_or_path \
74 | --vocab_path \
75 | --do_train \
76 | --num_train_epochs 4 \
77 | --save_steps 10000 \
78 | --logging_steps 1000 \
79 | --per_gpu_train_batch_size 4 \
80 | --mlm;
81 | ```
82 | The base model (`--model_name_or_path`) can be any of the previously trained models: the CMLM model, the MLM model, or BERT-Base.
83 |
84 | Note: To run training without distributed launch, use `python3 training.py` with the same arguments.
85 |
86 | 3. Run Inference
87 |
88 | ```python
89 | python3 eval.py \
90 | --model_name \
91 | --tokenizer_name \
92 | --block_size 800 \
93 | --data_file \
94 | --prefix ft_bintoo_test \
95 | --batch_size 16 \
96 | --pred_path resultdir \
97 | --out_vocab_map
98 | ```
99 |
100 |
101 | ### Vocab Files
102 | Vocab files consist of the top N most frequently occurring variable names from the chosen training dataset. Specifically, we use the top 50K variable names from the Human Source Code (HSC) dataset and the top 150K variable names from the VarCorpus dataset. These are the variable names the model learns to predict.
103 |
104 | (We use `50001` and `150001` as the vocabulary sizes for the CMLM model and for fine-tuning, respectively: the top 50K/150K variable names plus one extra id for UNK.)
105 |
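Both vocab files are plain JSON dictionaries: `word_to_idx` maps a variable name to an integer id and `idx_to_word` is its inverse. The sketch below only illustrates that file shape (building them from frequency counts is an assumption here; `generate_vocab.py` is the tool that actually produces the files). Any name missing from `word_to_idx` is later treated as UNK, with id equal to the vocabulary size:

```python
import json
from collections import Counter

# names: all (normalized) variable names collected from the training set
names = ["i", "len", "buf", "i", "ptr", "len", "i"]
top_names = [w for w, _ in Counter(names).most_common(150000)]

word_to_idx = {w: i for i, w in enumerate(top_names)}
idx_to_word = {i: w for i, w in enumerate(top_names)}

json.dump(word_to_idx, open("word_to_idx.json", "w"))
json.dump(idx_to_word, open("idx_to_word.json", "w"))
```
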
106 | Use existing vocab files from our paper:
107 |
108 | Please note that we have different vocab files for each model. You can find our vocab files in tarball of each trained model available at [link](https://www.dropbox.com/scl/fo/socl7rd5lsv926whylqpn/h?rlkey=i0x74bdipj41hys5rorflxawo&dl=0).
109 |
110 | To create new vocab files:
111 |
112 | ```python
113 | python3 generate_vocab.py \
114 | --dataset_type \
115 | --train_file \
116 | --test_file \
117 | --vocab_size \
118 | --output_dir \
119 | ```
120 |
121 |
122 | ### Tokenizer
123 | To adapt the model to the unique nature of the new data (source code), we use a Byte-Pair Encoding (BPE) tokenizer to learn a new source vocabulary. We train the tokenizer with a vocabulary size of 50K, similar to RoBERTa's 50,265.
124 |
125 | Use tokenizer trained on HSC data set (train set): [Tokenizer](https://www.dropbox.com/scl/fi/i8seayujpqdc0egavks18/tokenizer.tar.gz?rlkey=fnhorh3uo2diqv0v1qaymzo2r&dl=0)
126 | ```bash
127 | wget -O tokenizer.tar.gz "https://www.dropbox.com/scl/fi/i8seayujpqdc0egavks18/tokenizer.tar.gz?rlkey=fnhorh3uo2diqv0v1qaymzo2r&dl=0"
128 | ```
129 |
130 | Training a new Tokenizer:
131 | 1. Prepare the train set: the input to the tokenizer should be in plain-text format.
132 | (If using [HSC data set](https://www.dropbox.com/scl/fi/1eekwcsg7wr7cux6y34xb/hsc_data.tar.gz?rlkey=s3kjroqt7a27hoeoc56mfyljw&dl=0))
133 |
134 | ```python
135 | python3 preprocess.py \
136 | --input_file \
137 | --output_file
138 | ```
139 |
140 | 2. Training Tokenizer
141 |
142 | ```python
143 | python3 train_bpe_tokenizer.py \
144 | --input_path \
145 | --vocab_size 50265 \
146 | --min_frequency 2 \
147 | --output_path
148 | ```
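
As a quick sanity check, the trained tokenizer can be loaded the same way the preprocessing scripts load it (the path below is a placeholder for your `--output_path`):

```python
from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("path/to/tokenizer")  # placeholder path
print(tokenizer.tokenize("size_t n = strlen(buf);"))
```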
149 |
150 |
151 | **To pre-train a model from scratch**
152 |
153 | ### Masked Language Modeling (MLM)
154 |
155 | We learn representations of code tokens by training BERT from scratch with a Masked Language Modeling objective similar to the one used in RoBERTa: a fraction of tokens is masked at random, and the model learns to predict the masked tokens.
156 |
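The snippet below is only an illustration of random masking with Hugging Face's standard MLM collator (the training command further down uses `--mlm_probability 0.15`); it is not the repository's training code:

```python
from transformers import DataCollatorForLanguageModeling, RobertaTokenizerFast

tok = RobertaTokenizerFast.from_pretrained("path/to/tokenizer")  # placeholder path
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=True, mlm_probability=0.15)

batch = collator([tok("int sum = a + b;")])
print(batch["input_ids"][0])  # roughly 15% of positions replaced by <mask> (or randomized)
print(batch["labels"][0])     # original token ids at masked positions, -100 elsewhere
```
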
157 | #### Using a Pre-trained MLM Model
158 |
159 | Access our pre-trained MLM model, trained on 5.2M functions, from our paper: [MLM Model](https://www.dropbox.com/scl/fi/72ku0tf3o93kn67k60d7d/CMLM_MODEL.tar.gz?rlkey=8kwlfwc87uwcsab86np4bhub0&dl=0)
160 |
161 | #### To train a new model, follow these steps:
162 |
163 | 1. Preprocess data set for training
164 | [HSC data set](https://www.dropbox.com/scl/fi/1eekwcsg7wr7cux6y34xb/hsc_data.tar.gz?rlkey=s3kjroqt7a27hoeoc56mfyljw&dl=0)
165 |
166 | ```python
167 | python3 preprocess.py \
168 | --train_file \
169 | --test_file \
170 | --output_train_file \
171 | --output_test_file
172 | ```
173 |
174 | 2. Train MLM Model
175 |
176 | - The MLM model is trained on top of the BERT-Base model using pre-processed HSC train and test files from step 1.
177 | - For the BERT-Base model, refer to [BERT-Base Model](https://www.dropbox.com/scl/fi/18p37f5drph8pekv8kcj2/BERT_Base.tar.gz?rlkey=3x4mpr4hmkyndunhg9fpu0p3b&dl=0)
178 |
179 |
180 | ```python
181 | python3 training.py \
182 | --model_name_or_path \
183 | --model_type roberta \
184 | --tokenizer_name \
185 | --train_file \
186 | --validation_file \
187 | --max_seq_length 800 \
188 | --mlm_probability 0.15 \
189 | --num_train_epochs 40 \
190 | --do_train \
191 | --do_eval \
192 | --per_device_train_batch_size 44 \
193 | --per_device_eval_batch_size 44 \
194 | --output_dir \
195 | --save_steps 10000 \
196 | --logging_steps 5000 \
197 | --overwrite_output_dir
198 | ```
199 |
200 | ### Constrained Masked Language Modeling (CMLM)
201 | Constrained MLM is a variation of MLM. In this approach, tokens are not randomly masked; instead, we specifically mask certain tokens, which in our case are variable names in source code functions.
202 |
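A minimal sketch of the idea, assuming the `@@var_N@@name@@` markers used elsewhere in this repository (the actual pipeline, including the separate 50K variable-name vocabulary, lives in the `preprocess.py` scripts):

```python
import re
from transformers import RobertaTokenizerFast

tok = RobertaTokenizerFast.from_pretrained("path/to/tokenizer")  # placeholder path
func = "int @@var_1@@total@@ = @@var_2@@count@@ + 1 ;"

input_ids, labels = [], []
for word in func.split(" "):
    m = re.fullmatch(r"@@\w+@@(\w+)@@", word)
    if m:  # variable position: always masked; the model must predict the original name
        input_ids.append(tok.mask_token_id)
        labels.append(tok.convert_tokens_to_ids(m.group(1)))
    else:  # ordinary code tokens are never masked and are ignored by the loss
        for t in tok.tokenize(word):
            input_ids.append(tok.convert_tokens_to_ids(t))
            labels.append(-100)
```
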
203 | #### Using a Pre-trained CMLM Model:
204 | Access our pre-trained CMLM model from the paper: [CMLM Model](https://www.dropbox.com/scl/fi/72ku0tf3o93kn67k60d7d/CMLM_MODEL.tar.gz?rlkey=8kwlfwc87uwcsab86np4bhub0&dl=0)
205 |
206 | To train a new model follow these steps:
207 | - **Base Model**: [MLM Model](https://www.dropbox.com/scl/fi/a0i61xeij0bogkusr4yf7/MLM_MODEL.tar.gz?rlkey=pqenu7f851sgdn6ofcfp6dxoa&dl=0)
208 | - **Tokenizer**: Refer to [Tokenizer](#tokenizer)
209 | - **Train and Test sets**: [CMLM Data set](https://www.dropbox.com/scl/fi/q0itko6fitpxx3dx71qhv/cmlm_dataset.tar.gz?rlkey=51j9iagvg8u3rak79euqjocml&dl=0)
210 | - **Vocab Files**: To generate new vocab files refer to [Vocab Files](#vocab-files), or use the existing ones at [link](https://www.dropbox.com/scl/fi/yot7urpeem53dttditg7p/cmlm_dataset.tar.gz?rlkey=cned7sgijladr1pu5ery82z8a&dl=0)
211 |
212 |
213 | 1. Preprocess data set for training
214 | [HSC CMLM data set](https://www.dropbox.com/scl/fi/q0itko6fitpxx3dx71qhv/cmlm_dataset.tar.gz?rlkey=51j9iagvg8u3rak79euqjocml&dl=0)
215 |
216 | ```python
217 | python3 preprocess.py \
218 | --train_file \
219 | --test_file \
220 | --tokenizer \
221 | --vocab_word_to_idx \
222 | --vocab_idx_to_word \
223 | --max_chunk_size 800 \
224 | --out_train_file \
225 | --out_test_file \
226 | --workers 4 \
227 | --vocab_size 50001
228 | ```
229 |
230 | 2. Train CMLM Model
231 | The CMLM model is trained on top of the MLM model using pre-processed HSC train and test files from step 1.
232 | [MLM Model](https://www.dropbox.com/scl/fi/a0i61xeij0bogkusr4yf7/MLM_MODEL.tar.gz?rlkey=pqenu7f851sgdn6ofcfp6dxoa&dl=0)
233 |
234 | ```python
235 | python3 training.py \
236 | --overwrite_output_dir \
237 | --train_data_file \
238 | --output_dir \
239 | --block_size 800 \
240 | --model_type roberta \
241 | --model_name_or_path \
242 | --tokenizer_name \
243 | --do_train \
244 | --num_train_epochs 30 \
245 | --save_steps 50000 \
246 | --logging_steps 5000 \
247 | --per_gpu_train_batch_size 32 \
248 | --mlm;
249 | ```
250 |
251 | 3. Run Inference
252 |
253 | ```python
254 | python run_cmlm_scoring.py \
255 | --model_name \
256 | --data_file \
257 | --prefix cmlm_hsc_5M_50K \
258 | --pred_path \
259 | --batch_size 40;
260 | ```
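
The evaluation code in `varbert/cmlm/eval.py` prints perplexity and Top-1/3/5/10 accuracy and writes `<prefix>_pred_list.json` and `<prefix>_gold_list.json` under `--pred_path`. If needed, the accuracies can be recomputed from those files; the paths below assume the prefix from the example command and a `--pred_path` of `outputs`:

```python
import json

pred_list = json.load(open("outputs/cmlm_hsc_5M_50K_pred_list.json"))
gold_list = json.load(open("outputs/cmlm_hsc_5M_50K_gold_list.json"))

for k in ["1", "3", "5", "10"]:  # JSON object keys are strings
    matched = total = 0
    for batch_preds, batch_gold in zip(pred_list[k], gold_list):
        for preds, gold in zip(batch_preds, batch_gold):
            total += 1
            matched += int(gold in preds)
    print(f"Top-{k}: {matched / total:.4f}")
```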
261 |
262 | ### Resize Model
263 |
264 | It is necessary to resize the model when the output vocabulary size changes. For example, if you fine-tune on top of a CMLM model that was trained with a 50K vocabulary but want a 150K vocabulary for the fine-tuned model, the CMLM model must be resized first.
265 |
266 | ```python
267 | python3 resize_model.py \
268 | --old_model \
269 | --vocab_path \
270 | --out_model_path
271 | ```
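
Conceptually, resizing keeps the encoder and rebuilds the output projection over the name vocabulary at the new size, copying the rows that overlap. The sketch below shows only that general idea (the hidden size and head layout are assumptions); `resize_model.py` is the repository's actual implementation:

```python
import torch
import torch.nn as nn

hidden, old_vocab, new_vocab = 768, 50001, 150001
old_head = nn.Linear(hidden, old_vocab)   # stands in for the trained LM head
new_head = nn.Linear(hidden, new_vocab)   # freshly initialized at the new size

with torch.no_grad():
    n = min(old_vocab, new_vocab)
    new_head.weight[:n] = old_head.weight[:n]  # reuse learned rows; rows beyond n keep their fresh init
    new_head.bias[:n] = old_head.bias[:n]
```
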
--------------------------------------------------------------------------------
/varbert/cmlm/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import pprint
5 | import logging
6 | import argparse
7 | import jsonlines
8 | import numpy as np
9 | import pandas as pd
10 | from typing import Dict, List, Tuple
11 | from collections import defaultdict
12 | from tqdm import tqdm, trange
13 |
14 | import torch
15 | from torch.nn.utils.rnn import pad_sequence
16 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
17 | from torch.utils.data.distributed import DistributedSampler
18 |
19 | from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
20 |
21 | from transformers import (
22 | WEIGHTS_NAME,
23 | AdamW,
24 | BertConfig,
25 | BertForMaskedLM,
26 | BertTokenizer,
27 | CamembertConfig,
28 | CamembertForMaskedLM,
29 | CamembertTokenizer,
30 | DistilBertConfig,
31 | DistilBertForMaskedLM,
32 | DistilBertTokenizer,
33 | GPT2Config,
34 | GPT2LMHeadModel,
35 | GPT2Tokenizer,
36 | OpenAIGPTConfig,
37 | OpenAIGPTLMHeadModel,
38 | OpenAIGPTTokenizer,
39 | PreTrainedModel,
40 | PreTrainedTokenizer,
41 | RobertaConfig,
42 | RobertaForMaskedLM,
43 | RobertaTokenizer,
44 | get_linear_schedule_with_warmup,
45 | )
46 |
47 | import torch.nn as nn
48 | from torch.nn import CrossEntropyLoss, MSELoss
49 |
50 | from transformers.activations import ACT2FN, gelu
51 |
52 |
53 | from transformers import RobertaConfig
54 | from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel, BertPreTrainedModel
55 | from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead
56 |
57 | ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
58 | "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin",
59 | "roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
60 | "roberta-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
61 | "distilroberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/distilroberta-base-pytorch_model.bin",
62 | "roberta-base-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-openai-detector-pytorch_model.bin",
63 | "roberta-large-openai-detector": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-openai-detector-pytorch_model.bin",
64 | }
65 |
66 | l = logging.getLogger('model_main')
67 | vocab_size = 50001
68 |
69 | class RobertaLMHead2(nn.Module):
70 |
71 | def __init__(self,config):
72 | super().__init__()
73 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
74 | # self.dense = nn.Linear(config.hidden_size, config.hidden_size*8)
75 |
76 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
77 | # self.layer_norm = nn.LayerNorm(8*config.hidden_size, eps=config.layer_norm_eps)
78 |
79 | self.decoder = nn.Linear(config.hidden_size, vocab_size, bias=False)
80 | # self.decoder = nn.Linear(8*config.hidden_size, vocab_size, bias=False)
81 |
82 | self.bias = nn.Parameter(torch.zeros(vocab_size))
83 |
84 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
85 | self.decoder.bias = self.bias
86 |
87 | def forward(self, features, **kwargs):
88 | x = self.dense(features)
89 | x = gelu(x)
90 | x = self.layer_norm(x)
91 |
92 | # project back to size of vocabulary with bias
93 | x = self.decoder(x)
94 |
95 | return x
96 |
97 |
98 | class RobertaForMaskedLMv2(RobertaForMaskedLM):
99 |
100 | def __init__(self, config):
101 | super().__init__(config)
102 | self.lm_head2 = RobertaLMHead2(config)
103 | self.init_weights()
104 |
105 | def forward(
106 | self,
107 | input_ids=None,
108 | attention_mask=None,
109 | token_type_ids=None,
110 | position_ids=None,
111 | head_mask=None,
112 | inputs_embeds=None,
113 | encoder_hidden_states=None,
114 | encoder_attention_mask=None,
115 | masked_lm_labels=None,
116 | output_attentions=None,
117 | output_hidden_states=None,
118 | return_dict=None,
119 | ):
120 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
121 |
122 | # print("inputs:",input_ids,"labels:",labels)
123 | outputs = self.roberta(
124 | input_ids,
125 | attention_mask=attention_mask,
126 | token_type_ids=token_type_ids,
127 | position_ids=position_ids,
128 | head_mask=head_mask,
129 | inputs_embeds=inputs_embeds,
130 | encoder_hidden_states=encoder_hidden_states,
131 | encoder_attention_mask=encoder_attention_mask,
132 | output_attentions=output_attentions,
133 | output_hidden_states=output_hidden_states,
134 | return_dict=return_dict,
135 | )
136 | sequence_output = outputs[0]
137 | prediction_scores = self.lm_head2(sequence_output)
138 | output_pred_scores = torch.topk(prediction_scores,k=20,dim=-1)
139 | outputs = (output_pred_scores,) # Add hidden states and attention if they are here
140 |
141 | masked_lm_loss = None
142 | if masked_lm_labels is not None:
143 | loss_fct = CrossEntropyLoss()
144 | masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), masked_lm_labels.view(-1))
145 | outputs = (masked_lm_loss,) + outputs
146 |
147 | return outputs
148 | # output = (prediction_scores,) + outputs[2:]
149 | # return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
150 |
151 |
152 | class RobertaForCMLM(BertPreTrainedModel):
153 | config_class = RobertaConfig
154 | pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
155 | base_model_prefix = "roberta"
156 |
157 | def __init__(self, config):
158 | super().__init__(config)
159 |
160 | self.roberta = RobertaModel(config)
161 | self.lm_head = RobertaLMHead(config)
162 |
163 | self.init_weights()
164 |
165 | def get_output_embeddings(self):
166 | return self.lm_head.decoder
167 |
168 | def forward(self,input_ids=None,attention_mask=None,token_type_ids=None,
169 | position_ids=None,head_mask=None,inputs_embeds=None,masked_lm_labels=None):
170 | outputs = self.roberta(
171 | input_ids,
172 | attention_mask=attention_mask,
173 | token_type_ids=token_type_ids,
174 | position_ids=position_ids,
175 | head_mask=head_mask,
176 | inputs_embeds=inputs_embeds,
177 | )
178 | sequence_output = outputs[0]
179 | prediction_scores = self.lm_head(sequence_output)
180 | output_pred_scores = torch.topk(prediction_scores,k=20,dim=-1)
181 | outputs = (output_pred_scores,) # Add hidden states and attention if they are here
182 |
183 | if masked_lm_labels is not None:
184 | loss_fct = CrossEntropyLoss()
185 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
186 | outputs = (masked_lm_loss,) + outputs
187 |
188 | return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
189 |
190 |
191 | try:
192 | from torch.utils.tensorboard import SummaryWriter
193 | except ImportError:
194 | from tensorboardX import SummaryWriter
195 |
196 |
197 | logger = logging.getLogger(__name__)
198 |
199 |
200 | MODEL_CLASSES = {
201 | "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
202 | "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
203 | "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
204 | "roberta": (RobertaConfig, RobertaForMaskedLMv2, RobertaTokenizer),
205 | "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
206 | "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
207 | }
208 |
209 |
210 | class CMLDataset(Dataset):
211 | def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size=512,limit=None):
212 | assert os.path.isfile(file_path)
213 | logger.info("Creating features from dataset file at %s", file_path)
214 |
215 | self.examples=[]
216 | with jsonlines.open(file_path, 'r') as f:
217 | for ix,line in tqdm(enumerate(f),desc="Reading Jsonlines",ascii=True):
218 | if limit is not None and ix>limit:
219 | continue
220 | if (None in line["inputids"]) or (None in line["labels"]):
221 | print("LineNum:",ix,line)
222 | continue
223 | else:
224 | self.examples.append(line)
225 |
226 | if limit is not None:
227 | self.examples = self.examples[0:limit]
228 | self.block_size = int(block_size)
229 | self.tokenizer = tokenizer
230 | self.truncated={}
231 |
232 | def __len__(self):
233 | return len(self.examples)
234 |
235 | def __getitem__(self, i):
236 | block_size = self.block_size
237 | tokenizer = self.tokenizer
238 |
239 | item = self.examples[i]
240 |
241 | input_ids = item["inputids"]
242 | labels = item["labels"]
243 | assert len(input_ids) == len(labels)
244 |
245 | if len(input_ids) > block_size-2:
246 | self.truncated[i]=1
247 |
248 | if len(input_ids) >= block_size-2:
249 | input_ids = input_ids[0:block_size-2]
250 | labels = labels[0:block_size-2]
251 | elif len(input_ids) < block_size-2:
252 | input_ids = input_ids+[tokenizer.pad_token_id]*(self.block_size-2-len(input_ids))
253 | labels = labels + [tokenizer.pad_token_id]*(self.block_size-2-len(labels))
254 |
255 | input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)
256 | labels = tokenizer.build_inputs_with_special_tokens(labels)
257 |
258 | assert len(input_ids) == len(labels)
259 | assert len(input_ids) == block_size
260 | try:
261 | input_ids = torch.tensor(input_ids, dtype=torch.long)
262 | labels = torch.tensor(labels, dtype=torch.long)
263 | mask_idxs = (input_ids==tokenizer.mask_token_id).bool()
264 | labels[~mask_idxs]=-100
265 | except:
266 | l.error(f"Unexpected error at index {i}: {sys.exc_info()[0]}")
267 | raise
268 |
269 | return input_ids , labels
270 |
271 | parser = argparse.ArgumentParser()
272 | parser.add_argument(
273 | "--model_name",
274 | type=str,
275 | help="The model checkpoint for weights initialization.",
276 | )
277 | parser.add_argument(
278 | "--data_file",
279 | type=str,
280 | help="Input Data File to Score",
281 | )
282 | parser.add_argument(
283 | "--meta_file",
284 | type=str,
285 | help="Input Meta File to Score",
286 | )
287 |
288 | parser.add_argument(
289 | "--prefix",
290 | default="test",
291 | type=str,
292 | help="prefix to separate the output files",
293 | )
294 |
295 | parser.add_argument(
296 | "--pred_path",
297 | default="outputs",
298 | type=str,
299 | help="path where the predictions will be stored",
300 | )
301 |
302 | parser.add_argument(
303 | "--batch_size",
304 | default=20,
305 | type=int,
306 | help="Eval Batch Size",
307 | )
308 | args = parser.parse_args()
309 |
310 | device = torch.device("cuda")
311 | n_gpu = torch.cuda.device_count()
312 | config_class, model_class, tokenizer_class = MODEL_CLASSES["roberta"]
313 |
314 | config = config_class.from_pretrained(args.model_name)
315 | tokenizer = tokenizer_class.from_pretrained(args.model_name)
316 | model = model_class.from_pretrained(
317 | args.model_name,
318 | from_tf=bool(".ckpt" in args.model_name),
319 | config=config,
320 | )
321 |
322 | model.to(device)
323 | tiny_dataset = CMLDataset(tokenizer,file_path=args.data_file,block_size=1024)
324 | eval_sampler = SequentialSampler(tiny_dataset)
325 | eval_dataloader = DataLoader(tiny_dataset, sampler=eval_sampler, batch_size=args.batch_size)
326 |
327 | model.eval()
328 | eval_loss = 0.0
329 | nb_eval_steps = 0
330 |
331 | matched={1:0,3:0,5:0,10:0}
332 | totalmasked={1:0,3:0,5:0,10:0}
333 |
334 | pred_list={
335 | 1 : [],
336 | 3 : [],
337 | 5 : [],
338 | 10: []
339 | }
340 | gold_list=[]
341 |
342 | for batch in tqdm(eval_dataloader, desc="Evaluating"):
343 | inputs, labels = batch[0], batch[1]
344 | only_masked = inputs==tokenizer.mask_token_id
345 | masked_gold = labels[only_masked]
346 |
347 | inputs = inputs.to(device)
348 | labels = labels.to(device)
349 |
350 | gold_list.append(masked_gold.tolist())
351 |
352 | with torch.no_grad():
353 | outputs = model(inputs, masked_lm_labels=labels)
354 | lm_loss = outputs[0]
355 | inference = outputs[1].indices.cpu()
356 | eval_loss += lm_loss.mean().item()
357 |
358 | # TopK Calculation
359 | masked_predict = inference[only_masked]
360 | for k in [1,3,5,10]:
361 | topked = masked_predict[:,0:k]
362 | pred_list[k].append(topked.tolist())
363 | for i,cur_gold_tok in enumerate(masked_gold):
364 | totalmasked[k]+=1
365 | cur_predict_scores = topked[i]
366 | if cur_gold_tok in cur_predict_scores:
367 | matched[k]+=1
368 |
369 | nb_eval_steps += 1
370 |
371 | eval_loss = eval_loss / nb_eval_steps
372 | perplexity = torch.exp(torch.tensor(eval_loss))
373 | print(f"Perplexity: {perplexity}")
374 | for i in [1,3,5,10]:
375 | print("TopK:", i, matched[i]/totalmasked[i])
376 |
377 | print("Truncated:", len(tiny_dataset.truncated))
378 |
379 | os.makedirs(args.pred_path, exist_ok=True)
380 | json.dump(pred_list,open(os.path.join(args.pred_path, args.prefix+"_pred_list.json"),"w"))
381 | json.dump(gold_list,open(os.path.join(args.pred_path, args.prefix+"_gold_list.json"),"w"))
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/create_dataset_splits.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import json
4 | import sys
5 | from tqdm import tqdm
6 | import concurrent.futures
7 | from multiprocessing import Manager
8 | from itertools import islice
9 | from functools import partial
10 | import time
11 | import re
12 | import hashlib
13 | import numpy as np
14 | import itertools
15 | import random
16 | import logging
17 | from collections import OrderedDict, Counter
18 | from preprocess_vars import normalize_variables
19 |
20 | l = logging.getLogger('main')
21 |
22 | manager = Manager()
23 | md5_to_funcline = manager.dict()
24 | funcline_to_data = manager.dict()
25 | cfuncs_md5 = manager.dict()
26 | cppfuncs_md5 = manager.dict()
27 | dwarf_vars = manager.list()
28 | decompiler_vars = manager.list()
29 | new_funcs = manager.list()
30 |
31 | def read_json(filename):
32 | with open(filename, 'r') as r:
33 | data = json.loads(r.read())
34 | if data:
35 | return data
36 |
37 | def write_json(filename, data):
38 | with open(filename, 'w') as w:
39 | w.write(json.dumps(data))
40 |
41 | def write_list(filename, data, distribution=False, save_list=True):
42 | if distribution:
43 | c = Counter(data)
44 | sorted_dict = dict(sorted(c.items(), key=lambda item: item[1], reverse = True))
45 | write_json(filename + '_distribution.json', sorted_dict)
46 |
47 | if save_list:
48 | with open(filename, 'w') as w:
49 | w.write('\n'.join(data))
50 |
51 |
52 | def write_jsonlines(path, data):
53 | with open(path, 'a+') as w:
54 | w.write(json.dumps(data) + "\n")
55 | return True
56 |
57 | # collect hash(func_body)
58 | def hash_funcs(workdir, decomp, file_):
59 | global md5_to_funcline
60 | global funcline_to_data
61 | global cfuncs_md5
62 | global cppfuncs_md5
63 |
64 | var_regex = r"@@(var_\d+)@@(\w+)@@"
65 | json_data = read_json(os.path.join(workdir, "map", decomp, file_))
66 | if not json_data:
67 | return
68 |
69 | id_ = 0
70 | for name, data in islice(json_data.items(), 0, len(json_data)):
71 | func = data['func']
72 | up_func = re.sub(var_regex, "\\2", func)
73 | # dedup func body
74 | func_body = '{'.join(up_func.split('{')[1:])
75 | # for Ghidra
76 | # if 'anon_var' in func_body or 'anon_func' in func_body:
77 | # continue
78 | func_body = func_body.replace('\t', '').replace('\n', '').replace(' ','')
79 | md5 = hashlib.md5(func_body.encode('utf-8')).hexdigest()
80 | id_ += 1
81 | data['id'] = id_
82 | data['func_name'] = name
83 | data['md5'] = md5
84 | md5_to_funcline[md5] = name
85 | funcline_to_data[name] = data
86 | if data['language'] == 'c':
87 | cfuncs_md5[md5] = name
88 | else:
89 | cppfuncs_md5[md5] = name
90 |
91 | # de-dup functions before you split!
92 | def func_dedup( workdir, decomp, start: int, batch_size: int, input):
93 | end = start + batch_size
94 | for hash_, funcline in islice(md5_to_funcline.items(), start, end):
95 | try:
96 | binary_name = funcline.split('_')[0]
97 | ret = write_jsonlines(os.path.join(workdir, "dedup_func", decomp, f'{binary_name}.jsonl'), funcline_to_data[funcline])
98 | if not ret:
99 | l.error(f"error in dumping {hash_} | {funcline}")
100 | except Exception as e:
101 | l.error(f"Error in deduplicating functions: {e}")
102 |
103 | def find_c_and_cpp_files(dedup_dir, decomp):
104 |
105 | files = os.listdir(os.path.join(dedup_dir, decomp))
106 | random.seed(5)
107 | random.shuffle(files)
108 |
109 | c_name_and_lines = OrderedDict()
110 | cpp_name_and_lines = OrderedDict()
111 | # find c and cpp files
112 | c_funcs_per_binary, cpp_funcs_per_binary = [], []
113 | for f in files:
114 | with open(os.path.join(dedup_dir, decomp, f), 'r') as r:
115 | lines = r.readlines()
116 | if lines:
117 | if json.loads(lines[0])['language'] == 'c':
118 | c_funcs_per_binary.append(len(lines))
119 | c_name_and_lines[f] = len(lines)
120 | else:
121 | cpp_funcs_per_binary.append(len(lines))
122 | cpp_name_and_lines[f] = len(lines)
123 |
124 | l.info(f"c files: {len(c_funcs_per_binary)} \tcpp files: {len(cpp_funcs_per_binary)}")
125 | return c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines
126 |
127 | # decide splits
128 | def decide_per_binary_split(dedup_dir, c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines, workdir, decomp):
129 | l.debug("Deciding binary splits...")
130 | def split_files(list_, name_and_lines):
131 | list_ = list(name_and_lines.items())
132 | random.shuffle(list_)
133 | train_, test_ = 0, 0
134 | # sum_ = sum(list_)
135 | sum_ = sum([lines for _, lines in list_])
136 | # if the total number of functions is small, adjust these limits
137 | t1 = round(sum_ * 0.709, 2)
138 | t2 = round(sum_ * 0.805, 2)
139 |
140 | l.debug(f"total functions!: {sum_} \tlower limit: {t1} \tupper limit: {t2}")
141 | got = []
142 | found = False
143 | for elem in range(0, len(list_)):
144 | # train_ += list_[elem]
145 | _, lines = list_[elem]
146 | train_ += lines
147 | got.append(list_[elem])
148 | if t1 <= train_ <= t2:
149 | l.debug(f"functions in train split!: {train_}")
150 | l.debug(f"found the split! pick files until: {elem}")
151 | break
152 | l.debug(f"functions in test split: {sum_ - train_}")
153 | # train_files = dict(islice(name_and_lines.items(), 0, elem+1))
154 | # test_files = dict(islice(name_and_lines.items(), elem+1, len(list_)))
155 | train_files = dict(list_[:elem+1])
156 | test_files = dict(list_[elem+1:])
157 | return train_files, test_files
158 |
159 | if len(c_funcs_per_binary):
160 | c_train_files, c_test_files = split_files(c_funcs_per_binary, c_name_and_lines)
161 | while True:
162 | if len(c_test_files) == 0:
163 | l.debug("retrying for c...")
164 | c_train_files, c_test_files = split_files(c_funcs_per_binary, c_name_and_lines)
165 | else:
166 | break
167 | if len(cpp_funcs_per_binary):
168 | cpp_train_files, cpp_test_files = split_files(cpp_funcs_per_binary, cpp_name_and_lines)
169 | while True:
170 | if len(cpp_test_files) == 0:
171 | l.debug("retrying for cpp...")
172 | cpp_train_files, cpp_test_files = split_files(cpp_funcs_per_binary, cpp_name_and_lines)
173 | else:
174 | break
175 |
176 | if len(cpp_funcs_per_binary) and len(c_funcs_per_binary):
177 | train_files = {**c_train_files, **cpp_train_files}
178 | test_files = {**c_test_files, **cpp_test_files}
179 | elif len(cpp_funcs_per_binary):
180 | train_files = cpp_train_files
181 | test_files = cpp_test_files
182 | elif len(c_funcs_per_binary):
183 | train_files = c_train_files
184 | test_files = c_test_files
185 |
186 | with open(os.path.join(workdir, "splits", decomp, "train_files.json"), 'w') as w:
187 | w.write(json.dumps(train_files))
188 |
189 | with open(os.path.join(workdir, "splits", decomp, "test_files.json"), 'w') as w:
190 | w.write(json.dumps(test_files))
191 | return list(train_files.keys()), list(test_files.keys())
192 |
193 | # decide splits
194 | def decide_per_func_split(workdir, decomp):
195 | l.debug("Deciding function splits...")
196 | def split_funcs(all_funcs):
197 |
198 | shuffled_funcnames = list(all_funcs.values())
199 | random.seed(5)
200 | random.shuffle(shuffled_funcnames)
201 |
202 | total = len(all_funcs)
203 | train_ = int(total * 0.80)
204 | train_funcs = shuffled_funcnames[:train_]
205 | test_funcs = shuffled_funcnames[train_:]
206 | return train_funcs, test_funcs
207 |
208 | l.info(f"Total C funcs {len(cfuncs_md5)}")
209 | c_train_funcs, c_test_funcs = split_funcs(cfuncs_md5)
210 | l.info(f"Total CPP funcs {len(cppfuncs_md5)}")
211 | cpp_train_funcs, cpp_test_funcs = split_funcs(cppfuncs_md5)
212 |
213 | # total funcs now
214 | c_train_funcs.extend(cpp_train_funcs)
215 | c_test_funcs.extend(cpp_test_funcs)
216 |
217 | with open(os.path.join(workdir, "splits", decomp, "train_funcs.txt"), 'w') as w:
218 | w.write("\n".join(c_train_funcs))
219 |
220 | with open(os.path.join(workdir, "splits", decomp, "test_funcs.txt"), 'w') as w:
221 | w.write("\n".join(c_test_funcs))
222 |
223 | return c_train_funcs, c_test_funcs
224 |
225 | # create binary splits
226 | def create_per_binary_split(files, filename, workdir, decomp):
227 |
228 | with open(os.path.join(workdir, "splits", decomp,'per-binary', f'{filename}.jsonl'), 'w') as w:
229 | for file_ in files:
230 | with open(os.path.join(workdir, "dedup_func", decomp, file_), 'r') as r:
231 | lines = r.readlines()
232 | for line in lines:
233 | w.write(json.dumps(json.loads(line)) + '\n')
234 |
235 | # create func splits
236 | def create_per_func_split(funcs, filename, workdir, decomp):
237 |
238 | with open(os.path.join(workdir, "splits", decomp,'per-func', f'{filename}.jsonl'), 'w') as w:
239 | for func in funcs:
240 | if func not in funcline_to_data:
241 | l.error(f"function missing from funcline_to_data: {func}"); continue
242 | w.write( json.dumps(funcline_to_data[func])+ "\n")
243 |
244 | def get_variables(decomp, func_):
245 |
246 | global dwarf_vars
247 | global decompiler_vars
248 |
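    |     # functions carry "@@var_<id>@@<name>@@" tokens in their text; 'type_stripped_vars' maps each such
    |     # <name> to its label: 'dwarf' for developer-named variables, or the decompiler name for generated ones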
249 | var_regex = r"@@(var_\d+)@@(\w+)@@"
250 | data = json.loads(func_)
251 | vars_ = data['type_stripped_vars']
252 | for var, ty in vars_.items():
253 | if ty == "dwarf":
254 | dwarf_vars.append(var)
255 | elif ty == decomp:
256 | decompiler_vars.append(var)
257 | else:
258 | l.error(f"Error in getting variables: {var} {ty}")
259 |
260 | def substitute_norm(lookup, func):
261 |
262 | global new_funcs
263 | try:
264 | data = json.loads(func)
265 | func = data['func']
266 | tyvars_= data['type_stripped_vars']
267 | # if the variable is removed in cleaning invalid vars
268 | if not tyvars_:
269 | return
270 | vars_ = []
271 | for v in tyvars_:
272 | if tyvars_[v] == 'dwarf':
273 | vars_.append(v)
274 | # vars_map can be empty if there are no dwarf variable names in the func
275 | vars_map = []
276 | norm_func = func
277 | for var in vars_:
278 | og_var = rf"(@@var_\d+@@)({var})(@@)"
279 | if var in lookup:
280 | if lookup[var] == '':
281 | continue
282 | else:
283 | norm_func = re.sub(og_var, rf'\1{lookup[var]}\3', norm_func)
284 | vars_map.append([var, lookup[var]])
285 | else:
286 | l.error(f"something fishy! {var} {data['id']}")
287 |
288 | up_data = data
289 | up_data['norm_func'] = norm_func
290 | up_data['vars_map'] = vars_map
291 | up_data['fid'] = str(data['id']) + "-" + data['func_name'].replace("_","-")
292 |
293 | # for preprocessing for vocab (vars_map and type_stripped_norm_vars same thing)
294 | norm_var_type = {}
295 | for pair in vars_map:
296 | norm_var = pair[1]
297 | var = pair[0]
298 | if norm_var in norm_var_type and up_data["type_stripped_vars"][var] != 'dwarf':
299 | norm_var_type[norm_var] = 'dwarf'
300 | else:
301 | norm_var_type[norm_var] = up_data["type_stripped_vars"][var]
302 |
303 | up_data['type_stripped_norm_vars'] = norm_var_type
304 | new_funcs.append(up_data)
305 | except Exception as e:
306 | l.error(f'Error in updating function in substitute_norm: {e}')
307 |
308 | def write_norm(filename, funcs):
309 | with open(filename, 'w') as w:
310 | for func in funcs:
311 | w.write(json.dumps(func) + "\n")
312 |
313 | def create_train_and_test_sets(workdir, decomp, WORKERS=4):
314 |
315 | total_steps = 4
316 | progress = tqdm(total=total_steps, desc="Initializing")
317 |
318 | # create dirs
319 | dirs = [f"dedup_func/{decomp}", "lookup", f"splits/{decomp}/per-func", f"splits/{decomp}/per-binary", f"vars/{decomp}/per-func", f"vars/{decomp}/per-binary" ]
320 | for d in dirs:
321 | os.makedirs(os.path.join(workdir, d), exist_ok=True)
322 | files = os.listdir(os.path.join(workdir, 'map', decomp))
323 |
324 | progress.set_description("Deduplicating functions")
325 | # create hash of funcs for de-dup
326 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor:
327 | # executor.map(hash_funcs, files)
328 | partial_func = partial(hash_funcs, workdir, decomp)
329 | executor.map(partial_func, files)
330 |
331 | l.info(f"functions before de-dup: {len(funcline_to_data)} \tfunctions after de-dup: {len(md5_to_funcline)}")
332 |
333 | batch = len(md5_to_funcline) // (WORKERS - 1)
334 | if os.listdir(os.path.join(workdir, "dedup_func", decomp)):
335 |         print("files already in dedup dir! remove them and re-run")
336 |         l.error("files already in dedup dir! remove them and re-run")
337 | exit(1)
338 |
339 | l.debug("now saving dedup code!")
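    |     # each worker presumably handles one contiguous slice of the de-duplicated functions,
    |     # starting at offset batch * i (see func_dedup for the exact slicing)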
340 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor:
341 | cal_partial = partial(func_dedup, workdir, decomp, batch_size=batch, input=input,)
342 | executor.map(cal_partial, [batch * i for i in range(WORKERS)])
343 |
344 | l.info(f'Binaries in dedup func: {len(os.listdir(os.path.join(workdir, "dedup_func", decomp)))}')
345 | progress.update(1)
346 | # find c and cpp files
347 | c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines = find_c_and_cpp_files(os.path.join(workdir, "dedup_func"), decomp)
348 |
349 |     # per-binary splits
350 | progress.set_description("Creating Binary Split")
351 | train_b, test_b = decide_per_binary_split(os.path.join(workdir, "dedup_func"), c_funcs_per_binary, cpp_funcs_per_binary, c_name_and_lines, cpp_name_and_lines, workdir, decomp)
352 | l.debug(f"per binary numbers: Train: {len(train_b)} \tTest: {len(test_b)}")
353 | create_per_binary_split(train_b, "train", workdir, decomp)
354 | create_per_binary_split(test_b, "test", workdir, decomp)
355 | train_b.extend(test_b)
356 | all_b = train_b
357 | l.debug(f"total samples in binary split: {len(all_b)}")
358 | progress.update(1)
359 |
360 | # per-func splits
361 | progress.set_description("Creating Function Split")
362 | train_fn, test_fn = decide_per_func_split(workdir, decomp)
363 | l.debug(f"per func numbers: Train: {len(train_fn)} \tTest: {len(test_fn)}")
364 | create_per_func_split(train_fn, "train", workdir, decomp)
365 | create_per_func_split(test_fn, "test", workdir, decomp)
366 | train_fn.extend(test_fn)
367 | all_fn = train_fn
368 | l.debug(f"total samples in func split: {len(all_fn)}")
369 | progress.update(1)
370 |
371 | progress.set_description("Updating variables and saving splits")
372 | clean_up_variables(WORKERS, workdir, decomp)
373 | progress.update(1)
374 | progress.set_description("Train and Test splits created!")
375 | progress.close()
376 |
377 | # TODO: Fix later
378 | def clean_up_variables(WORKERS, workdir, decomp):
379 | # replace variables
380 | import glob
381 | files = glob.glob(f'{workdir}/splits/{decomp}/**/*', recursive=True)
382 | for f in files:
383 | new_funcs[:] = []
384 | if os.path.isfile(f) and f.endswith('jsonl') and 'all' not in f and 'final' not in f:
385 | tmp = f.split('/')
386 | ty = tmp[-2]
387 | file_ = tmp[-1]
388 | with open(f, 'r') as r:
389 | funclines = r.readlines()
390 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor:
391 | partial_func = partial(get_variables, decomp)
392 | executor.map(partial_func, funclines)
393 |
394 | # create variable lists
395 | var_out_file = os.path.join(workdir, "vars", decomp, f"{ty}", f"{file_[:-6]}_dwarf_vars")
396 | write_list(var_out_file, list(dwarf_vars), True, False)
397 |
398 | write_list(os.path.join(workdir, "vars", decomp, f"{ty}", f"{file_[:-6]}_decomp_vars"), list(decompiler_vars), True, False)
399 |
400 | clean_var_out = os.path.join(workdir, "vars", decomp, f"{ty}")
401 | # subprocess.call(['python3.8', 'preprocess_vars.py', f'{var_out_file}_distribution.json', file_[:-6], clean_var_out, os.path.join(workdir, 'lookup')])
402 | normalize_variables(f'{var_out_file}_distribution.json', file_[:-6], clean_var_out, os.path.join(workdir, 'lookup'), WORKERS)
403 |             import time; time.sleep(10)  # presumably gives normalize_variables' output time to land on disk before the lookup is read back
404 | with open(os.path.join(workdir, 'lookup', 'universal_lookup.json'), 'r') as r:
405 | lookup = json.loads(r.read())
406 |
407 | with concurrent.futures.ProcessPoolExecutor(max_workers=WORKERS) as executor:
408 | partial_func = partial(substitute_norm, lookup)
409 | executor.map(partial_func, funclines)
410 |             # FIXME: investigate why so many functions end up here
411 | write_norm(os.path.join(workdir, "splits", decomp, ty, f"final_{file_[:-6]}.jsonl"), new_funcs)
412 | l.info(f"Functions in file {f}: {str(len(new_funcs))}")
413 | # create_train_and_test_sets('/tmp/varbert_tmpdir', 'ghidra')
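    | # Rough layout written under <workdir> by create_train_and_test_sets (based on the paths used above):
    | #   splits/<decomp>/train_files.json, test_files.json       per-binary split decision
    | #   splits/<decomp>/train_funcs.txt, test_funcs.txt         per-function split decision
    | #   splits/<decomp>/per-binary/{train,test}.jsonl           raw per-binary splits
    | #   splits/<decomp>/per-func/{train,test}.jsonl             raw per-function splits
    | #   splits/<decomp>/per-*/final_{train,test}.jsonl          normalized functions with vars_map
    | #   vars/<decomp>/per-*/ and lookup/universal_lookup.json   variable lists and the name lookup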
414 |
--------------------------------------------------------------------------------
/varcorpus/dataset-gen/variable_matching.py:
--------------------------------------------------------------------------------
1 | from binary import Binary
2 | from utils import read_text, read_json, write_json
3 | import os
4 | import string
5 | import re
6 | import json
7 | from collections import defaultdict, Counter, OrderedDict
8 | import sys
9 | import hashlib
10 | import logging
11 | from parse_decompiled_code import IDAParser, GhidraParser
12 | from typing import Any, List, Dict
13 |
14 | l = logging.getLogger('main')
15 |
16 | class DataLoader:
17 | def __init__(self, binary_name, dwarf_info_path, decompiler, decompiled_code_strip_path,
18 | decompiled_code_type_strip_path, joern_strip_data, joern_type_strip_data, data_map_dump, language, decompiled_code_type_strip_ln_fn_names) -> None:
19 | self.binary_name = binary_name
20 | self.decompiler = decompiler
21 | self.decompiled_code_strip_path = decompiled_code_strip_path
22 | self.decompiled_code_type_strip_path = decompiled_code_type_strip_path
23 | self.joern_strip_data = joern_strip_data
24 | self.joern_type_strip_data = joern_type_strip_data
25 | self.dwarf_vars, self.dwarf_funcs, self.linkage_name_to_func_name, self.decompiled_code_type_strip_ln_fn_names = self._update_values((dwarf_info_path), decompiled_code_type_strip_ln_fn_names)
26 | self.data_map_dump = data_map_dump
27 | self.language = language
28 | self.sample = {}
29 | self.run()
30 |
31 | def _update_values(self, path, decompiled_code_type_strip_ln_fn_names):
32 | if isinstance(path, str):
33 |             path = read_json(path)
34 |         tmp = path
35 | linkage_name_to_func_name = {}
36 | if 'vars_per_func' in tmp.keys():
37 | dwarf_vars = tmp['vars_per_func']
38 | dwarf_funcs = dwarf_vars.keys()
39 | if 'linkage_name_to_func_name' in tmp.keys():
40 | if tmp['linkage_name_to_func_name']:
41 | linkage_name_to_func_name = tmp['linkage_name_to_func_name']
42 | else:
43 | for funcname, _ in tmp['vars_per_func'].items():
44 | linkage_name_to_func_name[funcname] = funcname
45 |         # TODO: enable this if the corpus language is set to both C and C++
46 |         # else:
47 |         #     # C has no linkage names, and C++ support was not in place when the dataset was first built, so fall back to the function name itself
48 | # for funcname, _ in tmp['vars_per_func'].items():
49 | # linkage_name_to_func_name[funcname] = funcname
50 | # print("linkage", linkage_name_to_func_name)
51 |
52 | # cpp
53 | if os.path.exists(decompiled_code_type_strip_ln_fn_names):
54 | decompiled_code_type_strip_ln_fn_names = read_json(decompiled_code_type_strip_ln_fn_names)
55 | else:
56 | # c
57 | decompiled_code_type_strip_ln_fn_names = linkage_name_to_func_name
58 | return dwarf_vars, dwarf_funcs, linkage_name_to_func_name, decompiled_code_type_strip_ln_fn_names
59 |
60 | def run(self):
61 | if self.decompiler.lower() == 'ida':
62 | parser_strip = IDAParser(read_text(self.decompiled_code_strip_path), self.binary_name, None)
63 | parser_type_strip = IDAParser(read_text(self.decompiled_code_type_strip_path), self.binary_name, self.decompiled_code_type_strip_ln_fn_names)
64 | elif self.decompiler.lower() == 'ghidra':
65 | parser_strip = GhidraParser(read_text(self.decompiled_code_strip_path), self.binary_name)
66 | parser_type_strip = GhidraParser(read_text(self.decompiled_code_type_strip_path), self.binary_name)
67 |
68 | # load joern data
69 | joern_data_strip = JoernDataLoader(self.joern_strip_data, parser_strip.func_name_to_addr, None)
70 | joern_data_type_strip = JoernDataLoader(self.joern_type_strip_data, parser_type_strip.func_name_to_addr, parser_type_strip.func_name_to_linkage_name)
71 | # filter out functions that are not common
72 | common_func_addr_decompiled_code = set(parser_strip.func_addr_to_name.keys()).intersection(parser_type_strip.func_addr_to_name.keys())
73 | common_func_addr_from_joern = set(joern_data_strip.joern_addr_to_name.keys()).intersection(joern_data_type_strip.joern_addr_to_name.keys())
74 |
75 | # joern names to addr
76 |         # comments are removed only at the end, because Joern's line numbers were computed with the comments still present
77 | for addr in common_func_addr_from_joern:
78 | try:
79 | if addr not in common_func_addr_decompiled_code:
80 | continue
81 | # remove decompiler/compiler generated functions - we get dwarf names from type-strip
82 | if addr not in parser_type_strip.func_addr_to_name:
83 | continue
84 | l.debug(f'Trying for func addr {self.binary_name} :: {self.decompiler} {addr}')
85 |
86 | func_name_type_strip = parser_type_strip.func_addr_to_name[addr]
87 | func_name_strip = parser_strip.func_addr_to_name[addr]
88 |
89 |                 # if the function is compiler-generated (artificial), we did not collect DWARF info for it (we only want source-level functions)
90 |                 # For Ghidra, the DWARF info collected from the binary is keyed by address rather than by function name
91 | dwarf_addr_compatible_addr = None
92 | if self.decompiler.lower() == 'ghidra':
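   |                     # Ghidra places ELF binaries at a default image base of 0x00100000; subtracting it recovers
   |                     # the file-relative address that keys this binary's DWARF info (addresses, not names, for Ghidra)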
93 | dwarf_addr_compatible_addr = str(int(addr[1:-1], 16) - 0x00100000)
94 | if dwarf_addr_compatible_addr not in self.dwarf_funcs:
95 | l.debug(f'Func not in DWARF / Source! {self.binary_name} :: {self.decompiler} {dwarf_addr_compatible_addr}')
96 | continue
97 |
98 |                 # For IDA, the func name collected from the binary has no line-number suffix
99 |             elif self.decompiler.lower() == 'ida':
100 |                 # C++ func names contain '::' in the type-strip decompiled code but not in the parsed DWARF data, so linkage names are used to bridge the two
101 | funcname_from_type_strip = parser_type_strip.functions[addr]['func_name_no_line']
102 |
103 | if funcname_from_type_strip not in self.dwarf_funcs and funcname_from_type_strip.strip().split('::')[-1] not in self.dwarf_funcs:
104 | l.debug(f'Func not in DWARF / Source! {self.binary_name} :: {self.decompiler} {funcname_from_type_strip}')
105 | continue
106 |
107 | dc_functions_strip = parser_strip.functions[addr]
108 | dc_functions_type_strip = parser_type_strip.functions[addr]
109 |
110 | dwarf_func_name = '_'.join(func_name_type_strip.split('_')[:-1])
111 | cfunc_strip = CFunc(self, dc_functions_strip, joern_data_strip, "strip", func_name_strip, dwarf_func_name, addr, None)
112 | cfunc_type_strip = CFunc(self, dc_functions_type_strip, joern_data_type_strip, "type-strip", func_name_type_strip, dwarf_func_name, addr, self.linkage_name_to_func_name)
113 |
114 | # unequal number of joern vars! do not proceed
115 | if len(cfunc_strip.joern_vars) != len(cfunc_type_strip.joern_vars):
116 | l.debug(f'Return! unequal number of joern vars! {self.binary_name} :: {self.decompiler} {addr}')
117 | continue
118 |
119 |                 # proceed only when strip and type-strip report the same number of local vars
120 |                 if len(cfunc_strip.local_vars_dc) == len(cfunc_type_strip.local_vars_dc):
121 |                     # identify vars joern missed and rewrite their lines as function-relative (not strictly needed for type_strip, but done for uniformity)
122 | l.debug(f'Identify missing vars! {self.binary_name} :: {self.decompiler} :: {addr}')
123 | cfunc_strip.joern_vars = self.identify_missing_vars(cfunc_strip)
124 | cfunc_type_strip.joern_vars = self.identify_missing_vars(cfunc_type_strip)
125 |
126 | l.debug(f'Removing invalid variable names! {self.binary_name} :: {self.decompiler} :: {addr}')
127 |
128 | # remove invalid variable names
129 | valid_strip_2_type_strip_varnames, cfunc_strip.joern_vars, cfunc_type_strip.joern_vars = self.remove_invalid_variable_names(cfunc_strip, cfunc_type_strip,
130 | parser_strip, parser_type_strip)
131 |
132 | l.debug(f'Start matching variable names! {self.binary_name} :: {self.decompiler} :: {addr}')
133 | matchvar = MatchVariables(cfunc_strip, cfunc_type_strip, valid_strip_2_type_strip_varnames, self.dwarf_vars, self.decompiler, cfunc_type_strip.dwarf_func_name, self.binary_name, self.language, dwarf_addr_compatible_addr)
134 |
135 | sample_name, sample_data = matchvar.dump_sample()
136 | if sample_data:
137 | l.debug(f'Dump sample! {self.binary_name} :: {self.decompiler} :: {addr}')
138 | self.sample[sample_name] = sample_data
139 |
140 | except Exception as e:
141 |                 l.error(f"Error {self.binary_name} :: {e} {addr}")
142 | l.info(f'Variable Matching complete for {self.binary_name} :: {self.decompiler} :: samples :: {len(self.sample)}')
143 | write_json(self.data_map_dump, dict(self.sample))
144 |
145 | def identify_missing_vars(self, cfunc):
146 | # 1. update file lines to func lines
147 |         # 2. joern may miss some occurrences of variables that have no declaration (e.g. variables from the .data/.bss sections); find them and update the corresponding lines
148 | joern_vars = cfunc.joern_vars
149 | tmp_dict = {}
150 | for var, var_lines in joern_vars.items():
151 | updated_lines = []
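    |             # Joern reports file-relative line numbers; shift them so that line 1 is the first line of this function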
152 | updated_lines[:] = [int(number) - int(cfunc.joern_func_start) +1 for number in var_lines]
153 | var = var.strip('~')
154 | find_all = rf'([^\d\w_-])({var})([^\d\w_-])'
155 | for i, line in enumerate(cfunc.func_lines, 1):
156 | matches = re.search(find_all, line)
157 | if matches:
158 | if i not in updated_lines and not line.startswith('//'):
159 | updated_lines.append(i)
160 |
161 | updated_lines = list(set(updated_lines))
162 |
163 | tmp_dict[var] = updated_lines
164 | return tmp_dict
165 |
166 |
167 | def remove_invalid_variable_names(self, cfunc_strip, cfunc_type_strip, parser_strip, parser_type_strip):
168 |
169 | # https://www.hex-rays.com/products/ida/support/idadoc/1361.shtml
170 | if self.decompiler.lower() == 'ida':
171 | data_types = ['_BOOL1', '_BOOL2', '_BOOL4', '__int8', '__int16', '__int32', '__int64', '__int128', '_BYTE', '_WORD', '_DWORD', '_QWORD', '_OWORD', '_TBYTE', '_UNKNOWN', '__pure', '__noreturn', '__usercall', '__userpurge', '__spoils', '__hidden', '__return_ptr', '__struct_ptr', '__array_ptr', '__unused', '__cppobj', '__ptr32', '__ptr64', '__shifted', '__high']
172 |
173 |         # any more undefined* types to add?
174 | if self.decompiler.lower() == 'ghidra':
175 | data_types = ['ulong', 'uint', 'ushort', 'ulonglong', 'bool', 'char', 'int', 'long', 'undefined', 'undefined1', 'undefined2', 'undefined4', 'undefined8', 'byte', 'FILE', 'size_t']
176 |
177 | invalid_vars = ['null', 'true', 'false', 'True', 'False', 'NULL', 'char', 'int']
178 | # also func names which may have been identified as var names
179 |
180 | zip_joern_vars = dict(zip(cfunc_strip.joern_vars, cfunc_type_strip.joern_vars))
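    |         # NOTE: this pairing is positional - it assumes the strip and type-strip Joern dicts list variables in the same order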
181 | # strip: type-strip
182 | valid_strip_2_type_strip_varnames = {}
183 | valid_strip_vars = {}
184 | valid_type_strip_vars = {}
185 |
186 | for strip_var, type_strip_var in zip_joern_vars.items():
187 | try:
188 | strip_var_woc, type_strip_var_woc = strip_var, type_strip_var
189 | strip_var = strip_var.strip('~')
190 | type_strip_var = type_strip_var.strip('~')
191 |
192 | # remove func name
193 | if strip_var in parser_strip.func_name_wo_line.keys() or type_strip_var in parser_type_strip.func_name_wo_line.keys():
194 | continue
195 | # remove func name with ::
196 | if strip_var.strip().split('::')[-1] in parser_strip.func_name_wo_line.keys() or type_strip_var.strip().split('::')[-1] in parser_type_strip.func_name_wo_line.keys():
197 | continue
198 | # (rn: sub_2540_475)
199 | if strip_var in parser_strip.func_name_to_addr.keys() or type_strip_var in parser_type_strip.func_name_to_addr.keys():
200 | continue
201 | # remove types
202 | if strip_var in data_types or type_strip_var in data_types:
203 | continue
204 | # remove invalid name
205 | if strip_var in invalid_vars or type_strip_var in invalid_vars:
206 | continue
207 |
208 | if '::' in strip_var:
209 | strip_var_woc = strip_var.strip().split('::')[-1]
210 | if '::' in type_strip_var:
211 | type_strip_var_woc = type_strip_var.strip().split('::')[-1]
212 |
213 | # create valid mapping
214 | valid_strip_2_type_strip_varnames[strip_var_woc] = type_strip_var_woc
215 |
216 | # update joern vars
217 | if strip_var in cfunc_strip.joern_vars:
218 | valid_strip_vars[strip_var_woc] = cfunc_strip.joern_vars[strip_var]
219 |
220 | elif f'~{strip_var}' in cfunc_strip.joern_vars:
221 | # if strip_var_woc:
222 | valid_strip_vars[strip_var_woc] = cfunc_strip.joern_vars[f'~{strip_var}']
223 |
224 | if type_strip_var in cfunc_type_strip.joern_vars:
225 | # if type_strip_var_woc:
226 | valid_type_strip_vars[type_strip_var_woc] = cfunc_type_strip.joern_vars[type_strip_var]
227 | # else:
228 | # valid_type_strip_vars[type_strip_var] = cfunc_type_strip.joern_vars[type_strip_var]
229 |
230 | if f'~{type_strip_var}' in cfunc_type_strip.joern_vars:
231 | # if type_strip_var_woc:
232 | valid_type_strip_vars[type_strip_var_woc] = cfunc_type_strip.joern_vars[f'~{type_strip_var}']
233 | # else:
234 | # valid_type_strip_vars[type_strip_var] = cfunc_type_strip.joern_vars[f'~{type_strip_var}']
235 | except Exception as e:
236 | l.error(f'Error in removing invalid variable names! {self.binary_name} :: {self.decompiler} :: {e}')
237 | return valid_strip_2_type_strip_varnames, valid_strip_vars, valid_type_strip_vars
238 |
239 |
240 | class JoernDataLoader:
241 | def __init__(self, joern_data, decompiled_code_name_to_addr, func_name_to_linkage_name) -> None:
242 | self.functions = {}
243 | self.joern_data = joern_data
244 | self.decompiled_code_name_to_addr = decompiled_code_name_to_addr
245 | self.joern_name_to_addr = {}
246 | self.joern_addr_to_name = {}
247 | self.joern_start_line_to_funcname = {}
248 | self._load(func_name_to_linkage_name)
249 |
250 | def _load(self, func_name_to_linkage_name):
251 | # key: func name, val: {var: lines}
252 | func_line_num_to_addr = {}
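    |         # decompiled function names carry their starting line number as a suffix (e.g. "foo_123");
    |         # index them by that suffix as a fallback for Joern names that differ from the decompiler's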
253 | for k, v in self.decompiled_code_name_to_addr.items():
254 | line_num = str(k.split('_')[-1])
255 | func_line_num_to_addr[line_num] = v
256 | try:
257 | counter = 0
258 | for func_name, v in self.joern_data.items():
259 | try:
260 | # check if func name in decompiled code funcs and get addr mapping
261 | # type-strip - replace demangled name with mangled name (cpp) or simply a func name (c)
262 | # some func names from joern are different from parsed decompiled code (IDA: YaSkkServ::`anonymous namespace'::signal_dictionary_update_handler | Joern: signal_dictionary_update_handler )
263 | if func_name_to_linkage_name:
264 | tmp_wo_line = func_name.split('_')
265 |                     if len(tmp_wo_line) <= 1:
266 | continue
267 | name_wo_line, line_num = '_'.join(tmp_wo_line[:-1]), tmp_wo_line[-1]
268 |
269 | if name_wo_line in func_name_to_linkage_name:
270 | func_name = func_name_to_linkage_name[name_wo_line]
271 | func_name = f'{func_name}_{line_num}'
272 |
273 | if func_name in self.decompiled_code_name_to_addr:
274 | addr = self.decompiled_code_name_to_addr[func_name]
275 | elif line_num in func_line_num_to_addr:
276 | addr = func_line_num_to_addr[line_num]
277 | else:
278 |                         l.warning("joern func name did not match decompiled code; skipping")
279 |                         continue
280 | self.joern_addr_to_name[addr] = func_name
281 | self.joern_name_to_addr[func_name] = addr
282 | self.joern_start_line_to_funcname[v['func_start']] = func_name
283 |
284 | joern_vars = v['variable']
285 | start = v['func_start']
286 | end = v['func_end']
287 | if start and end:
288 | self.functions[func_name] = {'func_name': func_name,
289 | 'variables': joern_vars,
290 | # 'var_lines': joern_var_lines,
291 | 'start': start,
292 | 'end': end}
293 | counter += 1
294 | except Exception as e:
295 | l.error(f'Error in loading joern data! {e}')
296 | except Exception as e:
297 | l.error(f' error in getting joern vars! {e}')
298 |
299 |
300 | class CFunc:
301 |
302 | def __init__(self, dcl, dc_func, joern_data, binary_type, func_name, dwarf_func_name, func_addr, linkage_name_to_func_name ) -> None:
303 |
304 | self.local_vars_dc = dc_func['local_vars']
305 |
306 | self.func = dc_func['func']
307 | self.func_addr = func_addr
308 | self.func_prototype = None
309 | self.func_body = None
310 | self.func_lines = None
311 |
312 | self.func_name = func_name
313 | self.func_name_no_line = '_'.join(func_name.split('_')[:-1])
314 | self.line = func_name.split('_')[-1]
315 | self.binary_name = dcl.binary_name
316 | self.binary_type = binary_type
317 |
318 | self.decompiler = dcl.decompiler
319 | self.dwarf_func_name = dwarf_func_name
320 | self.dwarf_mangled_func_name = linkage_name_to_func_name
321 |
322 | # start
323 | if self.func_name in joern_data.functions:
324 | self.joern_func_start = joern_data.functions[self.func_name]['start']
325 | elif self.line in joern_data.joern_start_line_to_funcname:
326 | tmp_func_name = joern_data.joern_start_line_to_funcname[self.line]
327 | self.joern_func_start = joern_data.functions[tmp_func_name]['start']
328 |
329 | # variables
330 | if self.func_name in joern_data.functions:
331 | self.all_vars_joern = joern_data.functions[self.func_name]['variables']
332 | elif self.line in joern_data.joern_start_line_to_funcname:
333 | tmp_func_name = joern_data.joern_start_line_to_funcname[self.line]
334 | self.all_vars_joern = joern_data.functions[tmp_func_name]['variables']
335 |
336 | self.set_func_details()
337 |
338 |     def __repr__(self) -> str:
339 |         return f'CFunc({self.binary_name}::{self.func_name})'
340 |
341 | @property
342 | def joern_vars(self):
343 | return self.all_vars_joern
344 |
345 | @joern_vars.setter
346 | def joern_vars(self, new_value):
347 | self.all_vars_joern = new_value
348 |
349 | def set_func_details(self) -> None:
350 | if self.func:
351 | self.func_lines = self.func.split('\n')
352 | self.func_prototype = self.func.split('{')[0]
353 | self.func_body = '{'.join(self.func.split('{')[1:])
354 |
355 |
356 | class MatchVariables:
357 |
358 | def __init__(self, strip: CFunc, type_strip: CFunc, valid_strip_2_type_strip_varnames, dwarf_vars, decompiler, dwarf_func_name, binary_name, language, dwarf_addr_compatible_addr=None) -> None:
359 | self.cfunc_strip = strip
360 | self.cfunc_type_strip = type_strip
361 | self.mapped_vars = valid_strip_2_type_strip_varnames
362 | self.dwarf_vars = dwarf_vars
363 | self.labelled_vars = {}
364 | self.modified_func = None
365 | self.decompiler = decompiler
366 | self.dwarf_func_name = dwarf_func_name
367 | self.md5_hash = None
368 | self.binary_name = binary_name
369 | self.language = language
370 | self.dwarf_addr_compatible_addr = dwarf_addr_compatible_addr
371 | self.match()
372 |
373 | def match(self):
374 | if len(self.cfunc_strip.joern_vars) != len(self.cfunc_type_strip.joern_vars):
375 | return
376 | if len(self.mapped_vars) == 0:
377 | return
378 |
379 | self.modified_func = self.update_func()
380 | # label DWARF and decompiler-gen
381 | self.labelled_vars = self.label_vars()
382 | self.md5_hash = self.func_hash()
383 |
384 | def update_func(self):
385 | def rm_comments(func):
386 | cm_regex = r'// .*'
387 | cm_func = re.sub(cm_regex, ' ', func).strip()
388 | return cm_func
389 |
390 |         # pre-process variables: replace each occurrence with a "@@var_<id>@@<type-stripped var name>@@" token
391 | varname2token = {}
392 | for i, varname in enumerate(self.mapped_vars, 0):
393 | varname2token[varname] = f"@@var_{i}@@{self.mapped_vars[varname]}@@"
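    |             # e.g. a stripped placeholder mapped to the type-strip name "count" becomes the token "@@var_0@@count@@"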
394 | new_func = self.cfunc_strip.func
395 |
396 | # if no line numbers available
397 | allowed_prefixes = [" ", "&", "(", "*", "++", "--", "!"]
398 | allowed_suffixes = [" ", ")", ",", ";", "[", "++", "--"]
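    |         # plain string replacement guarded by these delimiters acts as a cheap word-boundary check,
    |         # so substrings of longer identifiers are not accidentally renamed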
399 | for varname, newname in varname2token.items():
400 | for p in allowed_prefixes:
401 | for s in allowed_suffixes:
402 | new_func = new_func.replace(f"{p}{varname}{s}", f"{p}{newname}{s}")
403 |
404 |         # if no token made it into the function (e.g. the only candidate was something like stderr, which is neither in the global vars nor in this function's vars), drop the sample
405 | if '@@' not in new_func:
406 | return None
407 | return rm_comments(new_func)
408 |
409 |
410 | def label_vars(self):
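    |         # a variable is labelled 'dwarf' if its type-strip name appears in this function's DWARF variables
    |         # or in the global variables; otherwise it is treated as decompiler-generated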
411 | labelled_vars = {}
412 | for strip_var, type_strip_var in self.mapped_vars.items():
413 | if self.decompiler == 'ida':
414 | check_value = self.cfunc_type_strip.dwarf_func_name
415 | elif self.decompiler == 'ghidra':
416 | check_value = self.dwarf_addr_compatible_addr
417 | if type_strip_var in self.dwarf_vars[check_value] or type_strip_var in self.dwarf_vars['global_vars']:
418 | labelled_vars[type_strip_var] = 'dwarf'
419 | else:
420 | labelled_vars[type_strip_var] = self.decompiler
421 | return labelled_vars
422 |
423 |
424 |     def func_hash(self) -> str:
425 | var_regex = r"@@(var_\d+)@@(\w+)@@"
426 | up_func = re.sub(var_regex, "\\2", self.modified_func)
427 | func_body = '{'.join(up_func.split('{')[1:])
428 | md5 = hashlib.md5(func_body.encode('utf-8')).hexdigest()
429 | return md5
430 |
431 |
432 | def dump_sample(self):
433 | if self.decompiler == 'ida':
434 | func_name_dwarf = str(self.cfunc_type_strip.dwarf_mangled_func_name[str(self.dwarf_func_name)])
435 | elif self.decompiler == 'ghidra':
436 | func_name_dwarf = str(self.cfunc_type_strip.dwarf_func_name)
437 | if self.modified_func:
438 | name = f'{self.binary_name}_{self.cfunc_strip.func_addr}'
439 | data = { 'func': self.modified_func,
440 | 'type_stripped_vars': dict(self.labelled_vars),
441 | 'stripped_vars': list(self.mapped_vars.keys()),
442 | 'mapped_vars': dict(self.mapped_vars),
443 | 'func_name_dwarf':func_name_dwarf,
444 | 'dwarf_mangled_func_name': str(self.dwarf_func_name),
445 | 'hash': self.md5_hash,
446 | 'language': self.language
447 | }
448 |
449 |             return name, data
450 | else:
451 | return None, None
452 |
--------------------------------------------------------------------------------