├── requirements.txt
├── asm2vec
│   ├── __init__.py
│   ├── model.py
│   ├── utils.py
│   └── datatype.py
├── setup.py
├── LICENSE
├── scripts
│   ├── compare.py
│   ├── test.py
│   ├── train.py
│   └── bin2asm.py
├── .gitignore
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7,<2
2 | click>=7.1,<8
3 | r2pipe>=1.5,<2
4 |
--------------------------------------------------------------------------------
/asm2vec/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
__all__ = ['model', 'datatype', 'utils']

# Eagerly import every public submodule so that a plain `import asm2vec`
# exposes `asm2vec.model`, `asm2vec.datatype` and `asm2vec.utils`
# (the scripts call e.g. `asm2vec.utils.load_model` after `import asm2vec`).
# An explicit relative import replaces the former importlib loop: it does not
# leave a stray loop variable in the package namespace, and static analysers
# can see the submodules.
from . import model, datatype, utils  # noqa: E402,F401
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
setup(
    name='asm2vec',
    version='1.0.0',
    description='Unofficial implementation of asm2vec using pytorch',
    # Mirrors requirements.txt. The commas matter: without them Python
    # concatenates adjacent string literals into a single bogus requirement
    # ('torch>=1.7,<2click>=7.1,<8r2pipe>=1.5,<2').
    install_requires=[
        'torch>=1.7,<2',
        'click>=7.1,<8',
        'r2pipe>=1.5,<2',
    ],
    # The code base uses f-strings; README states python >= 3.6.
    python_requires='>=3.6',
    author='oalieno',
    author_email='jeffrey6910@gmail.com',
    license='MIT License',
    packages=find_packages(),
)
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 oalieno
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/compare.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import click
4 | import asm2vec
5 |
def cosine_similarity(v1, v2):
    '''
    Return the cosine similarity of two vectors as a Python float.

    Computed as (v1 @ v2) / (|v1| * |v2|); a zero vector yields nan, exactly
    as the plain division does.
    '''
    dot_product = v1 @ v2
    norm_product = v1.norm() * v2.norm()
    return (dot_product / norm_product).item()
8 |
@click.command()
@click.option('-i1', '--input1', 'ipath1', help='target function 1', required=True)
@click.option('-i2', '--input2', 'ipath2', help='target function 2', required=True)
@click.option('-m', '--model', 'mpath', help='model path', required=True)
@click.option('-e', '--epochs', default=10, help='training epochs', show_default=True)
@click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True)
@click.option('-lr', '--learning-rate', 'lr', default=0.02, help="learning rate", show_default=True)
def cli(ipath1, ipath2, mpath, epochs, device, lr):
    # Learn embeddings for two functions against a trained model (only the
    # function embeddings are optimised) and print their cosine similarity.
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load model, tokens
    model, tokens = asm2vec.utils.load_model(mpath, device=device)
    functions, tokens_new = asm2vec.utils.load_data([ipath1, ipath2])
    # Embeddings 0 and 1 are compared below and only two function embeddings
    # are allocated, so each input must be exactly one function file.
    # load_data expands a directory into all of its files, which would
    # otherwise index past the allocated embeddings or compare the wrong pair.
    if len(functions) != 2:
        raise click.UsageError(
            f'expected exactly 2 functions but loaded {len(functions)}; '
            'pass one assembly file to each of --input1 / --input2'
        )
    tokens.update(tokens_new)
    model.update(2, tokens.size())
    model = model.to(device)

    # train function embedding
    model = asm2vec.utils.train(
        functions,
        tokens,
        model=model,
        epochs=epochs,
        device=device,
        mode='test',
        learning_rate=lr
    )

    # compare 2 function vectors (pure lookup, no autograd needed)
    with torch.no_grad():
        v1, v2 = model.to('cpu').embeddings_f(torch.tensor([0, 1]))

    print(f'cosine similarity : {cosine_similarity(v1, v2):.6f}')


if __name__ == '__main__':
    cli()
45 |
--------------------------------------------------------------------------------
/scripts/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import click
4 | import asm2vec
5 |
@click.command()
@click.option('-i', '--input', 'ipath', help='target function', required=True)
@click.option('-m', '--model', 'mpath', help='model path', required=True)
@click.option('-e', '--epochs', default=10, help='training epochs', show_default=True)
@click.option('-n', '--neg-sample-num', 'neg_sample_num', default=25, help='negative sampling amount', show_default=True)
@click.option('-l', '--limit', help='limit the amount of output probability result', type=int)
@click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True)
@click.option('-lr', '--learning-rate', 'lr', default=0.02, help="learning rate", show_default=True)
@click.option('-p', '--pretty', default=False, help='pretty print table', show_default=True, is_flag=True)
def cli(ipath, mpath, epochs, neg_sample_num, limit, device, lr, pretty):
    # Learn an embedding for the target function against a trained model,
    # then print the model's token predictions for its instructions.
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # load model, tokens
    model, tokens = asm2vec.utils.load_model(mpath, device=device)
    functions, tokens_new = asm2vec.utils.load_data(ipath)
    tokens.update(tokens_new)
    # One function embedding per loaded function: load_data expands a
    # directory into all of its files, and preprocess uses every function
    # index, so a fixed size of 1 would index out of range.
    model.update(len(functions), tokens.size())
    model = model.to(device)

    # train function embedding
    model = asm2vec.utils.train(
        functions,
        tokens,
        model=model,
        epochs=epochs,
        neg_sample_num=neg_sample_num,
        device=device,
        mode='test',
        learning_rate=lr
    )

    # show predicted probability results
    x, y = asm2vec.utils.preprocess(functions, tokens)
    # Inference only: don't build an autograd graph over the
    # (samples x vocabulary) score tensor.
    with torch.no_grad():
        probs = model.predict(x.to(device), y.to(device))
    asm2vec.utils.show_probs(x, y, probs, tokens, limit=limit, pretty=pretty)


if __name__ == '__main__':
    cli()
45 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import click
3 | import asm2vec
4 |
@click.command()
@click.option('-i', '--input', 'ipath', help='training data folder', required=True)
@click.option('-o', '--output', 'opath', default='model.pt', help='output model path', show_default=True)
@click.option('-m', '--model', 'mpath', help='load previous trained model path', type=str)
@click.option('-l', '--limit', help='limit the number of functions to be loaded', show_default=True, type=int)
# '--ebedding-dimension' (sic) stays as an alias so existing command lines keep working.
@click.option('-d', '--embedding-dimension', '--ebedding-dimension', 'embedding_size', default=100, help='embedding dimension', show_default=True)
@click.option('-b', '--batch-size', 'batch_size', default=1024, help='batch size', show_default=True)
@click.option('-e', '--epochs', default=10, help='training epochs', show_default=True)
@click.option('-n', '--neg-sample-num', 'neg_sample_num', default=25, help='negative sampling amount', show_default=True)
@click.option('-a', '--calculate-accuracy', 'calc_acc', help='whether calculate accuracy ( will be significantly slower )', is_flag=True)
@click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True)
@click.option('-lr', '--learning-rate', 'lr', default=0.02, help="learning rate", show_default=True)
def cli(ipath, opath, mpath, limit, embedding_size, batch_size, epochs, neg_sample_num, calc_acc, device, lr):
    # Train a model (optionally continuing from a previous checkpoint) and
    # save it to `opath` after every epoch.
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if mpath:
        model, tokens = asm2vec.utils.load_model(mpath, device=device)
        functions, tokens_new = asm2vec.utils.load_data(ipath, limit=limit)
        tokens.update(tokens_new)
        model.update(len(functions), tokens.size())
    else:
        model = None
        functions, tokens = asm2vec.utils.load_data(ipath, limit=limit)

    def callback(context):
        progress = f'{context["epoch"]} | time = {context["time"]:.2f}, loss = {context["loss"]:.4f}'
        # `is not None` rather than truthiness: an accuracy of exactly 0 is
        # still a value worth reporting.
        if context["accuracy"] is not None:
            progress += f', accuracy = {context["accuracy"]:.4f}'
        print(progress)
        # checkpoint every epoch so an interrupted run keeps its progress
        asm2vec.utils.save_model(opath, context["model"], context["tokens"])

    asm2vec.utils.train(
        functions,
        tokens,
        model=model,
        embedding_size=embedding_size,
        batch_size=batch_size,
        epochs=epochs,
        neg_sample_num=neg_sample_num,
        calc_acc=calc_acc,
        device=device,
        callback=callback,
        learning_rate=lr
    )


if __name__ == '__main__':
    cli()
53 |
--------------------------------------------------------------------------------
/asm2vec/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
# Module-level, parameter-free loss / activation modules.
bce = nn.BCELoss()
sigmoid = nn.Sigmoid()
softmax = nn.Softmax(dim=1)  # normalises over dim 1
5 |
class ASM2VEC(nn.Module):
    '''
    Asm2Vec-style model built from three embedding tables.

    * ``embeddings``   - (vocab_size, E) token vectors used to build an
      instruction's context; initialised to zeros.
    * ``embeddings_f`` - (function_size, 2E) per-function vectors.
    * ``embeddings_r`` - (vocab_size, 2E) output token vectors that context
      vectors are scored against.
    '''

    def __init__(self, vocab_size, function_size, embedding_size):
        super(ASM2VEC, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_size, _weight=torch.zeros(vocab_size, embedding_size))
        self.embeddings_f = nn.Embedding(function_size, 2 * embedding_size, _weight=self._random_weight(function_size, embedding_size))
        self.embeddings_r = nn.Embedding(vocab_size, 2 * embedding_size, _weight=self._random_weight(vocab_size, embedding_size))

    @staticmethod
    def _random_weight(rows, embedding_size, device=None):
        '''Return a (rows, 2 * embedding_size) tensor uniform in
        [-1 / (4 * embedding_size), 1 / (4 * embedding_size)).'''
        # Drawn on the CPU generator first (as before) so seeding is device-independent.
        weight = (torch.rand(rows, 2 * embedding_size) - 0.5) / embedding_size / 2
        return weight if device is None else weight.to(device)

    def update(self, function_size_new, vocab_size_new):
        '''
        Resize the model for a new dataset.

        The token tables (``embeddings``, ``embeddings_r``) keep their trained
        rows and get new rows appended when the vocabulary size changed
        (zeros / random respectively). ``embeddings_f`` is always re-created
        with fresh random rows for ``function_size_new`` functions.
        '''
        device = self.embeddings.weight.device
        vocab_size = self.embeddings.num_embeddings
        embedding_size = self.embeddings.embedding_dim
        # The new weights become fresh leaf parameters; there is nothing to
        # back-propagate through the concatenation.
        with torch.no_grad():
            if vocab_size_new != vocab_size:
                extra = vocab_size_new - vocab_size
                weight = torch.cat([self.embeddings.weight, torch.zeros(extra, embedding_size).to(device)])
                self.embeddings = nn.Embedding(vocab_size_new, embedding_size, _weight=weight)
                weight_r = torch.cat([self.embeddings_r.weight, self._random_weight(extra, embedding_size, device)])
                self.embeddings_r = nn.Embedding(vocab_size_new, 2 * embedding_size, _weight=weight_r)
            self.embeddings_f = nn.Embedding(function_size_new, 2 * embedding_size, _weight=self._random_weight(function_size_new, embedding_size, device))

    def v(self, inp):
        '''
        Build the context vector of each sample.

        ``inp`` rows follow ``utils.preprocess``:
        ``[function, prev_op, prev_arg1, prev_arg2, next_op, next_arg1, next_arg2]``.
        Each neighbouring instruction becomes ``[op ; mean(args)]`` (2E wide);
        the function vector and both neighbours are averaged.
        Returns a (batch, 2E, 1) tensor ready for ``torch.bmm``.
        '''
        e = self.embeddings(inp[:, 1:])
        v_f = self.embeddings_f(inp[:, 0])
        v_prev = torch.cat([e[:, 0], (e[:, 1] + e[:, 2]) / 2], dim=1)
        v_next = torch.cat([e[:, 3], (e[:, 4] + e[:, 5]) / 2], dim=1)
        v = ((v_f + v_prev + v_next) / 3).unsqueeze(2)
        return v

    def forward(self, inp, pos, neg):
        '''
        Negative-sampling loss.

        pos: (batch, P) token indices of the centre instruction (P == 3 with
             ``utils.preprocess``); neg: (batch, N) sampled token indices.
        Returns the mean binary cross-entropy over all P + N scores, with
        positives labelled 1 and negatives 0.
        '''
        v = self.v(inp)
        # Squeeze only the trailing singleton dim: a bare .squeeze() also
        # dropped the batch dim when batch == 1, breaking the label shape.
        pred = torch.bmm(self.embeddings_r(torch.cat([pos, neg], dim=1)), v).squeeze(2)
        label = torch.cat([pred.new_ones(pos.shape), pred.new_zeros(neg.shape)], dim=1)
        # Numerically stable equivalent of BCELoss()(sigmoid(pred), label).
        return nn.functional.binary_cross_entropy_with_logits(pred, label)

    def predict(self, inp, pos):
        '''
        Return a (batch, vocab_size) softmax distribution over all tokens for
        each sample. ``pos`` is accepted for symmetry with ``forward`` but is
        not used.
        '''
        v = self.v(inp).squeeze(2)  # (batch, 2E)
        # Score every output token at once. This equals the bmm over
        # arange(vocab).repeat(batch, 1) without materialising a
        # (batch, vocab, 2E) tensor.
        logits = v @ self.embeddings_r.weight.t()  # (batch, vocab)
        return torch.softmax(logits, dim=1)
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/vim,python
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=vim,python
3 |
4 | ### Python ###
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | pytestdebug.log
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 | doc/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 | pythonenv*
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
138 | # pytype static type analyzer
139 | .pytype/
140 |
141 | # profiling data
142 | .prof
143 |
144 | ### Vim ###
145 | # Swap
146 | [._]*.s[a-v][a-z]
147 | !*.svg # comment out if you don't need vector files
148 | [._]*.sw[a-p]
149 | [._]s[a-rt-v][a-z]
150 | [._]ss[a-gi-z]
151 | [._]sw[a-p]
152 |
153 | # Session
154 | Session.vim
155 | Sessionx.vim
156 |
157 | # Temporary
158 | .netrwhist
159 | *~
160 | # Auto-generated tag files
161 | tags
162 | # Persistent undo
163 | [._]*.un~
164 |
165 | # End of https://www.toptal.com/developers/gitignore/api/vim,python
166 |
--------------------------------------------------------------------------------
/scripts/bin2asm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import re
3 | import os
4 | import click
5 | import r2pipe
6 | import hashlib
7 | from pathlib import Path
8 |
def sha3(data):
    '''Return the hex SHA3-256 digest of the encoded string *data* (used as the output file name).'''
    digest = hashlib.sha3_256()
    digest.update(data.encode())
    return digest.hexdigest()
11 |
def validEXE(filename):
    '''Return True when *filename* starts with a supported magic number (ELF only).'''
    accepted_magics = [bytes.fromhex('7f454c46')]  # "\x7fELF"
    with open(filename, 'rb') as binary:
        magic = binary.read(len(accepted_magics[0]))
    return magic in accepted_magics
17 |
def normalize(opcode):
    '''
    Canonicalise immediates in a disassembly string.

    ' - ' becomes ' + ' (so displacement sign is ignored); hex literals,
    scale factors after '*', and decimal numbers that follow a space are all
    replaced by ``CONST``.
    '''
    opcode = opcode.replace(' - ', ' + ')
    opcode = re.sub(r'0x[0-9a-f]+', 'CONST', opcode)
    # Match whole digit runs: with a single [0-9] a multi-digit literal such
    # as ' 16' became ' CONST6'.
    opcode = re.sub(r'\*[0-9]+', '*CONST', opcode)
    opcode = re.sub(r' [0-9]+', ' CONST', opcode)
    return opcode
24 |
def fn2asm(pdf, minlen):
    '''
    Render one function's ``pdfj`` JSON as text readable by ``Function.load``.

    Returns None (the function is skipped) when ``pdf`` is None or has no
    ``ops``, has fewer than ``minlen`` instructions, contains an ``invalid``
    instruction, or has an instruction whose offset is None.

    A jump whose target offset lies inside the function is written as
    ``<type> LABEL<n>``, where n is the index of the first instruction that
    jumps to that offset; the target instruction is preceded by ``LABEL<n>:``.
    Every other instruction is passed through ``normalize``.
    '''
    # check
    if pdf is None:
        return None
    ops = pdf.get('ops') or []
    if len(ops) < minlen:
        return None
    if any(op['type'] == 'invalid' for op in ops):
        return None

    # set label
    scope = {op['offset'] for op in ops}
    # Offsets are needed to resolve jump targets. Skip the function rather
    # than assert (asserts are stripped under `python -O`).
    if None in scope:
        return None
    labels = {}
    for i, op in enumerate(ops):
        if op.get('jump') in scope:
            labels.setdefault(op.get('jump'), i)

    # dump output
    lines = []
    for op in ops:
        # add label (label numbers can be 0, hence `is not None`)
        if labels.get(op.get('offset')) is not None:
            lines.append(f'LABEL{labels[op["offset"]]}:\n')
        # add instruction
        if labels.get(op.get('jump')) is not None:
            lines.append(f' {op["type"]} LABEL{labels[op["jump"]]}\n')
        else:
            lines.append(f' {normalize(op["opcode"])}\n')

    return ''.join(lines)
56 |
def bin2asm(filename, opath, minlen):
    '''
    Disassemble every function of one executable with radare2 and write each
    usable one to ``opath / <sha3 of its assembly>``.

    filename: pathlib.Path of the binary (its ``.name`` goes into the header).
    Returns the number of assembly files written (0 for non-ELF input).
    '''
    # check
    if not validEXE(filename):
        return 0

    r = r2pipe.open(str(filename))
    count = 0
    try:
        r.cmd('aaaa')
        # cmdj may yield None (fn2asm guards the same case for 'pdfj')
        for fn in r.cmdj('aflj') or []:
            r.cmd(f's {fn["offset"]}')
            asm = fn2asm(r.cmdj('pdfj'), minlen)
            if asm:
                uid = sha3(asm)
                # Header lines start with whitespace so Function.load treats
                # them as ".key value" meta data rather than labels.
                header = (
                    f' .name {fn["name"]}\n'
                    f' .offset {fn["offset"]:016x}\n'
                    f' .file {filename.name}\n'
                )
                with open(opath / uid, 'w') as f:
                    f.write(header + asm)
                count += 1
    finally:
        # Always close the r2pipe session (and the radare2 process behind it),
        # otherwise one lingers per scanned binary.
        r.quit()

    print(f'[+] {filename}')

    return count
83 |
@click.command()
@click.option('-i', '--input', 'ipath', help='input directory / file', required=True)
@click.option('-o', '--output', 'opath', default='asm', help='output directory')
@click.option('-l', '--len', 'minlen', default=10, help='ignore assembly code with instructions amount smaller than minlen')
def cli(ipath, opath, minlen):
    '''
    Extract assembly functions from binary executable
    '''
    ipath = Path(ipath)
    opath = Path(opath)

    # create output directory: parents=True also creates missing parent
    # directories; exist_ok=True avoids the exists-then-mkdir race
    opath.mkdir(parents=True, exist_ok=True)

    fcount, bcount = 0, 0

    # directory: scan its regular (non-symlink, non-directory) entries,
    # without recursing
    if ipath.is_dir():
        for f in ipath.iterdir():
            if not f.is_symlink() and not f.is_dir():
                fcount += bin2asm(f, opath, minlen)
                bcount += 1
    # file
    elif ipath.exists():
        fcount += bin2asm(ipath, opath, minlen)
        bcount += 1
    else:
        print(f'[Error] No such file or directory: {ipath}')

    print(f'[+] Total scan binary: {bcount} => Total generated assembly functions: {fcount}')


if __name__ == '__main__':
    cli()
118 |
--------------------------------------------------------------------------------
/asm2vec/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | from torch.utils.data import DataLoader, Dataset
5 | from pathlib import Path
6 | from .datatype import Tokens, Function, Instruction
7 | from .model import ASM2VEC
8 |
class AsmDataset(Dataset):
    '''Map-style dataset over two parallel sequences: item k is ``(x[k], y[k])``.'''

    def __init__(self, x, y):
        # x: context rows, y: target rows (as returned by ``preprocess``)
        self.x, self.y = x, y

    def __getitem__(self, index):
        return (self.x[index], self.y[index])

    def __len__(self):
        return len(self.x)
17 |
def load_data(paths, limit=None):
    '''
    Load assembly functions (e.g. files written by ``scripts/bin2asm.py``).

    paths: one path (str or os.PathLike) or an iterable of paths (list,
           tuple, ...). A directory contributes its regular files in sorted
           name order; any other path is read as a single file.
    limit: if truthy, stop after this many files in total.

    Returns ``(functions, tokens)``: the parsed ``Function`` objects and a
    ``Tokens`` vocabulary counting every token they contain.
    '''
    # Wrap a lone path. Any other iterable is accepted as a collection of
    # paths; previously only a list was, and a tuple crashed in isdir().
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    filenames = []
    for path in paths:
        path = Path(path)
        if path.is_dir():
            filenames += [path / name for name in sorted(os.listdir(path)) if (path / name).is_file()]
        else:
            filenames.append(path)

    functions, tokens = [], Tokens()
    for i, filename in enumerate(filenames):
        if limit and i >= limit:
            break
        with open(filename) as f:
            fn = Function.load(f.read())
        functions.append(fn)
        tokens.add(fn.tokens())

    return functions, tokens
39 |
def preprocess(functions, tokens):
    '''
    Turn random walks into training tensors.

    For every function and each of its random walks, every instruction that
    has both a predecessor and a successor yields one sample:
      x row = [function index] + token indices of the previous and next instruction
      y row = token indices of the instruction itself
    Returns ``(torch.tensor(x_rows), torch.tensor(y_rows))``.
    '''
    contexts, targets = [], []
    for fn_index, fn in enumerate(functions):
        for walk in fn.random_walk():
            for before, centre, after in zip(walk, walk[1:], walk[2:]):
                neighbours = before.tokens() + after.tokens()
                contexts.append([fn_index] + [tokens[name].index for name in neighbours])
                targets.append([tokens[name].index for name in centre.tokens()])
    return torch.tensor(contexts), torch.tensor(targets)
48 |
def train(
    functions,
    tokens,
    model=None,
    embedding_size=100,
    batch_size=1024,
    epochs=10,
    neg_sample_num=25,
    calc_acc=False,
    device='cpu',
    mode='train',
    callback=None,
    learning_rate=0.02
):
    '''
    Train an ASM2VEC model with Adam and negative sampling.

    mode='train': all parameters are optimised; a new model is created when
                  ``model`` is None.
    mode='test' : ``model`` is required and only ``model.embeddings_f`` is
                  given to the optimizer, so the other tables are not updated.

    callback, if given, is called after every epoch with a dict holding
    'model', 'tokens', 'epoch', 'time' (seconds), 'loss' (mean batch loss;
    nan if there were no batches) and 'accuracy' (first-batch accuracy, or
    None when calc_acc is False).

    Returns the trained model.
    Raises ValueError for an unknown mode, or for mode='test' without a model.
    '''
    if mode == 'train':
        if model is None:
            model = ASM2VEC(tokens.size(), function_size=len(functions), embedding_size=embedding_size).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif mode == 'test':
        if model is None:
            raise ValueError("test mode required pretrained model")
        optimizer = torch.optim.Adam(model.embeddings_f.parameters(), lr=learning_rate)
    else:
        raise ValueError("Unknown mode")

    loader = DataLoader(AsmDataset(*preprocess(functions, tokens)), batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        start = time.time()
        loss_sum, loss_count, accs = 0.0, 0, []

        model.train()
        for i, (inp, pos) in enumerate(loader):
            neg = tokens.sample(inp.shape[0], neg_sample_num)
            loss = model(inp.to(device), pos.to(device), neg.to(device))
            # Accumulate a detached value: summing the live loss tensors kept
            # every batch's autograd graph reachable for the whole epoch.
            loss_sum, loss_count = loss_sum + loss.detach(), loss_count + 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i == 0 and calc_acc:
                # accuracy is sampled on the first batch only; no gradients needed
                with torch.no_grad():
                    probs = model.predict(inp.to(device), pos.to(device))
                accs.append(accuracy(pos, probs))

        if callback:
            callback({
                'model': model,
                'tokens': tokens,
                'epoch': epoch,
                'time': time.time() - start,
                # an empty dataset yields no batches: report nan instead of
                # raising ZeroDivisionError
                'loss': loss_sum / loss_count if loss_count else float('nan'),
                'accuracy': torch.tensor(accs).mean() if calc_acc else None
            })

    return model
104 |
def save_model(path, model, tokens):
    '''
    Write a checkpoint with ``torch.save``.

    The checkpoint stores the constructor arguments of ``ASM2VEC``
    (vocab_size, function_size, embedding_size), the model's state dict and
    the vocabulary's state dict, so ``load_model`` can rebuild both.
    '''
    constructor_args = (
        model.embeddings.num_embeddings,    # vocab_size
        model.embeddings_f.num_embeddings,  # function_size
        model.embeddings.embedding_dim,     # embedding_size
    )
    checkpoint = {
        'model_params': constructor_args,
        'model': model.state_dict(),
        'tokens': tokens.state_dict(),
    }
    torch.save(checkpoint, path)
115 |
def load_model(path, device='cpu'):
    '''
    Rebuild the ``(model, tokens)`` pair stored by ``save_model``.

    Tensors are mapped onto ``device`` while loading, and the model is moved
    there before it is returned.
    NOTE(review): torch.load unpickles the checkpoint, which includes Token
    objects; only load checkpoint files from trusted sources.
    '''
    checkpoint = torch.load(path, map_location=device)
    vocab_size, function_size, embedding_size = checkpoint['model_params']
    model = ASM2VEC(vocab_size, function_size, embedding_size)
    model.load_state_dict(checkpoint['model'])
    tokens = Tokens()
    tokens.load_state_dict(checkpoint['tokens'])
    return model.to(device), tokens
124 |
def show_probs(x, y, probs, tokens, limit=None, pretty=False, top_k=5):
    '''
    Print, for each sample, its neighbouring instructions, the target
    instruction and the ``top_k`` most probable predicted tokens.

    x, y   : rows from ``preprocess`` (context / target token indices).
    probs  : (samples, vocab_size) scores, e.g. from ``ASM2VEC.predict``.
    tokens : the ``Tokens`` vocabulary used to decode indices.
    limit  : if truthy, print at most this many samples.
    pretty : draw the frame with Unicode box characters instead of ASCII.
    top_k  : predictions listed per sample; capped at the vocabulary size
             because ``topk`` fails when k exceeds it.
    Predicted tokens that occur in the target instruction are highlighted
    with a green ANSI colour code.
    '''
    if pretty:
        TL, TR, BL, BR = '┌', '┐', '└', '┘'
        LM, RM, TM, BM = '├', '┤', '┬', '┴'
        H, V = '─', '│'
        arrow = ' ➔'
    else:
        TL = TR = BL = BR = '+'
        LM = RM = TM = BM = '+'
        H, V = '-', '|'
        arrow = '->'
    top = probs.topk(min(top_k, probs.shape[-1]))
    for i, (xi, yi) in enumerate(zip(x, y)):
        if limit and i >= limit:
            break
        xi, yi = xi.tolist(), yi.tolist()
        # Every row is 44 columns wide: 4 columns precede the instruction
        # text (4 spaces, or ' ' + arrow + ' ' on the target row).
        print(TL + H * 42 + TR)
        print(f'{V}    {str(Instruction(tokens[xi[1]], tokens[xi[2:4]])):37} {V}')
        print(f'{V} {arrow} {str(Instruction(tokens[yi[0]], tokens[yi[1:3]])):37} {V}')
        print(f'{V}    {str(Instruction(tokens[xi[4]], tokens[xi[5:7]])):37} {V}')
        print(LM + H * 8 + TM + H * 33 + RM)
        for value, index in zip(top.values[i], top.indices[i]):
            # compare plain ints rather than a 0-dim tensor against the list
            index = index.item()
            if index in yi:
                colorbegin, colorclear = '\033[92m', '\033[0m'
            else:
                colorbegin, colorclear = '', ''
            print(f'{V} {colorbegin}{value*100:05.2f}%{colorclear} {V} {colorbegin}{tokens[index].name:31}{colorclear} {V}')
        print(BL + H * 8 + BM + H * 33 + BR)
153 |
def accuracy(y, probs):
    '''
    Mean, over samples, of the probability mass ``probs`` assigns to that
    sample's target token indices (a repeated index is counted each time).
    '''
    per_sample = [probs[row][targets].sum() for row, targets in enumerate(y)]
    return torch.tensor(per_sample).mean()
156 |
157 |
--------------------------------------------------------------------------------
/asm2vec/datatype.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import warnings
4 |
class Token:
    '''
    One vocabulary entry: its text ``name``, its position ``index`` in the
    vocabulary (used as the embedding row) and an occurrence ``count`` that
    starts at 1.
    '''

    def __init__(self, name, index):
        self.name, self.index, self.count = name, index, 1

    def __str__(self):
        return self.name
12 |
class Tokens:
    '''
    Vocabulary of ``Token`` objects addressable by name or by index.

    ``name_to_index`` maps a name to its position in ``tokens``. ``add`` and
    ``update`` always append new tokens with ``index == len(tokens)``, so
    ``tokens[i].index == i`` holds for vocabularies built through them.
    '''

    def __init__(self, name_to_index=None, tokens=None):
        self.name_to_index = name_to_index or {}
        self.tokens = tokens or []
        # Cached negative-sampling weights (see ``weights``). Every method that
        # changes the vocabulary or the counts must reset this to None.
        self._weights = None

    def __getitem__(self, key):
        '''
        Look up tokens.

        * str  -> the Token with that name. Unknown names emit a warning and
                  fall back to the empty-string token ``""``.
        * int  -> the Token at that index (IndexError past the end).
        * any other iterable -> a list with each element looked up.
        Raises ValueError when a key of another type cannot be resolved.
        '''
        if isinstance(key, str):
            index = self.name_to_index.get(key)
            if index is None:
                warnings.warn("Unknown token in training dataset")
                return self.tokens[self.name_to_index[""]]
            return self.tokens[index]
        if isinstance(key, int):
            return self.tokens[key]
        try:
            return [self[k] for k in key]
        except (TypeError, KeyError, IndexError) as err:
            raise ValueError(f'invalid token key: {key!r}') from err

    def __iter__(self):
        '''Iterate over the Token objects in index order.'''
        # Explicit form of the implicit __getitem__(0), (1), ... protocol that
        # ``update(other_vocabulary)`` relied on.
        return iter(self.tokens)

    def load_state_dict(self, sd):
        '''Restore the state produced by ``state_dict``.'''
        self.name_to_index = sd['name_to_index']
        self.tokens = sd['tokens']
        self._weights = None  # vocabulary replaced: cached weights are stale

    def state_dict(self):
        return {'name_to_index': self.name_to_index, 'tokens': self.tokens}

    def size(self):
        return len(self.tokens)

    def add(self, names):
        '''Record one occurrence of each name (a single name or a list of names).'''
        self._weights = None
        if type(names) is not list:
            names = [names]
        for name in names:
            if name not in self.name_to_index:
                token = Token(name, len(self.tokens))
                self.name_to_index[name] = token.index
                self.tokens.append(token)
            else:
                self.tokens[self.name_to_index[name]].count += 1

    def update(self, tokens_new):
        '''
        Merge Token objects from another vocabulary (any iterable of Tokens,
        including a ``Tokens`` instance).

        Unknown tokens are appended, and their ``index`` attribute is rewritten
        in place to the new position. Known tokens add their ``count``.
        '''
        # Counts and size change. Without this reset, a cache computed earlier
        # kept the old size, so new tokens were never sampled as negatives.
        self._weights = None
        for token in tokens_new:
            if token.name not in self.name_to_index:
                token.index = len(self.tokens)
                self.name_to_index[token.name] = token.index
                self.tokens.append(token)
            else:
                self.tokens[self.name_to_index[token.name]].count += token.count

    def weights(self):
        '''
        Return a float tensor holding ``(count / total) ** 0.75`` for each
        token, in index order. The values are not normalised
        (``torch.multinomial`` does not need them to be). The tensor is cached
        until the vocabulary changes.
        '''
        if self._weights is None:
            total = sum(token.count for token in self.tokens)
            weights = torch.zeros(len(self.tokens))
            for token in self.tokens:
                weights[token.index] = (token.count / total) ** 0.75
            self._weights = weights
        return self._weights

    def sample(self, batch_size, num=5):
        '''Draw a (batch_size, num) tensor of token indices, with replacement, in proportion to ``weights()``.'''
        return torch.multinomial(self.weights(), num * batch_size, replacement=True).view(batch_size, num)
67 |
class Function:
    '''
    An assembly function parsed into instructions and basic blocks.

    insts  : all Instruction objects in text order.
    blocks : BasicBlock objects in text order; random walks start at blocks[0].
    meta   : values of the indented ``.key value`` header lines.
    '''

    def __init__(self, insts, blocks, meta):
        self.insts = insts
        self.blocks = blocks
        self.meta = meta

    @classmethod
    def load(cls, text):
        '''
        Parse a function in a gcc -S compatible text format.

        Lines indented with a space or tab are meta data (``.key value``) or
        instructions; unindented lines are labels (``name:``). Blank and
        whitespace-only lines are skipped.
        A new basic block starts after an instruction that ends a block (a
        jump or ``ret``). Each new block is linked as a successor of the
        previous one, and a jump whose first operand is a known label is
        linked to the labelled block. Label operands are then replaced by
        ``CONST``.
        '''
        label, labels, insts, blocks, meta = None, {}, [], [], {}
        for line in text.strip('\n').split('\n'):
            # blank lines used to crash on line[0] with IndexError
            if not line.strip():
                continue
            if line[0] in [' ', '\t']:
                line = line.strip()
                # meta data
                if line[0] == '.':
                    key, _, value = line[1:].strip().partition(' ')
                    meta[key] = value
                # instruction
                else:
                    inst = Instruction.load(line)
                    insts.append(inst)
                    if len(blocks) == 0 or blocks[-1].end():
                        blocks.append(BasicBlock())
                        # link prev and next block
                        if len(blocks) > 1:
                            blocks[-2].successors.add(blocks[-1])
                    # a pending label names the block this instruction joins
                    if label:
                        labels[label], label = blocks[-1], None
                    blocks[-1].add(inst)
            # label
            else:
                label = line.partition(':')[0]
        # link label
        for block in blocks:
            inst = block.insts[-1]
            if inst.is_jmp() and inst.args[0] in labels:
                block.successors.add(labels[inst.args[0]])
        # replace label with CONST
        for inst in insts:
            inst.args[:] = ['CONST' if arg in labels else arg for arg in inst.args]
        return cls(insts, blocks, meta)

    def tokens(self):
        '''All tokens of all instructions, in order.'''
        return [token for inst in self.insts for token in inst.tokens()]

    def random_walk(self, num=3):
        '''Return ``num`` independent random walks (lists of instructions).'''
        return [self._random_walk() for _ in range(num)]

    def _random_walk(self):
        '''
        Walk from blocks[0] through randomly chosen successors. Stop when a
        block would be revisited, has no successors, or ends with ``ret``.
        Return the concatenated instructions of the visited blocks.
        '''
        # a function without instructions has no blocks to walk
        if not self.blocks:
            return []
        current, visited, seq = self.blocks[0], set(), []
        while current not in visited:
            visited.add(current)
            seq += current.insts
            # no following block / hit return
            if len(current.successors) == 0 or current.insts[-1].op == 'ret':
                break
            current = random.choice(list(current.successors))
        return seq
126 |
class BasicBlock:
    '''A straight-line run of instructions plus the set of successor blocks.'''

    def __init__(self):
        self.insts, self.successors = [], set()

    def add(self, inst):
        '''Append an instruction to the block.'''
        self.insts.append(inst)

    def end(self):
        '''True when the block's last instruction is a jump or ``ret``.'''
        last = self.insts[-1]
        if last.is_jmp():
            return True
        return last.op == 'ret'
136 |
class Instruction:
    '''One assembly instruction: an opcode ``op`` and exactly two operand slots ``args``.'''

    def __init__(self, op, args):
        self.op = op
        self.args = args

    def __str__(self):
        return f'{self.op} {", ".join([str(arg) for arg in self.args if str(arg)])}'

    @classmethod
    def load(cls, text):
        '''
        Parse ``"op arg1, arg2, ..."`` into an Instruction.

        A leading ``bnd`` prefix word is dropped. Operands are padded with
        ``''`` or truncated so that there are always exactly two.
        '''
        text = text.strip()
        # Drop the BND prefix as a whole word. str.strip('bnd') removes any run
        # of the characters b/n/d from BOTH ends, so 'nop' became 'op',
        # 'dec rbx' became 'ec rbx' and 'mov eax, r8d' became 'mov eax, r8'.
        head, _, rest = text.partition(' ')
        if head == 'bnd':
            text = rest.strip()
        op, _, args = text.partition(' ')
        if args:
            args = [arg.strip() for arg in args.split(',')]
        else:
            args = []
        args = (args + ['', ''])[:2]
        return cls(op, args)

    def tokens(self):
        '''``[op, arg1, arg2]``.'''
        return [self.op] + self.args

    def is_jmp(self):
        '''True when the opcode contains 'jmp' or starts with 'j'.'''
        # startswith instead of op[0] so an empty opcode does not raise
        return 'jmp' in self.op or self.op.startswith('j')

    def is_call(self):
        return self.op == 'call'
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # asm2vec-pytorch
2 |
3 |
4 |
5 |
6 |
7 | Unofficial implementation of `asm2vec` using pytorch ( with GPU acceleration )
8 | The details of the model can be found in the original paper: [(sp'19) Asm2Vec: Boosting Static Representation Robustness for Binary Clone Search against Code Obfuscation and Compiler Optimization](https://www.computer.org/csdl/proceedings-article/sp/2019/666000a038/19skfc3ZfKo)
9 |
10 | ## Requirements
11 |
12 | python >= 3.6
13 |
14 | | packages | for |
15 | | --- | --- |
16 | | r2pipe | `scripts/bin2asm.py` |
17 | | click | `scripts/*` |
18 | | torch | almost all of the code |
19 |
20 | You also need to install `radare2` to run `scripts/bin2asm.py`; `r2pipe` is only the Python interface to `radare2`.
21 |
22 | If you only want to use the library code, you just need to install `torch`
23 |
24 | ## Install
25 |
26 | ```
27 | python setup.py install
28 | ```
29 |
30 | or
31 |
32 | ```
33 | pip install git+https://github.com/oalieno/asm2vec-pytorch.git
34 | ```
35 |
36 | ## Benchmark
37 |
38 | An implementation already exists here: [Lancern/asm2vec](https://github.com/Lancern/asm2vec)
39 | The following is a benchmark of training on 1000 functions for 1 epoch.
40 |
41 | | Implementation | Time (s) |
42 | | :-: | :-: |
43 | | [Lancern/asm2vec](https://github.com/Lancern/asm2vec) | 202.23 |
44 | | [oalieno/asm2vec-pytorch](https://github.com/oalieno/asm2vec-pytorch) (with CPU) | 9.11 |
45 | | [oalieno/asm2vec-pytorch](https://github.com/oalieno/asm2vec-pytorch) (with GPU) | 0.97 |
46 |
47 | ## Get Started
48 |
49 | ```bash
50 | python scripts/bin2asm.py -i /bin/ -o asm/
51 | ```
52 |
53 | First, generate asm files from the binaries under `/bin/`.
54 | You can hit `Ctrl+C` anytime when there is enough data.
55 |
56 | ```bash
57 | python scripts/train.py -i asm/ -l 100 -o model.pt --epochs 100
58 | ```
59 |
60 | Try training the model with only 100 functions and 100 epochs to get a taste.
61 | Then you can use more data if you want.
62 |
63 | ```bash
64 | python scripts/test.py -i asm/123456 -m model.pt
65 | ```
66 |
67 | After you train your model, try to grab an assembly function and see the result.
68 | This script will show you how the model performs.
69 | Once you are satisfied, you can take out the embedding vector of the function and do whatever you want with it.
70 |
71 | ## Usage
72 |
73 | ### bin2asm.py
74 |
75 | ```
76 | Usage: bin2asm.py [OPTIONS]
77 |
78 | Extract assembly functions from binary executable
79 |
80 | Options:
81 | -i, --input TEXT input directory / file [required]
82 | -o, --output TEXT output directory
83 | -l, --len INTEGER ignore assembly code with instructions amount smaller
84 | than minlen
85 |
86 | --help Show this message and exit.
87 | ```
88 |
89 | ```bash
90 | # Example
91 | python bin2asm.py -i /bin/ -o asm/
92 | ```
93 |
94 | ### train.py
95 |
96 | ```
97 | Usage: train.py [OPTIONS]
98 |
99 | Options:
100 | -i, --input TEXT training data folder [required]
101 | -o, --output TEXT output model path [default: model.pt]
102 | -m, --model TEXT load previous trained model path
103 | -l, --limit INTEGER limit the number of functions to be loaded
104 | -d, --ebedding-dimension INTEGER
105 | embedding dimension [default: 100]
106 | -b, --batch-size INTEGER batch size [default: 1024]
107 | -e, --epochs INTEGER training epochs [default: 10]
108 | -n, --neg-sample-num INTEGER negative sampling amount [default: 25]
109 | -a, --calculate-accuracy whether calculate accuracy ( will be
110 | significantly slower )
111 |
112 | -c, --device TEXT hardware device to be used: cpu / cuda /
113 | auto [default: auto]
114 |
115 | -lr, --learning-rate FLOAT learning rate [default: 0.02]
116 | --help Show this message and exit.
117 | ```
118 |
119 | ```bash
120 | # Example
121 | python train.py -i asm/ -o model.pt --epochs 100
122 | ```
123 |
124 | ### test.py
125 |
126 | ```
127 | Usage: test.py [OPTIONS]
128 |
129 | Options:
130 | -i, --input TEXT target function [required]
131 | -m, --model TEXT model path [required]
132 | -e, --epochs INTEGER training epochs [default: 10]
133 | -n, --neg-sample-num INTEGER negative sampling amount [default: 25]
134 | -l, --limit INTEGER limit the amount of output probability result
135 | -c, --device TEXT hardware device to be used: cpu / cuda / auto
136 | [default: auto]
137 |
138 | -lr, --learning-rate FLOAT learning rate [default: 0.02]
139 | -p, --pretty pretty print table [default: False]
140 | --help Show this message and exit.
141 | ```
142 |
143 | ```bash
144 | # Example
145 | python test.py -i asm/123456 -m model.pt
146 | ```
147 |
148 | ```
149 | ┌──────────────────────────────────────────┐
150 | │ endbr64 │
151 | │ ➔ push r15 │
152 | │ push r14 │
153 | ├────────┬─────────────────────────────────┤
154 | │ 34.68% │ [rdx + rsi*CONST + CONST] │
155 | │ 20.29% │ push │
156 | │ 16.22% │ r15 │
157 | │ 04.36% │ r14 │
158 | │ 03.55% │ r11d │
159 | └────────┴─────────────────────────────────┘
160 | ```
161 |
162 | ### compare.py
163 |
164 | ```
165 | Usage: compare.py [OPTIONS]
166 |
167 | Options:
168 | -i1, --input1 TEXT target function 1 [required]
169 | -i2, --input2 TEXT target function 2 [required]
170 | -m, --model TEXT model path [required]
171 | -e, --epochs INTEGER training epochs [default: 10]
172 | -c, --device TEXT hardware device to be used: cpu / cuda / auto
173 | [default: auto]
174 |
175 | -lr, --learning-rate FLOAT learning rate [default: 0.02]
176 | --help Show this message and exit.
177 | ```
178 |
179 | ```bash
180 | # Example
181 | python compare.py -i1 asm/123456 -i2 asm/654321 -m model.pt -e 30
182 | ```
183 |
184 | ```
185 | cosine similarity : 0.873684
186 | ```
187 |
--------------------------------------------------------------------------------