├── requirements.txt ├── asm2vec ├── __init__.py ├── model.py ├── utils.py └── datatype.py ├── setup.py ├── LICENSE ├── scripts ├── compare.py ├── test.py ├── train.py └── bin2asm.py ├── .gitignore └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7,<2 2 | click>=7.1,<8 3 | r2pipe>=1.5,<2 4 | -------------------------------------------------------------------------------- /asm2vec/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | __all__ = ['model', 'datatype', 'utils'] 4 | 5 | for module in __all__: 6 | importlib.import_module(f'.{module}', 'asm2vec') 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='asm2vec', 5 | version='1.0.0', 6 | description='Unofficial implementation of asm2vec using pytorch', 7 | install_requires=['torch>=1.7,<2' 8 | 'click>=7.1,<8' 9 | 'r2pipe>=1.5,<2'], 10 | author='oalieno', 11 | author_email='jeffrey6910@gmail.com', 12 | license='MIT License', 13 | packages = find_packages(), 14 | ) 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 oalieno 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright 
notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/compare.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import click 4 | import asm2vec 5 | 6 | def cosine_similarity(v1, v2): 7 | return (v1 @ v2 / (v1.norm() * v2.norm())).item() 8 | 9 | @click.command() 10 | @click.option('-i1', '--input1', 'ipath1', help='target function 1', required=True) 11 | @click.option('-i2', '--input2', 'ipath2', help='target function 2', required=True) 12 | @click.option('-m', '--model', 'mpath', help='model path', required=True) 13 | @click.option('-e', '--epochs', default=10, help='training epochs', show_default=True) 14 | @click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True) 15 | @click.option('-lr', '--learning-rate', 'lr', default=0.02, help="learning rate", show_default=True) 16 | def cli(ipath1, ipath2, mpath, epochs, device, lr): 17 | if device == 'auto': 18 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | 20 | # load model, tokens 21 | model, tokens = asm2vec.utils.load_model(mpath, device=device) 22 | functions, tokens_new = asm2vec.utils.load_data([ipath1, ipath2]) 23 | tokens.update(tokens_new) 24 | model.update(2, tokens.size()) 25 | model = model.to(device) 26 | 27 | 
# train function embedding 28 | model = asm2vec.utils.train( 29 | functions, 30 | tokens, 31 | model=model, 32 | epochs=epochs, 33 | device=device, 34 | mode='test',  # in test mode only model.embeddings_f is handed to the optimizer 35 | learning_rate=lr 36 | ) 37 | 38 | # compare 2 function vectors 39 | v1, v2 = model.to('cpu').embeddings_f(torch.tensor([0, 1]))  # rows 0 and 1 follow the [ipath1, ipath2] order passed to load_data 40 | 41 | print(f'cosine similarity : {cosine_similarity(v1, v2):.6f}') 42 | 43 | if __name__ == '__main__': 44 | cli() 45 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import click 4 | import asm2vec 5 | 6 | @click.command() 7 | @click.option('-i', '--input', 'ipath', help='target function', required=True) 8 | @click.option('-m', '--model', 'mpath', help='model path', required=True) 9 | @click.option('-e', '--epochs', default=10, help='training epochs', show_default=True) 10 | @click.option('-n', '--neg-sample-num', 'neg_sample_num', default=25, help='negative sampling amount', show_default=True) 11 | @click.option('-l', '--limit', help='limit the amount of output probability result', type=int) 12 | @click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True) 13 | @click.option('-lr', '--learning-rate', 'lr', default=0.02, help="learning rate", show_default=True) 14 | @click.option('-p', '--pretty', default=False, help='pretty print table', show_default=True, is_flag=True) 15 | def cli(ipath, mpath, epochs, neg_sample_num, limit, device, lr, pretty): 16 | if device == 'auto': 17 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | 19 | # load model, tokens 20 | model, tokens = asm2vec.utils.load_model(mpath, device=device) 21 | functions, tokens_new = asm2vec.utils.load_data(ipath) 22 | tokens.update(tokens_new) 23 | model.update(1, tokens.size())  # one target function: embeddings_f is rebuilt with a single fresh random row 24 | model = model.to(device) 25 | 26 | # train function embedding 27 |
model = asm2vec.utils.train( 28 | functions, 29 | tokens, 30 | model=model, 31 | epochs=epochs, 32 | neg_sample_num=neg_sample_num, 33 | device=device, 34 | mode='test', 35 | learning_rate=lr 36 | ) 37 | 38 | # show predicted probability results 39 | x, y = asm2vec.utils.preprocess(functions, tokens) 40 | probs = model.predict(x.to(device), y.to(device)) 41 | asm2vec.utils.show_probs(x, y, probs, tokens, limit=limit, pretty=pretty) 42 | 43 | if __name__ == '__main__': 44 | cli() 45 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import click 3 | import asm2vec 4 | 5 | @click.command() 6 | @click.option('-i', '--input', 'ipath', help='training data folder', required=True) 7 | @click.option('-o', '--output', 'opath', default='model.pt', help='output model path', show_default=True) 8 | @click.option('-m', '--model', 'mpath', help='load previous trained model path', type=str) 9 | @click.option('-l', '--limit', help='limit the number of functions to be loaded', show_default=True, type=int) 10 | @click.option('-d', '--ebedding-dimension', 'embedding_size', default=100, help='embedding dimension', show_default=True) 11 | @click.option('-b', '--batch-size', 'batch_size', default=1024, help='batch size', show_default=True) 12 | @click.option('-e', '--epochs', default=10, help='training epochs', show_default=True) 13 | @click.option('-n', '--neg-sample-num', 'neg_sample_num', default=25, help='negative sampling amount', show_default=True) 14 | @click.option('-a', '--calculate-accuracy', 'calc_acc', help='whether calculate accuracy ( will be significantly slower )', is_flag=True) 15 | @click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True) 16 | @click.option('-lr', '--learning-rate', 'lr', default=0.02, help="learning rate", show_default=True) 17 | def 
cli(ipath, opath, mpath, limit, embedding_size, batch_size, epochs, neg_sample_num, calc_acc, device, lr): 18 | if device == 'auto': 19 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 20 | 21 | if mpath: 22 | model, tokens = asm2vec.utils.load_model(mpath, device=device) 23 | functions, tokens_new = asm2vec.utils.load_data(ipath, limit=limit) 24 | tokens.update(tokens_new) 25 | model.update(len(functions), tokens.size()) 26 | else: 27 | model = None 28 | functions, tokens = asm2vec.utils.load_data(ipath, limit=limit) 29 | 30 | def callback(context): 31 | progress = f'{context["epoch"]} | time = {context["time"]:.2f}, loss = {context["loss"]:.4f}' 32 | if context["accuracy"]: 33 | progress += f', accuracy = {context["accuracy"]:.4f}' 34 | print(progress) 35 | asm2vec.utils.save_model(opath, context["model"], context["tokens"]) 36 | 37 | model = asm2vec.utils.train( 38 | functions, 39 | tokens, 40 | model=model, 41 | embedding_size=embedding_size, 42 | batch_size=batch_size, 43 | epochs=epochs, 44 | neg_sample_num=neg_sample_num, 45 | calc_acc=calc_acc, 46 | device=device, 47 | callback=callback, 48 | learning_rate=lr 49 | ) 50 | 51 | if __name__ == '__main__': 52 | cli() 53 | -------------------------------------------------------------------------------- /asm2vec/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | bce, sigmoid, softmax = nn.BCELoss(), nn.Sigmoid(), nn.Softmax(dim=1) 5 | 6 | class ASM2VEC(nn.Module): 7 | def __init__(self, vocab_size, function_size, embedding_size): 8 | super(ASM2VEC, self).__init__() 9 | self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size)) 10 | self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size, _weight=(torch.rand(function_size, 2 * embedding_size)-0.5)/embedding_size/2) 11 | self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size, 
_weight=(torch.rand(vocab_size, 2 * embedding_size)-0.5)/embedding_size/2) 12 | 13 | def update(self, function_size_new, vocab_size_new): 14 | device = self.embeddings.weight.device 15 | vocab_size, function_size, embedding_size = self.embeddings.num_embeddings, self.embeddings_f.num_embeddings, self.embeddings.embedding_dim 16 | if vocab_size_new != vocab_size: 17 | weight = torch.cat([self.embeddings.weight, torch.zeros(vocab_size_new - vocab_size, embedding_size).to(device)]) 18 | self.embeddings = nn.Embedding(vocab_size_new, embedding_size, _weight=weight) 19 | weight_r = torch.cat([self.embeddings_r.weight, ((torch.rand(vocab_size_new - vocab_size, 2 * embedding_size)-0.5)/embedding_size/2).to(device)]) 20 | self.embeddings_r = nn.Embedding(vocab_size_new, 2 * embedding_size, _weight=weight_r) 21 | self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size, _weight=((torch.rand(function_size_new, 2 * embedding_size)-0.5)/embedding_size/2).to(device)) 22 | 23 | def v(self, inp): 24 | e = self.embeddings(inp[:,1:]) 25 | v_f = self.embeddings_f(inp[:,0]) 26 | v_prev = torch.cat([e[:,0], (e[:,1] + e[:,2]) / 2], dim=1) 27 | v_next = torch.cat([e[:,3], (e[:,4] + e[:,5]) / 2], dim=1) 28 | v = ((v_f + v_prev + v_next) / 3).unsqueeze(2) 29 | return v 30 | 31 | def forward(self, inp, pos, neg): 32 | device, batch_size = inp.device, inp.shape[0] 33 | v = self.v(inp) 34 | # negative sampling loss 35 | pred = torch.bmm(self.embeddings_r(torch.cat([pos, neg], dim=1)), v).squeeze() 36 | label = torch.cat([torch.ones(batch_size, 3), torch.zeros(batch_size, neg.shape[1])], dim=1).to(device) 37 | return bce(sigmoid(pred), label) 38 | 39 | def predict(self, inp, pos): 40 | device, batch_size = inp.device, inp.shape[0] 41 | v = self.v(inp) 42 | probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).to(device)), v).squeeze(dim=2) 43 | return softmax(probs) 44 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/vim,python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=vim,python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | pytestdebug.log 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | doc/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | pythonenv* 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # profiling data 142 | .prof 143 | 144 | ### Vim ### 145 | # Swap 146 | [._]*.s[a-v][a-z] 147 | !*.svg # comment out if you don't need vector files 148 | [._]*.sw[a-p] 149 | [._]s[a-rt-v][a-z] 150 | [._]ss[a-gi-z] 151 | [._]sw[a-p] 152 | 153 | # Session 154 | Session.vim 155 | Sessionx.vim 156 | 157 | # Temporary 158 | .netrwhist 159 | *~ 160 | # Auto-generated tag files 161 | tags 162 | # Persistent undo 163 | [._]*.un~ 164 | 165 | # End of https://www.toptal.com/developers/gitignore/api/vim,python 166 | -------------------------------------------------------------------------------- /scripts/bin2asm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import re 3 | import os 4 | import click 5 | import r2pipe 6 | import hashlib 7 | from pathlib import Path 8 | 9 | def sha3(data): 10 | return hashlib.sha3_256(data.encode()).hexdigest() 11 | 12 | def validEXE(filename): 13 | magics = [bytes.fromhex('7f454c46')] 14 | with open(filename, 'rb') as f: 15 | 
header = f.read(4) 16 | return header in magics 17 | 18 | def normalize(opcode): 19 | opcode = opcode.replace(' - ', ' + ') 20 | opcode = re.sub(r'0x[0-9a-f]+', 'CONST', opcode) 21 | opcode = re.sub(r'\*[0-9]', '*CONST', opcode) 22 | opcode = re.sub(r' [0-9]', ' CONST', opcode) 23 | return opcode 24 | 25 | def fn2asm(pdf, minlen): 26 | # check 27 | if pdf is None: 28 | return 29 | if len(pdf['ops']) < minlen: 30 | return 31 | if 'invalid' in [op['type'] for op in pdf['ops']]: 32 | return 33 | 34 | ops = pdf['ops'] 35 | 36 | # set label 37 | labels, scope = {}, [op['offset'] for op in ops] 38 | assert(None not in scope) 39 | for i, op in enumerate(ops): 40 | if op.get('jump') in scope: 41 | labels.setdefault(op.get('jump'), i) 42 | 43 | # dump output 44 | output = '' 45 | for op in ops: 46 | # add label 47 | if labels.get(op.get('offset')) is not None: 48 | output += f'LABEL{labels[op["offset"]]}:\n' 49 | # add instruction 50 | if labels.get(op.get('jump')) is not None: 51 | output += f' {op["type"]} LABEL{labels[op["jump"]]}\n' 52 | else: 53 | output += f' {normalize(op["opcode"])}\n' 54 | 55 | return output 56 | 57 | def bin2asm(filename, opath, minlen): 58 | # check 59 | if not validEXE(filename): 60 | return 0 61 | 62 | r = r2pipe.open(str(filename)) 63 | r.cmd('aaaa') 64 | 65 | count = 0 66 | 67 | for fn in r.cmdj('aflj'): 68 | r.cmd(f's {fn["offset"]}') 69 | asm = fn2asm(r.cmdj('pdfj'), minlen) 70 | if asm: 71 | uid = sha3(asm) 72 | asm = f''' .name {fn["name"]} 73 | .offset {fn["offset"]:016x} 74 | .file {filename.name} 75 | ''' + asm 76 | with open(opath / uid, 'w') as f: 77 | f.write(asm) 78 | count += 1 79 | 80 | print(f'[+] {filename}') 81 | 82 | return count 83 | 84 | @click.command() 85 | @click.option('-i', '--input', 'ipath', help='input directory / file', required=True) 86 | @click.option('-o', '--output', 'opath', default='asm', help='output directory') 87 | @click.option('-l', '--len', 'minlen', default=10, help='ignore assembly code with 
instructions amount smaller than minlen') 88 | def cli(ipath, opath, minlen): 89 | ''' 90 | Extract assembly functions from binary executable 91 | ''' 92 | ipath = Path(ipath) 93 | opath = Path(opath) 94 | 95 | # create output directory 96 | if not os.path.exists(opath): 97 | os.mkdir(opath) 98 | 99 | fcount, bcount = 0, 0 100 | 101 | # directory 102 | if os.path.isdir(ipath): 103 | for f in os.listdir(ipath): 104 | if not os.path.islink(ipath / f) and not os.path.isdir(ipath / f): 105 | fcount += bin2asm(ipath / f, opath, minlen) 106 | bcount += 1 107 | # file 108 | elif os.path.exists(ipath): 109 | fcount += bin2asm(ipath, opath, minlen) 110 | bcount += 1 111 | else: 112 | print(f'[Error] No such file or directory: {ipath}') 113 | 114 | print(f'[+] Total scan binary: {bcount} => Total generated assembly functions: {fcount}') 115 | 116 | if __name__ == '__main__': 117 | cli() 118 | -------------------------------------------------------------------------------- /asm2vec/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from torch.utils.data import DataLoader, Dataset 5 | from pathlib import Path 6 | from .datatype import Tokens, Function, Instruction 7 | from .model import ASM2VEC 8 | 9 | class AsmDataset(Dataset): 10 | def __init__(self, x, y): 11 | self.x = x 12 | self.y = y 13 | def __len__(self): 14 | return len(self.x) 15 | def __getitem__(self, index): 16 | return self.x[index], self.y[index] 17 | 18 | def load_data(paths, limit=None): 19 | if type(paths) is not list: 20 | paths = [paths] 21 | 22 | filenames = [] 23 | for path in paths: 24 | if os.path.isdir(path): 25 | filenames += [Path(path) / filename for filename in sorted(os.listdir(path)) if os.path.isfile(Path(path) / filename)] 26 | else: 27 | filenames += [Path(path)] 28 | 29 | functions, tokens = [], Tokens() 30 | for i, filename in enumerate(filenames): 31 | if limit and i >= limit: 32 | break 33 | with 
open(filename) as f: 34 | fn = Function.load(f.read()) 35 | functions.append(fn) 36 | tokens.add(fn.tokens()) 37 | 38 | return functions, tokens 39 | 40 | def preprocess(functions, tokens): 41 | x, y = [], [] 42 | for i, fn in enumerate(functions): 43 | for seq in fn.random_walk(): 44 | for j in range(1, len(seq) - 1): 45 | x.append([i] + [tokens[token].index for token in seq[j-1].tokens() + seq[j+1].tokens()]) 46 | y.append([tokens[token].index for token in seq[j].tokens()]) 47 | return torch.tensor(x), torch.tensor(y) 48 | 49 | def train( 50 | functions, 51 | tokens, 52 | model=None, 53 | embedding_size=100, 54 | batch_size=1024, 55 | epochs=10, 56 | neg_sample_num=25, 57 | calc_acc=False, 58 | device='cpu', 59 | mode='train', 60 | callback=None, 61 | learning_rate=0.02 62 | ): 63 | if mode == 'train': 64 | if model is None: 65 | model = ASM2VEC(tokens.size(), function_size=len(functions), embedding_size=embedding_size).to(device) 66 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 67 | elif mode == 'test': 68 | if model is None: 69 | raise ValueError("test mode required pretrained model") 70 | optimizer = torch.optim.Adam(model.embeddings_f.parameters(), lr=learning_rate) 71 | else: 72 | raise ValueError("Unknown mode") 73 | 74 | loader = DataLoader(AsmDataset(*preprocess(functions, tokens)), batch_size=batch_size, shuffle=True) 75 | for epoch in range(epochs): 76 | start = time.time() 77 | loss_sum, loss_count, accs = 0.0, 0, [] 78 | 79 | model.train() 80 | for i, (inp, pos) in enumerate(loader): 81 | neg = tokens.sample(inp.shape[0], neg_sample_num) 82 | loss = model(inp.to(device), pos.to(device), neg.to(device)) 83 | loss_sum, loss_count = loss_sum + loss, loss_count + 1 84 | 85 | optimizer.zero_grad() 86 | loss.backward() 87 | optimizer.step() 88 | 89 | if i == 0 and calc_acc: 90 | probs = model.predict(inp.to(device), pos.to(device)) 91 | accs.append(accuracy(pos, probs)) 92 | 93 | if callback: 94 | callback({ 95 | 'model': model, 96 | 
'tokens': tokens, 97 | 'epoch': epoch, 98 | 'time': time.time() - start, 99 | 'loss': loss_sum / loss_count, 100 | 'accuracy': torch.tensor(accs).mean() if calc_acc else None 101 | }) 102 | 103 | return model 104 | 105 | def save_model(path, model, tokens): 106 | torch.save({ 107 | 'model_params': ( 108 | model.embeddings.num_embeddings, 109 | model.embeddings_f.num_embeddings, 110 | model.embeddings.embedding_dim 111 | ), 112 | 'model': model.state_dict(), 113 | 'tokens': tokens.state_dict(), 114 | }, path) 115 | 116 | def load_model(path, device='cpu'): 117 | checkpoint = torch.load(path, map_location=device) 118 | tokens = Tokens() 119 | tokens.load_state_dict(checkpoint['tokens']) 120 | model = ASM2VEC(*checkpoint['model_params']) 121 | model.load_state_dict(checkpoint['model']) 122 | model = model.to(device) 123 | return model, tokens 124 | 125 | def show_probs(x, y, probs, tokens, limit=None, pretty=False): 126 | if pretty: 127 | TL, TR, BL, BR = '┌', '┐', '└', '┘' 128 | LM, RM, TM, BM = '├', '┤', '┬', '┴' 129 | H, V = '─', '│' 130 | arrow = ' ➔' 131 | else: 132 | TL = TR = BL = BR = '+' 133 | LM = RM = TM = BM = '+' 134 | H, V = '-', '|' 135 | arrow = '->' 136 | top = probs.topk(5) 137 | for i, (xi, yi) in enumerate(zip(x, y)): 138 | if limit and i >= limit: 139 | break 140 | xi, yi = xi.tolist(), yi.tolist() 141 | print(TL + H * 42 + TR) 142 | print(f'{V} {str(Instruction(tokens[xi[1]], tokens[xi[2:4]])):37} {V}') 143 | print(f'{V} {arrow} {str(Instruction(tokens[yi[0]], tokens[yi[1:3]])):37} {V}') 144 | print(f'{V} {str(Instruction(tokens[xi[4]], tokens[xi[5:7]])):37} {V}') 145 | print(LM + H * 8 + TM + H * 33 + RM) 146 | for value, index in zip(top.values[i], top.indices[i]): 147 | if index in yi: 148 | colorbegin, colorclear = '\033[92m', '\033[0m' 149 | else: 150 | colorbegin, colorclear = '', '' 151 | print(f'{V} {colorbegin}{value*100:05.2f}%{colorclear} {V} {colorbegin}{tokens[index.item()].name:31}{colorclear} {V}') 152 | print(BL + H * 8 + BM + H * 
33 + BR) 153 | 154 | def accuracy(y, probs): 155 | return torch.mean(torch.tensor([torch.sum(probs[i][yi]) for i, yi in enumerate(y)])) 156 | 157 | -------------------------------------------------------------------------------- /asm2vec/datatype.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import warnings 4 | 5 | class Token: 6 | def __init__(self, name, index): 7 | self.name = name 8 | self.index = index 9 | self.count = 1 10 | def __str__(self): 11 | return self.name 12 | 13 | class Tokens: 14 | def __init__(self, name_to_index=None, tokens=None): 15 | self.name_to_index = name_to_index or {} 16 | self.tokens = tokens or [] 17 | self._weights = None 18 | def __getitem__(self, key): 19 | if type(key) is str: 20 | if self.name_to_index.get(key) is None: 21 | warnings.warn("Unknown token in training dataset") 22 | return self.tokens[self.name_to_index[""]] 23 | return self.tokens[self.name_to_index[key]] 24 | elif type(key) is int: 25 | return self.tokens[key] 26 | else: 27 | try: 28 | return [self[k] for k in key] 29 | except: 30 | raise ValueError 31 | def load_state_dict(self, sd): 32 | self.name_to_index = sd['name_to_index'] 33 | self.tokens = sd['tokens'] 34 | def state_dict(self): 35 | return {'name_to_index': self.name_to_index, 'tokens': self.tokens} 36 | def size(self): 37 | return len(self.tokens) 38 | def add(self, names): 39 | self._weights = None 40 | if type(names) is not list: 41 | names = [names] 42 | for name in names: 43 | if name not in self.name_to_index: 44 | token = Token(name, len(self.tokens)) 45 | self.name_to_index[name] = token.index 46 | self.tokens.append(token) 47 | else: 48 | self.tokens[self.name_to_index[name]].count += 1 49 | def update(self, tokens_new): 50 | for token in tokens_new: 51 | if token.name not in self.name_to_index: 52 | token.index = len(self.tokens) 53 | self.name_to_index[token.name] = token.index 54 | self.tokens.append(token) 55 | 
else: 56 | self.tokens[self.name_to_index[token.name]].count += token.count 57 | def weights(self): 58 | # if no cache, calculate 59 | if self._weights is None: 60 | total = sum([token.count for token in self.tokens]) 61 | self._weights = torch.zeros(len(self.tokens)) 62 | for token in self.tokens: 63 | self._weights[token.index] = (token.count / total) ** 0.75 64 | return self._weights 65 | def sample(self, batch_size, num=5): 66 | return torch.multinomial(self.weights(), num * batch_size, replacement=True).view(batch_size, num) 67 | 68 | class Function: 69 | def __init__(self, insts, blocks, meta): 70 | self.insts = insts 71 | self.blocks = blocks 72 | self.meta = meta 73 | @classmethod 74 | def load(cls, text): 75 | ''' 76 | gcc -S format compatiable 77 | ''' 78 | label, labels, insts, blocks, meta = None, {}, [], [], {} 79 | for line in text.strip('\n').split('\n'): 80 | if line[0] in [' ', '\t']: 81 | line = line.strip() 82 | # meta data 83 | if line[0] == '.': 84 | key, _, value = line[1:].strip().partition(' ') 85 | meta[key] = value 86 | # instruction 87 | else: 88 | inst = Instruction.load(line) 89 | insts.append(inst) 90 | if len(blocks) == 0 or blocks[-1].end(): 91 | blocks.append(BasicBlock()) 92 | # link prev and next block 93 | if len(blocks) > 1: 94 | blocks[-2].successors.add(blocks[-1]) 95 | if label: 96 | labels[label], label = blocks[-1], None 97 | blocks[-1].add(inst) 98 | # label 99 | else: 100 | label = line.partition(':')[0] 101 | # link label 102 | for block in blocks: 103 | inst = block.insts[-1] 104 | if inst.is_jmp() and labels.get(inst.args[0]): 105 | block.successors.add(labels[inst.args[0]]) 106 | # replace label with CONST 107 | for inst in insts: 108 | for i, arg in enumerate(inst.args): 109 | if labels.get(arg): 110 | inst.args[i] = 'CONST' 111 | return cls(insts, blocks, meta) 112 | def tokens(self): 113 | return [token for inst in self.insts for token in inst.tokens()] 114 | def random_walk(self, num=3): 115 | return 
[self._random_walk() for _ in range(num)] 116 | def _random_walk(self): 117 | current, visited, seq = self.blocks[0], [], [] 118 | while current not in visited: 119 | visited.append(current) 120 | seq += current.insts 121 | # no following block / hit return 122 | if len(current.successors) == 0 or current.insts[-1].op == 'ret': 123 | break 124 | current = random.choice(list(current.successors)) 125 | return seq 126 | 127 | class BasicBlock: 128 | def __init__(self): 129 | self.insts = [] 130 | self.successors = set() 131 | def add(self, inst): 132 | self.insts.append(inst) 133 | def end(self): 134 | inst = self.insts[-1] 135 | return inst.is_jmp() or inst.op == 'ret' 136 | 137 | class Instruction: 138 | def __init__(self, op, args): 139 | self.op = op 140 | self.args = args 141 | def __str__(self): 142 | return f'{self.op} {", ".join([str(arg) for arg in self.args if str(arg)])}' 143 | @classmethod 144 | def load(cls, text): 145 | text = text.strip().strip('bnd').strip() # get rid of BND prefix 146 | op, _, args = text.strip().partition(' ') 147 | if args: 148 | args = [arg.strip() for arg in args.split(',')] 149 | else: 150 | args = [] 151 | args = (args + ['', ''])[:2] 152 | return cls(op, args) 153 | def tokens(self): 154 | return [self.op] + self.args 155 | def is_jmp(self): 156 | return 'jmp' in self.op or self.op[0] == 'j' 157 | def is_call(self): 158 | return self.op == 'call' 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # asm2vec-pytorch 2 | 3 | release 1.0.0 4 | mit 5 | python 6 | 7 | Unofficial implementation of `asm2vec` using pytorch ( with GPU acceleration ) 8 | The details of the model can be found in the original paper: [(sp'19) Asm2Vec: Boosting Static Representation Robustness for Binary Clone Search against Code Obfuscation and Compiler 
Optimization](https://www.computer.org/csdl/proceedings-article/sp/2019/666000a038/19skfc3ZfKo) 9 | 10 | ## Requirements 11 | 12 | python >= 3.6 13 | 14 | | packages | for | 15 | | --- | --- | 16 | | r2pipe | `scripts/bin2asm.py` | 17 | | click | `scripts/*` | 18 | | torch | almost all of the code needs it | 19 | 20 | You also need to install `radare2` to run `scripts/bin2asm.py`; `r2pipe` is just the Python interface to `radare2`. 21 | 22 | If you only want to use the library code, you only need to install `torch`. 23 | 24 | ## Install 25 | 26 | ``` 27 | python setup.py install 28 | ``` 29 | 30 | or 31 | 32 | ``` 33 | pip install git+https://github.com/oalieno/asm2vec-pytorch.git 34 | ``` 35 | 36 | ## Benchmark 37 | 38 | Another implementation already exists: [Lancern/asm2vec](https://github.com/Lancern/asm2vec) 39 | The following benchmark measures training on 1000 functions for 1 epoch. 40 | 41 | | Implementation | Time (s) | 42 | | :-: | :-: | 43 | | [Lancern/asm2vec](https://github.com/Lancern/asm2vec) | 202.23 | 44 | | [oalieno/asm2vec-pytorch](https://github.com/oalieno/asm2vec-pytorch) (with CPU) | 9.11 | 45 | | [oalieno/asm2vec-pytorch](https://github.com/oalieno/asm2vec-pytorch) (with GPU) | 0.97 | 46 | 47 | ## Get Started 48 | 49 | ```bash 50 | python scripts/bin2asm.py -i /bin/ -o asm/ 51 | ``` 52 | 53 | First, generate asm files from the binaries under `/bin/`. 54 | You can hit `Ctrl+C` at any time once there is enough data. 55 | 56 | ```bash 57 | python scripts/train.py -i asm/ -l 100 -o model.pt --epochs 100 58 | ``` 59 | 60 | For a first taste, train the model on only 100 functions for 100 epochs. 61 | Then you can use more data if you want. 62 | 63 | ```bash 64 | python scripts/test.py -i asm/123456 -m model.pt 65 | ``` 66 | 67 | After training your model, pick an assembly function and look at the result. 68 | This script shows you how the model performs.
69 | Once you are satisfied, you can extract the function's embedding vector and use it however you like. 70 | 71 | ## Usage 72 | 73 | ### bin2asm.py 74 | 75 | ``` 76 | Usage: bin2asm.py [OPTIONS] 77 | 78 | Extract assembly functions from binary executable 79 | 80 | Options: 81 | -i, --input TEXT input directory / file [required] 82 | -o, --output TEXT output directory 83 | -l, --len INTEGER ignore assembly code with instructions amount smaller 84 | than minlen 85 | 86 | --help Show this message and exit. 87 | ``` 88 | 89 | ```bash 90 | # Example 91 | python bin2asm.py -i /bin/ -o asm/ 92 | ``` 93 | 94 | ### train.py 95 | 96 | ``` 97 | Usage: train.py [OPTIONS] 98 | 99 | Options: 100 | -i, --input TEXT training data folder [required] 101 | -o, --output TEXT output model path [default: model.pt] 102 | -m, --model TEXT load previous trained model path 103 | -l, --limit INTEGER limit the number of functions to be loaded 104 | -d, --ebedding-dimension INTEGER 105 | embedding dimension [default: 100] 106 | -b, --batch-size INTEGER batch size [default: 1024] 107 | -e, --epochs INTEGER training epochs [default: 10] 108 | -n, --neg-sample-num INTEGER negative sampling amount [default: 25] 109 | -a, --calculate-accuracy whether calculate accuracy ( will be 110 | significantly slower ) 111 | 112 | -c, --device TEXT hardware device to be used: cpu / cuda / 113 | auto [default: auto] 114 | 115 | -lr, --learning-rate FLOAT learning rate [default: 0.02] 116 | --help Show this message and exit.
117 | ``` 118 | 119 | ```bash 120 | # Example 121 | python train.py -i asm/ -o model.pt --epochs 100 122 | ``` 123 | 124 | ### test.py 125 | 126 | ``` 127 | Usage: test.py [OPTIONS] 128 | 129 | Options: 130 | -i, --input TEXT target function [required] 131 | -m, --model TEXT model path [required] 132 | -e, --epochs INTEGER training epochs [default: 10] 133 | -n, --neg-sample-num INTEGER negative sampling amount [default: 25] 134 | -l, --limit INTEGER limit the amount of output probability result 135 | -c, --device TEXT hardware device to be used: cpu / cuda / auto 136 | [default: auto] 137 | 138 | -lr, --learning-rate FLOAT learning rate [default: 0.02] 139 | -p, --pretty pretty print table [default: False] 140 | --help Show this message and exit. 141 | ``` 142 | 143 | ```bash 144 | # Example 145 | python test.py -i asm/123456 -m model.pt 146 | ``` 147 | 148 | ``` 149 | ┌──────────────────────────────────────────┐ 150 | │ endbr64 │ 151 | │ ➔ push r15 │ 152 | │ push r14 │ 153 | ├────────┬─────────────────────────────────┤ 154 | │ 34.68% │ [rdx + rsi*CONST + CONST] │ 155 | │ 20.29% │ push │ 156 | │ 16.22% │ r15 │ 157 | │ 04.36% │ r14 │ 158 | │ 03.55% │ r11d │ 159 | └────────┴─────────────────────────────────┘ 160 | ``` 161 | 162 | ### compare.py 163 | 164 | ``` 165 | Usage: compare.py [OPTIONS] 166 | 167 | Options: 168 | -i1, --input1 TEXT target function 1 [required] 169 | -i2, --input2 TEXT target function 2 [required] 170 | -m, --model TEXT model path [required] 171 | -e, --epochs INTEGER training epochs [default: 10] 172 | -c, --device TEXT hardware device to be used: cpu / cuda / auto 173 | [default: auto] 174 | 175 | -lr, --learning-rate FLOAT learning rate [default: 0.02] 176 | --help Show this message and exit. 
177 | ``` 178 | 179 | ```bash 180 | # Example 181 | python compare.py -i1 asm/123456 -i2 asm/654321 -m model.pt -e 30 182 | ``` 183 | 184 | ``` 185 | cosine similarity : 0.873684 186 | ``` 187 | --------------------------------------------------------------------------------