├── __init__.py ├── utils ├── __init__.py ├── __pycache__ │ ├── utils.cpython-36.pyc │ └── __init__.cpython-36.pyc └── utils.py ├── asm_embedding ├── __init__.py ├── __init__.pyc ├── FunctionAnalyzerRadare.pyc ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── FunctionNormalizer.cpython-36.pyc │ ├── InstructionsConverter.cpython-36.pyc │ └── FunctionAnalyzerRadare.cpython-36.pyc ├── DocumentManipulation.py ├── InstructionsConverter.py ├── FunctionNormalizer.py └── FunctionAnalyzerRadare.py ├── dataset_creation ├── __init__.py ├── FunctionsEmbedder.py ├── convertDB.py ├── ExperimentUtil.py ├── DataSplitter.py └── DatabaseFactory.py ├── function_search ├── __init__.py ├── fromJsonSearchToPlot.py ├── EvaluateSearchEngine.py └── FunctionSearchEngine.py ├── neural_network ├── __init__.py ├── __pycache__ │ ├── SAFE_model.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── parameters.cpython-36.pyc │ ├── PairFactory.cpython-36.pyc │ ├── SAFEEmbedder.cpython-36.pyc │ └── SiameseSAFE.cpython-36.pyc ├── freeze_graph.sh ├── train.sh ├── SAFEEmbedder.py ├── train.py ├── parameters.py ├── PairFactory.py ├── SiameseSAFE.py └── SAFE_model.py ├── helloworld.o ├── img ├── 1.jpeg ├── 2.jpeg ├── 3.jpeg ├── 4.jpeg ├── 5.jpeg ├── metric.png └── safe2.jpg ├── requirements.txt ├── download_model.sh ├── LICENSE ├── helloworld.c ├── 404.html ├── Gemfile ├── _config.yml ├── godown.pl ├── safe.py ├── index.md ├── README.md ├── downloader.py └── Gemfile.lock /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asm_embedding/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset_creation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /function_search/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /neural_network/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /helloworld.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/helloworld.o -------------------------------------------------------------------------------- /img/1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/img/1.jpeg -------------------------------------------------------------------------------- /img/2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/img/2.jpeg -------------------------------------------------------------------------------- /img/3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/img/3.jpeg 
-------------------------------------------------------------------------------- /img/4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/img/4.jpeg -------------------------------------------------------------------------------- /img/5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/img/5.jpeg -------------------------------------------------------------------------------- /img/metric.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/img/metric.png -------------------------------------------------------------------------------- /img/safe2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/img/safe2.jpg -------------------------------------------------------------------------------- /asm_embedding/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/asm_embedding/__init__.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow 2 | sklearn 3 | numpy 4 | scipy 5 | matplotlib 6 | tqdm 7 | r2pipe 8 | pyfiglet -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /asm_embedding/FunctionAnalyzerRadare.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/asm_embedding/FunctionAnalyzerRadare.pyc -------------------------------------------------------------------------------- /download_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 downloader.py -b 4 | echo 'Model downloaded and, hopefully, ready to run' 5 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 2 | 3 | -------------------------------------------------------------------------------- /asm_embedding/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/asm_embedding/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /neural_network/__pycache__/SAFE_model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/neural_network/__pycache__/SAFE_model.cpython-36.pyc -------------------------------------------------------------------------------- /neural_network/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/neural_network/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /neural_network/__pycache__/parameters.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/neural_network/__pycache__/parameters.cpython-36.pyc -------------------------------------------------------------------------------- /neural_network/__pycache__/PairFactory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/neural_network/__pycache__/PairFactory.cpython-36.pyc -------------------------------------------------------------------------------- /neural_network/__pycache__/SAFEEmbedder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/neural_network/__pycache__/SAFEEmbedder.cpython-36.pyc -------------------------------------------------------------------------------- /neural_network/__pycache__/SiameseSAFE.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/neural_network/__pycache__/SiameseSAFE.cpython-36.pyc -------------------------------------------------------------------------------- /helloworld.c: -------------------------------------------------------------------------------- 1 | #include "stdio.h" 2 | 3 | 4 | int main(){ 5 | printf("hello world"); 6 | int a=10; 7 | int b=20; 8 | printf("%d",a+b); 9 | } 10 | -------------------------------------------------------------------------------- /asm_embedding/__pycache__/FunctionNormalizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/asm_embedding/__pycache__/FunctionNormalizer.cpython-36.pyc -------------------------------------------------------------------------------- /asm_embedding/__pycache__/InstructionsConverter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/asm_embedding/__pycache__/InstructionsConverter.cpython-36.pyc -------------------------------------------------------------------------------- /asm_embedding/__pycache__/FunctionAnalyzerRadare.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gadiluna/SAFE/HEAD/asm_embedding/__pycache__/FunctionAnalyzerRadare.cpython-36.pyc -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from pyfiglet import figlet_format 2 | 3 | 4 | def print_safe(): 5 | a = figlet_format('SAFE', font='starwars') 6 | print(a) 7 | print("By Massarelli L., Di Luna G. 
A., Petroni F., Querzoni L., Baldoni R.") 8 | print("Please cite: http://arxiv.org/abs/1811.05296 \n") -------------------------------------------------------------------------------- /neural_network/freeze_graph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "usage: ./freeze_graph.sh MODEL_DIR FREEZED_NAME" 4 | 5 | MODEL_DIR=$1 6 | FREEZED_NAME=$2 7 | 8 | freeze_graph --input_meta_graph $MODEL_DIR/checkpoints/model.meta \ 9 | --output_graph $FREEZED_NAME \ 10 | --output_node_names Embedding1/dense/BiasAdd \ 11 | --input_binary \ 12 | --input_checkpoint $MODEL_DIR/checkpoints/model 13 | -------------------------------------------------------------------------------- /neural_network/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | BASE_PATH="/home/luca/work/binary_similarity_data/" 5 | 6 | DATA_PATH=$BASE_PATH/experiments/arith_mean_openSSL_no_dropout_no_shuffle_no_regeneration_emb_random_trainable 7 | OUT_PATH=$DATA_PATH/out 8 | 9 | DB_PATH=$BASE_PATH/databases/openSSL_data.db 10 | 11 | EMBEDDER=$BASE_PATH/word2vec/filtered_100_embeddings/ 12 | 13 | RANDOM="" 14 | TRAINABLE_EMBEDD="" 15 | 16 | python3 train.py $RANDOM $TRAINABLE_EMBEDD --o $OUT_PATH -n $DB_PATH -e $EMBEDDER 17 | -------------------------------------------------------------------------------- /404.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | --- 4 | 5 | 18 | 19 |
20 |   <h1>404</h1>
21 |   <p><strong>Page not found :(</strong></p>
22 |   <p>The requested page could not be found.</p>
23 | </div>
25 | -------------------------------------------------------------------------------- /asm_embedding/DocumentManipulation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | 5 | def list_to_str(li): 6 | i='' 7 | for x in li: 8 | i=i+' '+x 9 | i=i+' endfun'*5 10 | return i 11 | 12 | def document_append(strin): 13 | with open('/Users/giuseppe/docuent_X86','a') as f: 14 | f.write(strin) 15 | 16 | ciro=set() 17 | cantina=[] 18 | num_total=0 19 | num_filtered=0 20 | with open('/Users/giuseppe/dump.x86.linux.json') as f: 21 | l=f.readline() 22 | print('loaded') 23 | r = re.split('(\[.*?\])(?= *\[)', l) 24 | del l 25 | for x in r: 26 | if '[' in x: 27 | gennaro=json.loads(x) 28 | for materdomini in gennaro: 29 | num_total=num_total+1 30 | if materdomini[0] not in ciro: 31 | ciro.add(materdomini[0]) 32 | num_filtered=num_filtered+1 33 | a=list_to_str(materdomini[1]) 34 | document_append(a) 35 | del x 36 | print(num_total) 37 | print(num_filtered) -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | # Hello! This is where you manage which Jekyll version is used to run. 4 | # When you want to use a different version, change it below, save the 5 | # file and run `bundle install`. Run Jekyll with `bundle exec`, like so: 6 | # 7 | # bundle exec jekyll serve 8 | # 9 | # This will help ensure the proper Jekyll version is running. 10 | # Happy Jekylling! 11 | gem "jekyll", "~> 3.7.4" 12 | 13 | # This is the default theme for new Jekyll sites. You may change this to anything you like. 14 | gem "minima", "~> 2.0" 15 | 16 | # If you want to use GitHub Pages, remove the "gem "jekyll"" above and 17 | # uncomment the line below. To upgrade, run `bundle update github-pages`. 18 | # gem "github-pages", group: :jekyll_plugins 19 | #gem "github-pages", group: :jekyll_plugins 20 | 21 | # If you have any plugins, put them here! 22 | group :jekyll_plugins do 23 | gem "jekyll-feed", "~> 0.6" 24 | end 25 | 26 | # Windows does not include zoneinfo files, so bundle the tzinfo-data gem 27 | gem "tzinfo-data", platforms: [:mingw, :mswin, :x64_mingw, :jruby] 28 | 29 | # Performance-booster for watching directories on Windows 30 | gem "wdm", "~> 0.1.0" if Gem.win_platform? 
31 | 32 | -------------------------------------------------------------------------------- /asm_embedding/InstructionsConverter.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import json 5 | 6 | 7 | class InstructionsConverter: 8 | 9 | def __init__(self, json_i2id): 10 | f = open(json_i2id, 'r') 11 | self.i2id = json.load(f) 12 | f.close() 13 | 14 | def convert_to_ids(self, instructions_list): 15 | ret_array = [] 16 | # For each instruction we add +1 to its ID because the first 17 | # element of the embedding matrix is zero 18 | for x in instructions_list: 19 | if x in self.i2id: 20 | ret_array.append(self.i2id[x] + 1) 21 | elif 'X_' in x: 22 | # print(str(x) + " is not a known x86 instruction") 23 | ret_array.append(self.i2id['X_UNK'] + 1) 24 | elif 'A_' in x: 25 | # print(str(x) + " is not a known arm instruction") 26 | ret_array.append(self.i2id['A_UNK'] + 1) 27 | else: 28 | # print("There is a problem " + str(x) + " does not appear to be an asm or arm instruction") 29 | ret_array.append(self.i2id['X_UNK'] + 1) 30 | return ret_array 31 | 32 | 33 | -------------------------------------------------------------------------------- /asm_embedding/FunctionNormalizer.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import numpy as np 5 | 6 | 7 | class FunctionNormalizer: 8 | 9 | def __init__(self, max_instruction): 10 | self.max_instructions = max_instruction 11 | 12 | def normalize(self, f): 13 | f = np.asarray(f[0:self.max_instructions]) 14 | length = f.shape[0] 15 | if f.shape[0] < self.max_instructions: 16 | f = np.pad(f, (0, self.max_instructions - f.shape[0]), mode='constant') 17 | return f, length 18 | 19 | def normalize_function_pairs(self, pairs): 20 | lengths = [] 21 | new_pairs = [] 22 | for x in pairs: 23 | f0, len0 = self.normalize(x[0]) 24 | f1, len1 = self.normalize(x[1]) 25 | lengths.append((len0, len1)) 26 | new_pairs.append((f0, f1)) 27 | return new_pairs, lengths 28 | 29 | def normalize_functions(self, functions): 30 | lengths = [] 31 | new_functions = [] 32 | for f in functions: 33 | f, length = self.normalize(f) 34 | lengths.append(length) 35 | new_functions.append(f) 36 | return new_functions, lengths 37 | -------------------------------------------------------------------------------- /neural_network/SAFEEmbedder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # SAFE TEAM 3 | # distributed under license: GPL 3 License http://www.gnu.org/licenses/ 4 | 5 | class SAFEEmbedder: 6 | 7 | def __init__(self, model_file): 8 | self.model_file = model_file 9 | self.session = None 10 | self.x_1 = None 11 | self.adj_1 = None 12 | self.len_1 = None 13 | self.emb = None 14 | 15 | def loadmodel(self): 16 | with tf.gfile.GFile(self.model_file, "rb") as f: 17 | graph_def = tf.GraphDef() 18 | graph_def.ParseFromString(f.read()) 19 | 20 | with tf.Graph().as_default() as graph: 21 | tf.import_graph_def(graph_def) 22 | 23 | sess = tf.Session(graph=graph) 24 | self.session = sess 25 | 26 | return sess 27 | 28 | def get_tensor(self): 29 | self.x_1 = self.session.graph.get_tensor_by_name("import/x_1:0") 30 | self.len_1 = 
self.session.graph.get_tensor_by_name("import/lengths_1:0") 31 | self.emb = tf.nn.l2_normalize(self.session.graph.get_tensor_by_name('import/Embedding1/dense/BiasAdd:0'), axis=1) 32 | 33 | def embedd(self, nodi_input, lengths_input): 34 | 35 | out_embedding= self.session.run(self.emb, feed_dict = { 36 | self.x_1: nodi_input, 37 | self.len_1: lengths_input}) 38 | 39 | return out_embedding 40 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | # Welcome to Jekyll! 2 | # 3 | # This config file is meant for settings that affect your whole blog, values 4 | # which you are expected to set up once and rarely edit after that. If you find 5 | # yourself editing this file very often, consider using Jekyll's data files 6 | # feature for the data you need to update frequently. 7 | # 8 | # For technical reasons, this file is *NOT* reloaded automatically when you use 9 | # 'bundle exec jekyll serve'. If you change this file, please restart the server process. 10 | 11 | # Site settings 12 | # These are used to personalize your new site. If you look in the HTML files, 13 | # you will see them accessed via {{ site.title }}, {{ site.email }}, and so on. 14 | # You can create any custom variable you would like, and they will be accessible 15 | # in the templates via {{ site.myvariable }}. 16 | title: 'SAFE: Self-Attentive Function Embeddings' 17 | email: safeteam@gmail.com 18 | description: >- # this means to ignore newlines until "baseurl:" 19 | Self-Attentive Function Embeddings for binary similarity. 20 | https://arxiv.org/abs/1811.05296 21 | baseurl: "" # the subpath of your site, e.g. /blog 22 | url: "" # the base hostname & protocol for your site, e.g. http://example.com 23 | twitter_username: 24 | github_username: 25 | 26 | # Build settings 27 | markdown: kramdown 28 | theme: minima 29 | #theme: jekyll-theme-midnight 30 | plugins: 31 | - jekyll-feed 32 | 33 | # Exclude from processing. 34 | # The following items will not be processed, by default. Create a custom list 35 | # to override the default setting. 36 | # exclude: 37 | # - Gemfile 38 | # - Gemfile.lock 39 | # - node_modules 40 | # - vendor/bundle/ 41 | # - vendor/cache/ 42 | # - vendor/gems/ 43 | # - vendor/ruby/ 44 | -------------------------------------------------------------------------------- /godown.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # Google Drive direct download of big files 4 | # ./gdown.pl 'gdrive file url' ['desired file name'] 5 | # 6 | # v1.0 by circulosmeos 04-2014. 7 | # v1.1 by circulosmeos 01-2017. 
8 | # http://circulosmeos.wordpress.com/2014/04/12/google-drive-direct-download-of-big-files 9 | # Distributed under GPL 3 (http://www.gnu.org/licenses/gpl-3.0.html) 10 | # 11 | use strict; 12 | 13 | my $TEMP='gdown.cookie.temp'; 14 | my $COMMAND; 15 | my $confirm; 16 | my $check; 17 | sub execute_command(); 18 | 19 | my $URL=shift; 20 | die "\n./gdown.pl 'gdrive file url' [desired file name]\n\n" if $URL eq ''; 21 | 22 | my $FILENAME=shift; 23 | $FILENAME='gdown' if $FILENAME eq ''; 24 | 25 | if ($URL=~m#^https?://drive.google.com/file/d/([^/]+)#) { 26 | $URL="https://docs.google.com/uc?id=$1&export=download"; 27 | } 28 | 29 | execute_command(); 30 | 31 | while (-s $FILENAME < 100000) { # only if the file isn't the download yet 32 | open fFILENAME, '<', $FILENAME; 33 | $check=0; 34 | foreach () { 35 | if (/href="(\/uc\?export=download[^"]+)/) { 36 | $URL='https://docs.google.com'.$1; 37 | $URL=~s/&/&/g; 38 | $confirm=''; 39 | $check=1; 40 | last; 41 | } 42 | if (/confirm=([^;&]+)/) { 43 | $confirm=$1; 44 | $check=1; 45 | last; 46 | } 47 | if (/"downloadUrl":"([^"]+)/) { 48 | $URL=$1; 49 | $URL=~s/\\u003d/=/g; 50 | $URL=~s/\\u0026/&/g; 51 | $confirm=''; 52 | $check=1; 53 | last; 54 | } 55 | } 56 | close fFILENAME; 57 | die "Couldn't download the file :-(\n" if ($check==0); 58 | $URL=~s/confirm=([^;&]+)/confirm=$confirm/ if $confirm ne ''; 59 | 60 | execute_command(); 61 | } 62 | 63 | unlink $TEMP; 64 | 65 | sub execute_command() { 66 | $COMMAND="wget --no-check-certificate --load-cookie $TEMP --save-cookie $TEMP \"$URL\""; 67 | $COMMAND.=" -O \"$FILENAME\"" if $FILENAME ne ''; 68 | `$COMMAND`; 69 | return 1; 70 | } 71 | -------------------------------------------------------------------------------- /dataset_creation/FunctionsEmbedder.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | from asm_embedding.FunctionNormalizer import FunctionNormalizer 5 | import json 6 | from neural_network.SAFEEmbedder import SAFEEmbedder 7 | import numpy as np 8 | import sqlite3 9 | from tqdm import tqdm 10 | 11 | 12 | class FunctionsEmbedder: 13 | 14 | def __init__(self, model, batch_size, max_instruction): 15 | self.batch_size = batch_size 16 | self.normalizer = FunctionNormalizer(max_instruction) 17 | self.safe = SAFEEmbedder(model) 18 | self.safe.loadmodel() 19 | self.safe.get_tensor() 20 | 21 | def compute_embeddings(self, functions): 22 | functions, lenghts = self.normalizer.normalize_functions(functions) 23 | embeddings = self.safe.embedd(functions, lenghts) 24 | return embeddings 25 | 26 | @staticmethod 27 | def create_table(db_name, table_name): 28 | conn = sqlite3.connect(db_name) 29 | c = conn.cursor() 30 | c.execute("CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY, {} TEXT)".format(table_name, table_name)) 31 | conn.commit() 32 | conn.close() 33 | 34 | def compute_and_save_embeddings_from_db(self, db_name, table_name): 35 | FunctionsEmbedder.create_table(db_name, table_name) 36 | conn = sqlite3.connect(db_name) 37 | cur = conn.cursor() 38 | q = cur.execute("SELECT id FROM functions WHERE id not in (SELECT id from {})".format(table_name)) 39 | ids = q.fetchall() 40 | 41 | for i in tqdm(range(0, len(ids), self.batch_size)): 42 | functions = [] 43 | batch_ids = ids[i:i+self.batch_size] 44 | for my_id in batch_ids: 45 | q = cur.execute("SELECT instructions_list FROM filtered_functions where id=?", my_id) 46 | 
functions.append(json.loads(q.fetchone()[0])) 47 | embeddings = self.compute_embeddings(functions) 48 | 49 | for l, id in enumerate(batch_ids): 50 | cur.execute("INSERT INTO {} VALUES (?,?)".format(table_name), (id[0], np.array2string(embeddings[l]))) 51 | conn.commit() 52 | -------------------------------------------------------------------------------- /safe.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | 5 | from asm_embedding.FunctionAnalyzerRadare import RadareFunctionAnalyzer 6 | from argparse import ArgumentParser 7 | from asm_embedding.FunctionNormalizer import FunctionNormalizer 8 | from asm_embedding.InstructionsConverter import InstructionsConverter 9 | from neural_network.SAFEEmbedder import SAFEEmbedder 10 | from utils import utils 11 | 12 | class SAFE: 13 | 14 | def __init__(self, model): 15 | self.converter = InstructionsConverter("data/i2v/word2id.json") 16 | self.normalizer = FunctionNormalizer(max_instruction=150) 17 | self.embedder = SAFEEmbedder(model) 18 | self.embedder.loadmodel() 19 | self.embedder.get_tensor() 20 | 21 | def embedd_function(self, filename, address): 22 | analyzer = RadareFunctionAnalyzer(filename, use_symbol=False, depth=0) 23 | functions = analyzer.analyze() 24 | instructions_list = None 25 | for function in functions: 26 | if functions[function]['address'] == address: 27 | instructions_list = functions[function]['filtered_instructions'] 28 | break 29 | if instructions_list is None: 30 | print("Function not found") 31 | return None 32 | converted_instructions = self.converter.convert_to_ids(instructions_list) 33 | instructions, length = self.normalizer.normalize_functions([converted_instructions]) 34 | embedding = self.embedder.embedd(instructions, length) 35 | return embedding 36 | 37 | 38 | if __name__ == '__main__': 39 | 40 | utils.print_safe() 41 | 42 | parser = ArgumentParser(description="Safe Embedder") 43 | 44 | parser.add_argument("-m", "--model", help="Safe trained model to generate function embeddings") 45 | parser.add_argument("-i", "--input", help="Input executable that contains the function to embedd") 46 | parser.add_argument("-a", "--address", help="Hexadecimal address of the function to embedd") 47 | 48 | args = parser.parse_args() 49 | 50 | address = int(args.address, 16) 51 | safe = SAFE(args.model) 52 | embedding = safe.embedd_function(args.input, address) 53 | print(embedding[0]) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /neural_network/train.py: -------------------------------------------------------------------------------- 1 | from SAFE_model import modelSAFE 2 | from parameters import Flags 3 | import sys 4 | import os 5 | import numpy as np 6 | from utils import utils 7 | import traceback 8 | 9 | 10 | def load_embedding_matrix(embedder_folder): 11 | matrix_file='embedding_matrix.npy' 12 | matrix_path=os.path.join(embedder_folder,matrix_file) 13 | if os.path.isfile(matrix_path): 14 | try: 15 | print('Loading embedding matrix....') 16 | with open(matrix_path,'rb') as f: 17 | return np.float32(np.load(f)) 18 | except Exception as e: 19 | print("Exception handling file:"+str(matrix_path)) 20 | print("Embedding matrix cannot be load") 21 | print(str(e)) 22 | sys.exit(-1) 23 | 24 | else: 25 | print('Embedding matrix not found at path:'+str(matrix_path)) 26 | sys.exit(-1) 27 | 28 | 29 | def 
run_test(): 30 | flags = Flags() 31 | flags.logger.info("\n{}\n".format(flags)) 32 | 33 | print(str(flags)) 34 | 35 | embedding_matrix = load_embedding_matrix(flags.embedder_folder) 36 | if flags.random_embedding: 37 | embedding_matrix = np.random.rand(*np.shape(embedding_matrix)).astype(np.float32) 38 | embedding_matrix[0, :] = np.zeros(np.shape(embedding_matrix)[1]).astype(np.float32) 39 | 40 | if flags.cross_val: 41 | print("STARTING CROSS VALIDATION") 42 | res = [] 43 | mean = 0 44 | for i in range(0, flags.cross_val_fold): 45 | print("CROSS VALIDATION STARTING FOLD: " + str(i)) 46 | if i > 0: 47 | flags.close_log() 48 | flags.reset_logdir() 49 | del flags 50 | flags = Flags() 51 | flags.logger.info("\n{}\n".format(flags)) 52 | 53 | flags.logger.info("Starting cross validation fold: {}".format(i)) 54 | 55 | flags.db_name = flags.db_name + "_val_" + str(i+1) + ".db" 56 | flags.logger.info("Cross validation db name: {}".format(flags.db_name)) 57 | 58 | trainer = modelSAFE(flags, embedding_matrix) 59 | best_val_auc = trainer.train() 60 | 61 | mean += best_val_auc 62 | res.append(best_val_auc) 63 | 64 | flags.logger.info("Cross validation fold {} finished best auc: {}".format(i, best_val_auc)) 65 | print("FINISH FOLD: " + str(i) + " BEST VAL AUC: " + str(best_val_auc)) 66 | 67 | print("CROSS VALIDATION ENDED") 68 | print("Result: " + str(res)) 69 | print("") 70 | 71 | flags.logger.info("Cross validation finished results: {}".format(res)) 72 | flags.logger.info(" mean: {}".format(mean / flags.cross_val_fold)) 73 | flags.close_log() 74 | 75 | else: 76 | trainer = modelSAFE(flags, embedding_matrix) 77 | trainer.train() 78 | flags.close_log() 79 | 80 | 81 | if __name__ == '__main__': 82 | utils.print_safe() 83 | print('-Trainer for SAFE-') 84 | run_test() 85 | -------------------------------------------------------------------------------- /index.md: -------------------------------------------------------------------------------- 1 | --- 2 | # Feel free to add content and custom Front Matter to this file. 3 | # To modify the layout, see https://jekyllrb.com/docs/themes/#overriding-theme-defaults 4 | 5 | layout: home 6 | 7 | --- 8 | 9 |
10 | 11 | What is SAFE? 12 | ------------- 13 | 14 | **SAFE** is a **S**elf-**A**ttentive neural network that takes as input a binary **F**unction and creates an **E**mbedding. 15 | 16 | What is an embedding? 17 | ------------- 18 | An embedding is a vector of real numbers. The nice feature of SAFE embeddings is that two similar binary functions should generate two embeddings 19 | that are close in the metric space. 20 | 21 |
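For example, closeness can be checked with cosine similarity. A minimal numpy sketch (the vectors below are random placeholders standing in for two SAFE embeddings, whose dimension is assumed here to be 100):

```python
import numpy as np

def cosine_similarity(a, b):
    # Cosine similarity between two embedding vectors: values near 1 mean "very similar".
    a = np.asarray(a, dtype=np.float32).ravel()
    b = np.asarray(b, dtype=np.float32).ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Random placeholders standing in for two SAFE embeddings.
embedding_x = np.random.rand(100)
embedding_y = np.random.rand(100)
print(cosine_similarity(embedding_x, embedding_y))
```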
22 | 23 | I want to know all the details! 24 | ------------- 25 | Good, read our paper on [arXiv](https://arxiv.org/abs/1811.05296). 26 | 27 | The paper is slightly amusing! How do I get SAFE? 28 | ------------- 29 | SAFE is available in our [GitHub](https://github.com/gadiluna/SAFE) repository. Keep in mind that SAFE has been developed as a research project. We only provide a minimal working proof-of-concept, 30 | with the code and data to replicate our experiments. We are not responsible for any self-harm episode correlated with reading our (sometimes badly written) code. 31 | 32 | How can I get involved with SAFE? 33 | ------------- 34 | If you are interested in this project, write us an email. 35 | 36 | 37 | ------------- 38 | SAFE has been designed and developed by: 39 |
40 | * [Luca Massarelli](https://scholar.google.it/citations?user=mJ_QjZIAAAAJ&hl=it) (development and research) 41 |
42 | * [Giuseppe Antonio Di Luna](https://scholar.google.it/citations?hl=it&user=RgAfuVgAAAAJ&view_op=list_works&sortby=pubdate) (development and research) 43 |
44 | * [Fabio Petroni](https://scholar.google.it/citations?user=vxQc2L4AAAAJ&hl=it) (development and research) 45 |
46 | * [Leonardo Querzoni](https://scholar.google.it/citations?user=-_WFIJIAAAAJ&hl=it) (research) 47 |
48 | * [Roberto Baldoni](https://scholar.google.it/citations?user=82tR6VoAAAAJ&hl=it) (research) 49 | 50 | 51 | 52 | 53 | #### **Acknowledgments**: 54 | We are in debt with Google for providing free access to its cloud computing platform through the Education Program. Moreover, the authors would like to thank NVIDIA Corporation for partially supporting this work through the donation of a GPGPU card used during prototype development. 55 | This work is supported by a grant of the Italian Presidency of the Council of Ministers and by the CINI (Consorzio Interuniversitario Nazionale Informatica) National Laboratory of Cyber Security. 56 | Finally, we thank Davide Italiano for the insightful discussions. 57 | 58 | SAFE License. 59 | ------- 60 | # SAFE TEAM 61 | # GPL 3 License http://www.gnu.org/licenses/ -------------------------------------------------------------------------------- /function_search/fromJsonSearchToPlot.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import matplotlib.pyplot as plt 5 | import json 6 | import math 7 | import numpy as np 8 | from multiprocessing import Pool 9 | 10 | 11 | def find_dcg(element_list): 12 | dcg_score = 0.0 13 | for j, sim in enumerate(element_list): 14 | dcg_score += float(sim) / math.log(j + 2) 15 | return dcg_score 16 | 17 | 18 | def count_ones(element_list): 19 | return len([x for x in element_list if x == 1]) 20 | 21 | 22 | def extract_info(file_1): 23 | with open(file_1, 'r') as f: 24 | data1 = json.load(f) 25 | 26 | performance1 = [] 27 | 28 | average_recall_k1 = [] 29 | precision_at_k1 = [] 30 | 31 | for f_index in range(0, len(data1)): 32 | 33 | f1 = data1[f_index][0] 34 | pf1 = data1[f_index][1] 35 | 36 | tp1 = [] 37 | 38 | recall_p1 = [] 39 | precision_p1 = [] 40 | # we start from 1 to remove ourselves 41 | for k in range(1, 200): 42 | cut1 = f1[0:k] 43 | dcg1 = find_dcg(cut1) 44 | ideal1 = find_dcg(([1] * (pf1) + [0] * (k - pf1))[0:k]) 45 | 46 | p1k = float(count_ones(cut1)) 47 | 48 | tp1.append(dcg1 / ideal1) 49 | recall_p1.append(p1k / pf1) 50 | precision_p1.append(p1k / k) 51 | 52 | performance1.append(tp1) 53 | average_recall_k1.append(recall_p1) 54 | precision_at_k1.append(precision_p1) 55 | 56 | avg_p1 = np.average(performance1, axis=0) 57 | avg_p10 = np.average(average_recall_k1, axis=0) 58 | average_precision = np.average(precision_at_k1, axis=0) 59 | return avg_p1, avg_p10, average_precision 60 | 61 | 62 | def print_graph(info1, file_name, label_y, title_1, p): 63 | fig, ax = plt.subplots() 64 | ax.plot(range(0, len(info1)), info1, color='b', label=title_1) 65 | ax.legend(loc=p, shadow=True, fontsize='x-large') 66 | plt.xlabel("Number of Nearest Results") 67 | plt.ylabel(label_y) 68 | fname = file_name 69 | plt.savefig(fname) 70 | plt.close(fname) 71 | 72 | 73 | def compare_and_print(file): 74 | filename = file.split('_')[0] + '_' + file.split('_')[1] 75 | t_short = filename 76 | label_1 = t_short + '_' + file.split('_')[3] 77 | 78 | avg_p1, recall_p1, precision1 = extract_info(file) 79 | 80 | fname = filename + '_nDCG.pdf' 81 | print_graph(avg_p1, fname, 'nDCG', label_1, 'upper right') 82 | 83 | fname = filename + '_recall.pdf' 84 | print_graph(recall_p1, fname, 'Recall', label_1, 'lower right') 85 | 86 | fname = filename + '_precision.pdf' 87 | print_graph(precision1, fname, 'Precision', label_1, 'upper right') 88 | 89 | return avg_p1, recall_p1, 
precision1 90 | 91 | 92 | e1 = 'embeddings_safe' 93 | 94 | opt = ['O0', 'O1', 'O2', 'O3'] 95 | compilers = ['gcc-7', 'gcc-4.8', 'clang-6.0', 'clang-4.0'] 96 | values = [] 97 | for o in opt: 98 | for c in compilers: 99 | f0 = '' + c + '_' + o + '_' + e1 + '_top200.json' 100 | values.append(f0) 101 | 102 | p = Pool(4) 103 | result = p.map(compare_and_print, values) 104 | 105 | avg_p1 = [] 106 | recal_p1 = [] 107 | pre_p1 = [] 108 | 109 | avg_p2 = [] 110 | recal_p2 = [] 111 | pre_p2 = [] 112 | 113 | for t in result: 114 | avg_p1.append(t[0]) 115 | recal_p1.append(t[1]) 116 | pre_p1.append(t[2]) 117 | 118 | avg_p1 = np.average(avg_p1, axis=0) 119 | recal_p1 = np.average(recal_p1, axis=0) 120 | pre_p1 = np.average(pre_p1, axis=0) 121 | 122 | print_graph(avg_p1[0:20], 'nDCG.pdf', 'normalized DCG', 'SAFE', 'upper right') 123 | print_graph(recal_p1, 'recall.pdf', 'recall', 'SAFE', 'lower right') 124 | print_graph(pre_p1[0:20], 'precision.pdf', 'precision', 'SAFE', 'upper right') 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAFE : Self Attentive Function Embedding 2 | 3 | Paper 4 | --- 5 | This software is the outcome of our academic research. See our arXiv paper: [arXiv](https://arxiv.org/abs/1811.05296) 6 | 7 | If you use this code, please cite our academic paper as: 8 | 9 | ```bibtex 10 | @inproceedings{massarelli2018safe, 11 | title={SAFE: Self-Attentive Function Embeddings for Binary Similarity}, 12 | author={Massarelli, Luca and Di Luna, Giuseppe Antonio and Petroni, Fabio and Querzoni, Leonardo and Baldoni, Roberto}, 13 | booktitle={Proceedings of 16th Conference on Detection of Intrusions and Malware & Vulnerability Assessment (DIMVA)}, 14 | year={2019} 15 | } 16 | ``` 17 | 18 | What you need 19 | ----- 20 | You need [radare2](https://github.com/radare/radare2) installed on your system. 21 | 22 | Quickstart 23 | ----- 24 | To create the embedding of a function: 25 | ``` 26 | git clone https://github.com/gadiluna/SAFE.git 27 | pip install -r requirements.txt 28 | chmod +x download_model.sh 29 | ./download_model.sh 30 | python safe.py -m data/safe.pb -i helloworld.o -a 100000F30 31 | ``` 32 | #### What to do with an embedding? 33 | Once you have two embeddings ```embedding_x``` and ```embedding_y```, you can compute the similarity of the corresponding functions as: 34 | ``` 35 | from sklearn.metrics.pairwise import cosine_similarity 36 | 37 | sim=cosine_similarity(embedding_x, embedding_y) 38 | 39 | ``` 40 | 41 | 42 | Data Needed 43 | ----- 44 | SAFE needs two models to work: a model that tells SAFE how to 45 | convert assembly instructions into vectors (the i2v model) and a model that tells SAFE how 46 | to convert a binary function into a vector. 47 | Both models can be downloaded with the command 48 | ``` 49 | ./download_model.sh 50 | ``` 51 | The downloader fetches the models and places them in the data directory. 52 | The directory tree after the download should be: 53 | ``` 54 | safe/-- githubcode 55 | \ 56 | \--data/-----safe.pb 57 | \ 58 | \---i2v/ 59 | 60 | ``` 61 | The safe.pb file contains the SAFE model used to convert binary functions to vectors. 62 | The i2v folder contains the i2v model. 63 | 64 | 65 | Hardcore Details 66 | ---- 67 | This section contains details that are needed to replicate our experiments; if you are a user of SAFE you can skip 68 | it.
69 | 70 | ### Safe.pb 71 | This is the frozen TensorFlow model trained for the AMD64 architecture. You can import it in your project using: 72 | 73 | ``` 74 | import tensorflow as tf 75 | 76 | with tf.gfile.GFile("safe.pb", "rb") as f: 77 | graph_def = tf.GraphDef() 78 | graph_def.ParseFromString(f.read()) 79 | 80 | with tf.Graph().as_default() as graph: 81 | tf.import_graph_def(graph_def) 82 | 83 | sess = tf.Session(graph=graph) 84 | ``` 85 | 86 | see file: neural_network/SAFEEmbedder.py 87 | 88 | ### i2v 89 | The i2v folder contains two files. 90 | A matrix where each row is the embedding of an asm instruction. 91 | A json file that contains a dictionary mapping asm instructions to row numbers of the matrix above. 92 | see file: asm_embedding/InstructionsConverter.py 93 | 94 | 95 | 96 | ## Train the model 97 | If you want to train the model using our datasets you first have to run: 98 | ``` 99 | python3 downloader.py -td 100 | ``` 101 | This will download the datasets into the data folder. Note that the datasets are compressed, so you have to decompress them yourself. 102 | The data are sqlite databases. 103 | To start training, use neural_network/train.sh. 104 | The db can be selected by changing the parameters in train.sh. 105 | If you want information on the datasets, see our paper. 106 | 107 | ## Create your own dataset 108 | If you want to create your own dataset you can use the ExperimentUtil script in the folder 109 | dataset_creation. 110 | 111 | ## Create a functions knowledge base 112 | If you want to use the SAFE binary code search engine, you can use the ExperimentUtil script to create 113 | the knowledge base. 114 | Then you can search through it using the scripts in function_search. 115 | 116 | 117 | Related Projects 118 | --- 119 | 120 | * YARASAFE: Automatic Binary Function Similarity Checks with Yara (https://github.com/lucamassarelli/yarasafe) 121 | * SAFEtorch: PyTorch implementation of the SAFE neural network (https://github.com/facebookresearch/SAFEtorch) 122 | 123 | Thanks 124 | --- 125 | In our code we use [godown](https://github.com/circulosmeos/gdown.pl) to download data from Google Drive. We thank 126 | circulosmeos, the creator of godown. 127 | 128 | We thank Davide Italiano for the useful discussions.
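As a complement to the i2v section above, the sketch below shows how the two i2v files could be loaded and queried directly. The dictionary path is the one used by safe.py; the matrix file name (embedding_matrix.npy) is taken from neural_network/train.py, so treat the exact layout of the downloaded bundle as an assumption.

```python
import json
import numpy as np

# Paths/names follow safe.py and neural_network/train.py; adjust them if your data folder differs.
with open("data/i2v/word2id.json") as f:
    word2id = json.load(f)                         # asm instruction -> row id

matrix = np.load("data/i2v/embedding_matrix.npy")  # one instruction embedding per row

row_id = word2id["X_UNK"]                          # "X_UNK" is the out-of-vocabulary token used by InstructionsConverter
vector = matrix[row_id]                            # embedding of that instruction
print(matrix.shape, vector.shape)
# Note: when feeding the network, InstructionsConverter shifts every id by +1,
# because index 0 of the model's embedding matrix is reserved for padding.
```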
129 | -------------------------------------------------------------------------------- /dataset_creation/convertDB.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import sqlite3 5 | import json 6 | from networkx.readwrite import json_graph 7 | import logging 8 | from tqdm import tqdm 9 | from asm_embedding.InstructionsConverter import InstructionsConverter 10 | 11 | 12 | # Create the db where data are stored 13 | def create_db(db_name): 14 | print('Database creation...') 15 | conn = sqlite3.connect(db_name) 16 | conn.execute(''' CREATE TABLE IF NOT EXISTS functions (id INTEGER PRIMARY KEY, 17 | project text, 18 | compiler text, 19 | optimization text, 20 | file_name text, 21 | function_name text, 22 | asm text, 23 | num_instructions INTEGER) 24 | ''') 25 | conn.execute('''CREATE TABLE IF NOT EXISTS filtered_functions (id INTEGER PRIMARY KEY, 26 | instructions_list text) 27 | ''') 28 | conn.commit() 29 | conn.close() 30 | 31 | 32 | def reverse_graph(cfg, lstm_cfg): 33 | instructions = [] 34 | asm = "" 35 | node_addr = list(cfg.nodes()) 36 | node_addr.sort() 37 | nodes = cfg.nodes(data=True) 38 | lstm_nodes = lstm_cfg.nodes(data=True) 39 | for addr in node_addr: 40 | a = nodes[addr]["asm"] 41 | if a is not None: 42 | asm += a 43 | instructions.extend(lstm_nodes[addr]['features']) 44 | return instructions, asm 45 | 46 | 47 | def copy_split(old_cur, new_cur, table): 48 | q = old_cur.execute("SELECT id FROM {}".format(table)) 49 | iii = q.fetchall() 50 | print("Copying table {}".format(table)) 51 | for ii in tqdm(iii): 52 | new_cur.execute("INSERT INTO {} VALUES (?)".format(table), ii) 53 | 54 | 55 | def copy_table(old_cur, new_cur, table_old, table_new): 56 | q = old_cur.execute("SELECT * FROM {}".format(table_old)) 57 | iii = q.fetchall() 58 | print("Copying table {} to {}".format(table_old, table_new)) 59 | for ii in tqdm(iii): 60 | new_cur.execute("INSERT INTO {} VALUES (?,?,?)".format(table_new), ii) 61 | 62 | logger = logging.getLogger() 63 | logger.setLevel(logging.DEBUG) 64 | 65 | db = "/home/lucamassarelli/binary_similarity_data/databases/big_dataset_X86.db" 66 | new_db = "/home/lucamassarelli/binary_similarity_data/new_databases/big_dataset_X86_new.db" 67 | 68 | create_db(new_db) 69 | 70 | conn_old = sqlite3.connect(db) 71 | conn_new = sqlite3.connect(new_db) 72 | 73 | 74 | cur_old = conn_old.cursor() 75 | cur_new = conn_new.cursor() 76 | 77 | 78 | q = cur_old.execute("SELECT id FROM functions") 79 | ids = q.fetchall() 80 | converter = InstructionsConverter() 81 | 82 | for my_id in tqdm(ids): 83 | 84 | q0 = cur_old.execute("SELECT id, project, compiler, optimization, file_name, function_name, cfg FROM functions WHERE id=?", my_id) 85 | meta = q.fetchone() 86 | 87 | q1 = cur_old.execute("SELECT lstm_cfg FROM lstm_cfg WHERE id=?", my_id) 88 | cfg = json_graph.adjacency_graph(json.loads(meta[6])) 89 | lstm_cfg = json_graph.adjacency_graph(json.loads(q1.fetchone()[0])) 90 | instructions, asm = reverse_graph(cfg, lstm_cfg) 91 | values = meta[0:6] + (asm, len(instructions)) 92 | q_n = cur_new.execute("INSERT INTO functions VALUES (?,?,?,?,?,?,?,?)", values) 93 | converted_instruction = json.dumps(converter.convert_to_ids(instructions)) 94 | q_n = cur_new.execute("INSERT INTO filtered_functions VALUES (?,?)", (my_id[0], converted_instruction)) 95 | 96 | conn_new.commit() 97 | 98 | cur_new.execute("CREATE TABLE train 
(id INTEGER PRIMARY KEY) ") 99 | cur_new.execute("CREATE TABLE validation (id INTEGER PRIMARY KEY) ") 100 | cur_new.execute("CREATE TABLE test (id INTEGER PRIMARY KEY) ") 101 | conn_new.commit() 102 | 103 | copy_split(cur_old, cur_new, "train") 104 | conn_new.commit() 105 | copy_split(cur_old, cur_new, "validation") 106 | conn_new.commit() 107 | copy_split(cur_old, cur_new, "test") 108 | conn_new.commit() 109 | 110 | cur_new.execute("CREATE TABLE train_pairs (id INTEGER PRIMARY KEY, true_pair TEXT, false_pair TEXT)") 111 | cur_new.execute("CREATE TABLE validation_pairs (id INTEGER PRIMARY KEY, true_pair TEXT, false_pair TEXT)") 112 | cur_new.execute("CREATE TABLE test_pairs (id INTEGER PRIMARY KEY, true_pair TEXT, false_pair TEXT)") 113 | conn_new.commit() 114 | 115 | copy_table(cur_old, cur_new, "train_couples", "train_pairs") 116 | conn_new.commit() 117 | copy_table(cur_old, cur_new, "validation_couples", "validation_pairs") 118 | conn_new.commit() 119 | copy_table(cur_old, cur_new, "test_couples", "test_pairs") 120 | conn_new.commit() 121 | 122 | conn_new.close() -------------------------------------------------------------------------------- /function_search/EvaluateSearchEngine.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | 5 | from FunctionSearchEngine import FunctionSearchEngine 6 | from sklearn import metrics 7 | import sqlite3 8 | 9 | from multiprocessing import Process 10 | import math 11 | 12 | import warnings 13 | import random 14 | import json 15 | 16 | class SearchEngineEvaluator: 17 | 18 | def __init__(self, db_name, table, limit=None,k=None): 19 | self.tables = table 20 | self.db_name = db_name 21 | self.SE = FunctionSearchEngine(db_name, table, limit=limit) 22 | self.k=k 23 | self.number_similar={} 24 | 25 | def do_search(self, target_db_name, target_fcn_ids): 26 | self.SE.load_target(target_db_name, target_fcn_ids) 27 | self.SE.pp_search(50) 28 | 29 | def calc_auc(self, target_db_name, target_fcn_ids): 30 | self.SE.load_target(target_db_name, target_fcn_ids) 31 | result = self.SE.auc() 32 | print(result) 33 | 34 | # 35 | # This methods searches for all target function in the DB, in our test we take num functions compiled with compiler and opt 36 | # moreover it populates the self.number_similar dictionary, that contains the number of similar function for each target 37 | # 38 | def find_target_fcn(self, compiler, opt, num): 39 | conn = sqlite3.connect(self.db_name) 40 | cur = conn.cursor() 41 | q = cur.execute("SELECT id, project, file_name, function_name FROM functions WHERE compiler=? 
AND optimization=?", (compiler, opt)) 42 | res = q.fetchall() 43 | ids = [i[0] for i in res] 44 | true_labels = [l[1]+"/"+l[2]+"/"+l[3] for l in res] 45 | n_ids = [] 46 | n_true_labels = [] 47 | num = min(num, len(ids)) 48 | 49 | for i in range(0, num): 50 | index = random.randrange(len(ids)) 51 | n_ids.append(ids[index]) 52 | n_true_labels.append(true_labels[index]) 53 | f_name=true_labels[index].split('/')[2] 54 | fi_name=true_labels[index].split('/')[1] 55 | q = cur.execute("SELECT num FROM count_func WHERE file_name='{}' and function_name='{}'".format(fi_name,f_name)) 56 | f = q.fetchone() 57 | if f is not None: 58 | num=int(f[0]) 59 | else: 60 | num = 0 61 | self.number_similar[true_labels[index]]=num 62 | 63 | return n_ids, n_true_labels 64 | 65 | @staticmethod 66 | def functions_ground_truth(labels, indices, values, true_label): 67 | y_true = [] 68 | y_score = [] 69 | for i, e in enumerate(indices): 70 | y_score.append(float(values[i])) 71 | l = labels[e] 72 | if l == true_label: 73 | y_true.append(1) 74 | else: 75 | y_true.append(0) 76 | return y_true, y_score 77 | 78 | # this methos execute the test 79 | # it select the targets functions and it looks up for the targets in the entire db 80 | # the outcome is json file containing the top 200 similar for each target function. 81 | # the json file is an array and such array contains an entry for each target function 82 | # each entry is a triple (t0,t1,t2) 83 | # t0: an array that contains 1 at entry j if the entry j is similar to the target 0 otherwise 84 | # t1: the number of similar functions to the target in the whole db 85 | # t2: an array that at entry j contains the similarity score of the j-th most similar function to the target. 86 | # 87 | # 88 | def evaluate_precision_on_all_functions(self, compiler, opt): 89 | target_fcn_ids, true_labels = self.find_target_fcn(compiler, opt, 10000) 90 | batch = 1000 91 | labels = self.SE.trunc_labels 92 | 93 | info=[] 94 | 95 | for i in range(0, len(target_fcn_ids), batch): 96 | if i + batch > len(target_fcn_ids): 97 | batch = len(target_fcn_ids) - i 98 | target = self.SE.load_target(self.db_name, target_fcn_ids[i:i+batch]) 99 | top_k = self.SE.top_k(target, self.k) 100 | 101 | for j in range(0, batch): 102 | a, b = SearchEngineEvaluator.functions_ground_truth(labels, top_k.indices[j, :], top_k.values[j, :], true_labels[i+j]) 103 | 104 | info.append((a,self.number_similar[true_labels[i + j]],b)) 105 | 106 | with open(compiler+'_'+opt+'_'+self.tables+'_top200.json', 'w') as outfile: 107 | json.dump(info, outfile) 108 | 109 | 110 | def test(dbName, table, opt,x,k): 111 | 112 | print("k:{} - Table: {} - Opt: {}".format(k,table, opt)) 113 | 114 | SEV = SearchEngineEvaluator(dbName, table, limit=2000000,k=k) 115 | SEV.evaluate_precision_on_all_functions(x, opt) 116 | 117 | print("-------------------------------------") 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | random.seed(12345) 123 | 124 | dbName = '../data/AMD64PostgreSQL.db' 125 | table = ['safe_embeddings'] 126 | opt = ["O0", "O1", "O2", "O3"] 127 | for x in ['gcc-4.8',"clang-4.0",'gcc-7','clang-6.0']: 128 | for t in table: 129 | for o in opt: 130 | p = Process(target=test, args=(dbName, t, o,x,200)) 131 | p.start() 132 | p.join() 133 | -------------------------------------------------------------------------------- /downloader.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo 
Querzoni, Roberto Baldoni 3 | 4 | 5 | import argparse 6 | import os 7 | import sys 8 | from subprocess import call 9 | 10 | class Downloader: 11 | 12 | def __init__(self): 13 | parser = argparse.ArgumentParser(description='SAFE downloader') 14 | 15 | parser.add_argument("-m", "--model", dest="model", help="Download the trained SAFE model for x86", 16 | action="store_true", 17 | required=False) 18 | 19 | parser.add_argument("-i2v", "--i2v", dest="i2v", help="Download the i2v dictionary and embedding matrix", 20 | action="store_true", 21 | required=False) 22 | 23 | parser.add_argument("-b", "--bundle", dest="bundle", 24 | help="Download all files necessary to run the model", 25 | action="store_true", 26 | required=False) 27 | 28 | parser.add_argument("-td", "--train_data", dest="train_data", 29 | help="Download the files necessary to train the model (It takes a lot of space!)", 30 | action="store_true", 31 | required=False) 32 | 33 | args = parser.parse_args() 34 | 35 | self.download_model = (args.model or args.bundle) 36 | self.download_i2v = (args.i2v or args.bundle) 37 | self.download_train = args.train_data 38 | 39 | if not (self.download_model or self.download_i2v or self.download_train): 40 | parser.print_help(sys.__stdout__) 41 | 42 | self.url_model = "https://drive.google.com/file/d/1Kwl8Jy-g9DXe1AUjUZDhJpjRlDkB4NBs/view?usp=sharing" 43 | self.url_i2v = "https://drive.google.com/file/d/1CqJVGYbLDEuJmJV6KH4Dzzhy-G12GjGP" 44 | self.url_train = ['https://drive.google.com/file/d/1sNahtLTfZY5cxPaYDUjqkPTK0naZ45SH/view?usp=sharing','https://drive.google.com/file/d/16D5AVDux_Q8pCVIyvaMuiL2cw2V6gtLc/view?usp=sharing','https://drive.google.com/file/d/1cBRda8fYdqHtzLwstViuwK6U5IVHad1N/view?usp=sharing'] 45 | self.train_name = ['AMD64ARMOpenSSL.tar.bz2','AMD64multipleCompilers.tar.bz2','AMD64PostgreSQL.tar.bz2'] 46 | self.base_path = "data" 47 | self.path_i2v = os.path.join(self.base_path, "") 48 | self.path_model = os.path.join(self.base_path, "") 49 | self.path_train_data = os.path.join(self.base_path, "") 50 | self.i2v_compress_name='i2v.tar.bz2' 51 | self.model_compress_name='model.tar.bz2' 52 | self.datasets_compress_name='safe.pb' 53 | 54 | @staticmethod 55 | def download_file(id,path): 56 | try: 57 | print("Downloading from "+ str(id) +" into "+str(path)) 58 | call(['./godown.pl',id,path]) 59 | except Exception as e: 60 | print("Error downloading file at url:" + str(id)) 61 | print(e) 62 | 63 | @staticmethod 64 | def decompress_file(file_src,file_path): 65 | try: 66 | call(['tar','-xvf',file_src,'-C',file_path]) 67 | except Exception as e: 68 | print("Error decompressing file:" + str(file_src)) 69 | print('you need tar command e b2zip support') 70 | print(e) 71 | 72 | def download(self): 73 | print('Making the godown.pl script executable, thanks:'+str('https://github.com/circulosmeos/gdown.pl')) 74 | call(['chmod', '+x','godown.pl']) 75 | print("SAFE --- downloading models") 76 | 77 | if self.download_i2v: 78 | print("Downloading i2v model.... in the folder data/i2v/") 79 | if not os.path.exists(self.path_i2v): 80 | os.makedirs(self.path_i2v) 81 | Downloader.download_file(self.url_i2v, os.path.join(self.path_i2v,self.i2v_compress_name)) 82 | print("Decompressing i2v model and placing in" + str(self.path_i2v)) 83 | Downloader.decompress_file(os.path.join(self.path_i2v,self.i2v_compress_name),self.path_i2v) 84 | 85 | if self.download_model: 86 | print("Downloading the SAFE model... 
in the folder data") 87 | if not os.path.exists(self.path_model): 88 | os.makedirs(self.path_model) 89 | Downloader.download_file(self.url_model, os.path.join(self.path_model,self.datasets_compress_name)) 90 | #print("Decompressing SAFE model and placing in" + str(self.path_model)) 91 | #Downloader.decompress_file(os.path.join(self.path_model,self.model_compress_name),self.path_model) 92 | 93 | if self.download_train: 94 | print("Downloading the train data.... in the folder data") 95 | if not os.path.exists(self.path_train_data): 96 | os.makedirs(self.path_train_data) 97 | for i,x in enumerate(self.url_train): 98 | print("Downloading dataset "+str(self.train_name[i])) 99 | Downloader.download_file(x, os.path.join(self.path_train_data,self.train_name[i])) 100 | #print("Decompressing the train data and placing in" + str(self.path_train_data)) 101 | #Downloader.decompress_file(os.path.join(self.path_train_data,self.datasets_compress_name),self.path_train_data) 102 | 103 | if __name__=='__main__': 104 | a=Downloader() 105 | a.download() -------------------------------------------------------------------------------- /dataset_creation/ExperimentUtil.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import argparse 5 | from dataset_creation import DatabaseFactory, DataSplitter, FunctionsEmbedder 6 | from utils.utils import print_safe 7 | 8 | 9 | def debug_msg(): 10 | msg = "SAFE DATABASE UTILITY" 11 | msg += "-------------------------------------------------\n" 12 | msg += "This program is an utility to save data into an sqlite database with SAFE \n\n" 13 | msg += "There are three main command: \n" 14 | msg += "BUILD: It create a db with two tables: functions, filtered_functions. \n" 15 | msg += " In the first table there are all the functions extracted from the executable with their hex code.\n" 16 | msg += " In the second table functions are converted to i2v representation. \n" 17 | msg += "SPLIT: Data are splitted into train validation and test set. 
" \ 18 | " Then it generate the pairs for the training of the network.\n" 19 | msg += "EMBEDD: Generate the embeddings of each function in the database using a trained SAFE model\n\n" 20 | msg += "If you want to train the network use build + split" 21 | msg += "If you want to create a knowledge base for the binary code search engine use build + embedd" 22 | msg += "This program has been written by the SAFE team.\n" 23 | msg += "-------------------------------------------------" 24 | return msg 25 | 26 | 27 | def build_configuration(db_name, root_dir, use_symbols, callee_depth): 28 | msg = "Database creation options: \n" 29 | msg += " - Database Name: {} \n".format(db_name) 30 | msg += " - Root dir: {} \n".format(root_dir) 31 | msg += " - Use symbols: {} \n".format(use_symbols) 32 | msg += " - Callee depth: {} \n".format(callee_depth) 33 | return msg 34 | 35 | 36 | def split_configuration(db_name, val_split, test_split, epochs): 37 | msg = "Splitting options: \n" 38 | msg += " - Database Name: {} \n".format(db_name) 39 | msg += " - Validation Size: {} \n".format(val_split) 40 | msg += " - Test Size: {} \n".format(test_split) 41 | msg += " - Epochs: {} \n".format(epochs) 42 | return msg 43 | 44 | 45 | def embedd_configuration(db_name, model, batch_size, max_instruction, embeddings_table): 46 | msg = "Embedding options: \n" 47 | msg += " - Database Name: {} \n".format(db_name) 48 | msg += " - Model: {} \n".format(model) 49 | msg += " - Batch Size: {} \n".format(batch_size) 50 | msg += " - Max Instruction per function: {} \n".format(max_instruction) 51 | msg += " - Table for saving embeddings: {}.".format(embeddings_table) 52 | return msg 53 | 54 | 55 | if __name__ == '__main__': 56 | 57 | print_safe() 58 | 59 | parser = argparse.ArgumentParser(description=debug_msg) 60 | 61 | parser.add_argument("-db", "--db", help="Name of the database to create", required=True) 62 | 63 | parser.add_argument("-b", "--build", help="Build db disassebling executables", action="store_true") 64 | parser.add_argument("-s", "--split", help="Perform data splitting for training", action="store_true") 65 | parser.add_argument("-e", "--embed", help="Compute functions embedding", action="store_true") 66 | 67 | parser.add_argument("-dir", "--dir", help="Root path of the directory to scan") 68 | parser.add_argument("-sym", "--symbols", help="Use it if you want to use symbols", action="store_true") 69 | parser.add_argument("-dep", "--depth", help="Recursive depth for analysis", default=0, type=int) 70 | 71 | parser.add_argument("-test", "--test_size", help="Test set size [0-1]", type=float, default=0.2) 72 | parser.add_argument("-val", "--val_size", help="Validation set size [0-1]", type=float, default=0.2) 73 | parser.add_argument("-epo", "--epochs", help="# Epochs to generate pairs for", type=int, default=25) 74 | 75 | parser.add_argument("-mod", "--model", help="Model for embedding generation") 76 | parser.add_argument("-bat", "--batch_size", help="Batch size for function embeddings", type=int, default=500) 77 | parser.add_argument("-max", "--max_instruction", help="Maximum instruction per function", type=int, default=150) 78 | parser.add_argument("-etb", "--embeddings_table", help="Name for the table that contains embeddings", 79 | default="safe_embeddings") 80 | 81 | try: 82 | args = parser.parse_args() 83 | except: 84 | parser.print_help() 85 | print(debug_msg()) 86 | exit(0) 87 | 88 | if args.build: 89 | print("Disassemblying files and creating dataset") 90 | print(build_configuration(args.db, args.dir, 
args.symbols, args.depth)) 91 | factory = DatabaseFactory.DatabaseFactory(args.db, args.dir) 92 | factory.build_db(args.symbols, args.depth) 93 | 94 | if args.split: 95 | print("Splitting data and generating epoch pairs") 96 | print(split_configuration(args.db, args.val_size, args.test_size, args.epochs)) 97 | splitter = DataSplitter.DataSplitter(args.db) 98 | splitter.split_data(args.val_size, args.test_size) 99 | splitter.create_pairs(args.epochs) 100 | 101 | if args.embed: 102 | print("Computing embeddings for function in db") 103 | print(embedd_configuration(args.db, args.model, args.batch_size, args.max_instruction, args.embeddings_table)) 104 | embedder = FunctionsEmbedder.FunctionsEmbedder(args.model, args.batch_size, args.max_instruction) 105 | embedder.compute_and_save_embeddings_from_db(args.db, args.embeddings_table) 106 | 107 | exit(0) 108 | -------------------------------------------------------------------------------- /function_search/FunctionSearchEngine.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import sys 5 | import numpy as np 6 | import sqlite3 7 | import pandas as pd 8 | import tqdm 9 | import tensorflow as tf 10 | 11 | if sys.version_info >= (3, 0): 12 | from functools import reduce 13 | 14 | 15 | pd.set_option('display.max_column',None) 16 | pd.set_option('display.max_rows',None) 17 | pd.set_option('display.max_seq_items',None) 18 | pd.set_option('display.max_colwidth', 500) 19 | pd.set_option('expand_frame_repr', True) 20 | 21 | class TopK: 22 | 23 | # 24 | # This class computes the similarities between the targets and the list of functions on which we are searching. 
25 | # This is done by using matrices multiplication and top_k of tensorflow 26 | def __init__(self): 27 | self.graph=tf.Graph() 28 | nop=0 29 | 30 | def loads_embeddings_SE(self, lista_embeddings): 31 | with self.graph.as_default(): 32 | tf.set_random_seed(1234) 33 | dim = lista_embeddings[0].shape[0] 34 | ll = np.asarray(lista_embeddings) 35 | self.matrix = tf.constant(ll, name='matrix_embeddings', dtype=tf.float32) 36 | self.target = tf.placeholder("float", [None, dim], name='target_embedding') 37 | self.sim = tf.matmul(self.target, self.matrix, transpose_b=True, name="embeddings_similarities") 38 | self.k = tf.placeholder(tf.int32, shape=(), name='k') 39 | self.top_k = tf.nn.top_k(self.sim, self.k, sorted=True) 40 | self.session = tf.Session() 41 | 42 | def topK(self, k, target): 43 | with self.graph.as_default(): 44 | tf.set_random_seed(1234) 45 | return self.session.run(self.top_k, {self.target: target, self.k: int(k)}) 46 | 47 | class FunctionSearchEngine: 48 | 49 | def __init__(self, db_name, table_name, limit=None): 50 | self.s2v = TopK() 51 | self.db_name = db_name 52 | self.table_name = table_name 53 | self.labels = [] 54 | self.trunc_labels = [] 55 | self.lista_embedding = [] 56 | self.ids = [] 57 | self.n_similar=[] 58 | self.ret = {} 59 | self.precision = None 60 | 61 | print("Query for ids") 62 | conn = sqlite3.connect(db_name) 63 | cur = conn.cursor() 64 | if limit is None: 65 | q = cur.execute("SELECT id, project, compiler, optimization, file_name, function_name FROM functions") 66 | res = q.fetchall() 67 | else: 68 | q = cur.execute("SELECT id, project, compiler, optimization, file_name, function_name FROM functions LIMIT {}".format(limit)) 69 | res = q.fetchall() 70 | 71 | for item in tqdm.tqdm(res, total=len(res)): 72 | q = cur.execute("SELECT " + self.table_name + " FROM " + self.table_name + " WHERE id=?", (item[0],)) 73 | e = q.fetchone() 74 | if e is None: 75 | continue 76 | 77 | self.lista_embedding.append(self.embeddingToNp(e[0])) 78 | 79 | element = "{}/{}/{}".format(item[1], item[4], item[5]) 80 | self.trunc_labels.append(element) 81 | 82 | element = "{}@{}/{}/{}/{}".format(item[5], item[1], item[2], item[3], item[4]) 83 | self.labels.append(element) 84 | self.ids.append(item[0]) 85 | 86 | conn.close() 87 | 88 | self.s2v.loads_embeddings_SE(self.lista_embedding) 89 | self.num_funcs = len(self.lista_embedding) 90 | 91 | def load_target(self, target_db_name, target_fcn_ids, calc_mean=False): 92 | conn = sqlite3.connect(target_db_name) 93 | cur = conn.cursor() 94 | mean = None 95 | for id in target_fcn_ids: 96 | 97 | if target_db_name == self.db_name and id in self.ids: 98 | idx = self.ids.index(id) 99 | e = self.lista_embedding[idx] 100 | else: 101 | q = cur.execute("SELECT " + self.table_name + " FROM " + self.table_name + " WHERE id=?", (id,)) 102 | e = q.fetchone() 103 | e = self.embeddingToNp(e[0]) 104 | 105 | 106 | if mean is None: 107 | mean = e.reshape([e.shape[0], 1]) 108 | else: 109 | mean = np.hstack((mean, e.reshape(e.shape[0], 1))) 110 | 111 | if calc_mean: 112 | target = [np.mean(mean, axis=1)] 113 | else: 114 | target = mean.T 115 | return target 116 | 117 | def embeddingToNp(self, e): 118 | e = e.replace('\n', '') 119 | e = e.replace('[', '') 120 | e = e.replace(']', '') 121 | emb = np.fromstring(e, dtype=float, sep=' ') 122 | return emb 123 | 124 | def top_k(self, target, k=None): 125 | if k is not None: 126 | top_k = self.s2v.topK(k, target) 127 | else: 128 | top_k = self.s2v.topK(len(self.lista_embedding), target) 129 | return top_k 130 | 131 | 
def pp_search(self, k): 132 | result = pd.DataFrame(columns=['Id', 'Name', 'Score']) 133 | top_k = self.s2v.topK(k) 134 | for i, e in enumerate(top_k.indices[0]): 135 | result = result.append({'Id': self.ids[e], 'Name': self.labels[e], 'Score': top_k.values[0][i]}, ignore_index=True) 136 | print(result) 137 | 138 | def search(self, k): 139 | result = [] 140 | top_k = self.s2v.topK(k) 141 | for i, e in enumerate(top_k.indices[0]): 142 | result = result.append({'Id': self.ids[e], 'Name': self.labels[e], 'Score': top_k.values[0][i]}) 143 | return result 144 | -------------------------------------------------------------------------------- /neural_network/parameters.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # distributed under license: GPL 3 License http://www.gnu.org/licenses/ 3 | 4 | import argparse 5 | import time 6 | import sys, os 7 | import logging 8 | 9 | 10 | # 11 | # Parameters File for the SAFE network. 12 | # 13 | # Authors: SAFE team 14 | 15 | 16 | def getLogger(logfile): 17 | logger = logging.getLogger(__name__) 18 | hdlr = logging.FileHandler(logfile) 19 | formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 20 | hdlr.setFormatter(formatter) 21 | logger.addHandler(hdlr) 22 | logger.setLevel(logging.INFO) 23 | return logger, hdlr 24 | 25 | 26 | class Flags: 27 | 28 | def __init__(self): 29 | parser = argparse.ArgumentParser(description='SAFE') 30 | 31 | parser.add_argument("-o", "--output", dest="output_file", help="output directory for logging and models", 32 | required=False) 33 | parser.add_argument("-e", "--embedder", dest="embedder_folder", 34 | help="file with the embedding matrix and dictionary for asm instructions", required=False) 35 | parser.add_argument("-n", "--dbName", dest="db_name", help="Name of the database", required=False) 36 | parser.add_argument("-ld", "--load_dir", dest="load_dir", help="Load the model from directory load_dir", 37 | required=False) 38 | parser.add_argument("-r", "--random", help="if present the network use random embedder", default=False, 39 | action="store_true", dest="random_embedding", required=False) 40 | parser.add_argument("-te", "--trainable_embedding", 41 | help="if present the network consider the embedding as trainable", action="store_true", 42 | dest="trainable_embeddings", default=False) 43 | parser.add_argument("-cv", "--cross_val", help="if present the training is done with cross validiation", 44 | default=False, action="store_true", dest="cross_val") 45 | 46 | args = parser.parse_args() 47 | 48 | # mode = mean_field 49 | self.batch_size = 250 # minibatch size (-1 = whole dataset) 50 | self.num_epochs = 50 # number of epochs 51 | self.embedding_size = 100 # dimension of the function embedding 52 | self.learning_rate = 0.001 # init learning_rate 53 | self.l2_reg_lambda = 0 # 0.002 #0.002 # regularization coefficient 54 | self.num_checkpoints = 1 # max number of checkpoints 55 | self.out_dir = args.output_file # directory for logging 56 | self.rnn_state_size = 50 # dimesion of the rnn state 57 | self.db_name = args.db_name 58 | self.load_dir = str(args.load_dir) 59 | self.random_embedding = args.random_embedding 60 | self.trainable_embeddings = args.trainable_embeddings 61 | self.cross_val = args.cross_val 62 | self.cross_val_fold = 5 63 | 64 | # 65 | ## 66 | ## RNN PARAMETERS, these parameters are only used for RNN model. 
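# With the defaults used here, the bidirectional GRU emits 2 * rnn_state_size = 100 features per
# instruction; the self-attentive layer flattens them to attention_hops * 2 * rnn_state_size = 1000
# features per function, and the dense layer (dense_layer_size) plus the final projection reduce
# this to embedding_size = 100 (see SiameseSAFE.self_attentive_network).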
67 | # 68 | self.rnn_depth = 1 # depth of the rnn 69 | self.max_instructions = 150 # number of instructions 70 | 71 | ## ATTENTION PARAMETERS 72 | self.attention_hops = 10 73 | self.attention_depth = 250 74 | 75 | # RNN SINGLE PARAMETER 76 | self.dense_layer_size = 2000 77 | 78 | self.seed = 2 # random seed 79 | 80 | # create logdir and logger 81 | self.reset_logdir() 82 | 83 | self.embedder_folder = args.embedder_folder 84 | 85 | def reset_logdir(self): 86 | # create logdir 87 | timestamp = str(int(time.time())) 88 | self.logdir = os.path.abspath(os.path.join(self.out_dir, "runs", timestamp)) 89 | os.makedirs(self.logdir, exist_ok=True) 90 | 91 | # create logger 92 | self.log_file = str(self.logdir) + '/console.log' 93 | self.logger, self.hdlr = getLogger(self.log_file) 94 | 95 | # create symlink for last_run 96 | sym_path_logdir = str(self.out_dir) + "/last_run" 97 | try: 98 | os.unlink(sym_path_logdir) 99 | except: 100 | pass 101 | try: 102 | os.symlink(self.logdir, sym_path_logdir) 103 | except: 104 | print("\nfailed to create symlink!\n") 105 | 106 | def close_log(self): 107 | self.hdlr.close() 108 | self.logger.removeHandler(self.hdlr) 109 | handlers = self.logger.handlers[:] 110 | for handler in handlers: 111 | handler.close() 112 | self.logger.removeHandler(handler) 113 | 114 | def __str__(self): 115 | msg = "" 116 | msg += "\nParameters:\n" 117 | msg += "\tRandom embedding: {}\n".format(self.random_embedding) 118 | msg += "\tTrainable embedding: {}\n".format(self.trainable_embeddings) 119 | msg += "\tlogdir: {}\n".format(self.logdir) 120 | msg += "\tbatch_size: {}\n".format(self.batch_size) 121 | msg += "\tnum_epochs: {}\n".format(self.num_epochs) 122 | msg += "\tembedding_size: {}\n".format(self.embedding_size) 123 | msg += "\trnn_state_size: {}\n".format(self.rnn_state_size) 124 | msg += "\tattention depth: {}\n".format(self.attention_depth) 125 | msg += "\tattention hops: {}\n".format(self.attention_hops) 126 | msg += "\tdense layer e: {}\n".format(self.dense_layer_size) 127 | 128 | msg += "\tlearning_rate: {}\n".format(self.learning_rate) 129 | msg += "\tl2_reg_lambda: {}\n".format(self.l2_reg_lambda) 130 | msg += "\tnum_checkpoints: {}\n".format(self.num_checkpoints) 131 | 132 | 133 | msg += "\tseed: {}\n".format(self.seed) 134 | msg += "\tMax Instructions per functions: {}\n".format(self.max_instructions) 135 | return msg 136 | -------------------------------------------------------------------------------- /dataset_creation/DataSplitter.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import json 5 | import random 6 | import sqlite3 7 | from tqdm import tqdm 8 | 9 | 10 | class DataSplitter: 11 | 12 | def __init__(self, db_name): 13 | self.db_name = db_name 14 | 15 | def create_pair_table(self, table_name): 16 | conn = sqlite3.connect(self.db_name) 17 | c = conn.cursor() 18 | c.executescript("DROP TABLE IF EXISTS {} ".format(table_name)) 19 | c.execute("CREATE TABLE {} (id INTEGER PRIMARY KEY, true_pair TEXT, false_pair TEXT)".format(table_name)) 20 | conn.commit() 21 | conn.close() 22 | 23 | def get_ids(self, set_type): 24 | conn = sqlite3.connect(self.db_name) 25 | cur = conn.cursor() 26 | q = cur.execute("SELECT id FROM {}".format(set_type)) 27 | ids = q.fetchall() 28 | conn.close() 29 | return ids 30 | 31 | @staticmethod 32 | def select_similar_cfg(id, provenance, ids, cursor): 33 | q1 = 
cursor.execute('SELECT id FROM functions WHERE project=? AND file_name=? and function_name=?', provenance) 34 | candidates = [i[0] for i in q1.fetchall() if (i[0] != id and i[0] in ids)] 35 | if len(candidates) == 0: 36 | return None 37 | id_similar = random.choice(candidates) 38 | return id_similar 39 | 40 | @staticmethod 41 | def select_dissimilar_cfg(ids, provenance, cursor): 42 | while True: 43 | id_dissimilar = random.choice(ids) 44 | q2 = cursor.execute('SELECT project, file_name, function_name FROM functions WHERE id=?', id_dissimilar) 45 | res = q2.fetchone() 46 | if res != provenance: 47 | break 48 | return id_dissimilar 49 | 50 | def create_epoch_pairs(self, epoch_number, pairs_table,id_table): 51 | random.seed = epoch_number 52 | 53 | conn = sqlite3.connect(self.db_name) 54 | cur = conn.cursor() 55 | ids = cur.execute("SELECT id FROM "+id_table).fetchall() 56 | id_set=set(ids) 57 | true_pair = [] 58 | false_pair = [] 59 | 60 | for my_id in tqdm(ids): 61 | q = cur.execute('SELECT project, file_name, function_name FROM functions WHERE id =?', my_id) 62 | cfg_0_provenance = q.fetchone() 63 | id_sim = DataSplitter.select_similar_cfg(my_id, cfg_0_provenance, id_set, cur) 64 | id_dissim = DataSplitter.select_dissimilar_cfg(ids, cfg_0_provenance, cur) 65 | if id_sim is not None and id_dissim is not None: 66 | true_pair.append((my_id, id_sim)) 67 | false_pair.append((my_id, id_dissim)) 68 | 69 | true_pair = str(json.dumps(true_pair)) 70 | false_pair = str(json.dumps(false_pair)) 71 | 72 | cur.execute("INSERT INTO {} VALUES (?,?,?)".format(pairs_table), (epoch_number, true_pair, false_pair)) 73 | conn.commit() 74 | conn.close() 75 | 76 | def create_pairs(self, total_epochs): 77 | 78 | self.create_pair_table('train_pairs') 79 | self.create_pair_table('validation_pairs') 80 | self.create_pair_table('test_pairs') 81 | 82 | for i in range(0, total_epochs): 83 | print("Creating training pairs for epoch {} of {}".format(i, total_epochs)) 84 | self.create_epoch_pairs(i, 'train_pairs','train') 85 | 86 | print("Creating validation pairs") 87 | self.create_epoch_pairs(0, 'validation_pairs','validation') 88 | 89 | print("Creating test pairs") 90 | self.create_epoch_pairs(0, "test_pairs",'test') 91 | 92 | 93 | @staticmethod 94 | def prepare_set(data_to_include, table_name, file_list, cur): 95 | i = 0 96 | while i < data_to_include and len(file_list) > 0: 97 | choice = random.choice(file_list) 98 | file_list.remove(choice) 99 | q = cur.execute("SELECT id FROM functions where project=? 
AND file_name=?", choice) 100 | data = q.fetchall() 101 | cur.executemany("INSERT INTO {} VALUES (?)".format(table_name), data) 102 | i += len(data) 103 | return file_list, i 104 | 105 | def split_data(self, validation_dim, test_dim): 106 | random.seed = 12345 107 | conn = sqlite3.connect(self.db_name) 108 | c = conn.cursor() 109 | 110 | q = c.execute('''SELECT project, file_name FROM functions ''') 111 | data = q.fetchall() 112 | conn.commit() 113 | 114 | num_data = len(data) 115 | num_test = int(num_data * test_dim) 116 | num_validation = int(num_data * validation_dim) 117 | 118 | filename = list(set(data)) 119 | 120 | c.execute("DROP TABLE IF EXISTS train") 121 | c.execute("DROP TABLE IF EXISTS test") 122 | c.execute("DROP TABLE IF EXISTS validation") 123 | 124 | c.execute("CREATE TABLE IF NOT EXISTS train (id INTEGER PRIMARY KEY)") 125 | c.execute("CREATE TABLE IF NOT EXISTS validation (id INTEGER PRIMARY KEY)") 126 | c.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY)") 127 | 128 | c.execute('''CREATE INDEX IF NOT EXISTS my_index ON functions(project, file_name, function_name)''') 129 | c.execute('''CREATE INDEX IF NOT EXISTS my_index_2 ON functions(project, file_name)''') 130 | 131 | filename, test_num = DataSplitter.prepare_set(num_test, 'test', filename, conn.cursor()) 132 | conn.commit() 133 | assert len(filename) > 0 134 | filename, val_num = self.prepare_set(num_validation, 'validation', filename, conn.cursor()) 135 | conn.commit() 136 | assert len(filename) > 0 137 | _, train_num = self.prepare_set(num_data - num_test - num_validation, 'train', filename, conn.cursor()) 138 | conn.commit() 139 | 140 | print("Train Size: {}".format(train_num)) 141 | print("Validation Size: {}".format(val_num)) 142 | print("Test Size: {}".format(test_num)) 143 | -------------------------------------------------------------------------------- /asm_embedding/FunctionAnalyzerRadare.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright (C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | import json 5 | import r2pipe 6 | 7 | 8 | class RadareFunctionAnalyzer: 9 | 10 | def __init__(self, filename, use_symbol, depth): 11 | self.r2 = r2pipe.open(filename, flags=['-2']) 12 | self.filename = filename 13 | self.arch, _ = self.get_arch() 14 | self.top_depth = depth 15 | self.use_symbol = use_symbol 16 | 17 | def __enter__(self): 18 | return self 19 | 20 | @staticmethod 21 | def filter_reg(op): 22 | return op["value"] 23 | 24 | @staticmethod 25 | def filter_imm(op): 26 | imm = int(op["value"]) 27 | if -int(5000) <= imm <= int(5000): 28 | ret = str(hex(op["value"])) 29 | else: 30 | ret = str('HIMM') 31 | return ret 32 | 33 | @staticmethod 34 | def filter_mem(op): 35 | if "base" not in op: 36 | op["base"] = 0 37 | 38 | if op["base"] == 0: 39 | r = "[" + "MEM" + "]" 40 | else: 41 | reg_base = str(op["base"]) 42 | disp = str(op["disp"]) 43 | scale = str(op["scale"]) 44 | r = '[' + reg_base + "*" + scale + "+" + disp + ']' 45 | return r 46 | 47 | @staticmethod 48 | def filter_memory_references(i): 49 | inst = "" + i["mnemonic"] 50 | 51 | for op in i["operands"]: 52 | if op["type"] == 'reg': 53 | inst += " " + RadareFunctionAnalyzer.filter_reg(op) 54 | elif op["type"] == 'imm': 55 | inst += " " + RadareFunctionAnalyzer.filter_imm(op) 56 | elif op["type"] == 'mem': 57 | inst += " " + RadareFunctionAnalyzer.filter_mem(op) 58 | if len(i["operands"]) > 1: 59 | inst = inst + "," 60 | 61 
| if "," in inst: 62 | inst = inst[:-1] 63 | inst = inst.replace(" ", "_") 64 | 65 | return str(inst) 66 | 67 | @staticmethod 68 | def get_callref(my_function, depth): 69 | calls = {} 70 | if 'callrefs' in my_function and depth > 0: 71 | for cc in my_function['callrefs']: 72 | if cc["type"] == "C": 73 | calls[cc['at']] = cc['addr'] 74 | return calls 75 | 76 | def get_instruction(self): 77 | instruction = json.loads(self.r2.cmd("aoj 1")) 78 | if len(instruction) > 0: 79 | instruction = instruction[0] 80 | else: 81 | return None 82 | 83 | operands = [] 84 | if 'opex' not in instruction: 85 | return None 86 | 87 | for op in instruction['opex']['operands']: 88 | operands.append(op) 89 | instruction['operands'] = operands 90 | return instruction 91 | 92 | def function_to_inst(self, functions_dict, my_function, depth): 93 | instructions = [] 94 | asm = "" 95 | 96 | if self.use_symbol: 97 | s = my_function['vaddr'] 98 | else: 99 | s = my_function['offset'] 100 | calls = RadareFunctionAnalyzer.get_callref(my_function, depth) 101 | self.r2.cmd('s ' + str(s)) 102 | 103 | if self.use_symbol: 104 | end_address = s + my_function["size"] 105 | else: 106 | end_address = s + my_function["realsz"] 107 | 108 | while s < end_address: 109 | instruction = self.get_instruction() 110 | asm += instruction["bytes"] 111 | if self.arch == 'x86': 112 | filtered_instruction = "X_" + RadareFunctionAnalyzer.filter_memory_references(instruction) 113 | elif self.arch == 'arm': 114 | filtered_instruction = "A_" + RadareFunctionAnalyzer.filter_memory_references(instruction) 115 | 116 | instructions.append(filtered_instruction) 117 | 118 | if s in calls and depth > 0: 119 | if calls[s] in functions_dict: 120 | ii, aa = self.function_to_inst(functions_dict, functions_dict[calls[s]], depth-1) 121 | instructions.extend(ii) 122 | asm += aa 123 | self.r2.cmd("s " + str(s)) 124 | 125 | self.r2.cmd("so 1") 126 | s = int(self.r2.cmd("s"), 16) 127 | 128 | return instructions, asm 129 | 130 | def get_arch(self): 131 | try: 132 | info = json.loads(self.r2.cmd('ij')) 133 | if 'bin' in info: 134 | arch = info['bin']['arch'] 135 | bits = info['bin']['bits'] 136 | except: 137 | print("Error loading file") 138 | arch = None 139 | bits = None 140 | return arch, bits 141 | 142 | def find_functions(self): 143 | self.r2.cmd('aaa') 144 | try: 145 | function_list = json.loads(self.r2.cmd('aflj')) 146 | except: 147 | function_list = [] 148 | return function_list 149 | 150 | def find_functions_by_symbols(self): 151 | self.r2.cmd('aa') 152 | try: 153 | symbols = json.loads(self.r2.cmd('isj')) 154 | fcn_symb = [s for s in symbols if s['type'] == 'FUNC'] 155 | except: 156 | fcn_symb = [] 157 | return fcn_symb 158 | 159 | def analyze(self): 160 | if self.use_symbol: 161 | function_list = self.find_functions_by_symbols() 162 | else: 163 | function_list = self.find_functions() 164 | 165 | functions_dict = {} 166 | if self.top_depth > 0: 167 | for my_function in function_list: 168 | if self.use_symbol: 169 | functions_dict[my_function['vaddr']] = my_function 170 | else: 171 | functions_dict[my_function['offset']] = my_function 172 | 173 | result = {} 174 | for my_function in function_list: 175 | if self.use_symbol: 176 | address = my_function['vaddr'] 177 | else: 178 | address = my_function['offset'] 179 | 180 | try: 181 | instructions, asm = self.function_to_inst(functions_dict, my_function, self.top_depth) 182 | result[my_function['name']] = {'filtered_instructions': instructions, "asm": asm, "address": address} 183 | except: 184 | print("Error in 
functions: {} from {}".format(my_function['name'], self.filename)) 185 | pass 186 | return result 187 | 188 | def close(self): 189 | self.r2.quit() 190 | 191 | def __exit__(self, exc_type, exc_value, traceback): 192 | self.r2.quit() 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /neural_network/PairFactory.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # distributed under license: GPL 3 License http://www.gnu.org/licenses/ 3 | import sqlite3 4 | 5 | import json 6 | import numpy as np 7 | 8 | from multiprocessing import Queue 9 | from multiprocessing import Process 10 | from asm_embedding.FunctionNormalizer import FunctionNormalizer 11 | 12 | # 13 | # PairFactory class, used for training the SAFE network. 14 | # This class generates the pairs for training, test and validation 15 | # 16 | # 17 | # Authors: SAFE team 18 | 19 | 20 | class PairFactory: 21 | 22 | def __init__(self, db_name, dataset_type, batch_size, max_instructions, shuffle=True): 23 | self.db_name = db_name 24 | self.dataset_type = dataset_type 25 | self.max_instructions = max_instructions 26 | self.batch_dim = 0 27 | self.num_pairs = 0 28 | self.num_batches = 0 29 | self.batch_size = batch_size 30 | conn = sqlite3.connect(self.db_name) 31 | cur = conn.cursor() 32 | q = cur.execute("SELECT true_pair from " + self.dataset_type + " WHERE id=?", (0,)) 33 | self.num_pairs=len(json.loads(q.fetchone()[0]))*2 34 | n_chunk = int(self.num_pairs / self.batch_size) - 1 35 | conn.close() 36 | self.num_batches = n_chunk 37 | self.shuffle = shuffle 38 | 39 | @staticmethod 40 | def split( a, n): 41 | return [a[i::n] for i in range(n)] 42 | 43 | @staticmethod 44 | def truncate_and_compute_lengths(pairs, max_instructions): 45 | lenghts = [] 46 | new_pairs=[] 47 | for x in pairs: 48 | f0 = np.asarray(x[0][0:max_instructions]) 49 | f1 = np.asarray(x[1][0:max_instructions]) 50 | lenghts.append((f0.shape[0], f1.shape[0])) 51 | if f0.shape[0] < max_instructions: 52 | f0 = np.pad(f0, (0, max_instructions - f0.shape[0]), mode='constant') 53 | if f1.shape[0] < max_instructions: 54 | f1 = np.pad(f1, (0, max_instructions - f1.shape[0]), mode='constant') 55 | 56 | new_pairs.append((f0, f1)) 57 | return new_pairs, lenghts 58 | 59 | def async_chunker(self, epoch): 60 | 61 | conn = sqlite3.connect(self.db_name) 62 | cur = conn.cursor() 63 | query_string = "SELECT true_pair,false_pair from {} where id=?".format(self.dataset_type) 64 | q = cur.execute(query_string, (int(epoch),)) 65 | true_pairs_id, false_pairs_id = q.fetchone() 66 | true_pairs_id = json.loads(true_pairs_id) 67 | false_pairs_id = json.loads(false_pairs_id) 68 | 69 | assert len(true_pairs_id) == len(false_pairs_id) 70 | data_len = len(true_pairs_id) 71 | 72 | # print("Data Len: " + str(data_len)) 73 | conn.close() 74 | 75 | n_chunk = int(data_len / (self.batch_size / 2)) - 1 76 | lista_chunk = range(0, n_chunk) 77 | coda = Queue(maxsize=50) 78 | n_proc = 8 # modify this to increase the parallelism for the db loading, from our thest 8-10 is the sweet spot on a 16 cores machine with K80 79 | listone = PairFactory.split(lista_chunk, n_proc) 80 | 81 | # this ugly workaround is somehow needed, Pool is working oddly when TF is loaded. 
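# Producer/consumer layout: each of the n_proc producer processes works on its own slice of the
# chunk indices (listone[i]) and pushes fully built batches into the bounded queue (at most 50
# prepared batches are kept in memory); the generator then yields one batch per chunk by draining
# the queue, so database access and padding overlap with training.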
82 | for i in range(0, n_proc): 83 | p = Process(target=self.async_create_couple, args=((epoch, listone[i], coda))) 84 | p.start() 85 | 86 | for i in range(0, n_chunk): 87 | yield self.async_get_dataset(coda) 88 | 89 | def get_pair_fromdb(self, id_1, id_2): 90 | conn = sqlite3.connect(self.db_name) 91 | cur = conn.cursor() 92 | q0 = cur.execute("SELECT instructions_list FROM filtered_functions WHERE id=?", (id_1,)) 93 | f0 = json.loads(q0.fetchone()[0]) 94 | 95 | q1 = cur.execute("SELECT instructions_list FROM filtered_functions WHERE id=?", (id_2,)) 96 | f1 = json.loads(q1.fetchone()[0]) 97 | conn.close() 98 | return f0, f1 99 | 100 | def get_couple_from_db(self, epoch_number, chunk): 101 | 102 | conn = sqlite3.connect(self.db_name) 103 | cur = conn.cursor() 104 | 105 | pairs = [] 106 | labels = [] 107 | 108 | q = cur.execute("SELECT true_pair, false_pair from " + self.dataset_type + " WHERE id=?", (int(epoch_number),)) 109 | true_pairs_id, false_pairs_id = q.fetchone() 110 | 111 | true_pairs_id = json.loads(true_pairs_id) 112 | false_pairs_id = json.loads(false_pairs_id) 113 | conn.close() 114 | data_len = len(true_pairs_id) 115 | 116 | i = 0 117 | 118 | normalizer = FunctionNormalizer(self.max_instructions) 119 | 120 | while i < self.batch_size: 121 | if chunk * int(self.batch_size / 2) + i > data_len: 122 | break 123 | 124 | p = true_pairs_id[chunk * int(self.batch_size / 2) + i] 125 | f0, f1 = self.get_pair_fromdb(p[0], p[1]) 126 | pairs.append((f0, f1)) 127 | labels.append(+1) 128 | 129 | p = false_pairs_id[chunk * int(self.batch_size / 2) + i] 130 | f0, f1 = self.get_pair_fromdb(p[0], p[1]) 131 | pairs.append((f0, f1)) 132 | labels.append(-1) 133 | 134 | i += 2 135 | 136 | pairs, lengths = normalizer.normalize_function_pairs(pairs) 137 | 138 | function1, function2 = zip(*pairs) 139 | len1, len2 = zip(*lengths) 140 | n_samples = len(pairs) 141 | 142 | if self.shuffle: 143 | shuffle_indices = np.random.permutation(np.arange(n_samples)) 144 | 145 | function1 = np.array(function1)[shuffle_indices] 146 | 147 | function2 = np.array(function2)[shuffle_indices] 148 | len1 = np.array(len1)[shuffle_indices] 149 | len2 = np.array(len2)[shuffle_indices] 150 | labels = np.array(labels)[shuffle_indices] 151 | else: 152 | function1=np.array(function1) 153 | function2=np.array(function2) 154 | len1=np.array(len1) 155 | len2=np.array(len2) 156 | labels=np.array(labels) 157 | 158 | upper_bound = min(self.batch_size, n_samples) 159 | len1 = len1[0:upper_bound] 160 | len2 = len2[0:upper_bound] 161 | function1 = function1[0:upper_bound] 162 | function2 = function2[0:upper_bound] 163 | y_ = labels[0:upper_bound] 164 | return function1, function2, len1, len2, y_ 165 | 166 | def async_create_couple(self, epoch,n_chunk,q): 167 | for i in n_chunk: 168 | function1, function2, len1, len2, y_ = self.get_couple_from_db(epoch, i) 169 | q.put((function1, function2, len1, len2, y_), block=True) 170 | 171 | def async_get_dataset(self, q): 172 | 173 | item = q.get() 174 | function1 = item[0] 175 | function2 = item[1] 176 | len1 = item[2] 177 | len2 = item[3] 178 | y_ = item[4] 179 | 180 | assert (len(function1) == len(y_)) 181 | n_samples = len(function1) 182 | self.batch_dim = n_samples 183 | #self.num_pairs += n_samples 184 | 185 | return function1, function2, len1, len2, y_ 186 | 187 | -------------------------------------------------------------------------------- /dataset_creation/DatabaseFactory.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # Copyright 
(C) 2019 Luca Massarelli, Giuseppe Antonio Di Luna, Fabio Petroni, Leonardo Querzoni, Roberto Baldoni 3 | 4 | from asm_embedding.InstructionsConverter import InstructionsConverter 5 | from asm_embedding.FunctionAnalyzerRadare import RadareFunctionAnalyzer 6 | import json 7 | import multiprocessing 8 | from multiprocessing import Pool 9 | from multiprocessing.dummy import Pool as ThreadPool 10 | import os 11 | import random 12 | import signal 13 | import sqlite3 14 | from tqdm import tqdm 15 | 16 | 17 | class DatabaseFactory: 18 | 19 | def __init__(self, db_name, root_path): 20 | self.db_name = db_name 21 | self.root_path = root_path 22 | 23 | @staticmethod 24 | def worker(item): 25 | DatabaseFactory.analyze_file(item) 26 | return 0 27 | 28 | @staticmethod 29 | def extract_function(graph_analyzer): 30 | return graph_analyzer.extractAll() 31 | 32 | 33 | @staticmethod 34 | def insert_in_db(db_name, pool_sem, func, filename, function_name, instruction_converter): 35 | path = filename.split(os.sep) 36 | if len(path) < 4: 37 | return 38 | asm = func["asm"] 39 | instructions_list = func["filtered_instructions"] 40 | instruction_ids = json.dumps(instruction_converter.convert_to_ids(instructions_list)) 41 | pool_sem.acquire() 42 | conn = sqlite3.connect(db_name) 43 | cur = conn.cursor() 44 | cur.execute('''INSERT INTO functions VALUES (?,?,?,?,?,?,?,?)''', (None, # id 45 | path[-4], # project 46 | path[-3], # compiler 47 | path[-2], # optimization 48 | path[-1], # file_name 49 | function_name, # function_name 50 | asm, # asm 51 | len(instructions_list)) # num of instructions 52 | ) 53 | inserted_id = cur.lastrowid 54 | cur.execute('''INSERT INTO filtered_functions VALUES (?,?)''', (inserted_id, 55 | instruction_ids) 56 | ) 57 | conn.commit() 58 | conn.close() 59 | pool_sem.release() 60 | 61 | @staticmethod 62 | def analyze_file(item): 63 | global pool_sem 64 | os.setpgrp() 65 | 66 | filename = item[0] 67 | db = item[1] 68 | use_symbol = item[2] 69 | depth = item[3] 70 | instruction_converter = item[4] 71 | 72 | analyzer = RadareFunctionAnalyzer(filename, use_symbol, depth) 73 | p = ThreadPool(1) 74 | res = p.apply_async(analyzer.analyze) 75 | 76 | try: 77 | result = res.get(120) 78 | except multiprocessing.TimeoutError: 79 | print("Aborting due to timeout:" + str(filename)) 80 | print('Try to modify the timeout value in DatabaseFactory instruction result = res.get(TIMEOUT)') 81 | os.killpg(0, signal.SIGKILL) 82 | except Exception: 83 | print("Aborting due to error:" + str(filename)) 84 | os.killpg(0, signal.SIGKILL) 85 | 86 | for func in result: 87 | DatabaseFactory.insert_in_db(db, pool_sem, result[func], filename, func, instruction_converter) 88 | 89 | analyzer.close() 90 | 91 | return 0 92 | 93 | # Create the db where data are stored 94 | def create_db(self): 95 | print('Database creation...') 96 | conn = sqlite3.connect(self.db_name) 97 | conn.execute(''' CREATE TABLE IF NOT EXISTS functions (id INTEGER PRIMARY KEY, 98 | project text, 99 | compiler text, 100 | optimization text, 101 | file_name text, 102 | function_name text, 103 | asm text, 104 | num_instructions INTEGER) 105 | ''') 106 | conn.execute('''CREATE TABLE IF NOT EXISTS filtered_functions (id INTEGER PRIMARY KEY, 107 | instructions_list text) 108 | ''') 109 | conn.commit() 110 | conn.close() 111 | 112 | # Scan the root directory to find all the file to analyze, 113 | # query also the db for already analyzed files. 
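# scan_for_file (below) walks the directory tree recursively and keeps only object files ending in
# '.o'; remove_override then rebuilds the project/compiler/optimization/file_name path of every
# function already stored in the db and drops those files from the work list, so an interrupted
# build can be resumed without re-analysing the same files.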
114 | def scan_for_file(self, start): 115 | file_list = [] 116 | # Scan recursively all the subdirectory 117 | directories = os.listdir(start) 118 | for item in directories: 119 | item = os.path.join(start,item) 120 | if os.path.isdir(item): 121 | file_list.extend(self.scan_for_file(item + os.sep)) 122 | elif os.path.isfile(item) and item.endswith('.o'): 123 | file_list.append(item) 124 | return file_list 125 | 126 | # Looks for already existing files in the database 127 | # It returns a list of files that are not in the database 128 | def remove_override(self, file_list): 129 | conn = sqlite3.connect(self.db_name) 130 | cur = conn.cursor() 131 | q = cur.execute('''SELECT project, compiler, optimization, file_name FROM functions''') 132 | names = q.fetchall() 133 | names = [os.path.join(self.root_path, n[0], n[1], n[2], n[3]) for n in names] 134 | names = set(names) 135 | # If some files is already in the db remove it from the file list 136 | if len(names) > 0: 137 | print(str(len(names)) + ' Already in the database') 138 | cleaned_file_list = [] 139 | for f in file_list: 140 | if not(f in names): 141 | cleaned_file_list.append(f) 142 | 143 | return cleaned_file_list 144 | 145 | # root function to create the db 146 | def build_db(self, use_symbol, depth): 147 | global pool_sem 148 | 149 | pool_sem = multiprocessing.BoundedSemaphore(value=1) 150 | 151 | instruction_converter = InstructionsConverter("data/i2v/word2id.json") 152 | self.create_db() 153 | file_list = self.scan_for_file(self.root_path) 154 | 155 | print('Found ' + str(len(file_list)) + ' during the scan') 156 | file_list = self.remove_override(file_list) 157 | print('Find ' + str(len(file_list)) + ' files to analyze') 158 | random.shuffle(file_list) 159 | 160 | t_args = [(f, self.db_name, use_symbol, depth, instruction_converter) for f in file_list] 161 | 162 | # Start a parallel pool to analyze files 163 | p = Pool(processes=None, maxtasksperchild=20) 164 | for _ in tqdm(p.imap_unordered(DatabaseFactory.worker, t_args), total=len(file_list)): 165 | pass 166 | 167 | p.close() 168 | p.join() 169 | 170 | 171 | -------------------------------------------------------------------------------- /Gemfile.lock: -------------------------------------------------------------------------------- 1 | GEM 2 | remote: https://rubygems.org/ 3 | specs: 4 | activesupport (4.2.10) 5 | i18n (~> 0.7) 6 | minitest (~> 5.1) 7 | thread_safe (~> 0.3, >= 0.3.4) 8 | tzinfo (~> 1.1) 9 | addressable (2.5.2) 10 | public_suffix (>= 2.0.2, < 4.0) 11 | coffee-script (2.4.1) 12 | coffee-script-source 13 | execjs 14 | coffee-script-source (1.11.1) 15 | colorator (1.1.0) 16 | commonmarker (0.17.13) 17 | ruby-enum (~> 0.5) 18 | concurrent-ruby (1.1.3) 19 | dnsruby (1.61.2) 20 | addressable (~> 2.5) 21 | em-websocket (0.5.1) 22 | eventmachine (>= 0.12.9) 23 | http_parser.rb (~> 0.6.0) 24 | ethon (0.11.0) 25 | ffi (>= 1.3.0) 26 | eventmachine (1.2.7) 27 | execjs (2.7.0) 28 | faraday (0.15.3) 29 | multipart-post (>= 1.2, < 3) 30 | ffi (1.9.25) 31 | forwardable-extended (2.6.0) 32 | gemoji (3.0.0) 33 | github-pages (193) 34 | activesupport (= 4.2.10) 35 | github-pages-health-check (= 1.8.1) 36 | jekyll (= 3.7.4) 37 | jekyll-avatar (= 0.6.0) 38 | jekyll-coffeescript (= 1.1.1) 39 | jekyll-commonmark-ghpages (= 0.1.5) 40 | jekyll-default-layout (= 0.1.4) 41 | jekyll-feed (= 0.11.0) 42 | jekyll-gist (= 1.5.0) 43 | jekyll-github-metadata (= 2.9.4) 44 | jekyll-mentions (= 1.4.1) 45 | jekyll-optional-front-matter (= 0.3.0) 46 | jekyll-paginate (= 1.1.0) 47 | jekyll-readme-index 
(= 0.2.0) 48 | jekyll-redirect-from (= 0.14.0) 49 | jekyll-relative-links (= 0.5.3) 50 | jekyll-remote-theme (= 0.3.1) 51 | jekyll-sass-converter (= 1.5.2) 52 | jekyll-seo-tag (= 2.5.0) 53 | jekyll-sitemap (= 1.2.0) 54 | jekyll-swiss (= 0.4.0) 55 | jekyll-theme-architect (= 0.1.1) 56 | jekyll-theme-cayman (= 0.1.1) 57 | jekyll-theme-dinky (= 0.1.1) 58 | jekyll-theme-hacker (= 0.1.1) 59 | jekyll-theme-leap-day (= 0.1.1) 60 | jekyll-theme-merlot (= 0.1.1) 61 | jekyll-theme-midnight (= 0.1.1) 62 | jekyll-theme-minimal (= 0.1.1) 63 | jekyll-theme-modernist (= 0.1.1) 64 | jekyll-theme-primer (= 0.5.3) 65 | jekyll-theme-slate (= 0.1.1) 66 | jekyll-theme-tactile (= 0.1.1) 67 | jekyll-theme-time-machine (= 0.1.1) 68 | jekyll-titles-from-headings (= 0.5.1) 69 | jemoji (= 0.10.1) 70 | kramdown (= 1.17.0) 71 | liquid (= 4.0.0) 72 | listen (= 3.1.5) 73 | mercenary (~> 0.3) 74 | minima (= 2.5.0) 75 | nokogiri (>= 1.8.2, < 2.0) 76 | rouge (= 2.2.1) 77 | terminal-table (~> 1.4) 78 | github-pages-health-check (1.8.1) 79 | addressable (~> 2.3) 80 | dnsruby (~> 1.60) 81 | octokit (~> 4.0) 82 | public_suffix (~> 2.0) 83 | typhoeus (~> 1.3) 84 | html-pipeline (2.9.1) 85 | activesupport (>= 2) 86 | nokogiri (>= 1.4) 87 | http_parser.rb (0.6.0) 88 | i18n (0.9.5) 89 | concurrent-ruby (~> 1.0) 90 | jekyll (3.7.4) 91 | addressable (~> 2.4) 92 | colorator (~> 1.0) 93 | em-websocket (~> 0.5) 94 | i18n (~> 0.7) 95 | jekyll-sass-converter (~> 1.0) 96 | jekyll-watch (~> 2.0) 97 | kramdown (~> 1.14) 98 | liquid (~> 4.0) 99 | mercenary (~> 0.3.3) 100 | pathutil (~> 0.9) 101 | rouge (>= 1.7, < 4) 102 | safe_yaml (~> 1.0) 103 | jekyll-avatar (0.6.0) 104 | jekyll (~> 3.0) 105 | jekyll-coffeescript (1.1.1) 106 | coffee-script (~> 2.2) 107 | coffee-script-source (~> 1.11.1) 108 | jekyll-commonmark (1.2.0) 109 | commonmarker (~> 0.14) 110 | jekyll (>= 3.0, < 4.0) 111 | jekyll-commonmark-ghpages (0.1.5) 112 | commonmarker (~> 0.17.6) 113 | jekyll-commonmark (~> 1) 114 | rouge (~> 2) 115 | jekyll-default-layout (0.1.4) 116 | jekyll (~> 3.0) 117 | jekyll-feed (0.11.0) 118 | jekyll (~> 3.3) 119 | jekyll-gist (1.5.0) 120 | octokit (~> 4.2) 121 | jekyll-github-metadata (2.9.4) 122 | jekyll (~> 3.1) 123 | octokit (~> 4.0, != 4.4.0) 124 | jekyll-mentions (1.4.1) 125 | html-pipeline (~> 2.3) 126 | jekyll (~> 3.0) 127 | jekyll-optional-front-matter (0.3.0) 128 | jekyll (~> 3.0) 129 | jekyll-paginate (1.1.0) 130 | jekyll-readme-index (0.2.0) 131 | jekyll (~> 3.0) 132 | jekyll-redirect-from (0.14.0) 133 | jekyll (~> 3.3) 134 | jekyll-relative-links (0.5.3) 135 | jekyll (~> 3.3) 136 | jekyll-remote-theme (0.3.1) 137 | jekyll (~> 3.5) 138 | rubyzip (>= 1.2.1, < 3.0) 139 | jekyll-sass-converter (1.5.2) 140 | sass (~> 3.4) 141 | jekyll-seo-tag (2.5.0) 142 | jekyll (~> 3.3) 143 | jekyll-sitemap (1.2.0) 144 | jekyll (~> 3.3) 145 | jekyll-swiss (0.4.0) 146 | jekyll-theme-architect (0.1.1) 147 | jekyll (~> 3.5) 148 | jekyll-seo-tag (~> 2.0) 149 | jekyll-theme-cayman (0.1.1) 150 | jekyll (~> 3.5) 151 | jekyll-seo-tag (~> 2.0) 152 | jekyll-theme-dinky (0.1.1) 153 | jekyll (~> 3.5) 154 | jekyll-seo-tag (~> 2.0) 155 | jekyll-theme-hacker (0.1.1) 156 | jekyll (~> 3.5) 157 | jekyll-seo-tag (~> 2.0) 158 | jekyll-theme-leap-day (0.1.1) 159 | jekyll (~> 3.5) 160 | jekyll-seo-tag (~> 2.0) 161 | jekyll-theme-merlot (0.1.1) 162 | jekyll (~> 3.5) 163 | jekyll-seo-tag (~> 2.0) 164 | jekyll-theme-midnight (0.1.1) 165 | jekyll (~> 3.5) 166 | jekyll-seo-tag (~> 2.0) 167 | jekyll-theme-minimal (0.1.1) 168 | jekyll (~> 3.5) 169 | jekyll-seo-tag (~> 2.0) 170 | 
jekyll-theme-modernist (0.1.1) 171 | jekyll (~> 3.5) 172 | jekyll-seo-tag (~> 2.0) 173 | jekyll-theme-primer (0.5.3) 174 | jekyll (~> 3.5) 175 | jekyll-github-metadata (~> 2.9) 176 | jekyll-seo-tag (~> 2.0) 177 | jekyll-theme-slate (0.1.1) 178 | jekyll (~> 3.5) 179 | jekyll-seo-tag (~> 2.0) 180 | jekyll-theme-tactile (0.1.1) 181 | jekyll (~> 3.5) 182 | jekyll-seo-tag (~> 2.0) 183 | jekyll-theme-time-machine (0.1.1) 184 | jekyll (~> 3.5) 185 | jekyll-seo-tag (~> 2.0) 186 | jekyll-titles-from-headings (0.5.1) 187 | jekyll (~> 3.3) 188 | jekyll-watch (2.1.2) 189 | listen (~> 3.0) 190 | jemoji (0.10.1) 191 | gemoji (~> 3.0) 192 | html-pipeline (~> 2.2) 193 | jekyll (~> 3.0) 194 | kramdown (1.17.0) 195 | liquid (4.0.0) 196 | listen (3.1.5) 197 | rb-fsevent (~> 0.9, >= 0.9.4) 198 | rb-inotify (~> 0.9, >= 0.9.7) 199 | ruby_dep (~> 1.2) 200 | mercenary (0.3.6) 201 | mini_portile2 (2.3.0) 202 | minima (2.5.0) 203 | jekyll (~> 3.5) 204 | jekyll-feed (~> 0.9) 205 | jekyll-seo-tag (~> 2.1) 206 | minitest (5.11.3) 207 | multipart-post (2.0.0) 208 | nokogiri (1.8.5) 209 | mini_portile2 (~> 2.3.0) 210 | octokit (4.13.0) 211 | sawyer (~> 0.8.0, >= 0.5.3) 212 | pathutil (0.16.2) 213 | forwardable-extended (~> 2.6) 214 | public_suffix (2.0.5) 215 | rb-fsevent (0.10.3) 216 | rb-inotify (0.9.10) 217 | ffi (>= 0.5.0, < 2) 218 | rouge (2.2.1) 219 | ruby-enum (0.7.2) 220 | i18n 221 | ruby_dep (1.5.0) 222 | rubyzip (1.2.2) 223 | safe_yaml (1.0.4) 224 | sass (3.7.2) 225 | sass-listen (~> 4.0.0) 226 | sass-listen (4.0.0) 227 | rb-fsevent (~> 0.9, >= 0.9.4) 228 | rb-inotify (~> 0.9, >= 0.9.7) 229 | sawyer (0.8.1) 230 | addressable (>= 2.3.5, < 2.6) 231 | faraday (~> 0.8, < 1.0) 232 | terminal-table (1.8.0) 233 | unicode-display_width (~> 1.1, >= 1.1.1) 234 | thread_safe (0.3.6) 235 | typhoeus (1.3.1) 236 | ethon (>= 0.9.0) 237 | tzinfo (1.2.5) 238 | thread_safe (~> 0.1) 239 | unicode-display_width (1.4.0) 240 | 241 | PLATFORMS 242 | ruby 243 | 244 | DEPENDENCIES 245 | github-pages 246 | jekyll (~> 3.7.4) 247 | jekyll-feed (~> 0.6) 248 | minima (~> 2.0) 249 | tzinfo-data 250 | 251 | BUNDLED WITH 252 | 1.17.1 253 | -------------------------------------------------------------------------------- /neural_network/SiameseSAFE.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # SAFE TEAM 3 | # 4 | # 5 | # distributed under license: CC BY-NC-SA 4.0 (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.txt) 6 | # 7 | 8 | # Siamese Self-Attentive Network for Binary Similarity: 9 | # 10 | # arXiv Nostro. 11 | # 12 | # based on the self attentive network:arXiv:1703.03130 Z. Lin at al. “A structured self-attentive sentence embedding'' 13 | # 14 | # Authors: SAFE team 15 | 16 | class SiameseSelfAttentive: 17 | 18 | def __init__(self, 19 | rnn_state_size, # Dimension of the RNN State 20 | learning_rate, # Learning rate 21 | l2_reg_lambda, 22 | batch_size, 23 | max_instructions, 24 | embedding_matrix, # Matrix containg the embeddings for each asm instruction 25 | trainable_embeddings, 26 | # if this value is True, the embeddings of the asm instruction are modified by the training. 27 | attention_hops, # attention hops parameter r of [1] 28 | attention_depth, # attention detph parameter d_a of [1] 29 | dense_layer_size, # parameter e of [1] 30 | embedding_size, # size of the final function embedding, in our test this is twice the rnn_state_size 31 | ): 32 | self.rnn_depth = 1 # if this value is modified then the RNN becames a multilayer network. 
In our tests we fix it to 1 feel free to be adventurous. 33 | self.learning_rate = learning_rate 34 | self.l2_reg_lambda = l2_reg_lambda 35 | self.rnn_state_size = rnn_state_size 36 | self.batch_size = batch_size 37 | self.max_instructions = max_instructions 38 | self.embedding_matrix = embedding_matrix 39 | self.trainable_embeddings = trainable_embeddings 40 | self.attention_hops = attention_hops 41 | self.attention_depth = attention_depth 42 | self.dense_layer_size = dense_layer_size 43 | self.embedding_size = embedding_size 44 | 45 | # self.generate_new_safe() 46 | 47 | def restore_model(self, old_session): 48 | graph = old_session.graph 49 | 50 | self.x_1 = graph.get_tensor_by_name("x_1:0") 51 | self.x_2 = graph.get_tensor_by_name("x_2:0") 52 | self.len_1 = graph.get_tensor_by_name("lengths_1:0") 53 | self.len_2 = graph.get_tensor_by_name("lengths_2:0") 54 | self.y = graph.get_tensor_by_name('y_:0') 55 | self.cos_similarity = graph.get_tensor_by_name("siamese_layer/cosSimilarity:0") 56 | self.loss = graph.get_tensor_by_name("Loss/loss:0") 57 | self.train_step = graph.get_operation_by_name("Train_Step/Adam") 58 | 59 | return 60 | 61 | def self_attentive_network(self, input_x, lengths): 62 | # each functions is a list of embeddings id (an id is an index in the embedding matrix) 63 | # with this we transform it in a list of embeddings vectors. 64 | embbedded_functions = tf.nn.embedding_lookup(self.instructions_embeddings_t, input_x) 65 | 66 | # We create the GRU RNN 67 | (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(self.cell_fw, self.cell_bw, embbedded_functions, 68 | sequence_length=lengths, dtype=tf.float32, 69 | time_major=False) 70 | 71 | # We create the matrix H 72 | H = tf.concat([output_fw, output_bw], axis=2) 73 | 74 | # We do a tile to account for training batches 75 | ws1_tiled = tf.tile(tf.expand_dims(self.WS1, 0), [tf.shape(H)[0], 1, 1], name="WS1_tiled") 76 | ws2_tile = tf.tile(tf.expand_dims(self.WS2, 0), [tf.shape(H)[0], 1, 1], name="WS2_tiled") 77 | 78 | # we compute the matrix A 79 | self.A = tf.nn.softmax(tf.matmul(ws2_tile, tf.nn.tanh(tf.matmul(ws1_tiled, tf.transpose(H, perm=[0, 2, 1])))), 80 | name="Attention_Matrix") 81 | # embedding matrix M 82 | M = tf.identity(tf.matmul(self.A, H), name="Attention_Embedding") 83 | 84 | # we create the flattened version of M 85 | flattened_M = tf.reshape(M, [tf.shape(M)[0], self.attention_hops * self.rnn_state_size * 2]) 86 | 87 | return flattened_M 88 | 89 | def generate_new_safe(self): 90 | self.instructions_embeddings_t = tf.Variable(initial_value=tf.constant(self.embedding_matrix), 91 | trainable=self.trainable_embeddings, 92 | name="instructions_embeddings", dtype=tf.float32) 93 | 94 | self.x_1 = tf.placeholder(tf.int32, [None, self.max_instructions], 95 | name="x_1") # List of instructions for Function 1 96 | self.lengths_1 = tf.placeholder(tf.int32, [None], name='lengths_1') # List of lengths for Function 1 97 | # example x_1=[[mov,add,padding,padding],[mov,mov,mov,padding]] 98 | # lenghts_1=[2,3] 99 | 100 | self.x_2 = tf.placeholder(tf.int32, [None, self.max_instructions], 101 | name="x_2") # List of instructions for Function 2 102 | self.lengths_2 = tf.placeholder(tf.int32, [None], name='lengths_2') # List of lengths for Function 2 103 | self.y = tf.placeholder(tf.float32, [None], name='y_') # Real label of the pairs, +1 similar, -1 dissimilar. 
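# Shapes at a glance: x_1 and x_2 are int32 [batch, max_instructions] matrices of instruction ids
# (truncated or padded up to max_instructions), lengths_1 and lengths_2 are int32 [batch] vectors
# holding the true number of instructions (consumed by the dynamic RNN), and y is a float [batch]
# vector of +1/-1 labels used only by the loss.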
104 | 105 | # Euclidean norms; p = 2 106 | self.norms = [] 107 | 108 | # Keeping track of l2 regularization loss (optional) 109 | l2_loss = tf.constant(0.0) 110 | 111 | with tf.name_scope('parameters_Attention'): 112 | self.WS1 = tf.Variable(tf.truncated_normal([self.attention_depth, 2 * self.rnn_state_size], stddev=0.1), 113 | name="WS1") 114 | self.WS2 = tf.Variable(tf.truncated_normal([self.attention_hops, self.attention_depth], stddev=0.1), 115 | name="WS2") 116 | 117 | rnn_layers_fw = [tf.nn.rnn_cell.GRUCell(size) for size in ([self.rnn_state_size] * self.rnn_depth)] 118 | rnn_layers_bw = [tf.nn.rnn_cell.GRUCell(size) for size in ([self.rnn_state_size] * self.rnn_depth)] 119 | 120 | self.cell_fw = tf.nn.rnn_cell.MultiRNNCell(rnn_layers_fw) 121 | self.cell_bw = tf.nn.rnn_cell.MultiRNNCell(rnn_layers_bw) 122 | 123 | with tf.name_scope('Self-Attentive1'): 124 | self.function_1 = self.self_attentive_network(self.x_1, self.lengths_1) 125 | with tf.name_scope('Self-Attentive2'): 126 | self.function_2 = self.self_attentive_network(self.x_2, self.lengths_2) 127 | 128 | self.dense_1 = tf.nn.relu(tf.layers.dense(self.function_1, self.dense_layer_size)) 129 | self.dense_2 = tf.nn.relu(tf.layers.dense(self.function_2, self.dense_layer_size)) 130 | 131 | with tf.name_scope('Embedding1'): 132 | self.function_embedding_1 = tf.layers.dense(self.dense_1, self.embedding_size) 133 | with tf.name_scope('Embedding2'): 134 | self.function_embedding_2 = tf.layers.dense(self.dense_2, self.embedding_size) 135 | 136 | with tf.name_scope('siamese_layer'): 137 | self.cos_similarity = tf.reduce_sum(tf.multiply(self.function_embedding_1, self.function_embedding_2), 138 | axis=1, 139 | name="cosSimilarity") 140 | 141 | # CalculateMean cross-entropy loss 142 | with tf.name_scope("Loss"): 143 | A_square = tf.matmul(self.A, tf.transpose(self.A, perm=[0, 2, 1])) 144 | 145 | I = tf.eye(tf.shape(A_square)[1]) 146 | I_tiled = tf.tile(tf.expand_dims(I, 0), [tf.shape(A_square)[0], 1, 1], name="I_tiled") 147 | self.A_pen = tf.norm(A_square - I_tiled) 148 | 149 | self.loss = tf.reduce_sum(tf.squared_difference(self.cos_similarity, self.y), name="loss") 150 | self.regularized_loss = self.loss + self.l2_reg_lambda * l2_loss + self.A_pen 151 | 152 | # Train step 153 | with tf.name_scope("Train_Step"): 154 | self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.regularized_loss) 155 | -------------------------------------------------------------------------------- /neural_network/SAFE_model.py: -------------------------------------------------------------------------------- 1 | # SAFE TEAM 2 | # distributed under license: GPL 3 License http://www.gnu.org/licenses/ 3 | 4 | from SiameseSAFE import SiameseSelfAttentive 5 | from PairFactory import PairFactory 6 | import tensorflow as tf 7 | import random 8 | import sys, os 9 | import numpy as np 10 | from sklearn import metrics 11 | import matplotlib 12 | import tqdm 13 | 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | class modelSAFE: 19 | 20 | def __init__(self, flags, embedding_matrix): 21 | self.embedding_size = flags.embedding_size 22 | self.num_epochs = flags.num_epochs 23 | self.learning_rate = flags.learning_rate 24 | self.l2_reg_lambda = flags.l2_reg_lambda 25 | self.num_checkpoints = flags.num_checkpoints 26 | self.logdir = flags.logdir 27 | self.logger = flags.logger 28 | self.seed = flags.seed 29 | self.batch_size = flags.batch_size 30 | self.max_instructions = flags.max_instructions 31 | self.embeddings_matrix = 
embedding_matrix 32 | self.session = None 33 | self.db_name = flags.db_name 34 | self.trainable_embeddings = flags.trainable_embeddings 35 | self.cross_val = flags.cross_val 36 | self.attention_hops = flags.attention_hops 37 | self.attention_depth = flags.attention_depth 38 | self.dense_layer_size = flags.dense_layer_size 39 | self.rnn_state_size = flags.rnn_state_size 40 | 41 | random.seed(self.seed) 42 | np.random.seed(self.seed) 43 | 44 | print(self.db_name) 45 | 46 | # loads an usable model 47 | # returns the network and a tensorflow session in which the network can be used. 48 | @staticmethod 49 | def load_model(path): 50 | session = tf.Session() 51 | checkpoint_dir = os.path.abspath(os.path.join(path, "checkpoints")) 52 | saver = tf.train.import_meta_graph(os.path.join(checkpoint_dir, "model.meta")) 53 | tf.global_variables_initializer().run(session=session) 54 | saver.restore(session, os.path.join(checkpoint_dir, "model")) 55 | network = SiameseSelfAttentive( 56 | rnn_state_size=1, 57 | learning_rate=1, 58 | l2_reg_lambda=1, 59 | batch_size=1, 60 | max_instructions=1, 61 | embedding_matrix=1, 62 | trainable_embeddings=1, 63 | attention_hops=1, 64 | attention_depth=1, 65 | dense_layer_size=1, 66 | embedding_size=1 67 | ) 68 | network.restore_model(session) 69 | return session, network 70 | 71 | def create_network(self): 72 | self.network = SiameseSelfAttentive( 73 | rnn_state_size=self.rnn_state_size, 74 | learning_rate=self.learning_rate, 75 | l2_reg_lambda=self.l2_reg_lambda, 76 | batch_size=self.batch_size, 77 | max_instructions=self.max_instructions, 78 | embedding_matrix=self.embeddings_matrix, 79 | trainable_embeddings=self.trainable_embeddings, 80 | attention_hops=self.attention_hops, 81 | attention_depth=self.attention_depth, 82 | dense_layer_size=self.dense_layer_size, 83 | embedding_size=self.embedding_size 84 | ) 85 | 86 | def train(self): 87 | tf.reset_default_graph() 88 | with tf.Graph().as_default() as g: 89 | session_conf = tf.ConfigProto( 90 | allow_soft_placement=True, 91 | log_device_placement=False 92 | ) 93 | sess = tf.Session(config=session_conf) 94 | 95 | # Sets the graph-level random seed. 96 | tf.set_random_seed(self.seed) 97 | 98 | self.create_network() 99 | self.network.generate_new_safe() 100 | # --tbrtr 101 | 102 | # Initialize all variables 103 | sess.run(tf.global_variables_initializer()) 104 | 105 | # TensorBoard 106 | # Summaries for loss and accuracy 107 | loss_summary = tf.summary.scalar("loss", self.network.loss) 108 | 109 | # Train Summaries 110 | train_summary_op = tf.summary.merge([loss_summary]) 111 | train_summary_dir = os.path.join(self.logdir, "summaries", "train") 112 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 113 | 114 | # Validation summaries 115 | val_summary_op = tf.summary.merge([loss_summary]) 116 | val_summary_dir = os.path.join(self.logdir, "summaries", "validation") 117 | val_summary_writer = tf.summary.FileWriter(val_summary_dir, sess.graph) 118 | 119 | # Test summaries 120 | test_summary_op = tf.summary.merge([loss_summary]) 121 | test_summary_dir = os.path.join(self.logdir, "summaries", "test") 122 | test_summary_writer = tf.summary.FileWriter(test_summary_dir, sess.graph) 123 | 124 | # Checkpoint directory. 
Tensorflow assumes this directory already exists so we need to create it 125 | checkpoint_dir = os.path.abspath(os.path.join(self.logdir, "checkpoints")) 126 | checkpoint_prefix = os.path.join(checkpoint_dir, "model") 127 | if not os.path.exists(checkpoint_dir): 128 | os.makedirs(checkpoint_dir) 129 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=self.num_checkpoints) 130 | 131 | best_val_auc = 0 132 | stat_file = open(str(self.logdir) + "/epoch_stats.tsv", "w") 133 | stat_file.write("#epoch\ttrain_loss\tval_loss\tval_auc\ttest_loss\ttest_auc\n") 134 | 135 | p_train = PairFactory(self.db_name, 'train_pairs', self.batch_size, self.max_instructions) 136 | p_validation = PairFactory(self.db_name, 'validation_pairs', self.batch_size, self.max_instructions, False) 137 | p_test = PairFactory(self.db_name, 'test_pairs', self.batch_size, self.max_instructions, False) 138 | 139 | step = 0 140 | for epoch in range(0, self.num_epochs): 141 | epoch_msg = "" 142 | epoch_msg += " epoch: {}\n".format(epoch) 143 | 144 | epoch_loss = 0 145 | 146 | # ----------------------# 147 | # TRAIN # 148 | # ----------------------# 149 | n_batch = 0 150 | for function1_batch, function2_batch, len1_batch, len2_batch, y_batch in tqdm.tqdm( 151 | p_train.async_chunker(epoch % 25), total=p_train.num_batches): 152 | feed_dict = { 153 | self.network.x_1: function1_batch, 154 | self.network.x_2: function2_batch, 155 | self.network.lengths_1: len1_batch, 156 | self.network.lengths_2: len2_batch, 157 | self.network.y: y_batch, 158 | } 159 | 160 | summaries, _, loss, norms, cs = sess.run( 161 | [train_summary_op, self.network.train_step, self.network.loss, self.network.norms, 162 | self.network.cos_similarity], 163 | feed_dict=feed_dict) 164 | 165 | train_summary_writer.add_summary(summaries, step) 166 | epoch_loss += loss * p_train.batch_dim # ??? 
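# Note on the line above: network.loss is a sum over the batch (tf.reduce_sum in SiameseSAFE), so
# weighting it by batch_dim and dividing by num_pairs at the end of the epoch yields a monitoring
# value proportional to the mean per-pair loss; it is only used for logging and the epoch stats.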
167 | step += 1 168 | # recap epoch 169 | epoch_loss /= p_train.num_pairs 170 | epoch_msg += "\ttrain_loss: {}\n".format(epoch_loss) 171 | 172 | # ----------------------# 173 | # VALIDATION # 174 | # ----------------------# 175 | val_loss = 0 176 | epoch_msg += "\n" 177 | val_y = [] 178 | val_pred = [] 179 | for function1_batch, function2_batch, len1_batch, len2_batch, y_batch in tqdm.tqdm( 180 | p_validation.async_chunker(0), total=p_validation.num_batches): 181 | feed_dict = { 182 | self.network.x_1: function1_batch, 183 | self.network.x_2: function2_batch, 184 | self.network.lengths_1: len1_batch, 185 | self.network.lengths_2: len2_batch, 186 | self.network.y: y_batch, 187 | } 188 | 189 | summaries, loss, similarities = sess.run( 190 | [val_summary_op, self.network.loss, self.network.cos_similarity], feed_dict=feed_dict) 191 | val_loss += loss * p_validation.batch_dim 192 | val_summary_writer.add_summary(summaries, step) 193 | val_y.extend(y_batch) 194 | val_pred.extend(similarities.tolist()) 195 | 196 | val_loss /= p_validation.num_pairs 197 | 198 | if np.isnan(val_pred).any(): 199 | print("Validation: carefull there is NaN in some ouput values, I am fixing it but be aware...") 200 | val_pred = np.nan_to_num(val_pred) 201 | 202 | val_fpr, val_tpr, val_thresholds = metrics.roc_curve(val_y, val_pred, pos_label=1) 203 | val_auc = metrics.auc(val_fpr, val_tpr) 204 | epoch_msg += "\tval_loss : {}\n\tval_auc : {}\n".format(val_loss, val_auc) 205 | 206 | sys.stdout.write( 207 | "\r\tepoch {} / {}, loss {:g}, val_auc {:g}, norms {}".format(epoch, self.num_epochs, epoch_loss, 208 | val_auc, norms)) 209 | sys.stdout.flush() 210 | 211 | # execute test only if validation auc increased 212 | test_loss = "-" 213 | test_auc = "-" 214 | 215 | # in case of cross validation we do not need to evaluate on a test split that is effectively missing 216 | if val_auc > best_val_auc and self.cross_val: 217 | # 218 | ##-- --## 219 | # 220 | best_val_auc = val_auc 221 | saver.save(sess, checkpoint_prefix) 222 | print("\nNEW BEST_VAL_AUC: {} !\n".format(best_val_auc)) 223 | # write ROC raw data 224 | with open(str(self.logdir) + "/best_val_roc.tsv", "w") as the_file: 225 | the_file.write("#thresholds\ttpr\tfpr\n") 226 | for t, tpr, fpr in zip(val_thresholds, val_tpr, val_fpr): 227 | the_file.write("{}\t{}\t{}\n".format(t, tpr, fpr)) 228 | 229 | # in case we are not cross validating we expect to have a test split. 
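# When the validation AUC improves, the branch below saves the checkpoint, scores the held-out test
# split once, and writes the test ROC raw data and plot to best_test_roc.tsv / best_test_roc.png in
# the log directory.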
230 | if val_auc > best_val_auc and not self.cross_val: 231 | 232 | best_val_auc = val_auc 233 | epoch_msg += "\tNEW BEST_VAL_AUC: {} !\n".format(best_val_auc) 234 | 235 | # save best model 236 | saver.save(sess, checkpoint_prefix) 237 | 238 | # ----------------------# 239 | # TEST # 240 | # ----------------------# 241 | 242 | # TEST 243 | test_loss = 0 244 | epoch_msg += "\n" 245 | test_y = [] 246 | test_pred = [] 247 | 248 | for function1_batch, function2_batch, len1_batch, len2_batch, y_batch in tqdm.tqdm( 249 | p_test.async_chunker(0), total=p_test.num_batches): 250 | feed_dict = { 251 | self.network.x_1: function1_batch, 252 | self.network.x_2: function2_batch, 253 | self.network.lengths_1: len1_batch, 254 | self.network.lengths_2: len2_batch, 255 | self.network.y: y_batch, 256 | } 257 | summaries, loss, similarities = sess.run( 258 | [test_summary_op, self.network.loss, self.network.cos_similarity], feed_dict=feed_dict) 259 | test_loss += loss * p_test.batch_dim 260 | test_summary_writer.add_summary(summaries, step) 261 | test_y.extend(y_batch) 262 | test_pred.extend(similarities.tolist()) 263 | 264 | test_loss /= p_test.num_pairs 265 | if np.isnan(test_pred).any(): 266 | print("Test: carefull there is NaN in some ouput values, I am fixing it but be aware...") 267 | test_pred = np.nan_to_num(test_pred) 268 | 269 | test_fpr, test_tpr, test_thresholds = metrics.roc_curve(test_y, test_pred, pos_label=1) 270 | 271 | # write ROC raw data 272 | with open(str(self.logdir) + "/best_test_roc.tsv", "w") as the_file: 273 | the_file.write("#thresholds\ttpr\tfpr\n") 274 | for t, tpr, fpr in zip(test_thresholds, test_tpr, test_fpr): 275 | the_file.write("{}\t{}\t{}\n".format(t, tpr, fpr)) 276 | 277 | test_auc = metrics.auc(test_fpr, test_tpr) 278 | epoch_msg += "\ttest_loss : {}\n\ttest_auc : {}\n".format(test_loss, test_auc) 279 | fig = plt.figure() 280 | plt.title('Receiver Operating Characteristic') 281 | plt.plot(test_fpr, test_tpr, 'b', 282 | label='AUC = %0.2f' % test_auc) 283 | fig.savefig(str(self.logdir) + "/best_test_roc.png") 284 | print( 285 | "\nNEW BEST_VAL_AUC: {} !\n\ttest_loss : {}\n\ttest_auc : {}\n".format(best_val_auc, test_loss, 286 | test_auc)) 287 | plt.close(fig) 288 | 289 | stat_file.write( 290 | "{}\t{}\t{}\t{}\t{}\t{}\n".format(epoch, epoch_loss, val_loss, val_auc, test_loss, test_auc)) 291 | self.logger.info("\n{}\n".format(epoch_msg)) 292 | stat_file.close() 293 | sess.close() 294 | return best_val_auc 295 | --------------------------------------------------------------------------------
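The utilities above leave the trained model in a checkpoint directory written by modelSAFE.train(). The snippet below is a minimal inference sketch, not part of the repository: it assumes the neural_network folder is on the import path, a hypothetical run directory "runs/last_run" produced by training, and two dummy instruction-id lists standing in for the output of InstructionsConverter (as used in DatabaseFactory.insert_in_db).

from SAFE_model import modelSAFE  # assumes neural_network/ is on sys.path

MAX_INSTRUCTIONS = 150  # must match the value used at training time (parameters.Flags)

def pad(ids):
    # truncate/pad an instruction-id list, mirroring what PairFactory does for training pairs
    ids = list(ids)[:MAX_INSTRUCTIONS]
    length = len(ids)
    return ids + [0] * (MAX_INSTRUCTIONS - length), length

# dummy id lists; real ones come from InstructionsConverter.convert_to_ids(...)
ids_a = [12, 7, 7, 93]
ids_b = [12, 7, 93]

x1, l1 = pad(ids_a)
x2, l2 = pad(ids_b)

# load_model restores the graph from <run_dir>/checkpoints and rebinds the named placeholders
session, network = modelSAFE.load_model("runs/last_run")  # hypothetical run directory
score = session.run(network.cos_similarity,
                    {network.x_1: [x1], network.x_2: [x2],
                     network.len_1: [l1], network.len_2: [l2]})
print(score[0])  # trained towards +1 for similar functions, -1 for dissimilar ones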