├── LICENSE ├── README.md ├── SiameseModel.py ├── model_bce_with_pregen_emb_2.pth ├── pred.py ├── preprocess.py ├── run-slice.py ├── store_emb.py ├── util_slice.py └── util_tokenizer.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 VUL337 Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CALLEE 2 | 3 | Official code of CALLEE: Recovering Call Graphs for Binaries with Transfer and Contrastive Learning. 4 | 5 | For ease of use, we have made some changes to the original implementation in the paper. 
6 | 7 | **Status: We have substituted the doc2vec model with transformers and released a new dataset.** 8 | 9 | * The new work **kTrans** is [here](https://github.com/Learner0x5a/kTrans-release). 10 | * The new dataset is [here](https://github.com/Learner0x5a/Callee-Dataset). 11 | 12 | We have decided to deprecate the old dataset since it was collected several years ago on older versions of Firefox and the Linux kernel. 13 | 14 | ## Usage 15 | 16 | ### Environment 17 | Tested on Ubuntu 18.04 with: 18 | - Python3 (python-magic, gensim, numpy, torch, tqdm, capstone) 19 | - IDA Pro 7.6 20 | - CUDA 10.2 21 | 22 | ### Pipeline 23 | 24 | ***NOTE: This is a single-threaded demo; consider multiprocessing for production or batch processing.*** 25 | 26 | **a. Slice the target binary with IDA** 27 | 28 | ``` 29 | python3 run-slice.py -i /path/to/binary -o /path/to/slices -n --ida_path /path/to/idat64 30 | ``` 31 | 32 | The script invokes IDA Pro to analyze the binary and perform slicing on indirect callsites and candidate callees. 33 | 34 | **b. Tokenize the slices** 35 | 36 | ``` 37 | python3 preprocess.py -i /path/to/slices -o /path/to/tokenized_slices 38 | ``` 39 | 40 | The script tokenizes the assembly instructions in the slices. 41 | 42 | **c. Generate embeddings with doc2vec** 43 | 44 | ``` 45 | python3 store_emb.py -i /path/to/tokenized_slices -o /path/to/embeddings --doc2vec_model /path/to/doc2vec_model 46 | ``` 47 | 48 | The script transforms the slices into embeddings with a pretrained doc2vec model. 49 | 50 | **d. Predict with the Siamese network** 51 | 52 | ``` 53 | python3 pred.py -i /path/to/embeddings 54 | ``` 55 | 56 | The script outputs a score for each (indirect callsite, candidate callee) pair.
57 | 58 | ## Tool for collecting indirect call 59 | 60 | Here is a qemu tcg plugin we've modified to collect indirect calls on x86_64: [`ibresolver`](https://github.com/Learner0x5a/ibresolver) 61 | 62 | -------------------------------------------------------------------------------- /SiameseModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FeatureExtractor(torch.nn.Module): 5 | def __init__(self,n_layers,dim_input,dim_hidden,dim_output): 6 | super(FeatureExtractor, self).__init__() 7 | self.n_layers = n_layers 8 | 9 | self.input_layer = torch.nn.Linear(dim_input,dim_hidden) 10 | self.fcn = torch.nn.Linear(dim_hidden,dim_hidden) 11 | self.output_layer = torch.nn.Linear(dim_hidden, dim_output) 12 | 13 | self.input_attention = torch.nn.Linear(dim_input, dim_input) 14 | # self.hidden_attention = torch.nn.Linear(dim_hidden, dim_hidden) 15 | self.activation_function = torch.nn.ReLU() 16 | self.input_norm = torch.nn.LayerNorm(dim_input) 17 | self.hidden_norm = torch.nn.LayerNorm(dim_hidden) 18 | 19 | def forward(self, x): 20 | x = self.input_norm(x) 21 | # att = self.input_attention(x) 22 | # x = torch.multiply(x, torch.softmax(att,dim=-1)) 23 | x = self.input_layer(x) 24 | x = self.hidden_norm(x) 25 | # x = torch.nn.Dropout()(x) 26 | x = self.activation_function(x) 27 | 28 | for i in range(self.n_layers): 29 | # att = self.hidden_attention(x) 30 | x = self.fcn(x) 31 | # x = torch.nn.LayerNorm(x.size()[-1])(x) 32 | # x = torch.nn.Dropout()(x) 33 | x = self.activation_function(x) 34 | 35 | x = self.output_layer(x) 36 | x = self.activation_function(x) 37 | return x 38 | 39 | 40 | 41 | class ContrastiveClassifier(torch.nn.Module): 42 | def __init__(self,n_layers_feature,dim_input_feature,dim_hidden_feature,dim_output_feature,n_layers_cls,dim_hidden_cls,dim_output_cls): 43 | super(ContrastiveClassifier, self).__init__() 44 | self.n_layers = n_layers_cls 45 | self.feature_extractor1 = 
FeatureExtractor(n_layers=n_layers_feature, dim_input=dim_input_feature, dim_hidden=dim_hidden_feature,dim_output=dim_output_feature) 46 | self.feature_extractor2 = FeatureExtractor(n_layers=n_layers_feature, dim_input=dim_input_feature, dim_hidden=dim_hidden_feature,dim_output=dim_output_feature) 47 | self.input_layer = torch.nn.Linear(2*dim_output_feature, dim_hidden_cls) 48 | self.hidden_layer = torch.nn.Linear(dim_hidden_cls, dim_hidden_cls) 49 | self.output_layer = torch.nn.Linear(dim_hidden_cls, dim_output_cls) 50 | self.batchnorm = torch.nn.LayerNorm(dim_hidden_cls) 51 | self.activation_function = torch.nn.ReLU() 52 | 53 | def forward(self, x1, x2): 54 | emb1 = self.feature_extractor1(x1) 55 | emb2 = self.feature_extractor2(x2) 56 | x = torch.cat((emb1,emb2),dim=-1) 57 | 58 | x = self.input_layer(x) 59 | x = self.batchnorm(x) 60 | # x = torch.nn.Dropout()(x) 61 | x = self.activation_function(x) 62 | 63 | for i in range(self.n_layers): 64 | x = self.hidden_layer(x) 65 | # x = self.batchnorm(x) 66 | # x = torch.nn.Dropout()(x) 67 | x = self.activation_function(x) 68 | 69 | x = self.output_layer(x) 70 | x = torch.sigmoid(x) 71 | 72 | return x -------------------------------------------------------------------------------- /model_bce_with_pregen_emb_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vul337/Callee/aebcbc3821b5967a01ddc0b2fb0b1a1d1313cac5/model_bce_with_pregen_emb_2.pth -------------------------------------------------------------------------------- /pred.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from torch.utils.data import DataLoader 7 | from SiameseModel import ContrastiveClassifier 8 | import numpy as np 9 | from glob import glob 10 | from argparse import ArgumentParser 11 | 12 | 13 | class 
AICTPairWithPreGenEmbDataset(torch.utils.data.Dataset): 14 | def __init__(self, dataset_path): 15 | self.dataset_path = dataset_path 16 | print('Loading dataset...') 17 | self.emb_files = [] 18 | self.load_data() 19 | 20 | def __getitem__(self, idx): # per callsite 21 | caller_embs = [] 22 | callee_embs = [] 23 | with open(self.emb_files[idx], 'rb') as f: 24 | call_pairs = pickle.load(f) 25 | for caller_sig, caller_emb, callee_sig, callee_emb in tqdm(call_pairs): 26 | caller_embs.append(caller_emb) 27 | callee_embs.append(callee_emb) 28 | print(self.emb_files[idx]) 29 | return self.emb_files[idx], np.array(caller_embs), np.array(callee_embs) 30 | 31 | def __len__(self): 32 | return len(self.emb_files) 33 | 34 | def load_data(self): 35 | for slice_file in tqdm(glob('{}/*.pkl'.format(self.dataset_path))): 36 | self.emb_files.append(slice_file) 37 | 38 | 39 | if torch.cuda.is_available(): 40 | dev=torch.device('cuda') 41 | else: 42 | dev=torch.device('cpu') 43 | print(dev) 44 | 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = ArgumentParser() 49 | parser.add_argument('-i','--emb_dir', type=str, help='embeddings dir', nargs='?', default='./aict-embeddings') 50 | parser.add_argument('--model', type=str, help='siamese network model', nargs='?', default='./model_bce_with_pregen_emb_2.pth') 51 | 52 | args = parser.parse_args() 53 | 54 | 55 | model = ContrastiveClassifier(3, 100, 256, 128, 1, 256, 1).to(dev) 56 | params_load = torch.load(args.model)['state_dict'] 57 | model.load_state_dict(params_load) 58 | 59 | aict_loader = DataLoader(AICTPairWithPreGenEmbDataset(args.emb_dir), batch_size = 1, num_workers=0, shuffle=True) 60 | model.eval() 61 | icts = {} 62 | with torch.no_grad(): 63 | for i, (binary_name, caller_embs, callee_embs) in tqdm(enumerate(aict_loader)): 64 | binary_name = binary_name[0] 65 | 66 | caller_embs = caller_embs.to(dev) 67 | caller_embs = torch.squeeze(caller_embs) 68 | callee_embs = callee_embs.to(dev) 69 | callee_embs = 
torch.squeeze(callee_embs) 70 | 71 | preds = model(caller_embs, callee_embs) 72 | 73 | print(f'Callsite {i}, preds:', preds.cpu().numpy()) 74 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from util_tokenizer import asmTokenizer 4 | from argparse import ArgumentParser 5 | 6 | def get_slice_iter(_dir): 7 | for root,parent,files in os.walk(_dir): 8 | for slice_file in files: 9 | filepath = os.path.join(root, slice_file) 10 | filesize = os.path.getsize(filepath)/(1024*1024) 11 | if filesize > 50: 12 | print('Found large file: {}\t{} MB'.format(slice_file, filesize)) 13 | continue 14 | 15 | slice_paris_tokenized = [] 16 | with open(filepath,'r') as f: 17 | slice_pairs = f.readlines() 18 | for line in slice_pairs: 19 | caller_data, callee_data = line.split(' -> ') 20 | caller_sig, caller = caller_data.split('|') 21 | callee_sig, callee = callee_data.split('|') 22 | caller_insns = caller.strip().split('\t') 23 | callee_insns = callee.strip().split('\t') 24 | tokenized_caller = tokenizer.tokenize_doc(caller_insns) 25 | tokenized_callee = tokenizer.tokenize_doc(callee_insns) 26 | 27 | caller_data_tokenized = '{}|{}'.format(caller_sig, tokenized_caller) 28 | callee_data_tokenized = '{}|{}'.format(callee_sig, tokenized_callee) 29 | slice_paris_tokenized.append((caller_data_tokenized, callee_data_tokenized)) 30 | 31 | yield slice_file, slice_paris_tokenized 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | parser = ArgumentParser() 37 | parser.add_argument('-i','--slice_dir', type=str, help='slice dir', nargs='?', default='slice') 38 | parser.add_argument('-o','--output_dir', type=str, nargs='?', 39 | help='Output tokenized slice dir', default='./slice-tokenized') 40 | args = parser.parse_args() 41 | 42 | os.makedirs(args.output_dir, exist_ok=True) 43 | tokenizer = asmTokenizer() 44 | for slice_file, 
slice_paris_tokenized in tqdm(get_slice_iter(args.slice_dir)): 45 | with open(os.path.join(args.output_dir, slice_file),'w') as f: 46 | for caller_data_tokenized, callee_data_tokenized in slice_paris_tokenized: 47 | f.write('{} -> {}\n'.format(caller_data_tokenized, callee_data_tokenized)) 48 | 49 | os.system("awk '!seen[$0]++' {} > {}.uniq".format(os.path.join(args.output_dir, slice_file), os.path.join(args.output_dir, slice_file))) 50 | -------------------------------------------------------------------------------- /run-slice.py: -------------------------------------------------------------------------------- 1 | import os 2 | import magic 3 | from concurrent.futures import ProcessPoolExecutor 4 | import subprocess 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-i','--bin', type=str, help='binary') 9 | parser.add_argument('-o','--output_dir', type=str, nargs='?', 10 | help='Output dir', default='./output-slice') 11 | parser.add_argument('-n','--workers', type=int, nargs='?', 12 | help='Max Workers', default=1) 13 | parser.add_argument('--ida_path', type=str, nargs='?', 14 | help='idapro dir', default='/workspace/idapro-7.6/idat64') 15 | 16 | args = parser.parse_args() 17 | 18 | 19 | def run(bin_path, output_dir): 20 | filename = bin_path.split(os.path.sep)[-1] 21 | script_cmd = './util_slice.py {}'.format(output_dir) 22 | ida_cmd = 'env TERM=xterm {} -L"log/{}.log" -A -S"{}" {}'.format(args.ida_path, filename, script_cmd, bin_path) 23 | print(ida_cmd) 24 | subprocess.run(ida_cmd, shell=True) 25 | 26 | def main(): 27 | output_dir = os.path.join(args.output_dir,'aict') 28 | os.makedirs(output_dir, exist_ok=True) 29 | 30 | run(args.bin, output_dir) 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /store_emb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 
| import gensim 5 | import random 6 | import numpy as np 7 | from tqdm import tqdm 8 | from glob import glob 9 | from argparse import ArgumentParser 10 | 11 | class CallPair(object): 12 | def __init__(self, line): 13 | caller_data_tokneized, callee_data_tokneized = line.split(' -> ') 14 | caller_sig, caller_insns = caller_data_tokneized.split('|') 15 | callee_sig, callee_insns = callee_data_tokneized.split('|') 16 | 17 | self.caller_sig = caller_sig 18 | self.caller_insns = caller_insns 19 | self.callee_sig = callee_sig 20 | self.callee_insns = callee_insns 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = ArgumentParser() 25 | parser.add_argument('-i','--input_dir', type=str, help='input slice dir, which contains slice pairs to be embedded into vectors', 26 | nargs='?', default='slice-tokenized') 27 | parser.add_argument('--doc2vec_model', type=str, help='path to the trained doc2vec model', 28 | nargs='?', default='./doc2vec.model.dbow') 29 | parser.add_argument('-o','--output_dir', type=str, nargs='?', 30 | help='Output dir containing the embedded slice pairs', default='./aict-embeddings') 31 | args = parser.parse_args() 32 | 33 | os.makedirs(args.output_dir, exist_ok=True) 34 | 35 | call_pairs = {} 36 | doc2vec_model = gensim.models.Doc2Vec.load(args.doc2vec_model) 37 | for slice_file in tqdm(glob('{}/*.slice.uniq'.format(args.input_dir))): 38 | call_pairs[slice_file] = [] 39 | with open(slice_file, 'r') as f: 40 | for line in f: 41 | call_pairs[slice_file].append(CallPair(line)) 42 | 43 | call_pairs_pkl = {} 44 | 45 | for slice_file in call_pairs: 46 | call_pairs_pkl[slice_file] = [] 47 | for idx, pair in tqdm(enumerate(call_pairs[slice_file])): 48 | caller_emb = doc2vec_model.infer_vector(pair.caller_insns.strip().split()) 49 | callee_emb = doc2vec_model.infer_vector(pair.callee_insns.strip().split()) 50 | 51 | call_pairs_pkl[slice_file].append([pair.caller_sig, caller_emb, pair.callee_sig, callee_emb]) 52 | with 
open('{}.pkl'.format(os.path.join(args.output_dir, slice_file.split(os.path.sep)[-1])), 'wb') as f: 53 | pickle.dump(call_pairs_pkl[slice_file], f) 54 | 55 | -------------------------------------------------------------------------------- /util_slice.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | from typing import Dict, List, Tuple 5 | import idc 6 | import idaapi 7 | import idautils 8 | import ida_pro 9 | import ida_auto 10 | import ida_nalt 11 | import ida_funcs 12 | ida_auto.auto_wait() 13 | 14 | from capstone import * 15 | md = Cs(CS_ARCH_X86, CS_MODE_64) 16 | md.detail = True 17 | 18 | text_start = idc.get_segm_by_sel(idc.selector_by_name(".text")) 19 | text_end = idc.get_segm_end(text_start) 20 | got_start = idc.get_segm_by_sel(idc.selector_by_name(".got")) 21 | got_end = idc.get_segm_end(got_start) 22 | got_plt_start = idc.get_segm_by_sel(idc.selector_by_name(".got.plt")) 23 | got_plt_end = idc.get_segm_end(got_plt_start) 24 | plt_start = idc.get_segm_by_sel(idc.selector_by_name(".plt")) 25 | plt_end = idc.get_segm_end(plt_start) 26 | plt_got_start = idc.get_segm_by_sel(idc.selector_by_name(".plt.got")) 27 | plt_got_end = idc.get_segm_end(plt_got_start) 28 | 29 | at_blacklist = ['main', '_start', '__do_global_dtors_aux','frame_dummy', '__lib_csu_init', '__lib_csu_fini'] 30 | exit_list = ["exit","_exit","terminate","_terminate"] 31 | init_reg_map = {"rdi":"rdi", "rsi":"rsi", "rdx":"rdx", "rcx":"rcx", "r8":"r8", "r9":"r9", "rax":"rax", "rsp":"rsp", "rbp":"rbp", 32 | "edi":"rdi", "esi":"rsi", "edx":"rdx", "ecx":"rcx", "r8d":"r8", "r9d":"r9", "eax":"rax", "esp":"rsp", "ebp":"rbp", 33 | "di":"rdi", "si":"rsi", "dx":"rdx", "cx":"rcx", "r8w":"r8", "r9w":"r9", "ax":"rax", "sp":"rsp", "bp":"rbp", 34 | "dil":"rdi", "sil":"rsi", "dl":"rdx", "cl":"rcx", "r8b":"r8", "r9b":"r9", "al":"rax", "spl":"rsp", "bpl":"rbp" , 35 | "xmm0":"zmm0","ymm0":"zmm0","zmm0":"zmm0", 36 | 
"xmm1":"zmm1","ymm1":"zmm1","zmm1":"zmm1", 37 | "xmm2":"zmm2","ymm2":"zmm2","zmm2":"zmm2", 38 | "xmm3":"zmm3","ymm3":"zmm3","zmm3":"zmm3", 39 | } 40 | reversed_reg_map = { 41 | "rdi": ["rdi","edi","di","dil"], 42 | "rsi": ["rsi","esi","si","sil"], 43 | "rdx": ["rdx","edx","dx","dl"], 44 | "rcx": ["rcx","ecx","cx","cl"], 45 | "r8": ["r8","r8d","r8w","r8b"], 46 | "r9": ["r9","r9d","r9w","r9b"], 47 | "rax": ["rax","eax","ax","al"], 48 | "rsp": ["rsp","esp","sp","spl"], 49 | "rbp": ["rbp","ebp","bp","bpl"], 50 | "zmm0": ["zmm0","ymm0","xmm0"], 51 | "zmm1": ["zmm1","ymm1","xmm1"], 52 | "zmm2": ["zmm2","ymm2","xmm2"], 53 | "zmm3": ["zmm3","ymm3","xmm3"], 54 | } 55 | 56 | reg_rw_threshold = 1 57 | call_ins_threshold = 1 58 | j_ins_threshold = 1 59 | 60 | SKIP_THRESHOLD = 5 61 | ALL_FUNCTIONS = list(idautils.Functions()) 62 | 63 | def getRealAddr(addr): 64 | 65 | initAddr = addr 66 | 67 | while initAddr < text_start or initAddr >= text_end: 68 | initAddr = idc.get_operand_value(initAddr, 0) 69 | 70 | xref = idautils.XrefsFrom(initAddr,0) 71 | xreflist = list(xref) 72 | lenlist = len(xreflist) 73 | if lenlist == 1: 74 | initAddr = xreflist[0].to 75 | elif lenlist == 0: 76 | initAddr = 0 77 | break 78 | else: 79 | print("0x%x: more than 1 xref?"%initAddr) 80 | 81 | return initAddr 82 | 83 | def is_tail_call(opcode, addr): 84 | flag = False 85 | if opcode.startswith('j'): 86 | if idc.get_operand_value(addr, 0) in ALL_FUNCTIONS: 87 | flag = True 88 | elif opcode == 'jmp' and (not ' short ' in idc.GetDisasm(addr)): 89 | flag = True 90 | 91 | return flag 92 | 93 | def getDisasmCapstone(addr): 94 | insn = None 95 | r,w = [],[] 96 | code = idc.get_bytes(addr, idc.get_item_size(addr)) 97 | if not code: 98 | return '',[],[] 99 | for i in md.disasm(code, addr): 100 | insn = "%s %s" % (i.mnemonic, i.op_str) 101 | if insn.startswith('nop'): 102 | continue 103 | (regs_read, regs_write) = i.regs_access() 104 | if regs_read: 105 | for reg in regs_read: 106 | # print "\tRead REG: %s" 
%(i.reg_name(reg)) 107 | r.append("%s"%i.reg_name(reg)) 108 | if regs_write: 109 | for reg in regs_write: 110 | # print "\tWrite REG: %s" %(i.reg_name(reg)) 111 | w.append("%s"%i.reg_name(reg)) 112 | 113 | return insn,r,w 114 | 115 | 116 | def get_num_insns(func_ea): 117 | 118 | if func_ea == idc.BADADDR: 119 | iter = func_ea 120 | backward_count = 0 121 | while backward_count < 100: 122 | backward_count += 1 123 | iter = idc.prev_head(iter) 124 | if iter in ALL_FUNCTIONS or idc.print_insn_mnem(iter) == 'retn': 125 | break 126 | func_start = idc.next_head(iter) 127 | 128 | iter = func_ea 129 | forward_count = 0 130 | while forward_count < 100: 131 | forward_count += 1 132 | iter = idc.next_head(iter) 133 | if iter in ALL_FUNCTIONS or idc.print_insn_mnem(iter) == 'retn': 134 | break 135 | func_end = idc.prev_head(iter) 136 | 137 | num_insns = backward_count + forward_count 138 | 139 | else: 140 | num_insns = len(list(idautils.FuncItems(func_ea))) 141 | 142 | return num_insns 143 | 144 | def get_func_boudary(ea): 145 | func_ea = idc.get_func_attr(ea, idc.FUNCATTR_START) 146 | 147 | if func_ea == idc.BADADDR: 148 | iter = func_ea 149 | backward_count = 0 150 | while backward_count < 100: 151 | backward_count += 1 152 | iter = idc.prev_head(iter) 153 | if iter in ALL_FUNCTIONS or idc.print_insn_mnem(iter) == 'retn': 154 | break 155 | functionStart = idc.next_head(iter) 156 | 157 | iter = func_ea 158 | forward_count = 0 159 | while forward_count < 100: 160 | forward_count += 1 161 | iter = idc.next_head(iter) 162 | if iter in ALL_FUNCTIONS or idc.print_insn_mnem(iter) == 'retn': 163 | break 164 | functionEnd = idc.prev_head(iter) 165 | 166 | else: 167 | functionStart = idc.get_func_attr(ea, idc.FUNCATTR_START) 168 | functionEnd = idc.find_func_end(functionStart) 169 | 170 | return functionStart, functionEnd 171 | 172 | 173 | 174 | class Callee: 175 | def __init__(self, addr) -> None: 176 | self.addr = addr 177 | self.slices = [] 178 | self.signature = [] 179 | 
self.functionStart = self.addr 180 | self.functionEnd = idc.find_func_end(self.functionStart) 181 | self.functionName = idc.get_func_name(self.addr) 182 | self.num_insns = len(list(idautils.FuncItems(self.functionStart))) 183 | 184 | def _SliceOnRegs(self, reg_count : dict, reg_map : dict) -> Tuple[list, list]: 185 | 186 | addr = self.addr 187 | signature = [] 188 | slices = [] 189 | ret_flag = False 190 | 191 | call_ins_count = 0 192 | j_ins_count = 0 193 | 194 | reg_status = {} 195 | for key in reg_count: 196 | reg_status[key] = "" 197 | 198 | while self.functionStart <= addr < self.functionEnd: 199 | flag = False 200 | 201 | opcode = idc.print_insn_mnem(addr) 202 | if opcode.startswith("nop"): 203 | addr = idc.next_head(addr) 204 | continue 205 | 206 | if opcode.startswith("call"): 207 | if call_ins_count < call_ins_threshold: 208 | call_ins_count += 1 209 | flag = True 210 | 211 | if j_ins_count < j_ins_threshold: 212 | if is_tail_call(opcode, addr): 213 | j_ins_count += 1 214 | flag = True 215 | 216 | insn,r,w = getDisasmCapstone(addr) 217 | for reg in r: 218 | if not reg in reg_map.keys(): 219 | continue 220 | reg_status[reg_map[reg]] += "r" 221 | if reg_map[reg] == "rax": 222 | continue 223 | if reg_count[reg_map[reg]] > reg_rw_threshold: 224 | continue 225 | reg_count[reg_map[reg]] += 1 226 | 227 | flag = True 228 | 229 | for reg in w: 230 | if not reg in reg_map.keys(): 231 | continue 232 | reg_status[reg_map[reg]] += "w" 233 | if reg_map[reg] != "rax": 234 | continue 235 | if reg_count[reg_map[reg]] > reg_rw_threshold: 236 | continue 237 | reg_count[reg_map[reg]] += 1 238 | ret_flag = True 239 | flag = True 240 | 241 | if flag: 242 | # print(insn) 243 | slices.append(insn) 244 | 245 | addr = idc.next_head(addr) 246 | 247 | 248 | float_reg_count = {} 249 | for key in reg_count: 250 | if key.startswith("zmm"): 251 | float_reg_count[key] = reg_count[key] 252 | 253 | for reg in reg_count: 254 | if reg_count[reg] > 0: 255 | signature.append(reg) 256 | 257 | 
for reg in float_reg_count: 258 | if float_reg_count[reg] > 0 and (reg not in signature): 259 | signature.append(reg) 260 | 261 | if ret_flag and ("rax" not in signature): 262 | signature.append("rax") 263 | 264 | return (signature, slices) 265 | 266 | def calleeSlice(self): 267 | 268 | reg_count = {"rdi":0,"rsi":0,"rdx":0,"rcx":0,"r8":0,"r9":0,"rsp":0,"rbp":0,"rax":0,"zmm0":0,"zmm1":0,"zmm2":0,"zmm3":0} 269 | signature,raw_slices = self._SliceOnRegs(reg_count, init_reg_map) 270 | 271 | new_reg_count = {} 272 | new_reg_map = {} 273 | for reg in signature: 274 | new_reg_count[reg] = 0 275 | for key in reversed_reg_map[reg]: 276 | new_reg_map[key] = reg 277 | 278 | signature,refined_slices = self._SliceOnRegs(new_reg_count, new_reg_map) 279 | 280 | self.signature = signature 281 | self.slices = refined_slices 282 | 283 | 284 | class Callsite: 285 | def __init__(self, addr) -> None: 286 | self.addr = addr 287 | self.functionName = idc.get_func_name(self.addr) 288 | self.functionStart, self.functionEnd = get_func_boudary(self.addr) 289 | self.signature = [] 290 | self.slices = [] 291 | 292 | 293 | def _BackwardSliceOnRegs(self, reg_count: dict, reg_map: dict) -> Tuple[list, list]: 294 | 295 | ret_slice = self._ForwardRetSlice() 296 | 297 | 298 | slices = [] 299 | signature = [] 300 | call_ins_count = 0 301 | j_ins_count = 0 302 | 303 | iter = self.addr 304 | while self.functionStart <= iter: 305 | flag = False 306 | iter = idc.prev_head(iter) 307 | opcode = idc.print_insn_mnem(iter) 308 | if opcode.startswith("nop"): 309 | continue 310 | 311 | insn,r,w = getDisasmCapstone(iter) 312 | 313 | if opcode.startswith("call"): 314 | break 315 | 316 | if j_ins_count < j_ins_threshold: 317 | if is_tail_call(opcode, iter): 318 | j_ins_count += 1 319 | flag = True 320 | else: 321 | break 322 | 323 | for reg in w: 324 | if not reg in reg_map.keys(): 325 | continue 326 | reg = reg_map[reg] 327 | if reg_map[reg] == "rax": 328 | continue 329 | if reg_count[reg_map[reg]] > 
reg_rw_threshold: 330 | continue 331 | 332 | reg_count[reg_map[reg]] += 1 333 | flag = True 334 | 335 | if flag: 336 | slices.append(insn) 337 | 338 | slices.reverse() 339 | slices.append("callsite callee") 340 | slices.extend(ret_slice) 341 | 342 | float_reg_count = {} 343 | for key in reg_count: 344 | if key.startswith("zmm"): 345 | float_reg_count[key] = reg_count[key] 346 | 347 | for reg in reg_count: 348 | if reg_count[reg] > 0: 349 | signature.append(reg) 350 | for reg in float_reg_count: 351 | if float_reg_count[reg] > 0 and (reg not in signature): 352 | signature.append(reg) 353 | 354 | if len(ret_slice) > 0 and ("rax" not in signature): 355 | signature.append("rax") 356 | 357 | return (signature, slices) 358 | 359 | 360 | def callsiteslice(self): 361 | reg_count = {"rdi":0,"rsi":0,"rdx":0,"rcx":0,"r8":0,"r9":0,"rsp":0,"rbp":0,"rax":0,"zmm0":0,"zmm1":0,"zmm2":0,"zmm3":0} 362 | signature,raw_slices = self._BackwardSliceOnRegs(reg_count, init_reg_map) 363 | 364 | new_reg_count = {} 365 | new_reg_map = {} 366 | for reg in signature: 367 | new_reg_count[reg] = 0 368 | for key in reversed_reg_map[reg]: 369 | new_reg_map[key] = reg 370 | 371 | signature,refined_slices = self._BackwardSliceOnRegs(new_reg_count, new_reg_map) 372 | 373 | self.signature = signature 374 | self.slices = refined_slices 375 | 376 | 377 | 378 | def _ForwardRetSlice(self) -> list: 379 | reg_count = {"rax":0} 380 | 381 | slices = [] 382 | call_ins_count = 0 383 | j_ins_count = 0 384 | addr = self.addr 385 | while self.functionStart <= addr < self.functionEnd: 386 | flag = False 387 | addr = idc.next_head(addr) 388 | 389 | opcode = idc.print_insn_mnem(addr) 390 | if opcode.startswith("nop"): 391 | addr = idc.next_head(addr) 392 | continue 393 | 394 | insn,r,w = getDisasmCapstone(addr) 395 | if opcode.startswith("call"): 396 | break 397 | 398 | for reg in r: 399 | if not reg in init_reg_map.keys(): 400 | continue 401 | if init_reg_map[reg] != "rax": 402 | continue 403 | 404 | reg_count["rax"] 
+= 1 405 | flag = True 406 | 407 | if flag: 408 | slices.append(insn) 409 | 410 | if reg_count["rax"] > reg_rw_threshold: 411 | break 412 | 413 | return slices 414 | 415 | 416 | if __name__ == '__main__': 417 | if len(idc.ARGV) < 2: 418 | print('\n\nGenerating AICT Eval Data') 419 | print('\tNeed to specify the output dir') 420 | print('\tUsage: /path/to/ida -A -Llog/{}.log -S"{} " /path/to/binary\n\n'.format(ida_nalt.get_root_filename(), idc.ARGV[0])) 421 | ida_pro.qexit(1) 422 | 423 | output_dir = idc.ARGV[1] 424 | 425 | AT_FUNCTIONS = [] 426 | ICALLSITES = [] 427 | text_func_count = 0 428 | for func in tqdm(idautils.Functions(), desc="Slicing..."): 429 | 430 | func_name = idc.get_func_name(func) 431 | demangle_name = idc.demangle_name(func_name, idc.get_inf_attr(idc.INF_SHORT_DEMNAMES)) 432 | if demangle_name: 433 | func_name = demangle_name 434 | if (func < plt_end and func >= plt_start) or (func < plt_got_end and func >= plt_got_start): 435 | func = getRealAddr(func) 436 | 437 | if text_start<= func < text_end: 438 | text_func_count += 1 439 | if list(idautils.DataRefsTo(func)): 440 | num_insns = get_num_insns(func) 441 | if num_insns < SKIP_THRESHOLD: 442 | print('Small function:', func_name, num_insns) 443 | continue 444 | if not func_name in at_blacklist: 445 | this_func = Callee(func) 446 | this_func.calleeSlice() 447 | AT_FUNCTIONS.append((set(this_func.signature), this_func.slices)) 448 | 449 | for (startea, endea) in idautils.Chunks(func): 450 | for head in idautils.Heads(startea, endea): 451 | opcode = idc.print_insn_mnem(head) 452 | 453 | if opcode == "call": 454 | optype = idc.get_operand_type(head, 0) 455 | callee_ea = idc.get_operand_value(head, 0) 456 | 457 | if 1<= optype < 5: 458 | callsite = Callsite(head) 459 | callsite.callsiteslice() 460 | callsite_slices = callsite.slices 461 | callsite_sig = set(callsite.signature) 462 | ICALLSITES.append((callsite_sig, callsite_slices)) 463 | 464 | callsite_idx = 0 465 | output_dir = idc.ARGV[1] 466 | for 
callsite_sig, callsite_slices in tqdm(ICALLSITES, desc="Storing slices..."): 467 | all_pairs = '' 468 | for callee_sig, callee_slices in AT_FUNCTIONS: 469 | all_pairs += '{}|{} -> {}|{}\n'.format( 470 | '.'.join(callsite_sig), 471 | '\t'.join(callsite_slices), 472 | '.'.join(callee_sig), 473 | '\t'.join(callee_slices) 474 | ) 475 | with open(os.path.join(output_dir, '{}_{}.slice'.format(ida_nalt.get_root_filename(), callsite_idx)),'w') as f: 476 | f.write(all_pairs) 477 | callsite_idx += 1 478 | 479 | ida_pro.qexit(0) 480 | -------------------------------------------------------------------------------- /util_tokenizer.py: -------------------------------------------------------------------------------- 1 | class asmTokenizer: 2 | def __init__(self): 3 | pass 4 | 5 | def tokenize_insn(self,insn): 6 | '''mov eax, 1 -> mov eax, imm''' 7 | insn = insn.replace('(', ' ( ') 8 | insn = insn.replace(')', ' ) ') 9 | insn = insn.replace('[', ' [ ') 10 | insn = insn.replace(']', ' ] ') 11 | insn = insn.replace(',', ' , ') 12 | insn = insn.replace('*', ' * ') 13 | insn = insn.replace('+', ' + ') 14 | insn = insn.replace('-', ' - ') 15 | insn = insn.replace(':', ' : ') 16 | ins_split = insn.split() 17 | while '' in ins_split: 18 | ins_split.remove('') 19 | 20 | opcode = ins_split[0] 21 | operands = ins_split[1:] 22 | new_insn = '' 23 | new_insn += opcode + ' ' 24 | 25 | for opnd in operands: 26 | 27 | if opnd.isdigit() or opnd.startswith('0x'):# 处理数字 28 | # print('found digit:',opnd) 29 | if opcode == 'call' or opcode[0] == 'j': 30 | new_opnd = 'addr' 31 | else: 32 | new_opnd = 'num' 33 | else: 34 | new_opnd = opnd 35 | new_insn += new_opnd + ' ' 36 | 37 | new_insn = new_insn.strip() 38 | return new_insn 39 | 40 | def tokenize_doc(self,doc): 41 | tokenized_doc = '' 42 | for insn in doc: 43 | insn = insn.strip() 44 | if insn: 45 | tokenized_doc += self.tokenize_insn(insn) + ' ' 46 | tokenized_doc = tokenized_doc.strip() 47 | return tokenized_doc 48 | 
--------------------------------------------------------------------------------