├── .gitignore
├── LICENSE
├── README.md
├── data.py
├── datautils
│   ├── README.md
│   ├── clean.sh
│   ├── dataset
│   │   ├── libcap-git-setcap-O0-00ac9f3e2f406891116850f0ddffd008
│   │   ├── libcap-git-setcap-O1-90826146c17034375256122b115dc849
│   │   ├── libcap-git-setcap-O2-8dc43f20ea80b7703f6973a1ea86e8b8
│   │   ├── libcap-git-setcap-O3-e8e5b8752b5ba79d135c254273edadae
│   │   └── libcap-git-setcap-Os-1ac8a004dfd1d92cb18bf4aa36d817e3
│   ├── playdata.py
│   ├── process.py
│   ├── run.py
│   └── util
│       ├── base.py
│       └── pairdata.py
├── eval_save.py
├── fasteval.py
├── figures
│   └── poolsizecompare.png
├── finetune.py
├── jtrans_tokenizer
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.txt
├── readidadata.py
└── tokenizer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | models/
2 | experiments/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 VUL337 Group
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # jTrans
 2 | This repo is the official code of **jTrans: Jump-Aware Transformer for Binary Code Similarity Detection**.
 3 | 
 4 | ![Illustrating the performance of the proposed jTrans](/figures/poolsizecompare.png)
 5 | 
 6 | ## News
 7 | * \[2023/3/2\] Added a writeup on using jTrans for binary diffing in [HackerGame2022](https://github.com/USTC-Hackergame/hackergame2022-writeups/tree/master/official/%E7%81%AB%E7%9C%BC%E9%87%91%E7%9D%9B%E7%9A%84%E5%B0%8F%20E).
 8 | * \[2022/7/7\] We updated BinaryCorp with the original [binaries](https://cloud.vul337.team:8443/s/W57ZWXxn7zSKG4q).
 9 | * \[2022/6/18\] We released the code and models of jTrans.
10 | * \[2022/6/9\] We released the preprocessing code and [BinaryCorp](https://cloud.vul337.team:8443/s/cxnH8DfZTADLKCs), the dataset we used in our paper.
11 | * \[2022/5/26\] jTrans is now on [ArXiv](https://arxiv.org/pdf/2205.12713.pdf).
12 | 
13 | ## Writeups
14 | * [Binary Diffing](https://github.com/USTC-Hackergame/hackergame2022-writeups/tree/master/official/%E7%81%AB%E7%9C%BC%E9%87%91%E7%9D%9B%E7%9A%84%E5%B0%8F%20E#jtrans)
15 | * Welcome PRs for more writeups :)
16 | 
17 | ## Get Started
18 | ### Prerequisites
19 | - Linux (macOS and Windows are not currently officially supported)
20 | - Python 3.8+
21 | - PyTorch 1.10+
22 | - CUDA 10.2+
23 | - IDA Pro 7.5+ (only used for dataset processing)
24 | 
25 | ### Quick Start
26 | 
27 | **a. Create a conda virtual environment and activate it.**
28 | ```
29 | conda create -n jtrans python=3.8 pandas tqdm -y
30 | conda activate jtrans
31 | ```
32 | 
33 | **b. Install PyTorch and other packages.**
34 | ```
35 | conda install pytorch cudatoolkit=11.0 -c pytorch
36 | python -m pip install simpletransformers networkx pyelftools
37 | ```
38 | 
39 | **c. Get the code and models of jTrans.**
40 | ```
41 | git clone https://github.com/vul337/jTrans.git && cd jTrans
42 | ```
43 | Download [experiments.tar.gz](https://cloud.vul337.team:8443/s/wmqzYFyJnSEfEgm) and [models.tar.gz](https://cloud.vul337.team:8443/s/tM5qGQPJa6iynCf) and extract them.
44 | ```
45 | tar -xzvf experiments.tar.gz && tar -xzvf models.tar.gz
46 | ```
47 | 
48 | **d. Get the BinaryCorp dataset.**
49 | Download the processed dataset from this [link](https://cloud.vul337.team:8443/s/cxnH8DfZTADLKCs).
50 | 
51 | **e. Finetune new models on BinaryCorp.**
52 | ```
53 | python finetune.py -h
54 | ```
55 | 
56 | **f. Evaluation.**
57 | ```
58 | python eval_save.py -h
59 | python fasteval.py -h
60 | ```
61 | To evaluate jTrans on BinaryCorp-3M after extracting experiments.tar.gz:
62 | ```
63 | python fasteval.py
64 | ```
65 | 
66 | **g. Try jTrans on your own binaries.**
67 | 
68 | Make sure you have IDA Pro 7.5+ and follow the instructions at [datautils](datautils/README.md). After extracting features from your binaries, you can run jTrans on them as in [eval_save.py](./eval_save.py); a minimal sketch follows.
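
The sketch below is not part of the shipped pipeline; it mirrors the tokenize-and-pool logic in [eval_save.py](./eval_save.py). The two input strings are assumed to be jTrans function strings of the kind `gen_funcstr` in [data.py](./data.py) produces, and the model/tokenizer paths assume you extracted models.tar.gz into the repo root.

```python
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel

# Same subclass as in eval_save.py: jTrans reuses the word embeddings
# as position embeddings to encode jump targets.
class BinBertModel(BertModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.embeddings.position_embeddings = self.embeddings.word_embeddings

tokenizer = BertTokenizer.from_pretrained('./jtrans_tokenizer')
model = BinBertModel.from_pretrained('./models/jTrans-finetune')  # from models.tar.gz
model.eval()

def embed(func_str):
    # func_str: space-separated jTrans tokens, e.g. the output of data.gen_funcstr
    ret = tokenizer([func_str], add_special_tokens=True, max_length=512,
                    padding='max_length', truncation=True, return_tensors='pt')
    with torch.no_grad():
        out = model(input_ids=ret['input_ids'], attention_mask=ret['attention_mask'])
    return out.pooler_output

# hypothetical token strings for two extracted functions
func_a = 'push rbp mov rbp rsp jz JUMP_ADDR_3 ret'
func_b = 'push rbp mov rbp rsp jnz JUMP_ADDR_2 ret'
print('cosine similarity:', F.cosine_similarity(embed(func_a), embed(func_b)).item())
```

The `pooler_output` vector is what eval_save.py stores into the experiment pickle, so embeddings produced this way are directly comparable with the pools consumed by `fasteval.py`.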
69 | 
70 | ## Dataset
71 | - We present a new large-scale and diversified dataset, [BinaryCorp](https://cloud.vul337.team:8443/s/cxnH8DfZTADLKCs), for the task of binary code similarity detection.
72 | - The description of the dataset can be found [here](datautils/README.md), and we give an [example](datautils/playdata.py) of using BinaryCorp.
73 | - If you need features that we do not provide in advance, such as call graphs, you can download the raw binaries from [here](https://cloud.vul337.team:8443/s/W57ZWXxn7zSKG4q).
74 | 
75 | ## Acknowledgement
76 | This project would not be possible without several great open-source code bases. We list some notable examples below.
77 | 
78 | * [transformers](https://github.com/huggingface/transformers)
79 | * [simpletransformers](https://github.com/ThilinaRajapakse/simpletransformers)
80 | 
81 | ## Bibtex
82 | If this work or the BinaryCorp dataset is helpful for your research, please consider citing the following BibTeX entries.
83 | 
84 | ```
85 | @inproceedings{10.1145/3533767.3534367,
86 | author = {Wang, Hao and Qu, Wenjie and Katz, Gilad and Zhu, Wenyu and Gao, Zeyu and Qiu, Han and Zhuge, Jianwei and Zhang, Chao},
87 | title = {JTrans: Jump-Aware Transformer for Binary Code Similarity Detection},
88 | year = {2022},
89 | isbn = {9781450393799},
90 | publisher = {Association for Computing Machinery},
91 | address = {New York, NY, USA},
92 | url = {https://doi.org/10.1145/3533767.3534367},
93 | doi = {10.1145/3533767.3534367},
94 | abstract = {Binary code similarity detection (BCSD) has important applications in various fields such as vulnerabilities detection, software component analysis, and reverse engineering. Recent studies have shown that deep neural networks (DNNs) can comprehend instructions or control-flow graphs (CFG) of binary code and support BCSD. In this study, we propose a novel Transformer-based approach, namely jTrans, to learn representations of binary code. It is the first solution that embeds control flow information of binary code into Transformer-based language models, by using a novel jump-aware representation of the analyzed binaries and a newly-designed pre-training task. Additionally, we release to the community a newly-created large dataset of binaries, BinaryCorp, which is the most diverse to date. Evaluation results show that jTrans outperforms state-of-the-art (SOTA) approaches on this more challenging dataset by 30.5% (i.e., from 32.0% to 62.5%). In a real-world task of known vulnerability searching, jTrans achieves a recall that is 2X higher than existing SOTA baselines.},
95 | booktitle = {Proceedings of the 31st ACM SIGSOFT International Symposium on Software Testing and Analysis},
96 | pages = {1–13},
97 | numpages = {13},
98 | keywords = {Binary Analysis, Similarity Detection, Neural Networks, Datasets},
99 | location = {Virtual, South Korea},
100 | series = {ISSTA 2022}
101 | }
102 | 
103 | @article{wang2022jtrans,
104 | title={jTrans: Jump-Aware Transformer for Binary Code Similarity},
105 | author={Wang, Hao and Qu, Wenjie and Katz, Gilad and Zhu, Wenyu and Gao, Zeyu and Qiu, Han and Zhuge, Jianwei and Zhang, Chao},
106 | journal={arXiv preprint arXiv:2205.12713},
107 | year={2022}
108 | }
109 | ```
110 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | from datautils.playdata import DatasetBase as DatasetBase
 3 | import networkx
 4 | import os
 5 | import networkx as nx
 6 | from collections import defaultdict
 7 | from tqdm import tqdm
 8 | import pickle
 9 | import argparse
10 | import re
11 | import readidadata
12 | import torch
13 | import random
14 | import time
15 | MAXLEN=512
16 | 
17 | vocab_data = open("./jtrans_tokenizer/vocab.txt").read().strip().split("\n") + ["[SEP]", "[PAD]", "[CLS]", "[MASK]"]
18 | my_vocab = defaultdict(lambda: 512, {vocab_data[i] : i for i in range(len(vocab_data))})  # unknown tokens fall back to index 512
19 | 
20 | def help_tokenize(line):
21 |     global my_vocab
22 |     ret = {}
23 |     split_line = line.strip().split(' ')
24 |     split_line_len = len(split_line)
25 |     if split_line_len <= 509:
26 |         split_line = ['[CLS]']+split_line+['[SEP]']
27 |         attention_mask = [1] * len(split_line) + [0] * (512 - len(split_line))
28 |         split_line = split_line + (512-len(split_line))*['[PAD]']
29 |     else:
30 |         split_line = ['[CLS]'] + split_line[:510] + ['[SEP]']  # truncate so [CLS] + 510 tokens + [SEP] fill exactly 512
31 |         attention_mask = [1]*512
32 |     input_ids = [my_vocab[e] for e in split_line]
33 |     ret['input_ids'] = torch.tensor(input_ids, dtype=torch.long)
34 |     ret['attention_mask'] = torch.tensor(attention_mask, dtype=torch.long)
35 |     return ret
36 | 
37 | def gen_funcstr(f,convert_jump):
38 |     cfg=f[3]
39 |     #print(hex(f[0]))
40 |     bb_ls,code_lst,map_id=[],[],{}
41 |     for bb in cfg.nodes:
42 |         bb_ls.append(bb)
43 |     bb_ls.sort()
44 |     for bx in range(len(bb_ls)):
45 |         bb=bb_ls[bx]
46 |         asm=cfg.nodes[bb]['asm']
47 |         map_id[bb]=len(code_lst)
48 |         for code in asm:
49 |             operator,operand1,operand2,operand3,annotation=readidadata.parse_asm(code)
50 |             code_lst.append(operator)
51 |             if operand1!=None:
52 |                 code_lst.append(operand1)
53 |             if operand2!=None:
54 |                 code_lst.append(operand2)
55 |             if operand3!=None:
56 |                 code_lst.append(operand3)
57 |     for c in range(len(code_lst)):
58 |         op=code_lst[c]
59 |         if op.startswith('hex_'):
60 |             jumpaddr=int(op[4:],base=16)
61 |             if map_id.get(jumpaddr) is not None:  # a plain .get() truthiness test would miss the first block (id 0)
62 |                 jumpid=map_id[jumpaddr]
63 |                 if jumpid < MAXLEN:
64 |                     code_lst[c]='JUMP_ADDR_{}'.format(jumpid)
65 |                 else:
66 |                     code_lst[c]='JUMP_ADDR_EXCEEDED'
67 |             else:
68 |                 code_lst[c]='UNK_JUMP_ADDR'
69 |             if not convert_jump:
70 |                 code_lst[c]='CONST'
71 |     func_str=' '.join(code_lst)
72 |     return func_str
73 | 
74 | def load_unpair_data(datapath,filt=None,alldata=True,convert_jump=True,opt=None, fp=None):
75 |     dataset = DatasetBase(datapath,filt, alldata)
76 |     dataset.load_unpair_data()
77 |     functions=[]
78 |     for i in dataset.get_unpaird_data(): #proj, func_name, func_addr, asm_list, rawbytes_list, cfg, bai_featrue
79 |         f = (i[2], i[3], i[4], i[5], i[6])
80 |         func_str=gen_funcstr(f,convert_jump)
81 |         if len(func_str) > 0:
82 |             fp.write(func_str+"\n")
83 | 
84 | def load_paired_data(datapath,filt=None,alldata=True,convert_jump=True,opt=None,add_ebd=False):
85 | 
86 |     dataset = DatasetBase(datapath,filt,alldata, opt=opt)
87 |     functions=[]
88 |     func_emb_data=[]
89 |     SUM=0
90 |     for i in dataset.get_paired_data_iter(): #proj, func_name, func_addr, asm_list, rawbytes_list, cfg, bai_featrue
91 |         functions.append([])
92 |         if add_ebd:
93 |             func_emb_data.append({'proj':i[0],'funcname':i[1]})
94 |         for o in opt:
95 |             if i[2].get(o):
96 |                 f=i[2][o]
97 |                 func_str=gen_funcstr(f,convert_jump)
98 |                 if len(func_str)>0:
99 |                     if add_ebd:
100 |                         func_emb_data[-1][o]=len(functions[-1])
101 |                     functions[-1].append(func_str)
102 |                     SUM+=1
103 | 
104 |     print('TOTAL ',SUM)
105 |     return functions,func_emb_data
106 | 
107 | class FunctionDataset_CL(torch.utils.data.Dataset): #binary version dataset
108 |     def __init__(self,tokenizer,path='../BinaryCorp/extract',filt=None,alldata=True,convert_jump_addr=True,opt=None,add_ebd=True): #random visit
109 |         functions,ebds=load_paired_data(datapath=path,filt=filt,alldata=alldata,convert_jump=convert_jump_addr,opt=opt,add_ebd=add_ebd)
110 |         self.datas=functions
111 |         self.ebds=ebds
112 |         self.tokenizer=tokenizer
113 |         self.opt=opt
114 |         self.convert_jump_addr=convert_jump_addr
115 |     def __getitem__(self, idx): #also return bad pair
116 | 
117 |         pairs=self.datas[idx]
118 |         if self.opt==None:
119 |             pos=random.randint(0,len(pairs)-1)
120 |             pos2=random.randint(0,len(pairs)-1)
121 |             while pos2==pos:
122 |                 pos2=random.randint(0,len(pairs)-1)
123 |             f1=pairs[pos] #give three pairs
124 |             f2=pairs[pos2]
125 |         else:
126 |             pos=0
127 |             pos2=1
128 |             f1=pairs[pos]
129 |             f2=pairs[pos2]
130 |         ftype=random.randint(0,len(self.datas)-1)
131 |         while ftype==idx:
132 |             ftype=random.randint(0,len(self.datas)-1)
133 |         pair_opp=self.datas[ftype]
134 |         pos3=random.randint(0,len(pair_opp)-1)
135 |         f3=pair_opp[pos3]
136 |         ret1 = help_tokenize(f1)
137 |         token_seq1=ret1['input_ids']
138 |         mask1=ret1['attention_mask']
139 | 
140 |         ret2 = help_tokenize(f2)
141 |         token_seq2=ret2['input_ids']
142 |         mask2=ret2['attention_mask']
143 | 
144 |         ret3 = help_tokenize(f3)
145 |         token_seq3=ret3['input_ids']
146 |         mask3=ret3['attention_mask']
147 | 
148 |         return token_seq1,token_seq2,token_seq3,mask1,mask2,mask3
149 |     def __len__(self):
150 |         return len(self.datas)
151 | 
152 | class FunctionDataset_CL_Load(torch.utils.data.Dataset): #binary version dataset
153 |     def __init__(self,tokenizer,path='../BinaryCorp/extract',filt=None,alldata=True,convert_jump_addr=True,opt=None,add_ebd=True, load=None): #random visit
154 |         if load:
155 |             start = time.time()
156 |             self.datas = pickle.load(open(load, 'rb'))
157 |             print('load time:', time.time() - start)
158 |             self.tokenizer=tokenizer
159 |             self.opt=opt
160 |             self.convert_jump_addr=convert_jump_addr
161 |         else:
162 |             functions,ebds=load_paired_data(datapath=path,filt=filt,alldata=alldata,convert_jump=convert_jump_addr,opt=opt,add_ebd=add_ebd)
163 |             self.datas=[]
164 |             for func_list in functions:
165 |                 tmp = []
166 |                 for f in func_list:
167 |                     tmp.append(help_tokenize(f))
168 |                 self.datas.append(tmp)
169 |             self.ebds=ebds
170 |             self.tokenizer=tokenizer
171 |             self.opt=opt
172 |             self.convert_jump_addr=convert_jump_addr
173 |     def __getitem__(self, idx): #also return bad pair
174 | 
175 |         pairs=self.datas[idx]
176 |         if self.opt!=None:
177 |             pos=random.randint(0,len(pairs)-1)
178 |             pos2=random.randint(0,len(pairs)-1)
179 |             while pos2==pos:
180 |                 pos2=random.randint(0,len(pairs)-1)
181 |             f1=pairs[pos] #give three pairs
182 |             f2=pairs[pos2]
183 |         else:
184 |             pos=0
185 |             pos2=1
186 |             f1=pairs[pos]
187 |             f2=pairs[pos2]
188 |         ftype=random.randint(0,len(self.datas)-1)
189 |         while ftype==idx:
190 |             ftype=random.randint(0,len(self.datas)-1)
191 |         pair_opp=self.datas[ftype]
192 |         pos3=random.randint(0,len(pair_opp)-1)
193 |         f3=pair_opp[pos3]
194 | 
195 |         token_seq1=f1['input_ids']
196 |         mask1=f1['attention_mask']
197 | 
198 |         token_seq2=f2['input_ids']
199 |         mask2=f2['attention_mask']
200 | 
201 |         token_seq3=f3['input_ids']
202 |         mask3=f3['attention_mask']
203 | 
204 |         return token_seq1,token_seq2,token_seq3,mask1,mask2,mask3
205 |     def __len__(self):
206 |         return len(self.datas)
207 | 
208 | def load_filter_list(name):
209 |     import csv
210 |     f=csv.reader(open(name,'r'))
211 |     S=set()
212 |     for i in f:
213 |         S.add(i[1])
214 |     return list(S)
--------------------------------------------------------------------------------
/datautils/README.md:
--------------------------------------------------------------------------------
 1 | # Prerequisites
 2 | - Linux IDA Pro
 3 | - IDA Python (python3) with networkx, pyelftools, binaryai
 4 | ```bash
 5 | python3 -m pip install pyelftools binaryai networkx
 6 | ```
 7 | 
 8 | # Quick Start
 9 | 
10 | ## Directory Description
11 | - dataset (original binaries)
12 | - dataset_strip (temp directory for stripped binaries)
13 | - extract (extracted features)
14 | - ida (Linux IDA directory)
15 | - idb (idb files for the original binaries)
16 | - log (processing logs)
17 | - util (script utilities)
18 |   - base.py (binary processing base class)
19 |   - pairdata.py (pairs the ground truth for functions built with different optimization levels)
20 | - process.py (IDA Python script for extracting features from binaries)
21 | - playdata.py (play with the extracted features)
22 | - run.py (parallel runner)
23 | 
24 | # Usage
25 | ## Extracting features for the binary similarity task
26 | - copy all the compiled binaries (with symbol tables) to dataset/
27 | - change config.py to suitable parameters
28 | - run the following commands
29 | ```bash
30 | ./ida/idapyswitch # switch to the system python3
31 | python3 run.py
32 | ```
33 | 
34 | ## Use the extracted features
35 | - Have a look at playdata.py
36 | - There are two types of processed datasets, one for unsupervised learning (unpair_data) and one for supervised learning (pair_data), both stored in .pickle files
37 | - unpair data
38 | ```python
39 | unpair_data = {
40 |     'foo': [
41 |         0x400000, # function_addr
42 |         ['sub rbp, rsp', 'ret'], # asm_list
43 |         b"\x48\x29\xe5\xc3", # raw bytes
44 |         cfg, # networkx DiGraph
45 |         binaryai_feature
46 |     ],
47 |     'bar': [
48 |         ...
49 |     ]
50 | }
51 | # traverse the CFG nodes
52 | def traverse_cfg_node(self, cfg):
53 |     for node in cfg.nodes():
54 |         yield cfg.nodes[node]['asm'], cfg.nodes[node]['raw']
55 | 
56 | # CFG construction code
57 | def get_cfg(self, func):
58 | 
59 |     def get_attr(block):
60 |         asm,raw=[],b""
61 |         curr_addr = block.start_ea
62 |         while curr_addr < block.end_ea:
63 |             asm.append(idc.GetDisasm(curr_addr))
64 |             raw+=idc.get_bytes(curr_addr, idc.get_item_size(curr_addr))
65 |             curr_addr = idc.next_head(curr_addr, block.end_ea)
66 |         return asm, raw
67 | 
68 |     nx_graph = nx.DiGraph()
69 |     flowchart = idaapi.FlowChart(idaapi.get_func(func), flags=idaapi.FC_PREDS)
70 |     for block in flowchart:
71 |         # Make sure all nodes are added (including edge-less nodes)
72 |         attr = get_attr(block)
73 |         nx_graph.add_node(block.start_ea, asm=attr[0], raw=attr[1])
74 | 
75 |         for pred in block.preds():
76 |             nx_graph.add_edge(pred.start_ea, block.start_ea)
77 |         for succ in block.succs():
78 |             nx_graph.add_edge(block.start_ea, succ.start_ea)
79 |     return nx_graph
80 | ```
81 | - pair data
82 | The pair data is organized by the ground truth (functions paired across different optimization levels)
83 | ```python
84 | pair_data = {
85 |     'foo': [
86 |         unpair_foo_O0, # unpair_func_foo_O0
87 |         unpair_foo_O1, # unpair_func_foo_O1
88 |         unpair_foo_O2, # unpair_func_foo_O2
89 |         ...
90 |     ],
91 |     'bar': [
92 |         ...
93 |     ]
94 | }
95 | ```
96 | 
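For instance, pairing the functions of one project across two optimization levels looks like this (a minimal sketch mirroring the demo at the bottom of [playdata.py](playdata.py); the `arenatracker-git-ArenaTracker` prefix is just the example project used there, and the import assumes you run from the datautils directory):

```python
from playdata import DatasetBase

# pair data for one project, O0 vs O1 (all_data must be False when filtering by prefix)
dataset = DatasetBase('./extract', ['arenatracker-git-ArenaTracker'], False, ['O0', 'O1'])
dataset.load_pair_data()
for proj, func_name, func_data in dataset.get_paired_data():
    for opt in ['O0', 'O1']:
        # each entry: (func_addr, asm_list, rawbytes_list, cfg, binaryai_feature)
        func_addr, asm_list, rawbytes_list, cfg, bai_feature = func_data[opt]
        print(proj, func_name, opt, hex(func_addr), len(asm_list), 'instructions')
```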
--------------------------------------------------------------------------------
/datautils/clean.sh:
--------------------------------------------------------------------------------
1 | rm log/*
2 | rm idb/*
3 | rm -rf extract/*
4 | rm dataset_strip/*
5 | mkdir -p log idb extract dataset_strip
--------------------------------------------------------------------------------
/datautils/dataset/libcap-git-setcap-O0-00ac9f3e2f406891116850f0ddffd008:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vul337/jTrans/1d405156865e3c582e4f183cbaa623b700ea0e9e/datautils/dataset/libcap-git-setcap-O0-00ac9f3e2f406891116850f0ddffd008
--------------------------------------------------------------------------------
/datautils/dataset/libcap-git-setcap-O1-90826146c17034375256122b115dc849:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vul337/jTrans/1d405156865e3c582e4f183cbaa623b700ea0e9e/datautils/dataset/libcap-git-setcap-O1-90826146c17034375256122b115dc849
--------------------------------------------------------------------------------
/datautils/dataset/libcap-git-setcap-O2-8dc43f20ea80b7703f6973a1ea86e8b8:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vul337/jTrans/1d405156865e3c582e4f183cbaa623b700ea0e9e/datautils/dataset/libcap-git-setcap-O2-8dc43f20ea80b7703f6973a1ea86e8b8
--------------------------------------------------------------------------------
/datautils/dataset/libcap-git-setcap-O3-e8e5b8752b5ba79d135c254273edadae:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vul337/jTrans/1d405156865e3c582e4f183cbaa623b700ea0e9e/datautils/dataset/libcap-git-setcap-O3-e8e5b8752b5ba79d135c254273edadae
--------------------------------------------------------------------------------
/datautils/dataset/libcap-git-setcap-Os-1ac8a004dfd1d92cb18bf4aa36d817e3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vul337/jTrans/1d405156865e3c582e4f183cbaa623b700ea0e9e/datautils/dataset/libcap-git-setcap-Os-1ac8a004dfd1d92cb18bf4aa36d817e3
--------------------------------------------------------------------------------
/datautils/playdata.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python3
 2 | 
 3 | import os
 4 | import networkx as nx
 5 | from collections import defaultdict
 6 | from tqdm import tqdm
 7 | import pickle
 8 | import argparse
 9 | from functools import reduce
10 | 
11 | class DatasetBase(object):
12 |     def __init__(self, path, prefixfilter=None, all_data=True, opt=None):
13 |         self.path = path
14 |         self.prefixfilter = prefixfilter
15 |         self.all_data = all_data
16 |         self.unpaired = defaultdict(list)
17 |         self.opt = opt
18 |         if self.opt is not None:
19 |             # assert len(self.opt) == 2, "set len(opt) != 2"
20 |             self.paired = defaultdict(defaultdict)
21 |         else:
22 |             self.paired = defaultdict(list)
23 |         assert os.path.exists(self.path), "Dataset Path Not Exists"
24 |         assert (self.prefixfilter is not None) != self.all_data, "You should set prefixfilter with all_data = False"
25 | 
26 |     def traverse_file(self):
27 |         for root, dirs, _ in os.walk(self.path):
28 |             for dir in dirs:
29 |                 if self.all_data:
30 |                     for file in os.listdir(os.path.join(root, dir)):
31 |                         yield dir, file,
os.path.join(root, dir, file) 32 | else: 33 | for filter in self.prefixfilter: 34 | if dir.startswith(filter): 35 | for file in os.listdir(os.path.join(root, dir)): 36 | yield dir, file, os.path.join(root, dir, file) 37 | 38 | def load_pickle(self, file): 39 | with open(file, 'rb') as f: 40 | return pickle.load(f) 41 | 42 | def load_unpair_data(self): 43 | for proj, filename, pkl_path in self.traverse_file(): 44 | if filename != 'saved_index.pkl': 45 | pickle_data = self.load_pickle(pkl_path) 46 | self.unpaired[proj].append(pickle_data) 47 | 48 | def load_pair_data(self): 49 | if self.opt is None: 50 | for proj, filename, pkl_path in self.traverse_file(): 51 | if filename == 'saved_index.pkl': 52 | pickle_data = self.load_pickle(pkl_path) 53 | self.paired[proj].append(pickle_data) 54 | else: 55 | for proj, filename, pkl_path in self.traverse_file(): 56 | if filename == 'saved_index.pkl': 57 | continue 58 | opt = filename.split('-')[-2] 59 | if opt in self.opt: 60 | print(filename) 61 | pickle_data = self.load_pickle(pkl_path) 62 | self.paired[proj][opt] = pickle_data 63 | 64 | def get_paired_data_iter(self): 65 | proj2pickle = defaultdict(defaultdict) 66 | for proj, filename, pkl_path in self.traverse_file(): 67 | if filename == 'saved_index.pkl': 68 | continue 69 | opt = filename.split('-')[-2] 70 | proj2pickle[proj][opt] = pkl_path 71 | 72 | for proj, pickle_path_dict in proj2pickle.items(): 73 | if len(pickle_path_dict) < 2: 74 | continue 75 | function_list = [] 76 | tmp_pickle_dict = {} 77 | for opt, pkl_path in pickle_path_dict.items(): 78 | pkl = pickle.load(open(pkl_path, 'rb')) 79 | function_list.append(list(pkl.keys())) 80 | tmp_pickle_dict[opt] = pkl 81 | function_set = reduce(lambda x,y : set(x) & set(y), function_list) 82 | for func_name in function_set: 83 | ret_func_data = defaultdict() 84 | for opt, pkl in tmp_pickle_dict.items(): 85 | ret_func_data[opt] = pkl[func_name] 86 | yield proj, func_name, ret_func_data 87 | 88 | 89 | def get_unpaird_data_iter(self): 90 | for proj, filename, pkl_path in self.traverse_file(): 91 | if filename != 'saved_index.pkl': 92 | pickle_data = self.load_pickle(pkl_path) 93 | for func_name, func_data in pickle_data.items(): 94 | func_addr, asm_list, rawbytes_list, cfg, biai_featrue = func_data 95 | yield proj, func_name, func_addr, asm_list, rawbytes_list, cfg, biai_featrue 96 | 97 | def get_unpaird_data(self): 98 | for proj, pkl_list in self.unpaired.items(): 99 | for pkl in pkl_list: 100 | for func_name, func_data in pkl.items(): 101 | func_addr, asm_list, rawbytes_list, cfg, biai_featrue = func_data 102 | yield proj, func_name, func_addr, asm_list, rawbytes_list, cfg, biai_featrue 103 | 104 | def get_paired_data(self): 105 | if self.opt is None: 106 | for proj, pkl_list in self.paired.items(): 107 | for pkl in pkl_list: 108 | for func_name, func_data_list in pkl.items(): 109 | yield proj, func_name, func_data_list 110 | # for func_data in func_data_list: 111 | # func_addr, asm_list, rawbytes_list, cfg, biai_featrue = func_data 112 | else: 113 | for proj, pkl_dict in self.paired.items(): 114 | if len(pkl_dict) < 2: 115 | continue 116 | function_list = [] 117 | for opt, pkl in pkl_dict.items(): 118 | function_list.append(list(pkl.keys())) 119 | function_set = reduce(lambda x,y : set(x) & set(y), function_list) 120 | for func_name in function_set: 121 | ret_func_data = defaultdict() 122 | for opt, pkl in pkl_dict.items(): 123 | ret_func_data[opt] = pkl[func_name] 124 | yield proj, func_name, ret_func_data 125 | 126 | def traverse_cfg_node(self, 
cfg): 127 | for node in cfg.nodes(): 128 | yield cfg.nodes[node]['asm'], cfg.nodes[node]['raw'] 129 | 130 | class DataBaseCrossCompiler(DatasetBase): 131 | def __init__(self, path, prefixfilter=None, all_data=True, opt=None): 132 | super(DataBaseCrossCompiler, self).__init__(path, prefixfilter, all_data, opt) 133 | 134 | def load_pair_data(self): 135 | if self.opt is not None: 136 | for proj, filename, pkl_path in self.traverse_file(): 137 | if filename == 'saved_index.pkl': 138 | continue 139 | opt = filename.split('-')[-2] 140 | compiler = filename.split('-')[-3] 141 | final_opt = compiler+opt 142 | if opt in self.opt: 143 | print(filename) 144 | pickle_data = self.load_pickle(pkl_path) 145 | self.paired[proj][final_opt] = pickle_data 146 | else: 147 | print("opt is None") 148 | exit(1) 149 | 150 | def get_paired_data(self): 151 | # return proj, func_name, ret_func_data 152 | # ret_func_data = { 153 | # opt: { 154 | # compiler : (func_addr, asm_list, rawbytes_list, cfg, biai_featrue) 155 | # } 156 | # } 157 | if self.opt is not None: 158 | for proj, pkl_dict in self.paired.items(): 159 | if len(pkl_dict) < 2: 160 | continue 161 | function_list = [] 162 | for opt, pkl in pkl_dict.items(): 163 | function_list.append(list(pkl.keys())) 164 | function_set = reduce(lambda x,y : set(x) & set(y), function_list) 165 | for func_name in function_set: 166 | ret_func_data = defaultdict() 167 | for opt, pkl in pkl_dict.items(): 168 | ret_func_data[opt] = pkl[func_name] 169 | yield proj, func_name, ret_func_data 170 | else: 171 | print("opt is None") 172 | exit(1) 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument('--dataset_path', type=str, default='../extract') 177 | parser.add_argument('--prefixfilter', type=str, default=None) 178 | parser.add_argument('--all_data', type=bool, default=True) 179 | args = parser.parse_args() 180 | dataset = DatasetBase(args.dataset_path, args.prefixfilter, args.all_data) 181 | # used for pretrain 182 | dataset.load_unpair_data() 183 | # used for contrastive learning 184 | # dataset.load_pair_data() 185 | pretrain_dataset = dataset.get_unpaird_data() 186 | cnt = 0 187 | for proj, func_name, func_addr, asm_list, rawbytes_list, cfg, biai_featrue in tqdm(pretrain_dataset): 188 | # print(proj, func_name, func_addr, asm_list, rawbytes_list, cfg, biai_featrue) 189 | pass 190 | 191 | # demo for contrastive learning dataset in different optimization level 192 | dataset = DatasetBase('./extract', ["arenatracker-git-ArenaTracker"], False, ['O0', 'O1']) 193 | dataset.load_pair_data() 194 | ft_dataset = dataset.get_paired_data() 195 | for proj, func_name, func_data in ft_dataset: 196 | for opt in ['O0', 'O1']: 197 | func_addr, asm_list, rawbytes_list, cfg, biai_featrue = func_data[opt] 198 | print(func_name, hex(func_addr)) 199 | 200 | # demo for cross compiler dataset 201 | dataset = DataBaseCrossCompiler('../extractDataset/coreutils', ["coreutils-b2sum"], False, ['O0', 'Os']) 202 | dataset.load_pair_data() 203 | cnt = 0 204 | functions = [] 205 | 206 | for proj, func_name, func_data in dataset.get_paired_data(): 207 | for opt in ['O0', 'Os']: 208 | for compiler in ['gcc', 'clang']: 209 | print('opt: ', opt, 'compiler', compiler) 210 | func_addr, asm_list, rawbytes_list, cfg, biai_featrue = func_data[compiler+opt] 211 | print(func_name, hex(func_addr)) 212 | cnt += 1 213 | if cnt > 5: 214 | break -------------------------------------------------------------------------------- /datautils/process.py: 
--------------------------------------------------------------------------------
 1 | import idc
 2 | import idautils
 3 | import idaapi
 4 | import pickle
 5 | import binaryai
 6 | import networkx as nx
 7 | from util.base import Binarybase
 8 | 
 9 | SAVEROOT = "./extract" # dir of pickle files saved by IDA
10 | DATAROOT = "./dataset" # dir of binaries (not stripped)
11 | 
12 | class BinaryData(Binarybase):
13 |     def __init__(self, unstrip_path):
14 |         super(BinaryData, self).__init__(unstrip_path)
15 |         self.fix_up()
16 | 
17 |     def fix_up(self):
18 |         for addr in self.addr2name:
19 |             # in case some functions' instructions are not recognized by IDA
20 |             idc.create_insn(addr)
21 |             idc.add_func(addr)
22 | 
23 |     def get_asm(self, func):
24 |         instGenerator = idautils.FuncItems(func)
25 |         asm_list = []
26 |         for inst in instGenerator:
27 |             asm_list.append(idc.GetDisasm(inst))
28 |         return asm_list
29 | 
30 |     def get_rawbytes(self, func):
31 |         instGenerator = idautils.FuncItems(func)
32 |         rawbytes_list = b""
33 |         for inst in instGenerator:
34 |             rawbytes_list += idc.get_bytes(inst, idc.get_item_size(inst))
35 |         return rawbytes_list
36 | 
37 |     def get_cfg(self, func):
38 | 
39 |         def get_attr(block, func_addr_set):
40 |             asm,raw=[],b""
41 |             curr_addr = block.start_ea
42 |             if curr_addr not in func_addr_set:
43 |                 return -1
44 |             # print(f"[*] cur: {hex(curr_addr)}, block_end: {hex(block.end_ea)}")
45 |             while curr_addr <= block.end_ea:
46 |                 asm.append(idc.GetDisasm(curr_addr))
47 |                 raw+=idc.get_bytes(curr_addr, idc.get_item_size(curr_addr))
48 |                 curr_addr = idc.next_head(curr_addr, block.end_ea)
49 |             return asm, raw
50 | 
51 |         nx_graph = nx.DiGraph()
52 |         flowchart = idaapi.FlowChart(idaapi.get_func(func), flags=idaapi.FC_PREDS)
53 |         func_addr_set = set([addr for addr in idautils.FuncItems(func)])
54 |         for block in flowchart:
55 |             # Make sure all nodes are added (including edge-less nodes)
56 |             attr = get_attr(block, func_addr_set)
57 |             if attr == -1:
58 |                 continue
59 |             nx_graph.add_node(block.start_ea, asm=attr[0], raw=attr[1])
60 |             # print(f"[*] bb: {hex(block.start_ea)}, asm: {attr[0]}")
61 |             for pred in block.preds():
62 |                 if pred.start_ea not in func_addr_set:
63 |                     continue
64 |                 nx_graph.add_edge(pred.start_ea, block.start_ea)
65 |             for succ in block.succs():
66 |                 if succ.start_ea not in func_addr_set:
67 |                     continue
68 |                 nx_graph.add_edge(block.start_ea, succ.start_ea)
69 |         return nx_graph
70 | 
71 |     def get_binai_feature(self, func):
72 |         return binaryai.ida.get_func_feature(func)
73 | 
74 |     def extract_all(self):
75 |         for func in idautils.Functions():
76 |             if idc.get_segm_name(func) in ['.plt','extern','.init','.fini']:
77 |                 continue
78 |             print("[+] %s" % idc.get_func_name(func))
79 |             asm_list = self.get_asm(func)
80 |             rawbytes_list = self.get_rawbytes(func)
81 |             cfg = self.get_cfg(func)
82 |             bai_feature = self.get_binai_feature(func)
83 |             yield (self.addr2name[func], func, asm_list, rawbytes_list, cfg, bai_feature)
84 | 
85 | if __name__ == "__main__":
86 |     import os
87 |     from collections import defaultdict
88 | 
89 |     assert os.path.exists(DATAROOT), "DATAROOT does not exist"
90 |     assert os.path.exists(SAVEROOT), "SAVEROOT does not exist"
91 | 
92 |     binary_abs_path = idc.get_input_file_path()
93 |     filename = binary_abs_path.split('/')[-1][:-6]  # drop the '.strip' suffix added by run.py
94 |     unstrip_path = os.path.join(DATAROOT, filename)
95 |     idc.auto_wait()
96 |     binary_data = BinaryData(unstrip_path)
97 | 
98 |     saved_dict = defaultdict(list)
99 |     saved_path = os.path.join(SAVEROOT, filename + "_extract.pkl") # unpair data
100 |     with open(saved_path, 'wb')
as f: 101 | for func_name, func, asm_list, rawbytes_list, cfg, bai_feature in binary_data.extract_all(): 102 | saved_dict[func_name] = [func, asm_list, rawbytes_list, cfg, bai_feature] 103 | pickle.dump(dict(saved_dict), f) 104 | idc.qexit(0) # exit IDA -------------------------------------------------------------------------------- /datautils/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import multiprocessing 4 | import time 5 | from util.pairdata import pairdata 6 | 7 | ida_path = "./ida/idat64" 8 | work_dir = os.path.abspath('.') 9 | dataset_dir = './dataset/' 10 | strip_path = "./dataset_strip/" 11 | script_path = f"./process.py" 12 | SAVE_ROOT = "./extract" 13 | 14 | def getTarget(path, prefixfilter=None): 15 | target = [] 16 | for root, dirs, files in os.walk(path): 17 | for file in files: 18 | if prefixfilter is None: 19 | target.append(os.path.join(root, file)) 20 | else: 21 | for prefix in prefixfilter: 22 | if file.startswith(prefix): 23 | target.append(os.path.join(root, file)) 24 | return target 25 | 26 | if __name__ == '__main__': 27 | # prefixfilter = ['libcap-git-setcap'] 28 | start = time.time() 29 | target_list = getTarget(dataset_dir) 30 | 31 | pool = multiprocessing.Pool(processes=8) 32 | for target in target_list: 33 | filename = target.split('/')[-1] 34 | filename_strip = filename + '.strip' 35 | ida_input = os.path.join(strip_path, filename_strip) 36 | os.system(f"strip -s {target} -o {ida_input}") 37 | print(f"strip -s {target} -o {ida_input}") 38 | 39 | cmd_str = f'{ida_path} -Llog/{filename}.log -c -A -S{script_path} -oidb/{filename}.idb {ida_input}' 40 | print(cmd_str) 41 | cmd = [ida_path, f'-Llog/{filename}.log', '-c', '-A', f'-S{script_path}', f'-oidb/{filename}.idb', f'{ida_input}'] 42 | pool.apply_async(subprocess.call, args=(cmd,)) 43 | pool.close() 44 | pool.join() 45 | print('[*] Features Extracting Done') 46 | pairdata(SAVE_ROOT) 47 | end = time.time() 48 | print(f"[*] Time Cost: {end - start} seconds") 49 | -------------------------------------------------------------------------------- /datautils/util/base.py: -------------------------------------------------------------------------------- 1 | from elftools.elf.elffile import ELFFile 2 | import elftools.elf.elffile as elffile 3 | from elftools.elf.sections import SymbolTableSection 4 | from collections import defaultdict 5 | import os 6 | 7 | class Binarybase(object): 8 | def __init__(self, unstrip_path): 9 | self.unstrip_path = unstrip_path 10 | assert os.path.exists(unstrip_path), f'{unstrip_path} not exists' 11 | self.addr2name = self.extract_addr2name(self.unstrip_path) 12 | 13 | def get_func_name(self, name, functions): 14 | 15 | if name not in functions: 16 | return name 17 | 18 | i = 0 19 | while True: 20 | 21 | new_name = name+'_'+str(i) 22 | if new_name not in functions: 23 | return new_name 24 | 25 | i += 1 26 | 27 | def scan_section(self, functions, section): 28 | """ 29 | Function to extract function names from a shared library file. 
30 | """ 31 | if not section or not isinstance(section, SymbolTableSection) or section['sh_entsize'] == 0: 32 | return 0 33 | 34 | count = 0 35 | for nsym, symbol in enumerate(section.iter_symbols()): 36 | 37 | if symbol['st_info']['type'] == 'STT_FUNC' and symbol['st_shndx'] != 'SHN_UNDEF': 38 | 39 | func = symbol.name 40 | 41 | name = self.get_func_name(func, functions) 42 | 43 | if not name in functions: 44 | 45 | functions[name] = {} 46 | 47 | functions[name]['begin'] = symbol.entry['st_value'] 48 | 49 | 50 | def extract_addr2name(self, path): 51 | ''' 52 | return: 53 | ''' 54 | functions = {} 55 | with open(path, 'rb') as stream: 56 | 57 | elffile = ELFFile(stream) 58 | 59 | self.scan_section(functions, elffile.get_section_by_name('.symtab')) 60 | 61 | self.scan_section(functions, elffile.get_section_by_name('.dynsym')) 62 | 63 | addr2name = {func['begin']: name for (name, func) in functions.items()} 64 | return defaultdict(lambda:-1, addr2name) 65 | 66 | -------------------------------------------------------------------------------- /datautils/util/pairdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | from shutil import move 5 | import pickle 6 | from functools import reduce 7 | import networkx as nx 8 | 9 | def pairdata(data_dir): 10 | def get_prefix(path): # get proj name 11 | l = path.split('-') 12 | prefix = '-'.join(l[:-2]) 13 | return prefix.split('/')[-1] 14 | 15 | proj2file = defaultdict(list) # proj to filename list 16 | for root, dirs, files in os.walk(data_dir, topdown=False): 17 | for name in tqdm(files): 18 | pickle_path = os.path.join(root, name) 19 | prefix = get_prefix(pickle_path) 20 | proj2file[prefix].append(name) 21 | 22 | for proj, filelist in proj2file.items(): 23 | if not os.path.exists(os.path.join(data_dir, proj)): 24 | os.mkdir(os.path.join(data_dir, proj)) 25 | 26 | binary_func_list = [] 27 | pkl_list = [] 28 | for name in filelist: 29 | src = os.path.join(data_dir, name) 30 | dst = os.path.join(data_dir, proj, name) 31 | pkl = pickle.load(open(src, 'rb')) 32 | pkl_list.append(pkl) 33 | func_list = [] 34 | for func_name in pkl: 35 | func_list.append(func_name) 36 | print(name, len(func_list)) 37 | binary_func_list.append(func_list) 38 | move(src, dst) # move file into proj dir 39 | 40 | final_index = reduce(lambda x,y : set(x) & set(y), binary_func_list) 41 | print('all', len(final_index)) 42 | 43 | saved_index = defaultdict(list) 44 | for func_name in final_index: 45 | for pkl in pkl_list: 46 | saved_index[func_name].append(pkl[func_name]) 47 | 48 | saved_pickle_name = os.path.join(data_dir, proj, 'saved_index.pkl') # pari data 49 | pickle.dump(dict(saved_index), open(saved_pickle_name, 'wb')) -------------------------------------------------------------------------------- /eval_save.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertForMaskedLM, BertModel 2 | from tokenizer import * 3 | import pickle 4 | from torch.utils.data import DataLoader 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | from tqdm import tqdm 10 | from data import help_tokenize, load_paired_data,FunctionDataset_CL 11 | from transformers import AdamW 12 | import torch.nn.functional as F 13 | import argparse 14 | import wandb 15 | import logging 16 | import sys 17 | import time 18 | import data 19 | WANDB = True 20 | 21 | def get_logger(name): 22 | 
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filename=name) 23 | logger = logging.getLogger(__name__) 24 | s_handle = logging.StreamHandler(sys.stdout) 25 | s_handle.setLevel(logging.INFO) 26 | s_handle.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(filename)s[:%(lineno)d] - %(message)s")) 27 | logger.addHandler(s_handle) 28 | return logger 29 | 30 | def eval(model, args, valid_set, logger): 31 | 32 | if WANDB: 33 | wandb.init(project=f'jTrans-finetune') 34 | wandb.config.update(args) 35 | logger.info("Initializing Model...") 36 | device = torch.device("cuda") 37 | model.to(device) 38 | logger.info("Finished Initialization...") 39 | valid_dataloader = DataLoader(valid_set, batch_size=args.eval_batch_size, num_workers=24, shuffle=True) 40 | global_steps = 0 41 | etc=0 42 | logger.info(f"Doing Evaluation ...") 43 | mrr = finetune_eval(model, valid_dataloader) 44 | logger.info(f"Evaluate: mrr={mrr}") 45 | if WANDB: 46 | wandb.log({ 47 | 'mrr': mrr 48 | }) 49 | 50 | def finetune_eval(net, data_loader): 51 | net.eval() 52 | print(net) 53 | with torch.no_grad(): 54 | avg=[] 55 | gt=[] 56 | cons=[] 57 | eval_iterator = tqdm(data_loader) 58 | for i, (seq1,seq2,seq3,mask1,mask2,mask3) in enumerate(eval_iterator): 59 | input_ids1, attention_mask1= seq1.cuda(),mask1.cuda() 60 | input_ids2, attention_mask2= seq2.cuda(),mask2.cuda() 61 | print(input_ids1.shape) 62 | print(attention_mask1.shape) 63 | anchor,pos=0,0 64 | 65 | output=net(input_ids=input_ids1,attention_mask=attention_mask1) 66 | #anchor=output.last_hidden_state[:,0:1,:] 67 | anchor=output.pooler_output 68 | output=net(input_ids=input_ids2,attention_mask=attention_mask2) 69 | #pos=output.last_hidden_state[:,0:1,:] 70 | pos=output.pooler_output 71 | ans=0 72 | for k in range(len(anchor)): # check every vector of (vA,vB) 73 | vA=anchor[k:k+1].cpu() 74 | sim=[] 75 | for j in range(len(pos)): 76 | vB=pos[j:j+1].cpu() 77 | #vB=vB[0] 78 | AB_sim=F.cosine_similarity(vA, vB).item() 79 | sim.append(AB_sim) 80 | if j!=k: 81 | cons.append(AB_sim) 82 | sim=np.array(sim) 83 | y=np.argsort(-sim) 84 | posi=0 85 | for j in range(len(pos)): 86 | if y[j]==k: 87 | posi=j+1 88 | 89 | gt.append(sim[k]) 90 | 91 | ans+=1/posi 92 | 93 | ans=ans/len(anchor) 94 | avg.append(ans) 95 | print("now mrr ",np.mean(np.array(avg))) 96 | fi=open("logft.txt","a") 97 | print("MRR ",np.mean(np.array(avg)),file=fi) 98 | print("FINAL MRR ",np.mean(np.array(avg))) 99 | fi.close() 100 | return np.mean(np.array(avg)) 101 | class BinBertModel(BertModel): 102 | def __init__(self, config, add_pooling_layer=True): 103 | super().__init__(config) 104 | self.config = config 105 | self.embeddings.position_embeddings=self.embeddings.word_embeddings 106 | from datautils.playdata import DatasetBase as DatasetBase 107 | 108 | if __name__ == '__main__': 109 | 110 | parser = argparse.ArgumentParser(description="jTrans-EvalSave") 111 | parser.add_argument("--model_path", type=str, default='./models/jTrans-finetune', help="Path to the model") 112 | parser.add_argument("--dataset_path", type=str, default='./BinaryCorp/small_test', help="Path to the dataset") 113 | parser.add_argument("--experiment_path", type=str, default='./experiments/BinaryCorp-3M/jTrans.pkl', help="Path to the experiment") 114 | parser.add_argument("--tokenizer", type=str, default='./jtrans_tokenizer/') 115 | 116 | args = parser.parse_args() 117 | 118 | from datetime import datetime 119 | now = datetime.now() # current date and time 120 | 
TIMESTAMP="%Y%m%d%H%M" 121 | tim = now.strftime(TIMESTAMP) 122 | logger = get_logger(f"jTrans-{args.model_path}-eval-{args.dataset_path}_savename_{args.experiment_path}_{tim}") 123 | logger.info(f"Loading Pretrained Model from {args.model_path} ...") 124 | model = BinBertModel.from_pretrained(args.model_path) 125 | 126 | model.eval() 127 | device = torch.device("cuda") 128 | model.to(device) 129 | 130 | logger.info("Done ...") 131 | tokenizer = BertTokenizer.from_pretrained(args.tokenizer) 132 | logger.info("Tokenizer Done ...") 133 | 134 | logger.info("Preparing Datasets ...") 135 | ft_valid_dataset=FunctionDataset_CL(tokenizer,args.dataset_path,None,True,opt=['O0', 'O1', 'O2', 'O3', 'Os'], add_ebd=True, convert_jump_addr=True) 136 | for i in tqdm(range(len(ft_valid_dataset.datas))): 137 | pairs=ft_valid_dataset.datas[i] 138 | for j in ['O0','O1','O2','O3','Os']: 139 | if ft_valid_dataset.ebds[i].get(j) is not None: 140 | idx=ft_valid_dataset.ebds[i][j] 141 | ret1=tokenizer([pairs[idx]], add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_tensors='pt') #tokenize them 142 | seq1=ret1['input_ids'] 143 | mask1=ret1['attention_mask'] 144 | input_ids1, attention_mask1= seq1.cuda(),mask1.cuda() 145 | output=model(input_ids=input_ids1,attention_mask=attention_mask1) 146 | anchor=output.pooler_output 147 | ft_valid_dataset.ebds[i][j]=anchor.detach().cpu() 148 | 149 | logger.info("ebds start writing") 150 | fi=open(args.experiment_path,'wb') 151 | pickle.dump(ft_valid_dataset.ebds,fi) 152 | fi.close() 153 | 154 | -------------------------------------------------------------------------------- /fasteval.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | from datautils.playdata import DatasetBase as DatasetBase 4 | from torch.utils.data import DataLoader 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from tqdm import tqdm 9 | import argparse 10 | 11 | def eval_O(ebds,TYPE1,TYPE2): 12 | funcarr1=[] 13 | funcarr2=[] 14 | 15 | for i in range(len(ebds)): 16 | if ebds[i].get(TYPE1) is not None and type(ebds[i][TYPE1]) is not int: 17 | if ebds[i].get(TYPE2) is not None and type(ebds[i][TYPE2]) is not int: 18 | ebd1,ebd2=ebds[i][TYPE1],ebds[i][TYPE2] 19 | funcarr1.append(ebd1 / ebd1.norm()) 20 | funcarr2.append(ebd2 / ebd2.norm()) 21 | else: 22 | continue 23 | 24 | ft_valid_dataset=FunctionDataset_Fast(funcarr1,funcarr2) 25 | dataloader = DataLoader(ft_valid_dataset, batch_size=POOLSIZE, num_workers=24, shuffle=True) 26 | SIMS=[] 27 | Recall_AT_1=[] 28 | 29 | for idx, (anchor,pos) in enumerate(tqdm(dataloader)): 30 | anchor = anchor.cuda() 31 | pos =pos.cuda() 32 | if anchor.shape[0]==POOLSIZE: 33 | for i in range(len(anchor)): # check every vector of (vA,vB) 34 | vA=anchor[i:i+1] #pos[i] 35 | sim = np.array(torch.mm(vA, pos.T).cpu().squeeze()) 36 | y=np.argsort(-sim) 37 | posi=0 38 | for j in range(len(pos)): 39 | if y[j]==i: 40 | posi=j+1 41 | break 42 | if posi==1: 43 | Recall_AT_1.append(1) 44 | else: 45 | Recall_AT_1.append(0) 46 | SIMS.append(1.0/posi) 47 | print(TYPE1,TYPE2,'MRR{}: '.format(POOLSIZE),np.array(SIMS).mean()) 48 | print(TYPE1,TYPE2,'Recall@1: ', np.array(Recall_AT_1).mean()) 49 | return np.array(Recall_AT_1).mean() 50 | 51 | class FunctionDataset_Fast(torch.utils.data.Dataset): 52 | def __init__(self,arr1,arr2): 53 | self.arr1=arr1 54 | self.arr2=arr2 55 | assert(len(arr1)==len(arr2)) 56 | def __getitem__(self, idx): 57 | return 
self.arr1[idx].squeeze(0),self.arr2[idx].squeeze(0) 58 | def __len__(self): 59 | return len(self.arr1) 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser(description="jTrans-FastEval") 63 | parser.add_argument("--experiment_path", type=str, default='./experiments/BinaryCorp-3M/jTrans.pkl', help="experiment to be evaluated") 64 | parser.add_argument("--poolsize", type=int, default=32, help="size of the function pool") 65 | args = parser.parse_args() 66 | 67 | POOLSIZE=args.poolsize 68 | ff=open(args.experiment_path,'rb') 69 | ebds=pickle.load(ff) 70 | ff.close() 71 | 72 | print(f'evaluating...poolsize={POOLSIZE}') 73 | 74 | eval_O(ebds,'O0','O3') 75 | eval_O(ebds,'O0','Os') 76 | eval_O(ebds,'O1','Os') 77 | eval_O(ebds,'O1','O3') 78 | eval_O(ebds,'O2','Os') 79 | eval_O(ebds,'O2','O3') -------------------------------------------------------------------------------- /figures/poolsizecompare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vul337/jTrans/1d405156865e3c582e4f183cbaa623b700ea0e9e/figures/poolsizecompare.png -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from unicodedata import name 2 | from transformers import BertTokenizer, BertForMaskedLM, BertModel 3 | import torch.multiprocessing 4 | from torch.utils.data import DataLoader 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | from tqdm import tqdm 10 | from data import load_paired_data, FunctionDataset_CL, FunctionDataset_CL_Load 11 | from transformers import AdamW 12 | import torch.nn.functional as F 13 | import argparse 14 | import wandb 15 | import logging 16 | import sys 17 | import time 18 | import data 19 | import pickle 20 | WANDB = True 21 | 22 | def get_logger(name): 23 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filename=name) 24 | logger = logging.getLogger(__name__) 25 | s_handle = logging.StreamHandler(sys.stdout) 26 | s_handle.setLevel(logging.INFO) 27 | s_handle.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(filename)s[:%(lineno)d] - %(message)s")) 28 | logger.addHandler(s_handle) 29 | return logger 30 | 31 | def train_dp(model, args, train_set, valid_set, logger): 32 | 33 | class Triplet_COS_Loss(nn.Module): 34 | def __init__(self,margin): 35 | super(Triplet_COS_Loss, self).__init__() 36 | self.margin=margin 37 | 38 | def forward(self, repr, good_code_repr, bad_code_repr): 39 | good_sim=F.cosine_similarity(repr, good_code_repr) 40 | bad_sim=F.cosine_similarity(repr, bad_code_repr) 41 | #print("simm ",good_sim.shape) 42 | loss=(self.margin-(good_sim-bad_sim)).clamp(min=1e-6).mean() 43 | return loss 44 | 45 | if WANDB: 46 | wandb.init(project=f'jTrans-finetune', name="jTrans_Freeze_10_Train_Test") 47 | wandb.config.update(args) 48 | 49 | logger.info("Initializing Model...") 50 | device = torch.device("cuda") 51 | model.to(device) 52 | logger.info("Finished Initialization...") 53 | train_dataloader = DataLoader(train_set, batch_size=args.batch_size, num_workers=48, shuffle=True, prefetch_factor=4) 54 | valid_dataloader = DataLoader(valid_set, batch_size=args.eval_batch_size, num_workers=48, shuffle=True, prefetch_factor=4) 55 | 56 | no_decay = ["bias", "LayerNorm.weight"] 57 | optimizer_grouped_parameters = [] 58 | 59 | optimizer_grouped_parameters.extend( 60 | [ 61 | { 62 | "params": 
[ 63 | p 64 | for n, p in model.named_parameters() 65 | if not any(nd in n for nd in no_decay) 66 | ], 67 | "weight_decay": args.weight_decay, 68 | }, 69 | { 70 | "params": [ 71 | p 72 | for n, p in model.named_parameters() 73 | if any(nd in n for nd in no_decay) 74 | ], 75 | "weight_decay": 0.0, 76 | }, 77 | ] 78 | ) 79 | 80 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.lr) 81 | 82 | model = nn.DataParallel(model) 83 | global_steps = 0 84 | etc=0 85 | for epoch in range(args.epoch): 86 | model.train() 87 | triplet_loss=Triplet_COS_Loss(margin=0.2) 88 | train_iterator = tqdm(train_dataloader) 89 | loss_list = [] 90 | for i, (seq1,seq2,seq3,mask1,mask2,mask3) in enumerate(train_iterator): 91 | t1=time.time() 92 | input_ids1, attention_mask1 = seq1.cuda(),mask1.cuda() 93 | input_ids2, attention_mask2 = seq2.cuda(),mask2.cuda() 94 | input_ids3, attention_mask3 = seq3.cuda(),mask3.cuda() 95 | 96 | optimizer.zero_grad() 97 | anchor,pos,neg=0,0,0 98 | 99 | output1 = model(input_ids=input_ids1, attention_mask=attention_mask1) 100 | anchor = output1.pooler_output 101 | 102 | output2 = model(input_ids=input_ids2, attention_mask=attention_mask2) 103 | pos = output2.pooler_output 104 | 105 | output3 = model(input_ids=input_ids3, attention_mask=attention_mask3) 106 | neg = output3.pooler_output 107 | 108 | optimizer.zero_grad() 109 | loss = triplet_loss(anchor, pos, neg) 110 | 111 | loss.backward() 112 | loss_list.append(loss) 113 | 114 | optimizer.step() 115 | if (i+1) % args.log_every == 0: 116 | global_steps += 1 117 | tmp_lr = optimizer.param_groups[0]["lr"] 118 | # logger.info(f"[*] epoch: [{epoch}/{args.epoch+1}], steps: [{i}/{len(train_iterator)}], lr={tmp_lr}, loss={loss}") 119 | train_iterator.set_description(f"[*] epoch: [{epoch}/{args.epoch+1}], steps: [{i}/{len(train_iterator)}], lr={tmp_lr}, loss={loss}") 120 | if WANDB: 121 | wandb.log({ 122 | 'triplet loss' : loss, 123 | 'lr' : tmp_lr, 124 | 'global_step' : global_steps, 125 | }) 126 | 127 | if (epoch+1) % args.eval_every == 0: 128 | logger.info(f"Doing Evaluation ...") 129 | mrr = finetune_eval(model, valid_dataloader) 130 | logger.info(f"[*] epoch: [{epoch}/{args.epoch+1}], mrr={mrr}") 131 | if WANDB: 132 | wandb.log({ 133 | 'mrr': mrr 134 | }) 135 | if (epoch+1) % args.save_every == 0: 136 | logger.info(f"Saving Model ...") 137 | model.module.save_pretrained(os.path.join(args.output_path, f"finetune_epoch_{epoch+1}")) 138 | logger.info(f"Done") 139 | 140 | 141 | def finetune_eval(net, data_loader): 142 | net.eval() 143 | with torch.no_grad(): 144 | avg=[] 145 | gt=[] 146 | cons=[] 147 | eval_iterator = tqdm(data_loader) 148 | for i, (seq1,seq2,_,mask1,mask2,_) in enumerate(eval_iterator): 149 | input_ids1, attention_mask1= seq1.cuda(),mask1.cuda() 150 | input_ids2, attention_mask2= seq2.cuda(),mask2.cuda() 151 | 152 | anchor,pos=0,0 153 | 154 | output1 = model(input_ids=input_ids1, attention_mask=attention_mask1) 155 | anchor = output1.pooler_output 156 | 157 | output2 = model(input_ids=input_ids2, attention_mask=attention_mask2) 158 | pos = output2.pooler_output 159 | 160 | ans=0 161 | for i in range(len(anchor)): # check every vector of (vA,vB) 162 | vA=anchor[i:i+1].cpu() #pos[i] 163 | sim=[] 164 | for j in range(len(pos)): 165 | vB=pos[j:j+1].cpu() # pos[j] 166 | AB_sim=F.cosine_similarity(vA, vB).item() 167 | sim.append(AB_sim) 168 | if j!=i: 169 | cons.append(AB_sim) 170 | sim=np.array(sim) 171 | y=np.argsort(-sim) 172 | posi=0 173 | for j in range(len(pos)): 174 | if y[j]==i: 175 | posi=j+1 176 | 177 | 
                gt.append(sim[i])
178 | 
179 |                 ans+=1/posi
180 | 
181 |             ans=ans/len(anchor)
182 |             avg.append(ans)
183 |     return np.mean(np.array(avg))
184 | 
185 | class BinBertModel(BertModel):
186 |     def __init__(self, config, add_pooling_layer=True):
187 |         super().__init__(config)
188 |         self.config = config
189 |         self.embeddings.position_embeddings=self.embeddings.word_embeddings
190 | 
191 | if __name__ == '__main__':
192 |     torch.multiprocessing.set_sharing_strategy('file_system')
193 |     parser = argparse.ArgumentParser(description="jTrans-Finetune")
194 |     parser.add_argument("--model_path", type=str, default='./models/jTrans-pretrain', help='the path of the pretrained model')
195 |     parser.add_argument("--output_path", type=str, default='./models/jTrans-finetune', help='the path where the finetuned model will be saved')
196 |     parser.add_argument("--tokenizer", type=str, default='./jtrans_tokenizer', help='the path of the tokenizer')
197 |     parser.add_argument("--epoch", type=int, default=10, help='number of training epochs')
198 |     parser.add_argument("--lr", type=float, default=1e-5, help='learning rate')
199 |     parser.add_argument("--warmup", type=int, default=1000, help='warmup steps')
200 |     parser.add_argument("--step_size", type=int, default=40000, help='scheduler step size')
201 |     parser.add_argument("--gamma", type=float, default=0.99, help='scheduler gamma')
202 |     parser.add_argument("--batch_size", type=int, default=64, help='training batch size')
203 |     parser.add_argument("--eval_batch_size", type=int, default=256, help='evaluation batch size')
204 |     parser.add_argument("--log_every", type=int, default=1, help='logging frequency')
205 |     parser.add_argument("--local_rank", type=int, default=0, help='local rank used for ddp')
206 |     parser.add_argument("--freeze_cnt", type=int, default=10, help='number of layers to freeze')
207 |     parser.add_argument("--weight_decay", type=float, default=1e-4, help='regularization weight decay')  # type must be float, not int
208 |     parser.add_argument("--eval_every", type=int, default=1, help="evaluate the model every x epochs")
209 |     parser.add_argument("--eval_every_step", type=int, default=1000, help="evaluate the model every x steps")
210 |     parser.add_argument("--save_every", type=int, default=1, help="save the model every x epochs")
211 |     parser.add_argument("--train_path", type=str, default='./BinaryCorp/small_train', help='the path of the training data')
212 |     parser.add_argument("--eval_path", type=str, default='./BinaryCorp/small_test', help='the path of the evaluation data')
213 |     parser.add_argument("--load_path", type=str, default='./experiments/BinaryCorp-3M/', help='load path')
214 | 
215 |     args = parser.parse_args()
216 | 
217 |     from datetime import datetime
218 |     now = datetime.now() # current date and time
219 |     TIMESTAMP="%Y%m%d%H%M"
220 |     tim = now.strftime(TIMESTAMP)
221 |     logger = get_logger(f"jTrans_{args.lr}_batchsize_{args.batch_size}_weight_decay_{args.weight_decay}_{tim}")
222 | 
223 |     logger.info(f"Loading Pretrained Model from {args.model_path} ...")
224 |     model = BinBertModel.from_pretrained(args.model_path)
225 | 
226 |     freeze_layer_count = args.freeze_cnt
227 |     for param in model.embeddings.parameters():
228 |         param.requires_grad = False
229 | 
230 |     if freeze_layer_count != -1:
231 |         for layer in model.encoder.layer[:freeze_layer_count]:
232 |             for param in layer.parameters():
233 |                 param.requires_grad = False
234 |     print(model)
235 | 
236 |     logger.info("Done ...")
237 |     tokenizer = BertTokenizer.from_pretrained(args.tokenizer)
238 |     logger.info("Tokenizer Done ...")
239 | 
240 |     load_train,
240 |     load_train, load_test = False, False
241 |     # load_train = f"{args.load_path}/jTrans-{args.train_path.split('/')[-1]}.pkl"
242 |     # load_test = f"{args.load_path}/jTrans-{args.eval_path.split('/')[-1]}.pkl"
243 |     ft_train_dataset = FunctionDataset_CL_Load(tokenizer, args.train_path, convert_jump_addr=True, load=load_train, opt=['O0','O1','O2','O3','Os'])
244 |     ft_valid_dataset = FunctionDataset_CL_Load(tokenizer, args.eval_path, convert_jump_addr=True, load=load_test, opt=['O0','O1','O2','O3','Os'])
245 |     if not load_train:  # cache the freshly preprocessed datasets for later runs
246 |         pickle.dump(ft_train_dataset.datas, open(f"{args.load_path}/jTrans-{args.train_path.split('/')[-1]}.pkl", 'wb'))
247 |         pickle.dump(ft_valid_dataset.datas, open(f"{args.load_path}/jTrans-{args.eval_path.split('/')[-1]}.pkl", 'wb'))
248 |     logger.info("Done ...")
249 |     train_dp(model, args, ft_train_dataset, ft_valid_dataset, logger)
250 |     logger.info("Finished Training")
251 | 
252 | 
--------------------------------------------------------------------------------
/jtrans_tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
--------------------------------------------------------------------------------
/jtrans_tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"do_lower_case": false, "do_basic_tokenize": false, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": false, "strip_accents": null, "tokenizer_class": "BertTokenizer"}
--------------------------------------------------------------------------------
/jtrans_tokenizer/vocab.txt:
--------------------------------------------------------------------------------
1 | JUMP_ADDR_0 2 | JUMP_ADDR_1 3 | JUMP_ADDR_2 4 | JUMP_ADDR_3 5 | JUMP_ADDR_4 6 | JUMP_ADDR_5 7 | JUMP_ADDR_6 8 | JUMP_ADDR_7 9 | JUMP_ADDR_8 10 | JUMP_ADDR_9 11 | JUMP_ADDR_10 12 | JUMP_ADDR_11 13 | JUMP_ADDR_12 14 | JUMP_ADDR_13 15 | JUMP_ADDR_14 16 | JUMP_ADDR_15 17 | JUMP_ADDR_16 18 | JUMP_ADDR_17 19 | JUMP_ADDR_18 20 | JUMP_ADDR_19 21 | JUMP_ADDR_20 22 | JUMP_ADDR_21 23 | JUMP_ADDR_22 24 | JUMP_ADDR_23 25 | JUMP_ADDR_24 26 | JUMP_ADDR_25 27 | JUMP_ADDR_26 28 | JUMP_ADDR_27 29 | JUMP_ADDR_28 30 | JUMP_ADDR_29 31 | JUMP_ADDR_30 32 | JUMP_ADDR_31 33 | JUMP_ADDR_32 34 | JUMP_ADDR_33 35 | JUMP_ADDR_34 36 | JUMP_ADDR_35 37 | JUMP_ADDR_36 38 | JUMP_ADDR_37 39 | JUMP_ADDR_38 40 | JUMP_ADDR_39 41 | JUMP_ADDR_40 42 | JUMP_ADDR_41 43 | JUMP_ADDR_42 44 | JUMP_ADDR_43 45 | JUMP_ADDR_44 46 | JUMP_ADDR_45 47 | JUMP_ADDR_46 48 | JUMP_ADDR_47 49 | JUMP_ADDR_48 50 | JUMP_ADDR_49 51 | JUMP_ADDR_50 52 | JUMP_ADDR_51 53 | JUMP_ADDR_52 54 | JUMP_ADDR_53 55 | JUMP_ADDR_54 56 | JUMP_ADDR_55 57 | JUMP_ADDR_56 58 | JUMP_ADDR_57 59 | JUMP_ADDR_58 60 | JUMP_ADDR_59 61 | JUMP_ADDR_60 62 | JUMP_ADDR_61 63 | JUMP_ADDR_62 64 | JUMP_ADDR_63 65 | JUMP_ADDR_64 66 | JUMP_ADDR_65 67 | JUMP_ADDR_66 68 | JUMP_ADDR_67 69 | JUMP_ADDR_68 70 | JUMP_ADDR_69 71 | JUMP_ADDR_70 72 | JUMP_ADDR_71 73 | JUMP_ADDR_72 74 | JUMP_ADDR_73 75 | JUMP_ADDR_74 76 | JUMP_ADDR_75 77 | JUMP_ADDR_76 78 | JUMP_ADDR_77 79 | JUMP_ADDR_78 80 | JUMP_ADDR_79 81 | JUMP_ADDR_80 82 | JUMP_ADDR_81 83 | JUMP_ADDR_82 84 | JUMP_ADDR_83 85 | JUMP_ADDR_84 86 | JUMP_ADDR_85 87 | JUMP_ADDR_86 88 | JUMP_ADDR_87 89 | JUMP_ADDR_88 90 | JUMP_ADDR_89 91 | JUMP_ADDR_90 92 | JUMP_ADDR_91 93 | JUMP_ADDR_92 94 
| JUMP_ADDR_93 95 | JUMP_ADDR_94 96 | JUMP_ADDR_95 97 | JUMP_ADDR_96 98 | JUMP_ADDR_97 99 | JUMP_ADDR_98 100 | JUMP_ADDR_99 101 | JUMP_ADDR_100 102 | JUMP_ADDR_101 103 | JUMP_ADDR_102 104 | JUMP_ADDR_103 105 | JUMP_ADDR_104 106 | JUMP_ADDR_105 107 | JUMP_ADDR_106 108 | JUMP_ADDR_107 109 | JUMP_ADDR_108 110 | JUMP_ADDR_109 111 | JUMP_ADDR_110 112 | JUMP_ADDR_111 113 | JUMP_ADDR_112 114 | JUMP_ADDR_113 115 | JUMP_ADDR_114 116 | JUMP_ADDR_115 117 | JUMP_ADDR_116 118 | JUMP_ADDR_117 119 | JUMP_ADDR_118 120 | JUMP_ADDR_119 121 | JUMP_ADDR_120 122 | JUMP_ADDR_121 123 | JUMP_ADDR_122 124 | JUMP_ADDR_123 125 | JUMP_ADDR_124 126 | JUMP_ADDR_125 127 | JUMP_ADDR_126 128 | JUMP_ADDR_127 129 | JUMP_ADDR_128 130 | JUMP_ADDR_129 131 | JUMP_ADDR_130 132 | JUMP_ADDR_131 133 | JUMP_ADDR_132 134 | JUMP_ADDR_133 135 | JUMP_ADDR_134 136 | JUMP_ADDR_135 137 | JUMP_ADDR_136 138 | JUMP_ADDR_137 139 | JUMP_ADDR_138 140 | JUMP_ADDR_139 141 | JUMP_ADDR_140 142 | JUMP_ADDR_141 143 | JUMP_ADDR_142 144 | JUMP_ADDR_143 145 | JUMP_ADDR_144 146 | JUMP_ADDR_145 147 | JUMP_ADDR_146 148 | JUMP_ADDR_147 149 | JUMP_ADDR_148 150 | JUMP_ADDR_149 151 | JUMP_ADDR_150 152 | JUMP_ADDR_151 153 | JUMP_ADDR_152 154 | JUMP_ADDR_153 155 | JUMP_ADDR_154 156 | JUMP_ADDR_155 157 | JUMP_ADDR_156 158 | JUMP_ADDR_157 159 | JUMP_ADDR_158 160 | JUMP_ADDR_159 161 | JUMP_ADDR_160 162 | JUMP_ADDR_161 163 | JUMP_ADDR_162 164 | JUMP_ADDR_163 165 | JUMP_ADDR_164 166 | JUMP_ADDR_165 167 | JUMP_ADDR_166 168 | JUMP_ADDR_167 169 | JUMP_ADDR_168 170 | JUMP_ADDR_169 171 | JUMP_ADDR_170 172 | JUMP_ADDR_171 173 | JUMP_ADDR_172 174 | JUMP_ADDR_173 175 | JUMP_ADDR_174 176 | JUMP_ADDR_175 177 | JUMP_ADDR_176 178 | JUMP_ADDR_177 179 | JUMP_ADDR_178 180 | JUMP_ADDR_179 181 | JUMP_ADDR_180 182 | JUMP_ADDR_181 183 | JUMP_ADDR_182 184 | JUMP_ADDR_183 185 | JUMP_ADDR_184 186 | JUMP_ADDR_185 187 | JUMP_ADDR_186 188 | JUMP_ADDR_187 189 | JUMP_ADDR_188 190 | JUMP_ADDR_189 191 | JUMP_ADDR_190 192 | JUMP_ADDR_191 193 | JUMP_ADDR_192 194 | JUMP_ADDR_193 195 | JUMP_ADDR_194 196 | JUMP_ADDR_195 197 | JUMP_ADDR_196 198 | JUMP_ADDR_197 199 | JUMP_ADDR_198 200 | JUMP_ADDR_199 201 | JUMP_ADDR_200 202 | JUMP_ADDR_201 203 | JUMP_ADDR_202 204 | JUMP_ADDR_203 205 | JUMP_ADDR_204 206 | JUMP_ADDR_205 207 | JUMP_ADDR_206 208 | JUMP_ADDR_207 209 | JUMP_ADDR_208 210 | JUMP_ADDR_209 211 | JUMP_ADDR_210 212 | JUMP_ADDR_211 213 | JUMP_ADDR_212 214 | JUMP_ADDR_213 215 | JUMP_ADDR_214 216 | JUMP_ADDR_215 217 | JUMP_ADDR_216 218 | JUMP_ADDR_217 219 | JUMP_ADDR_218 220 | JUMP_ADDR_219 221 | JUMP_ADDR_220 222 | JUMP_ADDR_221 223 | JUMP_ADDR_222 224 | JUMP_ADDR_223 225 | JUMP_ADDR_224 226 | JUMP_ADDR_225 227 | JUMP_ADDR_226 228 | JUMP_ADDR_227 229 | JUMP_ADDR_228 230 | JUMP_ADDR_229 231 | JUMP_ADDR_230 232 | JUMP_ADDR_231 233 | JUMP_ADDR_232 234 | JUMP_ADDR_233 235 | JUMP_ADDR_234 236 | JUMP_ADDR_235 237 | JUMP_ADDR_236 238 | JUMP_ADDR_237 239 | JUMP_ADDR_238 240 | JUMP_ADDR_239 241 | JUMP_ADDR_240 242 | JUMP_ADDR_241 243 | JUMP_ADDR_242 244 | JUMP_ADDR_243 245 | JUMP_ADDR_244 246 | JUMP_ADDR_245 247 | JUMP_ADDR_246 248 | JUMP_ADDR_247 249 | JUMP_ADDR_248 250 | JUMP_ADDR_249 251 | JUMP_ADDR_250 252 | JUMP_ADDR_251 253 | JUMP_ADDR_252 254 | JUMP_ADDR_253 255 | JUMP_ADDR_254 256 | JUMP_ADDR_255 257 | JUMP_ADDR_256 258 | JUMP_ADDR_257 259 | JUMP_ADDR_258 260 | JUMP_ADDR_259 261 | JUMP_ADDR_260 262 | JUMP_ADDR_261 263 | JUMP_ADDR_262 264 | JUMP_ADDR_263 265 | JUMP_ADDR_264 266 | JUMP_ADDR_265 267 | JUMP_ADDR_266 268 | JUMP_ADDR_267 269 | JUMP_ADDR_268 270 | JUMP_ADDR_269 271 | JUMP_ADDR_270 272 | 
JUMP_ADDR_271 273 | JUMP_ADDR_272 274 | JUMP_ADDR_273 275 | JUMP_ADDR_274 276 | JUMP_ADDR_275 277 | JUMP_ADDR_276 278 | JUMP_ADDR_277 279 | JUMP_ADDR_278 280 | JUMP_ADDR_279 281 | JUMP_ADDR_280 282 | JUMP_ADDR_281 283 | JUMP_ADDR_282 284 | JUMP_ADDR_283 285 | JUMP_ADDR_284 286 | JUMP_ADDR_285 287 | JUMP_ADDR_286 288 | JUMP_ADDR_287 289 | JUMP_ADDR_288 290 | JUMP_ADDR_289 291 | JUMP_ADDR_290 292 | JUMP_ADDR_291 293 | JUMP_ADDR_292 294 | JUMP_ADDR_293 295 | JUMP_ADDR_294 296 | JUMP_ADDR_295 297 | JUMP_ADDR_296 298 | JUMP_ADDR_297 299 | JUMP_ADDR_298 300 | JUMP_ADDR_299 301 | JUMP_ADDR_300 302 | JUMP_ADDR_301 303 | JUMP_ADDR_302 304 | JUMP_ADDR_303 305 | JUMP_ADDR_304 306 | JUMP_ADDR_305 307 | JUMP_ADDR_306 308 | JUMP_ADDR_307 309 | JUMP_ADDR_308 310 | JUMP_ADDR_309 311 | JUMP_ADDR_310 312 | JUMP_ADDR_311 313 | JUMP_ADDR_312 314 | JUMP_ADDR_313 315 | JUMP_ADDR_314 316 | JUMP_ADDR_315 317 | JUMP_ADDR_316 318 | JUMP_ADDR_317 319 | JUMP_ADDR_318 320 | JUMP_ADDR_319 321 | JUMP_ADDR_320 322 | JUMP_ADDR_321 323 | JUMP_ADDR_322 324 | JUMP_ADDR_323 325 | JUMP_ADDR_324 326 | JUMP_ADDR_325 327 | JUMP_ADDR_326 328 | JUMP_ADDR_327 329 | JUMP_ADDR_328 330 | JUMP_ADDR_329 331 | JUMP_ADDR_330 332 | JUMP_ADDR_331 333 | JUMP_ADDR_332 334 | JUMP_ADDR_333 335 | JUMP_ADDR_334 336 | JUMP_ADDR_335 337 | JUMP_ADDR_336 338 | JUMP_ADDR_337 339 | JUMP_ADDR_338 340 | JUMP_ADDR_339 341 | JUMP_ADDR_340 342 | JUMP_ADDR_341 343 | JUMP_ADDR_342 344 | JUMP_ADDR_343 345 | JUMP_ADDR_344 346 | JUMP_ADDR_345 347 | JUMP_ADDR_346 348 | JUMP_ADDR_347 349 | JUMP_ADDR_348 350 | JUMP_ADDR_349 351 | JUMP_ADDR_350 352 | JUMP_ADDR_351 353 | JUMP_ADDR_352 354 | JUMP_ADDR_353 355 | JUMP_ADDR_354 356 | JUMP_ADDR_355 357 | JUMP_ADDR_356 358 | JUMP_ADDR_357 359 | JUMP_ADDR_358 360 | JUMP_ADDR_359 361 | JUMP_ADDR_360 362 | JUMP_ADDR_361 363 | JUMP_ADDR_362 364 | JUMP_ADDR_363 365 | JUMP_ADDR_364 366 | JUMP_ADDR_365 367 | JUMP_ADDR_366 368 | JUMP_ADDR_367 369 | JUMP_ADDR_368 370 | JUMP_ADDR_369 371 | JUMP_ADDR_370 372 | JUMP_ADDR_371 373 | JUMP_ADDR_372 374 | JUMP_ADDR_373 375 | JUMP_ADDR_374 376 | JUMP_ADDR_375 377 | JUMP_ADDR_376 378 | JUMP_ADDR_377 379 | JUMP_ADDR_378 380 | JUMP_ADDR_379 381 | JUMP_ADDR_380 382 | JUMP_ADDR_381 383 | JUMP_ADDR_382 384 | JUMP_ADDR_383 385 | JUMP_ADDR_384 386 | JUMP_ADDR_385 387 | JUMP_ADDR_386 388 | JUMP_ADDR_387 389 | JUMP_ADDR_388 390 | JUMP_ADDR_389 391 | JUMP_ADDR_390 392 | JUMP_ADDR_391 393 | JUMP_ADDR_392 394 | JUMP_ADDR_393 395 | JUMP_ADDR_394 396 | JUMP_ADDR_395 397 | JUMP_ADDR_396 398 | JUMP_ADDR_397 399 | JUMP_ADDR_398 400 | JUMP_ADDR_399 401 | JUMP_ADDR_400 402 | JUMP_ADDR_401 403 | JUMP_ADDR_402 404 | JUMP_ADDR_403 405 | JUMP_ADDR_404 406 | JUMP_ADDR_405 407 | JUMP_ADDR_406 408 | JUMP_ADDR_407 409 | JUMP_ADDR_408 410 | JUMP_ADDR_409 411 | JUMP_ADDR_410 412 | JUMP_ADDR_411 413 | JUMP_ADDR_412 414 | JUMP_ADDR_413 415 | JUMP_ADDR_414 416 | JUMP_ADDR_415 417 | JUMP_ADDR_416 418 | JUMP_ADDR_417 419 | JUMP_ADDR_418 420 | JUMP_ADDR_419 421 | JUMP_ADDR_420 422 | JUMP_ADDR_421 423 | JUMP_ADDR_422 424 | JUMP_ADDR_423 425 | JUMP_ADDR_424 426 | JUMP_ADDR_425 427 | JUMP_ADDR_426 428 | JUMP_ADDR_427 429 | JUMP_ADDR_428 430 | JUMP_ADDR_429 431 | JUMP_ADDR_430 432 | JUMP_ADDR_431 433 | JUMP_ADDR_432 434 | JUMP_ADDR_433 435 | JUMP_ADDR_434 436 | JUMP_ADDR_435 437 | JUMP_ADDR_436 438 | JUMP_ADDR_437 439 | JUMP_ADDR_438 440 | JUMP_ADDR_439 441 | JUMP_ADDR_440 442 | JUMP_ADDR_441 443 | JUMP_ADDR_442 444 | JUMP_ADDR_443 445 | JUMP_ADDR_444 446 | JUMP_ADDR_445 447 | JUMP_ADDR_446 448 | JUMP_ADDR_447 449 | JUMP_ADDR_448 
450 | JUMP_ADDR_449 451 | JUMP_ADDR_450 452 | JUMP_ADDR_451 453 | JUMP_ADDR_452 454 | JUMP_ADDR_453 455 | JUMP_ADDR_454 456 | JUMP_ADDR_455 457 | JUMP_ADDR_456 458 | JUMP_ADDR_457 459 | JUMP_ADDR_458 460 | JUMP_ADDR_459 461 | JUMP_ADDR_460 462 | JUMP_ADDR_461 463 | JUMP_ADDR_462 464 | JUMP_ADDR_463 465 | JUMP_ADDR_464 466 | JUMP_ADDR_465 467 | JUMP_ADDR_466 468 | JUMP_ADDR_467 469 | JUMP_ADDR_468 470 | JUMP_ADDR_469 471 | JUMP_ADDR_470 472 | JUMP_ADDR_471 473 | JUMP_ADDR_472 474 | JUMP_ADDR_473 475 | JUMP_ADDR_474 476 | JUMP_ADDR_475 477 | JUMP_ADDR_476 478 | JUMP_ADDR_477 479 | JUMP_ADDR_478 480 | JUMP_ADDR_479 481 | JUMP_ADDR_480 482 | JUMP_ADDR_481 483 | JUMP_ADDR_482 484 | JUMP_ADDR_483 485 | JUMP_ADDR_484 486 | JUMP_ADDR_485 487 | JUMP_ADDR_486 488 | JUMP_ADDR_487 489 | JUMP_ADDR_488 490 | JUMP_ADDR_489 491 | JUMP_ADDR_490 492 | JUMP_ADDR_491 493 | JUMP_ADDR_492 494 | JUMP_ADDR_493 495 | JUMP_ADDR_494 496 | JUMP_ADDR_495 497 | JUMP_ADDR_496 498 | JUMP_ADDR_497 499 | JUMP_ADDR_498 500 | JUMP_ADDR_499 501 | JUMP_ADDR_500 502 | JUMP_ADDR_501 503 | JUMP_ADDR_502 504 | JUMP_ADDR_503 505 | JUMP_ADDR_504 506 | JUMP_ADDR_505 507 | JUMP_ADDR_506 508 | JUMP_ADDR_507 509 | JUMP_ADDR_508 510 | JUMP_ADDR_509 511 | JUMP_ADDR_510 512 | JUMP_ADDR_511 513 | [UNK] 514 | endbr64 515 | sub 516 | rsp 517 | CONST 518 | mov 519 | rax 520 | cs:xxx 521 | test 522 | jz 523 | call 524 | add 525 | retn 526 | push 527 | jmp 528 | rdi 529 | [rbx+CONST] 530 | callfunc_xxx 531 | rbp 532 | jnz 533 | movapd 534 | xmm0 535 | movups 536 | [r12] 537 | [r12+CONST] 538 | [rsp+arg_xxx] 539 | cmp 540 | r15 541 | r12 542 | rbx 543 | fs:xxx 544 | [rsp+CONST+var_xxx] 545 | xor 546 | eax 547 | lea 548 | [CONST_VAR+CONST] 549 | [CONST_VAR] 550 | rsi 551 | jbe 552 | edx 553 | [rbx] 554 | dl 555 | [rax] 556 | rcx 557 | rep 558 | rdx 559 | [rdx+rax] 560 | GLOBAL_VAR 561 | pop 562 | [rbp+CONST] 563 | edi 564 | r13 565 | [rax+CONST] 566 | sub_xxx 567 | r14 568 | [rsp+CONST] 569 | [rsp+CONST_VAR] 570 | r8 571 | [rsp+rax+arg_xxx] 572 | movzx 573 | [r14] 574 | al 575 | [r13+CONST] 576 | ebp 577 | dil 578 | movsx 579 | bpl 580 | off_xxx 581 | ebx 582 | imul 583 | [r14+CONST] 584 | dec 585 | ja 586 | UNK_ADDR 587 | movsxd 588 | ds:xxx 589 | sil 590 | [rax+rbp+CONST] 591 | esi 592 | or 593 | lock 594 | nop 595 | loc_xxx 596 | unk_xxx 597 | shr 598 | sar 599 | pxor 600 | movdqu 601 | [rdx] 602 | setz 603 | setnz 604 | xmm1 605 | [rax+rax+CONST] 606 | jle 607 | [rax+rcx*8] 608 | jge 609 | jg 610 | jb 611 | [rdx+CONST] 612 | movdqa 613 | xmm2 614 | shl 615 | paddd 616 | and 617 | [rcx+rax*4] 618 | [rcx+CONST] 619 | [rcx] 620 | [rdx+rbx*8] 621 | [rdx+r13*8] 622 | movd 623 | xmm3 624 | ecx 625 | pshufd 626 | [rdx+rax*4] 627 | [rcx+rsi*4] 628 | [rdx+rsi*4] 629 | [rax+r13*8] 630 | jnb 631 | [rdx+rsi*8] 632 | [r8+CONST] 633 | movq 634 | [rax+rbx] 635 | punpcklqdq 636 | [rdx+rbx*4] 637 | [r14+rbx*4] 638 | r14d 639 | r15d 640 | cmovnz 641 | setb 642 | cl 643 | [r15] 644 | [CONST_VAR+rdx+CONST] 645 | [CONST_VAR+rdx] 646 | [r15+r12] 647 | movaps 648 | cmova 649 | cmovnb 650 | r9 651 | [r15+rbx*4] 652 | [rcx+rbx*4] 653 | [r15+rbx*8] 654 | [rcx+rbx*8] 655 | movsd 656 | xmm5 657 | xmm4 658 | comisd 659 | maxsd 660 | xmm6 661 | minsd 662 | xmm7 663 | movhpd 664 | [rsp+var_xxx] 665 | [rsp+var_xxx+CONST] 666 | [rbp+var_xxx+CONST] 667 | [rbp+var_xxx] 668 | movupd 669 | [rax+r12] 670 | [r12+rax+CONST] 671 | [rbp+rax+CONST] 672 | [CONST_VAR+r12] 673 | [rsp+CONST+arg_xxx] 674 | [rsp+CONST+var_xxx+CONST] 675 | psubq 676 | [r15+CONST] 677 | cmovg 678 | r8d 679 | 
[rax+r14] 680 | [CONST_VAR+r14] 681 | xchg 682 | ax 683 | jns 684 | js 685 | [rsp+CONST+CONST_VAR] 686 | [r9] 687 | cdqe 688 | r12d 689 | bl 690 | r13d 691 | [r13+r14+CONST] 692 | [rax+rax*2] 693 | movss 694 | unpcklps 695 | movlps 696 | [r14+r13] 697 | [r14+r13+CONST] 698 | [r8] 699 | jl 700 | [CONST_VAR+r14+CONST] 701 | r10 702 | cmovb 703 | r11 704 | [r14+r10] 705 | [r11+r12] 706 | [rax+r13] 707 | [rax+rcx] 708 | [CONST_VAR+rax] 709 | [CONST_VAR+r13] 710 | cmovbe 711 | [rcx+rcx*2] 712 | [rax+rcx*4] 713 | [rbx+rbx*2] 714 | [r15+rax*4] 715 | [CONST_VAR+rax*4] 716 | [CONST_VAR+rax*4+CONST] 717 | [r13+rbp+CONST] 718 | [r15+rbp] 719 | comiss 720 | [rsp+rdx*8+CONST+var_xxx] 721 | movhlps 722 | movlhps 723 | [rax+r15] 724 | [CONST_VAR+r15] 725 | sbb 726 | [rdx+r12] 727 | [rdx+rbp] 728 | [r8+rdx] 729 | [r13+rax+CONST] 730 | [r12+rax] 731 | [rbp+rdx+CONST] 732 | [r9+rdx] 733 | unpcklpd 734 | unpckhpd 735 | not 736 | ucomisd 737 | setbe 738 | pandn 739 | andpd 740 | andnpd 741 | orpd 742 | cmplepd 743 | movhps 744 | [rdx+rcx] 745 | [r14+rbp] 746 | [r9+CONST] 747 | [rdx+rax*8] 748 | [rdx+r13] 749 | [r10+CONST] 750 | cmovns 751 | [r11+rax*8] 752 | r10d 753 | [r11] 754 | [rbp+rdx*8+CONST] 755 | cqo 756 | [rax+rdx] 757 | [rcx+rdx*8] 758 | [r15+r14] 759 | r11d 760 | [r10] 761 | [CONST_VAR+rcx*8] 762 | r9d 763 | cmovz 764 | cmovl 765 | setl 766 | setnle 767 | [rsp+CONST+CONST_VAR+CONST] 768 | [rcx+rax] 769 | [rcx+rax+CONST] 770 | [rax+rax] 771 | cdq 772 | idiv 773 | [r8+rbp] 774 | [r8+rax*4] 775 | [CONST_VAR+rcx*4] 776 | [CONST_VAR+rcx*8+CONST] 777 | [rax+CONST_VAR] 778 | [CONST_VAR+rbp] 779 | [rbp+rax*4+CONST] 780 | [r10+r11+CONST] 781 | [rdx+rdx] 782 | [rax+r12*4] 783 | [r9+r11+CONST] 784 | [rax+rbx*4] 785 | [rbp+rax*8+CONST] 786 | [rcx+CONST_VAR] 787 | div 788 | [r11+CONST_VAR+CONST] 789 | [r9+CONST_VAR+CONST] 790 | [rax+rdx*8] 791 | [r15+r13] 792 | [rcx+r12] 793 | [rbp+rbp*2+CONST] 794 | [r14+rdx*8] 795 | [r14+r12] 796 | [rcx+rsi*8] 797 | [r14+rax*8] 798 | [r9+rdx*4] 799 | punpckldq 800 | [r12+rdx*4] 801 | [r12+rax*4] 802 | [CONST_VAR+rdi*4] 803 | [r8+rax] 804 | setnbe 805 | [CONST_VAR+rcx*4+CONST] 806 | [r8+rcx*4] 807 | [r9+rcx*4] 808 | [CONST_VAR+CONST_VAR] 809 | [r9+r10+CONST] 810 | [CONST_VAR+rcx+CONST] 811 | [CONST_VAR+rdx*4] 812 | [r9+r9] 813 | [CONST_VAR+r9*4] 814 | [CONST_VAR+rsi*2] 815 | [CONST_VAR+r10*8] 816 | [CONST_VAR+rdx*8] 817 | [CONST_VAR+r8*8] 818 | [CONST_VAR+rsi*8] 819 | [rdx+rdx+CONST] 820 | [r13+r12+CONST] 821 | r8b 822 | [r13+r13*2+CONST] 823 | [r12+rax*4+CONST] 824 | [r12+rax*8] 825 | [r12+r13*8] 826 | [CONST_VAR+r9+CONST] 827 | [CONST_VAR+r9] 828 | [rbp+r12+CONST] 829 | cvtsi2sd 830 | mulsd 831 | cvttsd2si 832 | [r12+r13] 833 | addsd 834 | subsd 835 | btc 836 | [r8+CONST_VAR] 837 | neg 838 | [r13+r15+CONST] 839 | [rdx+rdx*4] 840 | [rcx+rsi*2+CONST] 841 | [rax+r13*4] 842 | [rax+rdx*4] 843 | cmovge 844 | cmovle 845 | [r8+r13] 846 | [r8+r15] 847 | [rcx+rdx] 848 | [rcx+r15] 849 | [rdx+rcx*4] 850 | [rax+rbx*8] 851 | [r13+rbx+CONST] 852 | adc 853 | [rcx+rdx*4] 854 | [rdx+rbp*8] 855 | [r8+rcx] 856 | [rcx+rdx+CONST] 857 | [rax+rdx+CONST] 858 | [r12+r15] 859 | bsr 860 | [rdx+rcx*8] 861 | [rax+rdi*4] 862 | [rax+rdi*8] 863 | [r15+CONST_VAR] 864 | [r15+CONST_VAR+CONST] 865 | [CONST_VAR+r8] 866 | [r8+rcx+CONST] 867 | [r9+r11*8] 868 | [r8+r12] 869 | [r8+rbx*8] 870 | [CONST_VAR+rbx*8] 871 | [CONST_VAR+CONST_VAR+CONST] 872 | [rax+rax*4] 873 | [CONST_VAR+r15+CONST] 874 | r15b 875 | [rbx+r14*8] 876 | [rbx+rax+CONST] 877 | [r14+CONST_VAR] 878 | [r14+CONST_VAR+CONST] 879 | [r14+rcx+CONST] 880 | 
r13b 881 | [r12+rbp] 882 | [r8+rdx*4+CONST] 883 | [rax+rbp*4] 884 | [r13+rbp*4+CONST] 885 | [r13+rax*4+CONST] 886 | [rbp+r15*4+CONST] 887 | [rsp+rax+CONST+var_xxx] 888 | [r11+rax*4] 889 | [r9+rax*4] 890 | [r12+rdx*8] 891 | [rsp+rax*4+CONST+var_xxx] 892 | r10b 893 | [rbx+r12*8] 894 | ud2 895 | [rax+rax*2+CONST] 896 | bt 897 | [rbp+rbx+CONST] 898 | [r10+r15] 899 | [r15+r10] 900 | [r15+rcx+CONST] 901 | r9b 902 | [r15+rcx] 903 | [rbp+r15+CONST] 904 | [rax+r10] 905 | r14b 906 | [r10+rdx] 907 | [rdx+r10] 908 | [rbp+r10+CONST] 909 | [r14+rdx] 910 | [r12+r14+CONST] 911 | [r12+rcx] 912 | mul 913 | [r12+CONST_VAR] 914 | [r12+rdx] 915 | [r12+rbx] 916 | [rsp+r8+CONST+var_xxx] 917 | [rsp+CONST_VAR+CONST+var_xxx] 918 | [r14+r9] 919 | [r14+rax] 920 | [r8+r8*2] 921 | [rdx+rcx*2+CONST] 922 | [rcx+rax*8] 923 | [rax+rbp*8] 924 | [rax+r12*8] 925 | [rcx+r12*4] 926 | [r12+r10] 927 | [r11+CONST] 928 | setnb 929 | r12b 930 | movmskpd 931 | xorpd 932 | jp 933 | dx 934 | fld 935 | fstp 936 | fst 937 | fchs 938 | fucomip 939 | st 940 | fabs 941 | fxch 942 | r10w 943 | [r14+rax*4] 944 | [CONST_VAR+r12*8] 945 | [CONST_VAR+rax*8] 946 | [r8+rdx*2+CONST] 947 | [r8+rdx*2] 948 | [CONST_VAR+r10] 949 | r11b 950 | [r8+rcx*2+CONST] 951 | [r8+rcx*2] 952 | [rax+rax*8] 953 | [r14+r14*8] 954 | [CONST_VAR+rdi*8] 955 | [rcx+rcx*8] 956 | [r13+rax*8+CONST] 957 | [r8+r8*8] 958 | [CONST_VAR+rsi*8+CONST] 959 | [rdx+rdx*8] 960 | [CONST_VAR+rbx] 961 | cx 962 | si 963 | jo 964 | mulss 965 | addss 966 | cvtsi2ss 967 | addps 968 | cvtps2pd 969 | cvtss2sd 970 | pcmpgtb 971 | pand 972 | por 973 | psrldq 974 | pcmpeqd 975 | pmaxub 976 | pminub 977 | pshuflw 978 | bh 979 | ah 980 | pmaxsw 981 | pminsw 982 | [CONST_VAR+rax*2] 983 | pextrw 984 | bp 985 | r13w 986 | psubusw 987 | paddw 988 | pcmpeqw 989 | pcmpgtd 990 | [rcx+r9*4] 991 | cvtdq2pd 992 | psubd 993 | minss 994 | maxss 995 | shufps 996 | [CONST_VAR+r11*8] 997 | [rdx+rdx*2] 998 | xmm8 999 | xmm9 1000 | packuswb 1001 | psrlw 1002 | [rbp+rcx*4+CONST] 1003 | [rax+rcx*8+CONST] 1004 | [rax+r10*8] 1005 | [rdx+r9] 1006 | [rax+r11*8] 1007 | [rdx+rdi*8] 1008 | [rbx+rax] 1009 | [r13+r15*4+CONST] 1010 | [r13+r15*8+CONST] 1011 | [r12+r15*8] 1012 | [r12+r14*8] 1013 | [rbx+r14] 1014 | [r14+r15] 1015 | [rbx+r12] 1016 | [rax+r14*8] 1017 | [rax+rcx+CONST] 1018 | [r13+rdx+CONST] 1019 | cvtpd2ps 1020 | cmovs 1021 | subpd 1022 | divpd 1023 | mulpd 1024 | addpd 1025 | divsd 1026 | andps 1027 | ucomiss 1028 | [rax+rsi*4] 1029 | [r8+rcx*8] 1030 | [rcx+r12*8] 1031 | [r12+r12*2] 1032 | [r8+rdi*8] 1033 | [r8+rax*8] 1034 | [rax+r13+CONST] 1035 | subss 1036 | mulps 1037 | [r15+rcx*8] 1038 | [r8+r10*8] 1039 | [r8+rdx*8] 1040 | cvtsd2ss 1041 | [r10+r10*2] 1042 | [CONST_VAR+r9*8] 1043 | [rbx+rcx*8] 1044 | [r11+rcx*4] 1045 | [r8+r11] 1046 | [rax+r10*4] 1047 | [CONST_VAR+r12+CONST] 1048 | divss 1049 | cvttss2si 1050 | [CONST_VAR+r8*8+CONST] 1051 | [r9+r9*2] 1052 | [rax+r8] 1053 | [rax+r8+CONST] 1054 | [rax+rdx*4+CONST] 1055 | [r8+rsi*8+CONST] 1056 | [rcx+rsi*4+CONST] 1057 | [r8+r9*8] 1058 | [r9+r8*4+CONST] 1059 | setp 1060 | sqrtss 1061 | xorps 1062 | [r14+rcx] 1063 | [rax+rdx*8+CONST] 1064 | [rsp+CONST+arg_xxx+CONST] 1065 | [rax+rsi*8] 1066 | extrn 1067 | extrn_xxx 1068 | inc 1069 | [rbx+rax*8] 1070 | dh 1071 | [r8+rsi*8] 1072 | [rdx+rax*8+CONST] 1073 | [rcx+rax*8+CONST] 1074 | [CONST_VAR+rax+CONST] 1075 | [CONST_VAR+rax*8+CONST] 1076 | [CONST_VAR+r10*8+CONST] 1077 | [CONST_VAR+r9*8+CONST] 1078 | [r12+rbx*8] 1079 | [rbp+r14*8+CONST] 1080 | [r10+rax*8] 1081 | [rcx+r13] 1082 | r12w 1083 | r14w 1084 | [CONST_VAR+rcx] 1085 | 
ch 1086 | [rbx+r9*8] 1087 | [r12+rbp*8+CONST] 1088 | [rsp+rdx+CONST+var_xxx] 1089 | [r8+rax*8+CONST] 1090 | [r8+rbx*8+CONST] 1091 | [r15+rbp*8+CONST] 1092 | [rdx+r13*4] 1093 | [r8+r8*2+CONST] 1094 | [rcx+rbp*4] 1095 | [r14+r13*8+CONST] 1096 | [CONST_VAR+rdx*8+CONST] 1097 | [r9+rcx*8+CONST] 1098 | [r9+rax*8] 1099 | [rcx+rbx] 1100 | [rdx+rbx] 1101 | [r9+rsi*8] 1102 | [rcx+r14] 1103 | [r9+rax*8+CONST] 1104 | [r9+r13*8+CONST] 1105 | [rax+r15*8] 1106 | [rbp+rbx*8+CONST] 1107 | [rsp+rbp+CONST+var_xxx] 1108 | movddup 1109 | [r15+rbx*8+CONST] 1110 | [rsp+rcx+CONST+var_xxx] 1111 | [rdx+r12*4] 1112 | [r12+rax*8+CONST] 1113 | [r12+rbx*8+CONST] 1114 | [rbx+r15] 1115 | [rbx+r15+CONST] 1116 | [rcx+rcx] 1117 | [rbx+r14+CONST] 1118 | [r11+rax] 1119 | [r12+rbx+CONST] 1120 | [r10+r10*4] 1121 | [r9+r10*4] 1122 | [rbx+r12+CONST] 1123 | [rbp+r14+CONST] 1124 | [rbp+rdi*4+CONST] 1125 | [rbx+CONST_VAR] 1126 | paddq 1127 | [rdx+rsi*8+CONST] 1128 | [CONST_VAR+r14*8] 1129 | [rcx+rbp*8] 1130 | btr 1131 | [r15+r14+CONST] 1132 | [r15+rax*8] 1133 | [rbx+rdx] 1134 | [r12+r14*8+CONST] 1135 | [r10+rax*8+CONST] 1136 | [r10+rdx*8+CONST] 1137 | [r8+rcx*8+CONST] 1138 | [r14+rbp*8] 1139 | [CONST_VAR+rbp*8] 1140 | [r14+rbx*8] 1141 | [r13+r12*8+CONST] 1142 | [r14+rbx*8+CONST] 1143 | [CONST_VAR+r11*8+CONST] 1144 | [rcx+rdx*8+CONST] 1145 | [rbp+r15*8+CONST] 1146 | [r14+r15*8] 1147 | setnl 1148 | cmpnltsd 1149 | [r14+r15*4] 1150 | cmpnlesd 1151 | [rcx+rcx*2+CONST] 1152 | [rax+r9*8] 1153 | [r15+rdx*4] 1154 | [r12+rdx+CONST] 1155 | sqrtsd 1156 | [rsp+rax*8+CONST+var_xxx] 1157 | [r15+rbx*4+CONST] 1158 | [r12+rbx*4] 1159 | [rbp+rbx*4+CONST] 1160 | [r13+r14*4+CONST] 1161 | [r13+rbx*4+CONST] 1162 | [r12+r15*4+CONST] 1163 | [r12+r14*4+CONST] 1164 | [rbx+rax*4] 1165 | [rbx+rdx*4] 1166 | [r15+rax*4+CONST] 1167 | [r15+rsi*4+CONST] 1168 | [r13+CONST_VAR+CONST] 1169 | [r15+rcx*4+CONST] 1170 | [r13+rsi*4+CONST] 1171 | [r13+r8*4+CONST] 1172 | [r13+rcx*4+CONST] 1173 | cvttps2dq 1174 | cmpltps 1175 | punpckhdq 1176 | [r14+rdx*4] 1177 | [rbp+rsi*4+CONST] 1178 | [rbp+rdx*4+CONST] 1179 | [r11+r11*2] 1180 | [r15+r11+CONST] 1181 | [r11+r15+CONST] 1182 | [rbp+r11*4+CONST] 1183 | [r15+rdx+CONST] 1184 | [rdx+r15+CONST] 1185 | [r15+r15*2] 1186 | [rdx+rax*4+CONST] 1187 | [rcx+r15+CONST] 1188 | [r15+rax+CONST] 1189 | [rax+r15+CONST] 1190 | [rbp+r9*4+CONST] 1191 | [rbp+r8*4+CONST] 1192 | [CONST_VAR+rsi*4+CONST] 1193 | [rcx+rdi*8] 1194 | [r10+rdi*4] 1195 | [rcx+r8+CONST] 1196 | [r10+r9+CONST] 1197 | [rdx+r8+CONST] 1198 | [rax+r9+CONST] 1199 | [rdx+r12*8] 1200 | [rax+rbp] 1201 | [r10+rsi*4] 1202 | [r9+rsi*4] 1203 | [r8+rsi*4] 1204 | [r12+rsi*4] 1205 | [r14+rsi*4] 1206 | [r11+rsi*4] 1207 | [rbx+rsi*4] 1208 | [r15+rsi*4] 1209 | [r11+rbp*8] 1210 | [CONST_VAR+rsi*4] 1211 | [r15+rax] 1212 | [r15+r8] 1213 | [r8+r14] 1214 | [rbx+r12*8+CONST] 1215 | [r15+rax*8+CONST] 1216 | [r14+rax*8+CONST] 1217 | [r12+r9] 1218 | [rbp+rsi*8+CONST] 1219 | [r8+rdx*4] 1220 | [r14+r12*4] 1221 | [rbx+r12*4] 1222 | [r14+r13*4] 1223 | [r13+r12*4+CONST] 1224 | [rbx+r13*4] 1225 | [r15+r14*4] 1226 | [rbp+CONST_VAR] 1227 | [r14+r12*8] 1228 | [r15+r12*4] 1229 | [rcx+r8] 1230 | [rbp+rcx+CONST] 1231 | setle 1232 | [rbp+rdi*8+CONST] 1233 | [r11+rcx] 1234 | [r9+rcx] 1235 | [rbp+rcx*8+CONST] 1236 | subps 1237 | unpckhps 1238 | [rbp+rax*2+CONST] 1239 | [r13+rcx+CONST] 1240 | [rbx+rcx+CONST] 1241 | [rbp+CONST_VAR+CONST] 1242 | [r13+rbx*8+CONST] 1243 | [rax+rbx*8+CONST] 1244 | psrad 1245 | punpcklwd 1246 | punpckhwd 1247 | psubb 1248 | [rbx+r10+CONST] 1249 | [rbx+r10] 1250 | [r9+CONST_VAR] 1251 | 
[rbx+r11+CONST] 1252 | [rbx+r11] 1253 | [r13+r9+CONST] 1254 | [rax+r13*8+CONST] 1255 | cmpltsd 1256 | [rbx+CONST_VAR+CONST] 1257 | [rax+rax*4+CONST] 1258 | [CONST_VAR+rsi*2+CONST] 1259 | [CONST_VAR+rdi*2] 1260 | [r12+rsi*8] 1261 | [rsp+rax+CONST+CONST_VAR] 1262 | [rax+rbx+CONST] 1263 | shufpd 1264 | [rbp+rax+var_xxx] 1265 | [rbp+rdx+var_xxx] 1266 | [rbx+rax*8+CONST] 1267 | [rax+CONST_VAR+CONST] 1268 | jnp 1269 | [rdx+CONST_VAR] 1270 | setnp 1271 | [rdx+r14*8] 1272 | [rax+r14*4] 1273 | [rdx+r15] 1274 | [rdx+rax+CONST] 1275 | [rbp+r13+CONST] 1276 | [rdx+rcx*8+CONST] 1277 | [r13+r8*8+CONST] 1278 | [r13+rsi*8+CONST] 1279 | [rbx+rbx] 1280 | [rcx+rbx*2] 1281 | [rdx+rdi*4] 1282 | [rax+rdx*2] 1283 | [rax+rbx*2] 1284 | [r14+r14] 1285 | [rbp+rbp+CONST] 1286 | [rcx+rbp*2] 1287 | [rdx+rax*2] 1288 | [r8+r9] 1289 | [r8+r9+CONST] 1290 | [rcx+r9] 1291 | [rax+r9] 1292 | [r11+CONST_VAR] 1293 | [r9+r11] 1294 | [rcx+r11] 1295 | [rax+r11] 1296 | [rcx+r9+CONST] 1297 | [r10+CONST_VAR] 1298 | [rcx+r10] 1299 | [rcx+r10+CONST] 1300 | [rdx+r8] 1301 | [r9+r10] 1302 | [rcx+rax*2] 1303 | r11w 1304 | [r15+rdx] 1305 | r15w 1306 | [CONST_VAR+rbx+CONST] 1307 | [r8+rbx] 1308 | [r10+rax+CONST] 1309 | [rbx+rcx] 1310 | [rdx+rcx*2] 1311 | [r9+r12] 1312 | [rdx+r9+CONST] 1313 | [r11+rdx+CONST] 1314 | [rcx+CONST_VAR+CONST] 1315 | bx 1316 | [r15+rbx] 1317 | [r15+rbp*8] 1318 | [r15+r14*8] 1319 | [r15+rsi*8] 1320 | [r14+r14+CONST] 1321 | [rbx+rbp*8] 1322 | [rax+r14*2] 1323 | [r9+rax*2] 1324 | r8w 1325 | [CONST_VAR+rdx*2] 1326 | [r9+rdx*2] 1327 | [rcx+r8*8] 1328 | [r10+rcx] 1329 | [r8+r10] 1330 | [r8+r10+CONST] 1331 | [r9+rax] 1332 | [r11+r9] 1333 | [rdx+r10+CONST] 1334 | [r8+rax*4+CONST] 1335 | [r10+rbp+CONST] 1336 | [rbx+rbp+CONST] 1337 | [rdx+r9*4] 1338 | [rcx+r14*4] 1339 | [rcx+r13*4] 1340 | [rcx+r15*4] 1341 | [r10+rax*4] 1342 | [r10+rax] 1343 | [r9+rbp+CONST] 1344 | [r9+rbx+CONST] 1345 | [r10+r12+CONST] 1346 | [r8+r13*4] 1347 | [r9+r12*4] 1348 | [r8+r15*4] 1349 | [r9+r15*4] 1350 | [rbx+rax*4+CONST] 1351 | [r11+rax+CONST] 1352 | [CONST_VAR+r10*4] 1353 | [rdx+r14+CONST] 1354 | [r11+r13+CONST] 1355 | [r10+rax*4+CONST] 1356 | [r9+rax*4+CONST] 1357 | [r11+rbp+CONST] 1358 | [rcx+r11*4] 1359 | [CONST_VAR+r11*4] 1360 | [r13+r11*4+CONST] 1361 | [r12+r11*4+CONST] 1362 | [CONST_VAR+rbx*4] 1363 | [r10+rcx+CONST] 1364 | [r12+rsi*4+CONST] 1365 | [r8+r14*4] 1366 | [r8+rbp*4] 1367 | [r9+rbp*4] 1368 | [r9+r13*4] 1369 | [r12+rcx*4+CONST] 1370 | [r14+rax+CONST] 1371 | [r15+rcx*4] 1372 | [rbx+rcx*4] 1373 | [r11+rax*4+CONST] 1374 | [r15+r10*4] 1375 | [rbx+r11*4] 1376 | [r8+rax+CONST] 1377 | [r10+rdx*4] 1378 | [r14+r8*4] 1379 | [rbx+r8*4] 1380 | [r15+r8*4] 1381 | [r15+rdi*4] 1382 | [rdx+r10*4] 1383 | [rcx+rax*4+CONST] 1384 | [r13+r10*4+CONST] 1385 | [rdx+r8*4] 1386 | [rax+r11*4] 1387 | [rdx+r11*4] 1388 | [rcx+r10*4] 1389 | [r15+r13*4+CONST] 1390 | [rax+r9*4] 1391 | [r12+r13*4+CONST] 1392 | [r10+rbx+CONST] 1393 | [r14+rbx] 1394 | [r8+rbx*4] 1395 | [r9+rbx*4] 1396 | [r8+r10*4] 1397 | [r9+rdi*4] 1398 | [r12+rcx*4] 1399 | [r12+r11*4] 1400 | [rdx+r12+CONST] 1401 | [rdx+CONST_VAR+CONST] 1402 | [rdx+r13+CONST] 1403 | [r12+r13*4] 1404 | [r15+r13*4] 1405 | [rdx+rbp*4] 1406 | [r14+rbp*4] 1407 | [r15+rbp*4] 1408 | [rbx+rbp*4] 1409 | [rsp+r8*4+CONST+var_xxx] 1410 | r9w 1411 | di 1412 | [rdx+r14] 1413 | [rax+r14+CONST] 1414 | [rbx+rcx*4+CONST] 1415 | [rbx+rsi*4+CONST] 1416 | cmpnless 1417 | andnps 1418 | orps 1419 | [rbx+r8*4+CONST] 1420 | [rbx+r11*4+CONST] 1421 | [rbx+rdi*4+CONST] 1422 | [rbx+rdx*4+CONST] 1423 | [rbx+r10*4+CONST] 1424 | [r9+rbx] 1425 | 
[r13+rbp*8+CONST] 1426 | [r13+rcx*8+CONST] 1427 | [CONST_VAR+r8+CONST] 1428 | [r9+r14*8+CONST] 1429 | [rsp+rbx+CONST+var_xxx] 1430 | [rcx+rbp] 1431 | [rbx+r14*8+CONST] 1432 | [rbp+r12*4+CONST] 1433 | [r8+rbp*8] 1434 | [r12+rbp*8] 1435 | [rax+rdx*2+CONST] 1436 | [r14+rbx+CONST] 1437 | [rcx+rdx*2+CONST] 1438 | bswap 1439 | hlt 1440 | [rsp+CONST_VAR+CONST+CONST_VAR] 1441 | [rsp+rdx+CONST+CONST_VAR] 1442 | [r12+rcx*8] 1443 | [r12+rcx*8+CONST] 1444 | [r13+rdx*4+CONST] 1445 | [r12+rbx*4+CONST] 1446 | [CONST_VAR+rax*2+CONST] 1447 | [CONST_VAR+rcx*2+CONST] 1448 | [rbx+rdx+CONST] 1449 | [rax+rbx*4+CONST] 1450 | [rsp+r13*4+CONST+var_xxx] 1451 | pslldq 1452 | [rbx+rbx*4] 1453 | [rcx+rdx*2] 1454 | [r13+r8+CONST] 1455 | [rax+r10+CONST] 1456 | [rbx+r8] 1457 | [r14+rax*2+CONST] 1458 | [rbx+rbp] 1459 | [r8+rdx+CONST] 1460 | [r8+CONST_VAR+CONST] 1461 | [r9+r13+CONST] 1462 | [r8+r8*4] 1463 | [rax+rbp*2+CONST] 1464 | [r8+rax*2+CONST] 1465 | [r8+r9*2+CONST] 1466 | [r12+r12*4] 1467 | [rbx+rax*2+CONST] 1468 | [rbx+rdx*8+CONST] 1469 | [rbx+r12*4+CONST] 1470 | bts 1471 | [rsp+r14*4+CONST+var_xxx] 1472 | [rax+r12+CONST] 1473 | [rsp+rbx+CONST+CONST_VAR] 1474 | [CONST_VAR+rcx*2] 1475 | [r8+rbp*8+CONST] 1476 | [r13+rdx*8+CONST] 1477 | [rcx+rax*2+CONST] 1478 | [rdx+rax*2+CONST] 1479 | [rbx+r13] 1480 | [r12+r15+CONST] 1481 | [rdx+rdi*2] 1482 | [rdx+r11] 1483 | [r12+r14] 1484 | [r9+r9*4] 1485 | [rcx+rcx*4] 1486 | [r11+rdx*8] 1487 | [r15+r13*8] 1488 | [rbx+rcx*8+CONST] 1489 | [rcx+rbx*8+CONST] 1490 | [rbx+rbp*8+CONST] 1491 | [r12+rsi*8+CONST] 1492 | [r12+rdx*8+CONST] 1493 | [rbx+rsi*8+CONST] 1494 | [rax+r14*8+CONST] 1495 | [r9+r8] 1496 | [rdx+r9*8] 1497 | [r8+rdi*8+CONST] 1498 | [r12+r9*4] 1499 | [rbx+r14*4] 1500 | [rbx+r14*4+CONST] 1501 | [rcx+rsi*8+CONST] 1502 | [r13+r13+CONST] 1503 | [rbx+rbx+CONST] 1504 | [rbp+r13*8+CONST] 1505 | [r15+rdx*8] 1506 | [r12+CONST_VAR+CONST] 1507 | [rsp+r14+CONST+var_xxx] 1508 | [r8+r8+CONST] 1509 | [CONST_VAR+rbx*8+CONST] 1510 | [CONST_VAR+r14*8+CONST] 1511 | [rcx+r13*8+CONST] 1512 | [rcx+r14*8+CONST] 1513 | [rdx+r14*4] 1514 | [r15+rdx*8+CONST] 1515 | [rax+rbp*8+CONST] 1516 | [rdx+r14*8+CONST] 1517 | [rcx+r12*8+CONST] 1518 | [rcx+r15*8+CONST] 1519 | [rax+r12*8+CONST] 1520 | [r14+r15*8+CONST] 1521 | movlpd 1522 | [r12+r13*8+CONST] 1523 | [rdx+rbx*8+CONST] 1524 | [rbp+rbp*4+CONST] 1525 | [CONST_VAR+r13*8] 1526 | [rdx+rcx+CONST] 1527 | [rbx+rdx*8] 1528 | [rbx+rsi*8] 1529 | [r12+rdi*8] 1530 | [rdx+r8*8] 1531 | [rdx+r15*8] 1532 | [r15+r15*4] 1533 | [rax+r8*8] 1534 | [r10+rcx*8] 1535 | [r10+rdx*8] 1536 | [r11+r8*8] 1537 | [r11+rcx*8] 1538 | [r11+rsi*8] 1539 | [r8+r8] 1540 | [rcx+r8*4] 1541 | [r11+r10*8+CONST] 1542 | [r13+r9*4+CONST] 1543 | [rbx+r8*8] 1544 | [CONST_VAR+r8*4] 1545 | [r8+r11*8] 1546 | [rbp+r12*8+CONST] 1547 | [rbx+r13*8] 1548 | [rbp+r9*8+CONST] 1549 | [rbx+r10*8] 1550 | [r15+r9*8] 1551 | [r15+r10*8] 1552 | [rcx+r9*8] 1553 | [r12+r9*8] 1554 | [r10+rsi*8] 1555 | [r10+rdi*8] 1556 | [r10+r9*8] 1557 | [r10+r14] 1558 | [r9+rcx*8] 1559 | [r11+r9*8] 1560 | [r9+r10*8] 1561 | [r12+r12] 1562 | [r9+r12*8] 1563 | [r14+r8*8] 1564 | [r9+rdi*8] 1565 | [r9+r8*8] 1566 | [r9+r14] 1567 | [r12+r11*8] 1568 | [r13+rdi*8+CONST] 1569 | [r14+rdi*8] 1570 | [r14+rcx*8] 1571 | [r10+r8*8+CONST] 1572 | [r9+r15] 1573 | [r11+r11*4] 1574 | [r11+rdx*8+CONST] 1575 | [rax+r10*8+CONST] 1576 | [r11+r10*8] 1577 | [r11+rdi*8] 1578 | [r11+rbx*8] 1579 | [r14+r11*8] 1580 | [r11+r13*8] 1581 | [CONST_VAR+rbp*4] 1582 | [CONST_VAR+r15*4] 1583 | [CONST_VAR+r15*8] 1584 | [rcx+r15*8] 1585 | [r9+rdx*8+CONST] 1586 | 
[r9+r15*8+CONST] 1587 | [r14+rbp*8+CONST] 1588 | [r12+r12+CONST] 1589 | [r10+rcx*8+CONST] 1590 | [r13+r14*8+CONST] 1591 | [r11+rax*8+CONST] 1592 | [rsp+r15+CONST+var_xxx] 1593 | [r9+r12*8+CONST] 1594 | [r11+rbp*8+CONST] 1595 | [r9+r8+CONST] 1596 | [rcx+rdi*2] 1597 | err 1598 | rol 1599 | [r12+rdx*4+CONST] 1600 | [r12+rdi*4+CONST] 1601 | [r15+rbx+CONST] 1602 | [r15+rbp+CONST] 1603 | [r14+r13*8] 1604 | punpckhqdq 1605 | [rax+rsi*8+CONST] 1606 | [CONST_VAR+r12*8+CONST] 1607 | [r12+r13+CONST] 1608 | [rbp+r8*8+CONST] 1609 | [r8+r10*2] 1610 | [rax+rcx*2+CONST] 1611 | [rax+rdi*2+CONST] 1612 | [rdx+rsi*2+CONST] 1613 | [CONST_VAR+rdi*2+CONST] 1614 | [rsp+rax*8+CONST+CONST_VAR] 1615 | [rsp+rcx*8+CONST+CONST_VAR] 1616 | [rcx+r14*8] 1617 | [rbx+r15*8] 1618 | [rax+rbp*2] 1619 | [CONST_VAR+r9*2] 1620 | [r9+r10*2] 1621 | [rax+r9*2] 1622 | [rax+rdi*2] 1623 | [CONST_VAR+r8*2] 1624 | [rax+r8*4] 1625 | [r12+r12*4+CONST] 1626 | [r12+r12*2+CONST] 1627 | [r14+r15+CONST] 1628 | fldz 1629 | fcomip 1630 | fnstcw 1631 | fmul 1632 | fsub 1633 | fldcw 1634 | fistp 1635 | fsubrp 1636 | fld1 1637 | fdivp 1638 | fmulp 1639 | fucomi 1640 | fcomi 1641 | fdiv 1642 | [rsp+rax+CONST+var_xxx+CONST] 1643 | fdivrp 1644 | [r13+r13*4+CONST] 1645 | [rdx+rsi*2] 1646 | [r8+rsi*2+CONST] 1647 | [r11+r8] 1648 | [rsp+arg_xxx+CONST] 1649 | [r8+r15*8+CONST] 1650 | [r15+r14*8+CONST] 1651 | [r12+r15*8+CONST] 1652 | [r10+r14*8+CONST] 1653 | [rbp+arg_xxx] 1654 | [r11+rdx*4] 1655 | [CONST_VAR+r10+CONST] 1656 | [rcx+r12+CONST] 1657 | [rdx+rsi*4+CONST] 1658 | [rax+rcx*4+CONST] 1659 | [r13+rdi*4+CONST] 1660 | [CONST_VAR+r14*4] 1661 | [r11+r13] 1662 | [r10+rbx] 1663 | [r12+r10*4] 1664 | [r14+rcx*4] 1665 | [r12+r8] 1666 | [rbx+rdi*4] 1667 | [rcx+rdx*4+CONST] 1668 | [CONST_VAR+rdx*4+CONST] 1669 | [rdx+rcx*4+CONST] 1670 | [r11+r14*4] 1671 | [rcx+r8*4+CONST] 1672 | [r10+r13] 1673 | [r10+r13*4] 1674 | [r12+rbp*4] 1675 | [rbx+r13*8+CONST] 1676 | jno 1677 | [r10+r13*8+CONST] 1678 | [r15+r13*8+CONST] 1679 | [CONST_VAR+r13*4] 1680 | [r15+r12*8+CONST] 1681 | [r14+rcx*8+CONST] 1682 | [r10+r12*8+CONST] 1683 | [r14+r10*8+CONST] 1684 | [r15+rsi*8+CONST] 1685 | [rsp+r12+CONST+var_xxx] 1686 | [r10+rbp*8+CONST] 1687 | [rsp+r13+CONST+var_xxx] 1688 | locretxxx 1689 | [rbx+r15*8+CONST] 1690 | [r14+rsi*8+CONST] 1691 | [r14+r12*8+CONST] 1692 | [r11+r14*8+CONST] 1693 | [rdx+rbp*8+CONST] 1694 | [CONST_VAR+r13*8+CONST] 1695 | [r9+rbx*8+CONST] 1696 | [r15+rcx*8+CONST] 1697 | [r11+rbx*8+CONST] 1698 | [r8+r13*8+CONST] 1699 | [r14+rdx+CONST] 1700 | [rbx+rdx*2] 1701 | [CONST_VAR+rdx*2+CONST] 1702 | [rax+rbx*2+CONST] 1703 | [r11+r11+CONST] 1704 | [CONST_VAR+r11] 1705 | [rax+rcx*2] 1706 | [r15+r15] 1707 | [rax+r15*2] 1708 | [r14+r12+CONST] 1709 | [r13+rdx*2+CONST] 1710 | [rcx+rcx+CONST] 1711 | [rbx+rcx*2] 1712 | [r13+rax*2+CONST] 1713 | [rbx+rax*2] 1714 | [CONST_VAR+r13+CONST] 1715 | [rcx+rbx+CONST] 1716 | [rax+rsi*2+CONST] 1717 | [rbp+r8+CONST] 1718 | [CONST_VAR+r9*2+CONST] 1719 | [rbx+r9+CONST] 1720 | [r14+r14*4] 1721 | [rax+r8*2+CONST] 1722 | [CONST_VAR+r8*2+CONST] 1723 | [rbx+r8+CONST] 1724 | [rdx+rdx*2+CONST] 1725 | [rcx+r10*8] 1726 | [r10+r14*8] 1727 | [rdx+r11*8] 1728 | [r9+r9+CONST] 1729 | [r9+rdx*8] 1730 | [rax+rsi*2] 1731 | [r8+rax*2] 1732 | [r8+r11*4] 1733 | [rcx+rdi*4] 1734 | [rcx+rdi*4+CONST] 1735 | [rbp+r9+CONST] 1736 | punpcklbw 1737 | pcmpeqb 1738 | paddb 1739 | punpckhbw 1740 | psllw 1741 | psrld 1742 | pslld 1743 | [rcx+r13+CONST] 1744 | [rsp+rsi*8+CONST+var_xxx] 1745 | [rsp+rdx+CONST+CONST_VAR+CONST] 1746 | [rdx+r8*8+CONST] 1747 | [rdx+rbx+CONST] 1748 | 
[r12+rcx+CONST] 1749 | [r8+r14+CONST] 1750 | [r9+rdx+CONST] 1751 | [r9+rax+CONST] 1752 | [r12+r8*8] 1753 | [r13+r10+CONST] 1754 | [r14+r9*4] 1755 | [r15+r9] 1756 | [r13+r11+CONST] 1757 | [rcx+rbp+CONST] 1758 | [r8+r9*4] 1759 | [rax+r13*2] 1760 | [r14+rsi*8] 1761 | [rbp+rbx*2+CONST] 1762 | ss:xxx 1763 | cwde 1764 | [r14+rsi*2] 1765 | pinsrw 1766 | [r10+r8] 1767 | [r9+rcx+CONST] 1768 | [r14+r8] 1769 | [r13+rbx*2+CONST] 1770 | [rbp+rax*8+var_xxx] 1771 | [rbp+rdx*8+var_xxx] 1772 | [rbp+rsi*8+var_xxx] 1773 | [rbp+rcx*8+var_xxx] 1774 | [rax+r15*4] 1775 | [r14+r14*2] 1776 | [r10+r12] 1777 | [rbx+r13+CONST] 1778 | [r15+r12+CONST] 1779 | [rcx+rcx*4+CONST] 1780 | [CONST_VAR+rbp+CONST] 1781 | [r12+r11] 1782 | [r12+rax*2+CONST] 1783 | [rdx+r13*2+CONST] 1784 | [rdx+r11+CONST] 1785 | [r10+r10] 1786 | [r11+r11] 1787 | [rax+r11+CONST] 1788 | [rax+rsi*4+CONST] 1789 | [r10+CONST_VAR+CONST] 1790 | [r14+r11] 1791 | [rsp+r12*8+CONST+var_xxx] 1792 | [rbp+CONST_VAR+var_xxx] 1793 | [rsp+rbx*8+CONST+var_xxx] 1794 | [rsp+rcx*8+CONST+var_xxx] 1795 | [r15+rdi*8] 1796 | [rbp+rcx+var_xxx] 1797 | [rbp+rax+var_xxx+CONST] 1798 | leave 1799 | rdsspq 1800 | incsspq 1801 | [rbx+r8*8+CONST] 1802 | [r11+r10] 1803 | pmuludq 1804 | psllq 1805 | [r8+r12*4] 1806 | [CONST_VAR+r12*4] 1807 | [rsp+rbp*8+CONST+var_xxx] 1808 | [r10+r11*8] 1809 | [r10+r8*8] 1810 | [rsp+CONST_VAR+CONST+var_xxx+CONST] 1811 | [rsp+r8+CONST+var_xxx+CONST] 1812 | [rsp+r11+CONST+var_xxx] 1813 | [rsp+r11+CONST+var_xxx+CONST] 1814 | [rsp+r12+CONST+var_xxx+CONST] 1815 | [rsp+rdx+CONST+var_xxx+CONST] 1816 | [rsp+rdi*8+CONST+var_xxx+CONST] 1817 | [rsp+r11*8+CONST+var_xxx+CONST] 1818 | [rsp+CONST_VAR+var_xxx] 1819 | [r15+rdx*4+CONST] 1820 | [r8+r12*8] 1821 | [r8+r11*8+CONST] 1822 | [CONST_VAR+r8*4+CONST] 1823 | [rsp+rsi*4+CONST+var_xxx] 1824 | [rsp+rdi*8+CONST+var_xxx] 1825 | [r9+rbx*8] 1826 | [r12+rbp+CONST] 1827 | [r9+r15+CONST] 1828 | [rbp+r11+CONST] 1829 | cmplesd 1830 | [rcx+r13*8] 1831 | [rax+r8*4+CONST] 1832 | [rsp+rdi*4+CONST+var_xxx] 1833 | [rsp+rdx*4+CONST+var_xxx] 1834 | [r14+rax*4+CONST] 1835 | movsb 1836 | [r9+rcx*2] 1837 | [r9+rsi*2] 1838 | [rbx+rbx*2+CONST] 1839 | [r15+rbp*4+CONST] 1840 | cvttpd2dq 1841 | [r14+r14*2+CONST] 1842 | [r8+r10*8+CONST] 1843 | [r14+rdi*8+CONST] 1844 | [rbp+rcx*2+CONST] 1845 | [r12+rcx*2] 1846 | [r10+rsi*4+CONST] 1847 | [r11+rdx] 1848 | [r11+rcx+CONST] 1849 | [rcx+r9*4+CONST] 1850 | [r8+r13+CONST] 1851 | [rsp+rsi*4+CONST+var_xxx+CONST] 1852 | cvtdq2ps 1853 | ror 1854 | [r8+rdx*8+CONST] 1855 | psrlq 1856 | [rcx+r11*8] 1857 | [rdx+r10*8] 1858 | [rax+rdi*4+CONST] 1859 | [rdx+r15*4] 1860 | cmpnltss 1861 | [r15+r13+CONST] 1862 | [rax+r8*8+CONST] 1863 | [r13+r12*2+CONST] 1864 | [r10+rcx*4] 1865 | pcmpgtw 1866 | [rbx+r9*8+CONST] 1867 | [rdx+r8*2+CONST] 1868 | [rsp+r15+CONST+CONST_VAR] 1869 | [rsp+r13+CONST+CONST_VAR] 1870 | [r10+rsi*8+CONST] 1871 | [r14+rax*2] 1872 | [rax+r13*2+CONST] 1873 | [rdx+rdi*2+CONST] 1874 | [rbx+r9] 1875 | [r11+r15] 1876 | [rdx+r13*8+CONST] 1877 | [rbp+r15*2+CONST] 1878 | [r12+r10*8+CONST] 1879 | [rbx+rdi*8] 1880 | [CONST_VAR+rdi*8+CONST] 1881 | [r14+rdx*8+CONST] 1882 | [r8+r12*8+CONST] 1883 | [r10+r12*4] 1884 | [r10+rbp*4] 1885 | [r9+r11*4] 1886 | [r12+rdi*8+CONST] 1887 | [r14+r10*4] 1888 | [rbp+r13*4+CONST] 1889 | [r12+r15*4] 1890 | [r11+r9*4] 1891 | [r11+r11*2+CONST] 1892 | [rax+r12*2+CONST] 1893 | [r15+rax*2] 1894 | [r11+rbp*4] 1895 | [r11+r12*4] 1896 | [rsp+rcx*4+CONST+var_xxx] 1897 | [rcx+r14+CONST] 1898 | [rsp+rbx*4+CONST+var_xxx] 1899 | [rbp+rbx*8+var_xxx] 1900 | [rbp+r12*8+var_xxx] 1901 | 
[rsp+r13*8+CONST+var_xxx] 1902 | [r9+r11*8+CONST] 1903 | [rsp+r14*8+CONST+var_xxx] 1904 | [r10+rdi*8+CONST] 1905 | [r10+rbx*8+CONST] 1906 | [r9+rbp*8+CONST] 1907 | [r8+r14*8+CONST] 1908 | [r10+r15*8+CONST] 1909 | [r10+rcx*4+CONST] 1910 | [r10+rdx*4+CONST] 1911 | [r11+rbx] 1912 | [r12+rdi*4] 1913 | [r12+r14*4] 1914 | [r15+r12*8] 1915 | [r11+rdx*2] 1916 | [rbx+rdi*8+CONST] 1917 | [r11+r15*8+CONST] 1918 | [r9+r10*8+CONST] 1919 | [r10+r10*8] 1920 | [r9+r9*8] 1921 | [r11+r11*8] 1922 | [r9+r13*8] 1923 | [r12+r10*8] 1924 | [rbx+rbx*8] 1925 | [r15+r11*4] 1926 | divps 1927 | [rbp+r14*4+CONST] 1928 | [rbx+r15*4] 1929 | [r14+r10+CONST] 1930 | [rbx+r10*4] 1931 | [r9+rdx*4+CONST] 1932 | [r15+r11*8] 1933 | [r10+r11*4] 1934 | [r10+r8*4] 1935 | [rsp+r15*8+CONST+arg_xxx] 1936 | [r13+r9*8+CONST] 1937 | [r11+r8*4] 1938 | clc 1939 | [r11+rdi*4] 1940 | [r11+r12*8] 1941 | [CONST_VAR+r12*4+CONST] 1942 | [rdx+r12*4+CONST] 1943 | [rax+r10*4+CONST] 1944 | [rax+r15*4+CONST] 1945 | [r10+rdx+CONST] 1946 | [r12+r12*8] 1947 | [rax+r13*4+CONST] 1948 | [rcx+rbp*4+CONST] 1949 | [CONST_VAR+r14*4+CONST] 1950 | [rcx+rbp*8+CONST] 1951 | [rbp+rbp*8+CONST] 1952 | [r13+r10*8+CONST] 1953 | [r11+r10+CONST] 1954 | [r15+rax*2+CONST] 1955 | [r10+rbp] 1956 | pmullw 1957 | [r10+r9] 1958 | [r12+r8*4] 1959 | [r12+rbx*2] 1960 | [r12+r8*2] 1961 | [r12+r14*2] 1962 | [r12+rdx*2] 1963 | [r13+r8*2+CONST] 1964 | [r13+r14*2+CONST] 1965 | [r12+rax*2] 1966 | [r10+r13*8] 1967 | cmpltpd 1968 | [r10+rdi*4+CONST] 1969 | [rsp+r10*4+CONST+var_xxx] 1970 | [rsp+r9*4+CONST+var_xxx] 1971 | [rsp+r15*4+CONST+var_xxx] 1972 | [rax+r15*8+CONST] 1973 | [r9+rbp] 1974 | [r15+r11] 1975 | [r10+r11] 1976 | [rdx+rbx*2] 1977 | [rdx+r13*2] 1978 | [rdx+r11*2] 1979 | [rdx+r12*2] 1980 | [rdx+r14*2] 1981 | [rdx+rbp*2] 1982 | [r11+rbp*2] 1983 | [r11+rsi*2] 1984 | [r11+r10*2] 1985 | [rbx+r14*2] 1986 | [r15+r10*2] 1987 | [r13+r15*2+CONST] 1988 | [r14+r15*2] 1989 | [rbx+r11*2] 1990 | [CONST_VAR+rbp*2] 1991 | [CONST_VAR+r10*2] 1992 | [CONST_VAR+r11*2] 1993 | [r11+r13*8+CONST] 1994 | [rdx+r12*8+CONST] 1995 | [r9+r8*4] 1996 | [r14+r8+CONST] 1997 | [r14+r9*8] 1998 | [r8+rcx*4+CONST] 1999 | [rsp+rdx*4+CONST+CONST_VAR] 2000 | [r8+r13*8] 2001 | [rax+rdi*8+CONST] 2002 | [r14+rbp+CONST] 2003 | [r15+rbp*2] 2004 | [rdx+r14*4+CONST] 2005 | [rcx+r14*4+CONST] 2006 | [rbx+r13*2] 2007 | [rbp+r10*4+CONST] 2008 | [r8+rsi*2] 2009 | [r8+rdi*4] 2010 | pause 2011 | [r11+rbx*4] 2012 | [r10+rbx*4] 2013 | [rax+r9*4+CONST] 2014 | maxps 2015 | [r14+rbx*4+CONST] 2016 | [r8+r9*8+CONST] 2017 | [r15+r11*4+CONST] 2018 | [r15+rdi*4+CONST] 2019 | [r15+r10*4+CONST] 2020 | [r15+r8*4+CONST] 2021 | [r10+rax*2] 2022 | [r15+rdx*2] 2023 | [r14+r10*8] 2024 | [r15+r9*4] 2025 | [r11+r8+CONST] 2026 | [r8+r11+CONST] 2027 | [r8+rbx+CONST] 2028 | [r13+r11*8+CONST] 2029 | [CONST_VAR+r11+CONST] 2030 | [rax+r8*2] 2031 | [rbx+rbp*2] 2032 | [r15+r15+CONST] 2033 | [r14+rbp*2] 2034 | [rbx+r12*2] 2035 | [rax+r12*2] 2036 | +24h 2037 | +48h 2038 | [rdx+rbp+CONST] 2039 | [r15+r15*8] 2040 | [r13+r13*8+CONST] 2041 | [rbx+r15*2] 2042 | rdtsc 2043 | [rdx+r8*2] 2044 | [rsp] 2045 | [r15+r8*8+CONST] 2046 | [r15+rdi*8+CONST] 2047 | [rcx+r8*8+CONST] 2048 | [r15+r9+CONST] 2049 | [rsp+r14+CONST+CONST_VAR] 2050 | fsubr 2051 | [rdx+rbp*2+CONST] 2052 | [r15+rdx*2+CONST] 2053 | [r14+rbx*2+CONST] 2054 | [r14+r9*8+CONST] 2055 | [r10+r9*8+CONST] 2056 | [r14+r13*2] 2057 | [rdx+rdi*8+CONST] 2058 | [r14+rdx*2] 2059 | [r9+r14*4] 2060 | [r9+rdi*8+CONST] 2061 | psubusb 2062 | [r10+r15*8] 2063 | [rsp+r15*8+CONST+var_xxx] 2064 | [r12+r11*8+CONST] 2065 | 
[r12+r8*8+CONST] 2066 | [r12+r9*8+CONST] 2067 | [r10+r14+CONST] 2068 | [r8+r15*8] 2069 | [CONST_VAR+r15*8+CONST] 2070 | [rdx+r15*8+CONST] 2071 | [r10+rbp*8] 2072 | cbw 2073 | [rbx+rbp*4+CONST] 2074 | [rbx+r15*4+CONST] 2075 | [r9+rdi*4+CONST] 2076 | [rdx+rbx*2+CONST] 2077 | [rcx+r11+CONST] 2078 | cmpltss 2079 | [r10+r9*4] 2080 | [r9+r14*8] 2081 | esp 2082 | [rdx+rdi*4+CONST] 2083 | [r11+r10*4] 2084 | [r15+r8+CONST] 2085 | [r15+r10+CONST] 2086 | [r15+rcx*2] 2087 | [r14+r11+CONST] 2088 | cmpneqpd 2089 | cmpeqpd 2090 | cmpneqps 2091 | cmpeqps 2092 | [rbx+r9*4] 2093 | [rbx+rsi*2] 2094 | [rbx+rsi*2+CONST] 2095 | [r14+r8*8+CONST] 2096 | [rsp+r9+CONST+var_xxx] 2097 | [r15+r8*8] 2098 | [r11+r8*8+CONST] 2099 | [rsp+rbx*8+CONST+var_xxx+CONST] 2100 | [CONST_VAR+r14*2] 2101 | [rbp+r12*2+CONST] 2102 | [CONST_VAR+r13*2] 2103 | [CONST_VAR+r12*2] 2104 | [r11+r12*2] 2105 | [CONST_VAR+rbx*2] 2106 | [rbx+rdi*2] 2107 | [rbp+r14*2+CONST] 2108 | [r9+r13] 2109 | [r9+rcx*2+CONST] 2110 | [rsp+r12*2+arg_xxx] 2111 | [rsp+rcx*2+arg_xxx] 2112 | [rsp+rsi*2+arg_xxx] 2113 | [rsp+rcx+arg_xxx] 2114 | [rsp+CONST_VAR+arg_xxx] 2115 | [rsp+rbx*2+arg_xxx] 2116 | [rsp+rdx+arg_xxx] 2117 | xgetbv 2118 | cpuid 2119 | [rsp+r10+CONST+var_xxx] 2120 | [rsp+rbx*2+CONST+var_xxx] 2121 | [rcx+r8*2+CONST] 2122 | [r10+rsi*2] 2123 | [rsp+rcx+CONST] 2124 | [rsp+CONST_VAR+CONST] 2125 | [rsp+rax+CONST] 2126 | [rsp+rsi*2+CONST+var_xxx] 2127 | [CONST_VAR+r10*2+CONST] 2128 | [rsp+rdx+CONST] 2129 | [CONST_VAR+r11*4+CONST] 2130 | [rbp+rax+CONST_VAR] 2131 | [rbp+rdx+CONST_VAR] 2132 | [rsp+rcx+CONST+var_xxx+CONST] 2133 | [rax+rax*8+CONST] 2134 | [rsp+rcx+CONST+CONST_VAR] 2135 | [rdx+r14*2+CONST] 2136 | [r15+r13*2] 2137 | [r15+rbp*2+CONST] 2138 | [rax+r15*2+CONST] 2139 | [r15+r13*2+CONST] 2140 | [rax+r9*8+CONST] 2141 | [rax+r14*4+CONST] 2142 | [rdx+r10*8+CONST] 2143 | [r15+rbx*2] 2144 | [r10+rdx*2+CONST] 2145 | [r10+rdx*2] 2146 | [r12+rbp*2+CONST] 2147 | [r12+r8*2+CONST] 2148 | [rbx+r8*2+CONST] 2149 | [rbx+r10*2] 2150 | [rbp+rax+CONST_VAR+CONST] 2151 | [rbp+rdx+CONST_VAR+CONST] 2152 | [r13+rcx*2+CONST] 2153 | [r14+rcx*2] 2154 | [r12+rsi*2] 2155 | [rsp+r15*2+CONST+var_xxx] 2156 | [r8+r12*2+CONST] 2157 | [r8+r14*2+CONST] 2158 | [CONST_VAR+r13*2+CONST] 2159 | [CONST_VAR+rbx*2+CONST] 2160 | [rbx+r13*2+CONST] 2161 | [r11+r13*4] 2162 | psadbw 2163 | [r11+rax*2] 2164 | [r9+r14+CONST] 2165 | [rsp+rax*4+CONST+CONST_VAR] 2166 | [rsp+r15+CONST+var_xxx+CONST] 2167 | [rsp+r13+CONST] 2168 | [rsp+rbp+CONST] 2169 | [rsp+rbx+CONST+var_xxx+CONST] 2170 | [rsp+r12+CONST] 2171 | [rcx+r13*2+CONST] 2172 | [rcx+r13*2] 2173 | [rsp+r14+CONST+var_xxx+CONST] 2174 | [r11+rbp] 2175 | [r11+r9+CONST] 2176 | [r10+r15+CONST] 2177 | [rsp+rbp+CONST+CONST_VAR] 2178 | [rsp+rax+CONST+CONST_VAR+CONST] 2179 | [rsp+r9+CONST+CONST_VAR] 2180 | [rsp+rbx+CONST+CONST_VAR+CONST] 2181 | [r8+rbp+CONST] 2182 | [rsp+rax*2+CONST+var_xxx] 2183 | [rsp+r8*2+CONST+var_xxx] 2184 | [rsp+rcx*2+CONST+var_xxx] 2185 | [rsp+rdi*2+CONST+var_xxx] 2186 | [rsp+rdx*2+CONST+var_xxx] 2187 | [rsp+rbp*2+CONST+var_xxx] 2188 | [rbp+r8+var_xxx] 2189 | [r12+rdx*2+CONST] 2190 | [r8+r12+CONST] 2191 | [rsp+rbp+CONST+var_xxx+CONST] 2192 | [rsp+r8+CONST+CONST_VAR+CONST] 2193 | [r12+r10+CONST] 2194 | [r15+rsi*2] 2195 | [rbx+r14*2+CONST] 2196 | [r14+r9+CONST] 2197 | [rdx+r10*2] 2198 | [rsp+rdx*2+CONST+var_xxx+CONST] 2199 | [rsp+r10*2+CONST+var_xxx] 2200 | [rdx+r9*2] 2201 | [rsp+r12*2+CONST+var_xxx] 2202 | [r10+rbx*2] 2203 | [rsp+r14*2+CONST+var_xxx] 2204 | [rsp+r15*2+CONST+var_xxx+CONST] 2205 | [rsp+r13*2+CONST+var_xxx] 2206 | [r11+r14] 
2207 | [r10+r9*2] 2208 | [rsp+r10*8+CONST+var_xxx] 2209 | [CONST_VAR+rdi*4+CONST] 2210 | cmovp 2211 | [rbp+r14*8+var_xxx] 2212 | [CONST_VAR+r10*4+CONST] 2213 | [rbp+r11*8+var_xxx] 2214 | vpcmpeqd 2215 | vmovdqa 2216 | vpxor 2217 | popcnt 2218 | vpor 2219 | vpand 2220 | vpandn 2221 | vmovq 2222 | vpextrq 2223 | ymm1 2224 | vandps 2225 | ymm0 2226 | vzeroupper 2227 | vorps 2228 | ymm2 2229 | ymm4 2230 | ymm5 2231 | ymm3 2232 | ymm6 2233 | ymm7 2234 | ymm9 2235 | ymm8 2236 | vextracti128 2237 | vpternlogd 2238 | zmm0 2239 | vmovdqa32 2240 | zmm1 2241 | zmm2 2242 | zmm5 2243 | vpxord 2244 | zmm4 2245 | vpord 2246 | zmm6 2247 | vpandd 2248 | zmm3 2249 | zmm7 2250 | zmm8 2251 | zmm9 2252 | vpandnd 2253 | [CONST_VAR+rbp*2+CONST] 2254 | [rax+rbp*4+CONST] 2255 | [r14+r13*4+CONST] 2256 | [r8+rsi*4+CONST] 2257 | [rdx+r9*8+CONST] 2258 | [r10+r11*8+CONST] 2259 | [r11+r15*4] 2260 | [r11+rbx+CONST] 2261 | [r12+r11+CONST] 2262 | [r12+r9+CONST] 2263 | [r14+rbx*2] 2264 | [rbp+rdx*2+CONST] 2265 | [rsp+rax*4+var_xxx] 2266 | [r8+rdi*4+CONST] 2267 | [r8+r13*4+CONST] 2268 | [r8+r15*4+CONST] 2269 | [r8+r14*4+CONST] 2270 | [CONST_VAR+rbp*8+CONST] 2271 | [r9+r15*8] 2272 | [r11+r14+CONST] 2273 | [rdx+r11*8+CONST] 2274 | [rax+r11*8+CONST] 2275 | [r15+rdi*2] 2276 | [rbp+rdi*2+CONST] 2277 | [rbp+rsi*2+CONST] 2278 | cmpless 2279 | [r14+r13*2+CONST] 2280 | [rbx+r15*2+CONST] 2281 | [rax+r11*2] 2282 | cmpleps 2283 | [r14+r12*2] 2284 | [r15+r12*2] 2285 | [r9+rbp*8] 2286 | [CONST_VAR+r9*4+CONST] 2287 | [r12+rbx*2+CONST] 2288 | [rsp+rdx*8+CONST+CONST_VAR] 2289 | [r13+rbp*2+CONST] 2290 | [rsp+rdi*8+CONST+CONST_VAR] 2291 | [rsp+r14*8+CONST+CONST_VAR] 2292 | [rdx+rbx*4+CONST] 2293 | [rcx+rdi*8+CONST] 2294 | [r9+r11*2] 2295 | [r9+r8*2] 2296 | [r8+r9*2] 2297 | [r8+r9*4+CONST] 2298 | psubw 2299 | [r10+rcx*2] 2300 | [rsp+r8*8+CONST+var_xxx] 2301 | [rsp+r9*8+CONST+var_xxx] 2302 | [rsp+r11*8+CONST+var_xxx] 2303 | fyl2x 2304 | fldl2e 2305 | frndint 2306 | f2xm1 2307 | faddp 2308 | fscale 2309 | cvtsd2si 2310 | fldln2 2311 | fldlg2 2312 | fsin 2313 | fcos 2314 | fptan 2315 | int 2316 | fpatan 2317 | [r14+rdx*4+CONST] 2318 | tzcnt 2319 | [rdx+r9*4+CONST] 2320 | [r9+r10*4+CONST] 2321 | [edi] 2322 | [rbp+r8*2+CONST] 2323 | [rdx+r15*2] 2324 | [rbx+rdx*2+CONST] 2325 | [CONST_VAR+r14*2+CONST] 2326 | [rax+r14*2+CONST] 2327 | [r15+r9*2] 2328 | [r11+rax*2+CONST] 2329 | [rsp+rdi*2+CONST+var_xxx+CONST] 2330 | [r11+rdi*8+CONST] 2331 | [r11+rsi*8+CONST] 2332 | [rbx+rbp*2+CONST] 2333 | [r12+r8+CONST] 2334 | [r14+rdx*2+CONST] 2335 | [rbx+rcx*2+CONST] 2336 | [esi] 2337 | [rbx+r11*8] 2338 | [r15+r15*2+CONST] 2339 | [r15+r14*4+CONST] 2340 | [rsp+r14*4+CONST+var_xxx+CONST] 2341 | [r11+rdi*4+CONST] 2342 | [r11+rbx*4+CONST] 2343 | [r9+rsi*4+CONST] 2344 | [rdx+r8*4+CONST] 2345 | [rcx+rsi*2] 2346 | cmovnp 2347 | [r10+rax*2+CONST] 2348 | [r15+r14*2+CONST] 2349 | bnd 2350 | [rsp+r8*8+CONST+CONST_VAR] 2351 | [rbp+r9*2+CONST] 2352 | [r11+r9*2] 2353 | [rsp+rbp*4+CONST+var_xxx] 2354 | [r10+rdi*2] 2355 | [r10+rbp*2] 2356 | [r11+r8*2] 2357 | [r10+r8*2] 2358 | [r11+rcx*2] 2359 | [r9+rcx*4+CONST] 2360 | [r14+rcx*4+CONST] 2361 | [rsp+rax+var_xxx] 2362 | [rcx+r8*2] 2363 | [rcx+r12*2] 2364 | [r12+r15*2] 2365 | [r14+r10*4+CONST] 2366 | [rbx+r8*2] 2367 | [rax+r10*2] 2368 | [r10+r10+CONST] 2369 | [r9+rdi*2] 2370 | [r9+rbp*2] 2371 | [r8+r11*2] 2372 | [rcx+r15*2] 2373 | [r8+r15*2] 2374 | [CONST_VAR+r15*4+CONST] 2375 | [rcx+r11*2] 2376 | [r10+rbx*2+CONST] 2377 | [rbx+r9*2] 2378 | [r10+rdi*2+CONST] 2379 | [rcx+r9*2] 2380 | [r10+r9*2+CONST] 2381 | [r8+rdi*2] 2382 | [r10+r14*4] 
2383 | [r14+r11*4] 2384 | [r14+rdi*4] 2385 | [r14+rsi*4+CONST] 2386 | [r9+r12+CONST] 2387 | [r8+r15+CONST] 2388 | [r15+rbx*2+CONST] 2389 | [rsp+r12+CONST+CONST_VAR] 2390 | [r10+rbx*8] 2391 | [rdx+rdx*4+CONST] 2392 | fild 2393 | [r8+r8*4+CONST] 2394 | fadd 2395 | [rbx+rbx*4+CONST] 2396 | [r10+r11*2] 2397 | pmaddwd 2398 | [r9+rbx*2] 2399 | [r14+r9*2] 2400 | [r12+r10*2] 2401 | [r12+r9*2] 2402 | [rdx+r9*2+CONST] 2403 | [rsp+rax*8+CONST+arg_xxx] 2404 | [r11+r9*4+CONST] 2405 | [r10+r9*4+CONST] 2406 | [r11+r12*8+CONST] 2407 | [rax+r12*4+CONST] 2408 | [CONST_VAR+r13*4+CONST] 2409 | [r9+rbx*4+CONST] 2410 | +3 2411 | [r12+rbp*4+CONST] 2412 | [r15+r9*8+CONST] 2413 | [r11+r9*8+CONST] 2414 | [rbx+r10*8+CONST] 2415 | [r9+r9*2+CONST] 2416 | [rsp+r11*4+CONST+var_xxx] 2417 | [r9+r8*8+CONST] 2418 | [r14+r11*8+CONST] 2419 | [rdx+rbp*4+CONST] 2420 | [rsp+rax*4+CONST+var_xxx+CONST] 2421 | [rsp+rax*8+CONST+var_xxx+CONST] 2422 | [rsp+rcx*4+CONST+var_xxx+CONST] 2423 | [rsp+rdx*8+CONST+var_xxx+CONST] 2424 | [rsp+rdx*4+CONST+var_xxx+CONST] 2425 | [rsp+rbp*8+CONST+var_xxx+CONST] 2426 | [r10+r10*2+CONST] 2427 | [CONST_VAR+r15*2+CONST] 2428 | [rbp+rdx*4+var_xxx] 2429 | [rbp+rdi*4+var_xxx] 2430 | [rbp+r8*8+var_xxx] 2431 | [rbp+r11+var_xxx] 2432 | [rbp+rax*4+var_xxx] 2433 | [rbp+rsi*4+var_xxx] 2434 | [rbp+r13+var_xxx] 2435 | [r11+r10*4+CONST] 2436 | [r11+r10*2+CONST] 2437 | [r11+r8*4+CONST] 2438 | [r13+r10*2+CONST] 2439 | [r11+r8*2+CONST] 2440 | [CONST_VAR+rbx*4+CONST] 2441 | [rbx+rdi*2+CONST] 2442 | [r10+rcx*2+CONST] 2443 | [r13+rsi*2+CONST] 2444 | [rsp+rbp*8+CONST+CONST_VAR] 2445 | [rsp+rbx*4+CONST+var_xxx+CONST] 2446 | [rsp+r12*4+CONST+var_xxx] 2447 | [rsp+r12*4+CONST+var_xxx+CONST] 2448 | [rcx+r12*4+CONST] 2449 | minpd 2450 | maxpd 2451 | [rbp+r10*8+CONST] 2452 | [rsp+r9*8+CONST+CONST_VAR] 2453 | [rsp+rsi*8+CONST+CONST_VAR] 2454 | +10h 2455 | [r9+rsi*8+CONST] 2456 | [rbp+r13*2+CONST] 2457 | [rdx+r13*4+CONST] 2458 | +0Ch 2459 | bsf 2460 | [rbx+r12*2+CONST] 2461 | [r14+rcx*2+CONST] 2462 | [r9+rax*2+CONST] 2463 | [rcx+r10*4+CONST] 2464 | [r12+r8*4+CONST] 2465 | [r12+r9*4+CONST] 2466 | [rcx+r11*4+CONST] 2467 | [rsp+r9*2+CONST+var_xxx] 2468 | [r14+r9*2+CONST] 2469 | [r14+r8*2] 2470 | [r14+r8*2+CONST] 2471 | [r14+r10*2] 2472 | [r14+r10*2+CONST] 2473 | [r14+r11*2] 2474 | [r14+r11*2+CONST] 2475 | emms 2476 | [r12+rbp*2] 2477 | [r14+rsi*2+CONST] 2478 | [CONST_VAR+r12*2+CONST] 2479 | [r13+r9*2+CONST] 2480 | [r9+rdi*2+CONST] 2481 | [r9+r11*2+CONST] 2482 | [r9+r15*2+CONST] 2483 | [rsp+r11*2+CONST+var_xxx] 2484 | [r9+rbx*2+CONST] 2485 | [rsp+rbx*2+CONST_VAR+var_xxx] 2486 | [rsp+r10*2+CONST_VAR+var_xxx] 2487 | [rsp+r8*2+CONST_VAR+var_xxx] 2488 | [rsp+rax*2+CONST_VAR+var_xxx] 2489 | [r15+r8*2] 2490 | [r8+r10*4+CONST] 2491 | [r11+rcx*8+CONST] 2492 | [r10+r13+CONST] 2493 | [r10+r8+CONST] 2494 | [rcx+rbx*4+CONST] 2495 | [rcx+rdi*2+CONST] 2496 | [r12+rcx*2+CONST] 2497 | [r12+rsi*2+CONST] 2498 | [rsp+r8+CONST+CONST_VAR] 2499 | [r11+r15*8] 2500 | seto 2501 | [rdx+r12*2+CONST] 2502 | [rcx+r12*2+CONST] 2503 | [r10+r15*4] 2504 | [r11+r12+CONST] 2505 | [rsp+CONST_VAR+var_xxx+CONST_VAR] 2506 | [rsp+CONST_VAR+CONST_VAR] 2507 | [r11+r14*8] 2508 | shld 2509 | [rsp+rdx*8+CONST+arg_xxx] 2510 | [r14+r15*4+CONST] 2511 | [r10+r8*4+CONST] 2512 | [rsp+rax*4+CONST+arg_xxx] 2513 | [rsp+rdx*4+CONST+arg_xxx] 2514 | [rsp+rbp*4+CONST+CONST_VAR] 2515 | [rsp+r10+CONST+arg_xxx] 2516 | [rsp+r9*4+CONST+arg_xxx] 2517 | [rsp+rax+CONST+arg_xxx] 2518 | [r8+r14*8] 2519 | [rbp+r9+var_xxx] 2520 | [rbp+r10+var_xxx] 2521 | [r9+r15*4+CONST] 2522 | [r14+r12*4+CONST] 2523 | 
[r11+rbx*2] 2524 | [r11+rbx*2+CONST] 2525 | [r11+rdi*2] 2526 | [r8+r14*2] 2527 | [CONST_VAR+r15*2] 2528 | [r9+r15*2] 2529 | [rbp+r11*8+CONST] 2530 | [r8+r10*2+CONST] 2531 | [rcx+r11*2+CONST] 2532 | [r10+r8*2+CONST] 2533 | minps 2534 | [r8+r13*2] 2535 | [r8+r13*2+CONST] 2536 | [r11+rdx*2+CONST] 2537 | [r8+rbx*2] 2538 | [rbx+r13*4+CONST] 2539 | [r12+r14*2+CONST] 2540 | [r8+rbp*4+CONST] 2541 | [r12+rdi*2] 2542 | [rbp+r10*2+CONST] 2543 | [r8+rbp*2] 2544 | [r14+rdi*2] 2545 | [r12+r13*2] 2546 | [r9+r9*4+CONST] 2547 | [rdx+r10*4+CONST] 2548 | [r9+r8*2+CONST] 2549 | [rbp+r13*4+var_xxx] 2550 | [r14+r15*2+CONST] 2551 | [rbp+r11*2+CONST] 2552 | [r12+rdi*2+CONST] 2553 | [rbp+rcx*4+var_xxx] 2554 | [r9+r9*8+CONST] 2555 | [rcx+rcx*8+CONST] 2556 | [rbp+r13*8+var_xxx] 2557 | [rbp+r15*4+var_xxx+CONST] 2558 | [r14+rbp*4+CONST] 2559 | spl 2560 | prefetcht0 2561 | [rcx+r11*8+CONST] 2562 | [rdx+r15*4+CONST] 2563 | [rbx+r9*4+CONST] 2564 | vmovdqu 2565 | vpinsrq 2566 | vmovsd 2567 | vpbroadcastq 2568 | vxorps 2569 | vcvtusi2sd 2570 | vdivsd 2571 | vcvtsi2sd 2572 | rorx 2573 | vxorpd 2574 | vmovupd 2575 | vmovss 2576 | vcvtusi2ss 2577 | vmulss 2578 | vcvtss2sd 2579 | vdivss 2580 | vcvttss2usi 2581 | vcvtsd2ss 2582 | vucomisd 2583 | kmovb 2584 | k1 2585 | k2 2586 | kandb 2587 | k0 2588 | kortestb 2589 | vinserti128 2590 | movbe 2591 | vmovd 2592 | vpinsrd 2593 | vmovdqu64 2594 | vcomiss 2595 | [r9+rdx*2+CONST] 2596 | vpmovsxdq 2597 | vmulsd 2598 | vpbroadcastd 2599 | vpaddq 2600 | vpaddd 2601 | vmovdqa64 2602 | vextracti64x4 2603 | vextracti32x8 2604 | vextracti64x2 2605 | vpextrd 2606 | vaddsd 2607 | [r13+CONST_VAR] 2608 | shlx 2609 | vcvttps2dq 2610 | vpermt2w 2611 | vpmovwb 2612 | vinserti64x4 2613 | vmovdqu8 2614 | vpbroadcastw 2615 | vpackuswb 2616 | vpermq 2617 | vcvttss2si 2618 | vbroadcastss 2619 | vaddps 2620 | vpermi2w 2621 | vaddss 2622 | vmovdqu16 2623 | vmovaps 2624 | vmovdqu32 2625 | vsubsd 2626 | vfmadd231ps 2627 | vfmadd231ss 2628 | vfmadd213ss 2629 | vfnmadd213ss 2630 | vfmsub231ss 2631 | vcmpordss 2632 | vblendvps 2633 | vcmpunordss 2634 | vunpcklps 2635 | vmovlps 2636 | vucomiss 2637 | shrx 2638 | vshufpd 2639 | vminps 2640 | vmaxps 2641 | vshufps 2642 | vdivps 2643 | vcvtsi2ss 2644 | vsubps 2645 | vmulps 2646 | vfmsub132ps 2647 | vsqrtps 2648 | vpsubq 2649 | vpsrlq 2650 | vpminuq 2651 | vcomisd 2652 | vcvttsd2usi 2653 | vandpd 2654 | [rbp+rdx+var_xxx+CONST] 2655 | vmovups 2656 | vmovsldup 2657 | vmovshdup 2658 | vfmadd132ss 2659 | vsqrtss 2660 | vsubss 2661 | vpunpcklqdq 2662 | vfmadd132ps 2663 | vpermt2ps 2664 | vpermi2ps 2665 | vpmullq 2666 | vmaxsd 2667 | andn 2668 | [rdx+rdx*8+CONST] 2669 | shrd 2670 | fxam 2671 | fnstsw 2672 | [rsp+r9+CONST+var_xxx+CONST] 2673 | [r13+rdi*2+CONST] 2674 | [rbp+rcx+arg_xxx] 2675 | [rsp+r10+CONST+var_xxx+CONST] 2676 | [r9+r12*4+CONST] 2677 | rdrand 2678 | rdseed 2679 | +28h 2680 | [rax+r11*2+CONST] 2681 | [r12+r10*4+CONST] 2682 | [r14+r11*4+CONST] 2683 | [rdx+r15*2+CONST] 2684 | [rcx+r10*2] 2685 | [rcx+r14*2] 2686 | [r10+rsi*2+CONST] 2687 | [rbp+rax*4+CONST_VAR] 2688 | [rbp+rcx*4+CONST_VAR] 2689 | syscall 2690 | [r13+r11*2+CONST] 2691 | [rsp+r13*8+CONST+CONST_VAR] 2692 | [rsp+rdi*8+CONST+arg_xxx] 2693 | [rsp+rsi*8+CONST+arg_xxx] 2694 | [rsp+r14*8+CONST+arg_xxx] 2695 | [rsp+r8*8+CONST+arg_xxx] 2696 | [rsp+r13*8+CONST+arg_xxx] 2697 | [r15+r15*4+CONST] 2698 | [rax+r9*2+CONST] 2699 | [rbx+rbx*8+CONST] 2700 | [rbp+arg_xxx+CONST] 2701 | [r15+r14*2] 2702 | [CONST_VAR+rbp*4+CONST] 2703 | pmulhw 2704 | psraw 2705 | packssdw 2706 | [r8+r11*2+CONST] 2707 | pmulhuw 2708 | 
[rbx+r9*2+CONST] 2709 | [rbx+r10*2+CONST] 2710 | [r9+r11*4+CONST] 2711 | [r15+r12*4+CONST] 2712 | [rbp+r15*8+var_xxx] 2713 | [rbp+r15*8+var_xxx+CONST] 2714 | [rbp+rax*4+var_xxx+CONST] 2715 | [rbp+rcx*4+var_xxx+CONST] 2716 | [rsp+rbx+CONST+arg_xxx] 2717 | [rsp+r13+CONST+var_xxx+CONST] 2718 | [rsp+rax*2+arg_xxx] 2719 | [rsp+rdx*2+arg_xxx] 2720 | [r15+r11*2] 2721 | [rsp+rbp*2+CONST+var_xxx+CONST] 2722 | [r8+r15*2+CONST] 2723 | [r14+r12*2+CONST] 2724 | [rsp+rax+CONST+arg_xxx+CONST] 2725 | [rsp+r15+CONST] 2726 | [rsp+r8*4+CONST+CONST_VAR] 2727 | [rsp+r10*4+CONST+CONST_VAR] 2728 | [rsp+rbx+CONST] 2729 | [rcx+r15*4+CONST] 2730 | [r11+rdx*4+CONST] 2731 | [rsp+rcx*8+CONST+var_xxx+CONST] 2732 | [r10+r12*8] 2733 | [rsp+rdi*8+arg_xxx] 2734 | [rsp+rax*8+arg_xxx] 2735 | [r15+r11*8+CONST] 2736 | [rbp+r8*4+var_xxx] 2737 | [rsp+rax*4+CONST+CONST_VAR+CONST] 2738 | [rsp+rdx*4+CONST+CONST_VAR+CONST] 2739 | [rsp+rsi*4+CONST+arg_xxx] 2740 | [rsp+rax*4+CONST+arg_xxx+CONST] 2741 | [rsp+r9*4+CONST+arg_xxx+CONST] 2742 | [rcx+r9*8+CONST] 2743 | [rsp+r10*8+CONST+CONST_VAR] 2744 | [rsp+r11*8+CONST+CONST_VAR] 2745 | [rsp+r15*8+CONST+CONST_VAR] 2746 | [rsp+rbx*8+CONST+CONST_VAR] 2747 | [rsp+r12*8+CONST+CONST_VAR] 2748 | [r15+r10*8+CONST] 2749 | [rbp+r14*4+var_xxx] 2750 | [rcx+r10*8+CONST] 2751 | [rsp+r8*4+CONST+arg_xxx+CONST] 2752 | [rbp+rax*8+CONST_VAR] 2753 | [r8+r8*8+CONST] 2754 | [rbp+rbx+var_xxx+CONST] 2755 | [r10+r11*4+CONST] 2756 | [rsp+rbx*4+CONST+CONST_VAR] 2757 | [rsp+rbp*4+CONST_VAR] 2758 | +50h 2759 | [r15+r9*4+CONST] 2760 | [rcx+rbx*2+CONST] 2761 | [rsp+r10*8+CONST+var_xxx+CONST] 2762 | [rsp+r9*8+CONST+var_xxx+CONST] 2763 | [rsp+rsi*8+CONST+var_xxx+CONST] 2764 | [r15+rcx*2+CONST] 2765 | [r15+rsi*2+CONST] 2766 | [rsp+rbx*8+CONST+arg_xxx] 2767 | [r11+rsi*4+CONST] 2768 | [rsp+rax*2+CONST+var_xxx+CONST] 2769 | [r14+rdi*2+CONST] 2770 | [r15+r12*2+CONST] 2771 | [rbx+r11*8+CONST] 2772 | [r9+rsi*2+CONST] 2773 | [rbp+rax*2+var_xxx] 2774 | [rsp+CONST_VAR+CONST+arg_xxx] 2775 | [rsp+rdx+CONST+arg_xxx] 2776 | [r11+rbp*4+CONST] 2777 | [r11+r13*2] 2778 | [r13+rax+CONST_VAR] 2779 | [r11+r15*2] 2780 | sin 2781 | [r9+r13*4+CONST] 2782 | setns 2783 | std 2784 | [rsp+rax*4+CONST] 2785 | [r11+rcx*2+CONST] 2786 | [rsp+r9*8+CONST+arg_xxx] 2787 | [rsp+r9+CONST+arg_xxx] 2788 | [rbp+rcx+var_xxx+CONST] 2789 | [rbp+r9*8+var_xxx] 2790 | [rbp+r12*4+var_xxx] 2791 | [rsp+rbp+var_xxx] 2792 | [rbp+r12+var_xxx] 2793 | [rsp+rbp*4+CONST+var_xxx+CONST] 2794 | kmovd 2795 | [rsp+r14*2+CONST+var_xxx+CONST] 2796 | korb 2797 | kmovq 2798 | k3 2799 | knotq 2800 | k4 2801 | k5 2802 | k6 2803 | k7 2804 | [rbp+rbx+var_xxx] 2805 | enter 2806 | [r8+rbp*2+CONST] 2807 | [rsp+r13+CONST+arg_xxx] 2808 | [rsp+rdx*8+CONST] 2809 | [rsp+rdx*8+arg_xxx] 2810 | [r12+r15*2+CONST] 2811 | [rax+r11*4+CONST] 2812 | [rbp+rbx*4+var_xxx] 2813 | [rsp+rbx*2+CONST+CONST_VAR] 2814 | [rsp+rbx+CONST+arg_xxx+CONST] 2815 | [r9+r12*2] 2816 | [r9+r14*2+CONST] 2817 | [rsp+rcx*4+CONST+CONST_VAR] 2818 | [rsp+rsi*2+CONST+var_xxx+CONST] 2819 | [rsp+rcx+var_xxx] 2820 | [rsp+rax+var_xxx+CONST] 2821 | [r14+rbp*2+CONST] 2822 | [rbp+rdi*8+var_xxx] 2823 | [rbp+r10*8+var_xxx] 2824 | [r14+rdi*4+CONST] 2825 | [rbp+r15*4+var_xxx] 2826 | [r8+r11*4+CONST] 2827 | [rsp+rbx*8+CONST+CONST_VAR+CONST] 2828 | [r15+r9*2+CONST] 2829 | [rdx+r11*2+CONST] 2830 | [r15+r10*2+CONST] 2831 | [r10+r14*2] 2832 | [rcx+r9*2+CONST] 2833 | [rbp+rbx*8+CONST_VAR] 2834 | fsubp 2835 | [rbp+rdx*8+CONST_VAR] 2836 | [rbp+rdx*2+var_xxx] 2837 | [rbp+rdx*4+CONST_VAR] 2838 | [rbp+rax*2+CONST_VAR] 2839 | +14h 2840 | [rbp+rcx+CONST_VAR] 
/readidadata.py:
--------------------------------------------------------------------------------
import pickle
import networkx
import time


def parse_operand(operator, location, operand1):
    # Normalize one IDA operand into a vocabulary token.
    # `location` is the 1-based position of the operand in the instruction.
    operand1 = operand1.strip(' ')
    operand1 = operand1.replace('ptr ', '')
    operand1 = operand1.replace('offset ', '')
    operand1 = operand1.replace('xmmword ', '')
    operand1 = operand1.replace('dword ', '')
    operand1 = operand1.replace('qword ', '')
    operand1 = operand1.replace('word ', '')
    operand1 = operand1.replace('byte ', '')
    operand1 = operand1.replace('short ', '')
    operand1 = operand1.replace('-', '+')

    # Segment-relative operands collapse to a per-segment placeholder.
    if operand1[0:3] == 'cs:':
        return 'cs:xxx'
    if operand1[0:3] == 'ss:':
        return 'ss:xxx'
    if operand1[0:3] == 'fs:':
        return 'fs:xxx'
    if operand1[0:3] == 'ds:':
        return 'ds:xxx'
    if operand1[0:3] == 'es:':
        return 'es:xxx'
    if operand1[0:3] == 'gs:':
        return 'gs:xxx'

    # Jump targets keep their numeric address (hex_<addr>) so that later
    # processing can turn them into jump-position tokens (cf. the
    # JUMP_ADDR_EXCEEDED / UNK_JUMP_ADDR entries in the vocabulary).
    if operator[0] == 'j' and not isregister(operand1):
        if operand1[0:4] == 'loc_' or operand1[0:7] == 'locret_' or operand1[0:4] == 'sub_':
            return 'hex_' + operand1[operand1.find('_') + 1:]
        else:
            return 'UNK_ADDR'

    # IDA auto-generated names collapse to one token per name family.
    if operand1[0:4] == 'loc_':
        return 'loc_xxx'
    if operand1[0:4] == 'off_':
        return 'off_xxx'
    if operand1[0:4] == 'unk_':
        return 'unk_xxx'
    if operand1[0:6] == 'locret':
        return 'locretxxx'
    if operand1[0:4] == 'sub_':
        return 'sub_xxx'
    if operand1[0:4] == 'arg_':
        return 'arg_xxx'
    if operand1[0:4] == 'def_':
        return 'def_xxx'
    if operand1[0:4] == 'var_':
        return 'var_xxx'

    if operand1[0] == '(' and operand1[-1] == ')':
        return 'CONST'

    # The source operand of lea is an address constant unless it is a plain
    # number or a bracketed expression.
    if operator == 'lea' and location == 2:
        if not ishexnumber(operand1) and not isaddr(operand1):
            return 'GLOBAL_VAR'

    if operator == 'call' and location == 1:
        if len(operand1) > 3:
            return 'callfunc_xxx'

    if operator == 'extrn':
        return 'extrn_xxx'

    if ishexnumber(operand1):
        return 'CONST'
    elif ispurenumber(operand1):
        return 'CONST'

    # Bracketed memory operands: normalize each '+'-separated component but
    # keep registers and scaled-index terms (e.g. rax*4) verbatim.
    if isaddr(operand1):
        params = operand1[1:-1].split('+')
        for i in range(len(params)):
            if ishexnumber(params[i]):
                params[i] = 'CONST'
            elif ispurenumber(params[i]):
                params[i] = 'CONST'
            elif params[i][0:4] == 'var_':
                params[i] = 'var_xxx'
            elif params[i][0:4] == 'arg_':
                params[i] = 'arg_xxx'
            elif not isregister(params[i]):
                if params[i].find('*') == -1:
                    params[i] = 'CONST_VAR'
        return '[' + '+'.join(params) + ']'

    # Anything long that is not a register is treated as a constant symbol.
    if not isregister(operand1) and len(operand1) > 4:
        return 'CONST'
    return operand1


def parse_asm(code):
    # Split one line of IDA disassembly into
    # (operator, operand1, operand2, operand3, annotation) and normalize the
    # operands so the text is better suited to the NLP model.
    annotation = None
    operator, operand = None, None
    operand1, operand2, operand3 = None, None, None
    if code.find(';') != -1:
        idx = code.find(';')
        annotation = code[idx + 1:]
        code = code[0:idx]
    if code.find(' ') != -1:
        idx = code.find(' ')
        operand = code[idx + 1:]
        operator = code[0:idx]
    else:
        operator = code
    if operand is not None:
        if operand.find(',') != -1:
            strs = operand.split(',')
            if len(strs) == 2:
                operand1, operand2 = strs[0], strs[1]
            else:
                operand1, operand2, operand3 = strs[0], strs[1], strs[2]
        else:
            operand1 = operand
            operand2 = None
    if operand1 is not None:
        operand1 = parse_operand(operator, 1, operand1)
    if operand2 is not None:
        operand2 = parse_operand(operator, 2, operand2)
    if operand3 is not None:
        operand3 = parse_operand(operator, 3, operand3)
    return operator, operand1, operand2, operand3, annotation


def isregister(x):
    # Note: this list matches the shipped vocabulary; rsi/rdi are absent
    # (esi/edi appear instead), so inside address expressions they normalize
    # to CONST_VAR unless they carry a scale factor such as rsi*4.
    registers = ['rax', 'rbx', 'rcx', 'rdx', 'esi', 'edi', 'rbp', 'rsp',
                 'r8', 'r9', 'r10', 'r11', 'r12', 'r13', 'r14', 'r15']
    return x in registers


def ispurenumber(number):
    # Single decimal digits only; longer immediates are rendered by IDA as
    # hex with an 'h' suffix and handled by ishexnumber.
    if len(number) == 1 and str.isdigit(number):
        return True
    return False


def isaddr(number):
    return number[0] == '[' and number[-1] == ']'


def ishexnumber(number):
    # IDA prints hex immediates in uppercase with a trailing 'h', e.g. 28h.
    if number[-1] == 'h':
        for i in range(len(number) - 1):
            if str.isdigit(number[i]) or (number[i] >= 'A' and number[i] <= 'F'):
                continue
            else:
                return False
    else:
        return False
    return True

--------------------------------------------------------------------------------
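A quick sketch of what `parse_asm` returns for a single line of IDA output; the instruction text below is hypothetical:

```python
import readidadata

op, op1, op2, op3, note = readidadata.parse_asm('mov rax, [rbp+var_10] ; load x')
# op == 'mov', op1 == 'rax', op2 == '[rbp+var_xxx]', note == ' load x'
```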
/tokenizer.py:
--------------------------------------------------------------------------------
import pickle
import re
import os
import readidadata


def tokenizer():
    # Load the token-to-id mapping produced by save_tokens().
    with open('token_ida.pkl', 'rb') as f:
        token_id = pickle.load(f)
    return token_id


def seq_to_token(token_id, seq, UNK):
    # Map a token sequence to ids, substituting UNK for
    # out-of-vocabulary tokens.
    ret = []
    cnt = 0
    for tok in seq:
        tid = token_id.get(tok)
        if tid is not None:
            ret.append(tid)
        else:
            cnt += 1
            ret.append(UNK)
    return ret


def normalize(opcode):
    # Regex-based normalization of an instruction string: hex constants,
    # scale factors, and bare single digits all become CONST.
    opcode = opcode.replace(' - ', ' + ')
    opcode = re.sub(r'0x[0-9a-f]+', 'CONST', opcode)
    opcode = re.sub(r'\*[0-9]', '*CONST', opcode)  # scale factors, e.g. *8
    opcode = re.sub(r' [0-9]', ' CONST', opcode)   # single decimal digits
    return opcode


def save_tokens(prefixs):
    # Walk ../largedata/ourclean, parse every *.nod file whose name starts
    # with one of `prefixs`, and assign an integer id (starting at 1) to
    # every distinct operator/operand token.
    token_id = {}
    cnts = 1
    file_cnt = 0
    binlist = []
    nowdir = '../largedata/ourclean'
    docs = os.listdir(nowdir)
    for i in docs:
        pth = os.path.join(nowdir, i)
        for fi in os.listdir(pth):
            idx = fi.find('.')
            fd = False
            for pre in prefixs:
                if fi.startswith(pre):
                    fd = True
            if fd and fi.endswith('.nod'):
                fn = os.path.join(pth, fi[0:idx])
                binlist.append(fn + '.nod')
                print(fn)
    for fi in binlist:
        fii = open(fi, 'rb')
        try:
            asm_seq = pickle.load(fii)
        except Exception:
            fii.close()
            continue
        else:
            fii.close()
        for bbid, addr, bb in asm_seq:
            for addr, instructions in bb:
                operator, operand1, operand2, operand3, annotation = readidadata.parse_asm(instructions)
                if operator is not None and token_id.get(operator) is None:
                    print(operator, cnts, " from ", hex(addr), instructions)
                    token_id[operator] = cnts
                    cnts += 1
                # Jump targets (hex_*) are excluded: they are handled as
                # jump-address tokens, not vocabulary entries.
                if operand1 is not None and not operand1.startswith('hex') and token_id.get(operand1) is None:
                    print(operand1, cnts, " from ", hex(addr), instructions)
                    token_id[operand1] = cnts
                    cnts += 1
                if operand2 is not None and token_id.get(operand2) is None:
                    print(operand2, cnts, " from ", hex(addr), instructions)
                    token_id[operand2] = cnts
                    cnts += 1
                if operand3 is not None and token_id.get(operand3) is None:
                    print(operand3, cnts, " from ", hex(addr), instructions)
                    token_id[operand3] = cnts
                    cnts += 1
        file_cnt += 1
        print("finish ", fi, file_cnt / len(binlist))
    print("token_number: ", cnts)
    with open('token_ida.pkl', 'wb') as f:
        pickle.dump(token_id, f)
    return token_id


if __name__ == "__main__":
    tokens = save_tokens(['proxmark3', 'pythonqt', 'pizmidi', 'plasma', 'qbs',
                          'qcad', 'sc3', 'vice', 'virtualgl', 'vtk', 'onics',
                          'odr', 'opencolorio', 'owncloud', 'sagemath', 'usd',
                          'lua', 'lxc'])
    # Disabled block: dumps the vocabulary to vocab.txt in id order.
    '''
    with open("vocab.txt", "w") as fi:
        output = [0 for _ in range(100000)]
        for tok in tokens:
            output[tokens[tok]] = tok
        for i in range(len(tokens)):
            print(output[i], file=fi)
    '''
--------------------------------------------------------------------------------
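For reference, a hedged sketch of the two helpers above on made-up inputs:

```python
from tokenizer import normalize, seq_to_token

print(normalize('lea rax, [rbx+rcx*8+0x20]'))
# -> 'lea rax, [rbx+rcx*CONST+CONST]'

tok2id = {'mov': 1, 'rax': 2, '[rbp+var_xxx]': 3}  # toy mapping, not the real token_ida.pkl
print(seq_to_token(tok2id, ['mov', 'rax', '[rbp+var_xxx]', 'bogus'], UNK=0))
# -> [1, 2, 3, 0]
```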